diff --git a/.clang-tidy b/.clang-tidy index 1681ed66e..f9b77bce8 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -4,7 +4,9 @@ ExtraArgs: [] FormatStyle: file UseColor: true WarningsAsErrors: '*' +# FIXME: Use `ExcludeHeaderFilterRegex` instead when all maintainers upgraded their `clang-tidy` HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*' +# ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$' # NOTE: there must be no spaces before the '-', so put the comma last. Checks: >- diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 3ba13e0ce..0086358db 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1 +1 @@ -blank_issues_enabled: false +blank_issues_enabled: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a475cd513..e939127cb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,23 +40,13 @@ jobs: timeout-minutes: 30 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: recursive - - name: Setup Python 3.8 - id: setup-pylowest - uses: actions/setup-python@v6 - with: - python-version: "3.8" # use lowest supported version for linting - update-environment: false - - - name: Check AST with Python 3.8 - run: | - "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang - - name: Setup Python 3.9 + id: setup-pylowest uses: actions/setup-python@v6 with: python-version: "3.9" @@ -67,6 +57,10 @@ jobs: requirements*.txt .pre-commit-config.yaml + - name: Check AST with Python 3.9 + run: | + "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang + - name: Pre-commit Lint run: | if ! pipx run pre-commit run --all-files --color=always --show-diff-on-failure; then @@ -93,7 +87,7 @@ jobs: name: self-hosted-amd # Format: [Nightly-]ROCm-.[.]. E.g., "ROCm-6.4" or "Nightly-ROCm-7.0". # Use "Nightly-" prefix to use torch nightly builds. - toolkit: ROCm-6.3 + toolkit: Nightly-ROCm-7.1 - tags: [macos-latest] name: macos-latest toolkit: Metal # or Nightly-Metal @@ -104,7 +98,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: recursive @@ -288,35 +282,59 @@ jobs: echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure." uv cache clean + - name: Enable core dump generation (Linux / GitHub-hosted runners) + if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kernel.core_uses_pid=0 + sudo sysctl -w fs.suid_dumpable=1 + sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable + + - name: Enable core dump generation (macOS / GitHub-hosted runners) + if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kern.coredump=1 + sudo sysctl -w kern.sugid_coredump=1 + sysctl kern.corefile kern.coredump kern.sugid_coredump + + - name: Install project (wheel form) + run: | + uv pip install -v . 
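A hedged aside on the relocated core-dump steps above: `kernel.core_pattern` / `kern.corefile` only control where and how the kernel names a core file; the crashing process additionally needs a nonzero `RLIMIT_CORE` (the `ulimit -c` soft limit), which the sysctl calls do not touch. The sketch below is an assumption about how a test harness could raise that per-process limit; it is not part of this workflow.

```python
# Minimal sketch (assumption, not taken from the workflow): raise the per-process
# core-file size limit so a crashing test can actually emit a core file matching
# the core_pattern configured by the sysctl steps above.
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_CORE)
# Lift the soft limit as far as the hard limit allows (often RLIM_INFINITY).
resource.setrlimit(resource.RLIMIT_CORE, (hard, hard))
print("RLIMIT_CORE:", resource.getrlimit(resource.RLIMIT_CORE))
```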
+ - name: Run clang-tidy id: clang-tidy if: runner.os == 'Linux' run: | echo "\$ $(command -v clang-tidy) --version" && clang-tidy --version - if [[ -x "$(command -v run-clang-tidy)" ]]; then - echo "Using run-clang-tidy from $(command -v run-clang-tidy)" - CLANG_TIDY=(run-clang-tidy) - else - RCT_URL=https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py - echo "Downloading run-clang-tidy script from ${RCT_URL}" - echo "import urllib.request; url = '${RCT_URL}'.rstrip('/'); urllib.request.urlretrieve(url, url.split('/')[-1])" | uv run --no-project --script - - CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py) - fi + # Download run-clang-tidy script + RCT_URL=https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py + echo "Downloading run-clang-tidy script from ${RCT_URL}" + echo "import urllib.request; url = '${RCT_URL}'.rstrip('/'); urllib.request.urlretrieve(url, url.split('/')[-1])" | uv run --no-project --script - + RUN_CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py) + if [[ -x "$(command -v clang-apply-replacements)" ]]; then echo "Using clang-apply-replacements from $(command -v clang-apply-replacements)" - CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)") + RUN_CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)") else echo "::warning::clang-apply-replacements not found in PATH, automatic fixing disabled." fi # Run cmake to create the build directory with compile_commands.json cmake -S . -B cmake-build --fresh ${CLANG_TIDY_CMAKE_OPTIONS} # no quotes here + echo "::group::compile_commands.json" + ls -alh cmake-build/compile_commands.json + uv run --no-project -m -- json.tool --no-ensure-ascii cmake-build/compile_commands.json + echo "::endgroup::" CXX_FILES=$(find src -type f -iname "*.[ch]pp" -o -iname "*.cc" -o -iname "*.c" -o -iname "*.h") rc=0 - "${CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \ + echo "::group::run-clang-tidy" + "${RUN_CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \ + -exclude-header-filter='^(3rdparty|tvm)/.*$' \ -p="cmake-build" ${CXX_FILES} || rc="$?" + echo "::endgroup::" rm -rf cmake-build run-clang-tidy.py if (( rc != 0 )); then echo "::error::clang-tidy found issues (exit code: ${rc}). Please run 'clang-tidy --fix' locally to fix them." @@ -324,26 +342,6 @@ jobs: exit "${rc}" fi - - name: Enable core dump generation (Linux / GitHub-hosted runners) - if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }} - run: | - sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" - sudo sysctl -w kernel.core_uses_pid=0 - sudo sysctl -w fs.suid_dumpable=1 - sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable - - - name: Enable core dump generation (macOS / GitHub-hosted runners) - if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }} - run: | - sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" - sudo sysctl -w kern.coredump=1 - sudo sysctl -w kern.sugid_coredump=1 - sysctl kern.corefile kern.coredump kern.sugid_coredump - - - name: Install project (wheel form) - run: | - uv pip install -v . 
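The rewritten `Run clang-tidy` step above fetches `run-clang-tidy.py` with a Python one-liner piped into `uv run --no-project --script -`. Expanded into a standalone sketch, the same download logic (same URL, same filename derivation from the URL) looks like this:

```python
# Sketch of the download performed inline in the clang-tidy workflow step above.
# The URL matches RCT_URL in the step; urlretrieve saves the script under its
# basename ("run-clang-tidy.py") in the current working directory.
import urllib.request

url = (
    "https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/"
    "clang-tools-extra/clang-tidy/tool/run-clang-tidy.py"
)
urllib.request.urlretrieve(url.rstrip("/"), url.split("/")[-1])
```

The step then runs the downloaded script via `uv run --no-project --script -- run-clang-tidy.py`, passing `-exclude-header-filter='^(3rdparty|tvm)/.*$'` on the command line, which mirrors the commented-out `ExcludeHeaderFilterRegex` in `.clang-tidy` for clang-tidy versions that do not yet support that key.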
- - name: Run examples with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) if: contains(matrix.runner.toolkit, 'CUDA') run: | @@ -369,6 +367,7 @@ jobs: ./python # AMD ROCm tests + # runtime and transform tests needs to repair, then rm it from ignore list - name: Run ROCm tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) id: rocm-tests if: contains(matrix.runner.toolkit, 'ROCm') @@ -379,7 +378,8 @@ jobs: pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ - ./python/amd/test_tilelang_test_amd.py + --ignore=./python/runtime --ignore=./python/transform \ + ./python # Apple Metal tests - name: Run Metal tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index 0ba3fbc30..dcfdcff14 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -1,5 +1,6 @@ name: Dist on: + workflow_dispatch: schedule: # gemini said this is 6:00 china time - cron: "0 22 * * *" @@ -17,6 +18,9 @@ on: - CMakeLists.txt - version_provider.py - .github/workflows/dist.yml + # temporarily add to dist check + # until we have type checking in ci / move to python 3.10 + - tilelang/_typing.py release: types: - published @@ -34,6 +38,11 @@ env: COLUMNS: "100" FORCE_COLOR: "1" CLICOLOR_FORCE: "1" + UV_INDEX_STRATEGY: "unsafe-best-match" + UV_HTTP_TIMEOUT: "600" + XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated + UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated jobs: build-sdist: @@ -52,7 +61,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 1 submodules: recursive @@ -71,6 +80,7 @@ jobs: - name: Setup ccache uses: hendrikmuhs/ccache-action@v1 with: + max-size: "200MB" create-symlink: true evict-old-files: "7d" append-timestamp: false @@ -91,7 +101,7 @@ jobs: - name: Upload SDist # Not PR to save artifact storage, as SDist is only needed for releases. if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: sdist path: dist/*.tar.gz @@ -105,24 +115,25 @@ jobs: strategy: matrix: target: - - { runner: ubuntu-latest, toolkit: "CUDA-12.1" } + - { runner: ubuntu-latest, toolkit: "CUDA-12.8" } - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" } + - { runner: ubuntu-latest, toolkit: "Nightly-CUDA-13.0" } + - { runner: ubuntu-24.04-arm, toolkit: "Nightly-CUDA-13.0" } - { runner: macos-latest, toolkit: "Metal" } python-version: # Wheels are built with Python 3.8 Limited API, they should work with all Python >= 3.8. # Only build wheels against Python 3.8 Limited API to save CI resources. - # FIXME: Here we use Python 3.9 because our dependency `apache-tvm-ffi` claims to support - # Python 3.8 but it depends on a version of `ml-dtypes` that requires Python >= 3.9. 
- "3.9" fail-fast: false timeout-minutes: 120 runs-on: ${{ matrix.target.runner }} env: - NO_VERSION_LABEL: ${{ github.event_name == 'release' && 'OFF' || 'ON' }} + IS_RELEASE: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') }} + NO_VERSION_LABEL: "OFF" steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 1 submodules: recursive @@ -130,16 +141,14 @@ jobs: - name: Setup ccache uses: hendrikmuhs/ccache-action@v1 with: + max-size: "200MB" create-symlink: true evict-old-files: "7d" append-timestamp: false - key: wheel-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }}-${{ hashFiles('**/*.cc') }} + key: wheel-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.cc') }} restore-keys: | - wheel-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }}-${{ hashFiles('**/*.cc') }} - wheel-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }} + wheel-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.cc') }} wheel-${{ runner.os }}-${{ runner.arch }} - ${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }} - ${{ runner.os }}-${{ runner.arch }} - name: Set CIBW_BUILD run: | @@ -150,26 +159,77 @@ jobs: if [[ "${{ matrix.target.toolkit }}" == *"CUDA"* ]]; then CUDA_VERSION="${{ matrix.target.toolkit }}" - CUDA_VERSION="${CUDA_VERSION#CUDA-}" + CUDA_VERSION="${CUDA_VERSION##*-}" + CUDA_VERSION_MAJMIN="$(echo ${CUDA_VERSION} | cut -d '.' -f-2)" + CUDA_VERSION_MAJMIN_NODOT="${CUDA_VERSION_MAJMIN//./}" echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" + if [[ "${{ matrix.target.toolkit }}" == "Nightly-"* ]]; then + # Use torch nightly builds + export UV_INDEX="https://download.pytorch.org/whl/nightly/cu${CUDA_VERSION_MAJMIN_NODOT}" + else + export UV_INDEX="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}" + echo "UV_TORCH_BACKEND=cu${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" + fi + echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" + fi + + if [[ "${{ env.IS_RELEASE }}" == "true" ]]; then + if [[ "${{ matrix.target.toolkit }}" == "Nightly-"* ]]; then + # Avoid using same file name for different toolkit. + echo "NO_GIT_VERSION=ON" | tee -a "${GITHUB_ENV}" + else + echo "NO_VERSION_LABEL=ON" | tee -a "${GITHUB_ENV}" + fi fi if [[ "${{ runner.os }}" == "Linux" ]]; then HOST_CCACHE_DIR="$(ccache --get-config cache_dir)" - echo "CIBW_BEFORE_BUILD_LINUX=yum install -y ccache && ccache -o cache_dir=/host${HOST_CCACHE_DIR}" | tee -a "${GITHUB_ENV}" + echo "CIBW_BEFORE_BUILD_LINUX=dnf install -y ccache && ccache -o cache_dir=/host${HOST_CCACHE_DIR}" | tee -a "${GITHUB_ENV}" fi - name: Build wheels - uses: pypa/cibuildwheel@v3.2 + uses: pypa/cibuildwheel@v3.3 with: package-dir: . 
output-dir: wheelhouse config-file: "{package}/pyproject.toml" + - name: Setup Python and uv with caching + id: setup-uv + uses: astral-sh/setup-uv@v7 + with: + python-version: "3.12" + activate-environment: true + + - name: Test built wheels + run: | + for WHEEL in wheelhouse/*.whl; do + echo "Testing wheel: ${WHEEL}" + ( + set -e + uv venv --python=3.12 test-venv + source test-venv/bin/activate + + uv pip install --upgrade pip setuptools wheel + if [[ "${UV_INDEX}" == *"/nightly/"* ]]; then + uv pip install --prerelease=allow -v torch + fi + + uv pip install -v "${WHEEL}" + ( + set -e + cd / + uv run --no-project -- python -c "import tilelang; print(tilelang.__version__)" + ) + deactivate + rm -rf test-venv + ) + done + - name: Upload wheels # Not PR to save artifact storage, as wheels are only needed for releases. if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: wheels-${{ matrix.python-version }}-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }} path: wheelhouse/*.whl @@ -184,7 +244,7 @@ jobs: timeout-minutes: 15 steps: - name: Download built SDist - uses: actions/download-artifact@v6 + uses: actions/download-artifact@v7 with: # unpacks default artifact into dist/ # if `name: artifact` is omitted, the action will create extra parent dir @@ -192,7 +252,7 @@ jobs: path: dist - name: Download built wheels - uses: actions/download-artifact@v6 + uses: actions/download-artifact@v7 with: pattern: wheels-* path: dist @@ -202,7 +262,7 @@ jobs: run: ls -lh dist/* - name: Upload artifacts - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: artifacts path: dist/* diff --git a/.github/workflows/pr-perfbench-bot.yml b/.github/workflows/pr-perfbench-bot.yml deleted file mode 100644 index 37da4e3c8..000000000 --- a/.github/workflows/pr-perfbench-bot.yml +++ /dev/null @@ -1,88 +0,0 @@ -name: Performance Benchmark Bot - -on: - issue_comment: - types: - - created - -permissions: - contents: read - -concurrency: - group: "${{ github.workflow }}-${{ github.ref }}" - cancel-in-progress: true # always cancel in-progress - -env: - PYTHONDEVMODE: "1" - PYTHONUNBUFFERED: "1" - PYTHONPATH: "" # explicit cleanup - PIP_USER: "" # explicit cleanup - COLUMNS: "100" - FORCE_COLOR: "1" - CLICOLOR_FORCE: "1" - XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated - PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated - -jobs: - perfbench: - name: Benchmark between PR and main - if: | - github.repository_owner == 'tile-ai' && - github.event.issue.pull_request && - (contains(github.event.comment.body, '/performance-report') || contains(github.event.comment.body, '/perf')) - runs-on: [self-hosted, nvidia] - steps: - - name: Checkout repository - uses: actions/checkout@v5 - with: - ref: refs/pull/${{ github.event.issue.number }}/merge - fetch-depth: 0 - submodules: recursive - - - name: Setup Python - uses: actions/setup-python@v6 - with: - python-version: "3.12" - update-environment: true - cache: pip - cache-dependency-path: | - pyproject.toml - requirements*.txt - - - name: Install merged version - run: | - python -m venv tll - source tll/bin/activate - pip install -r requirements-test.txt - pip install . - - - name: Install original version - run: | - echo "Check files to be deleted!" - git clean -dxf -e tll/ - echo "Delete files completed!" 
- git checkout main - python -m venv tl - source tl/bin/activate - pip install -r requirements-test.txt - pip install . - - - name: Run performance test - id: perfbench - run: | - source tl/bin/activate - python maint/scripts/ci_performance.py - - - name: Post test results as PR comment - uses: actions/github-script@v8 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: '📊 ​**Performance Test Results** (triggered by @' + context.payload.comment.user.login + '):\n\n' + - 'Run listed here: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}\n\n' + - "${{ steps.perfbench.outputs.stdout }}" - }) diff --git a/.github/workflows/pr-regression-test-bot.yml b/.github/workflows/pr-regression-test-bot.yml new file mode 100644 index 000000000..568ce8555 --- /dev/null +++ b/.github/workflows/pr-regression-test-bot.yml @@ -0,0 +1,273 @@ +name: Performance Regression Bot + +on: + issue_comment: + types: + - created + +permissions: + contents: read + issues: write + pull-requests: write + +concurrency: + # Use the issue/PR number to differentiate between different PRs + group: "${{ github.workflow }}-${{ github.event.issue.number }}" + cancel-in-progress: true + +env: + PYTHONDEVMODE: "1" + PYTHONUNBUFFERED: "1" + PYTHONPATH: "" # explicit cleanup + PIP_USER: "" # explicit cleanup + COLUMNS: "100" + FORCE_COLOR: "1" + CLICOLOR_FORCE: "1" + UV_INDEX_STRATEGY: "unsafe-best-match" + UV_HTTP_TIMEOUT: "600" + XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated + UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated + PRE_COMMIT_HOME: "${{ github.workspace }}/.cache/pip/.pre-commit" # to be updated + +jobs: + permissions-check: + name: Check bot permissions + if: | + github.repository_owner == 'tile-ai' && + github.event.issue.pull_request && + (contains(github.event.comment.body, '@regression-perf')) + runs-on: ubuntu-latest + steps: + - name: Get commenter permission + id: perm + uses: actions/github-script@v8 + with: + script: | + const username = context.payload.comment.user.login + const { owner, repo } = context.repo + const { data } = await github.rest.repos.getCollaboratorPermissionLevel({ owner, repo, username }) + core.setOutput('permission', data.permission) // admin|maintain|write|triage|read|none + + - name: Reject if not allowed + if: ${{ steps.perm.outputs.permission != 'admin' && steps.perm.outputs.permission != 'maintain' && steps.perm.outputs.permission != 'write' }} + run: | + echo "Not authorized: permission=${{ steps.perm.outputs.permission }}" + exit 1 + + pr-regression: + name: Performance regression test between PR and main + needs: [permissions-check] + runs-on: ${{ matrix.runner.tags }} + strategy: + matrix: + runner: + - tags: [self-hosted, nvidia] + name: self-hosted-nvidia + toolkit: CUDA-12.8 + python-version: + - "3.12" + fail-fast: false + timeout-minutes: 120 + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ref: refs/pull/${{ github.event.issue.number }}/merge + fetch-depth: 0 + submodules: recursive + + - name: Set environment (self-hosted runners) + if: startsWith(matrix.runner.name, 'self-hosted') + run: | + # Hide sensitive data in logs for self-hosted runners + if [[ -n "${{ secrets.SECRET_PATH_PREFIXES }}" ]]; then + echo "::add-mask::${{ secrets.SECRET_PATH_PREFIXES }}" + 
# Colon separated list of secrets to mask + for secret in $(echo "${{ secrets.SECRET_PATH_PREFIXES }}" | tr ':' '\n'); do + echo "::add-mask::${secret}" + done + fi + + # Use runner tool_cache as cache root for self-hosted runners to avoid internet connection + # issues and to share cache between jobs. + export XDG_CACHE_HOME="${{ runner.tool_cache }}/.ci-cache-${{ github.workflow }}" + echo "XDG_CACHE_HOME=${XDG_CACHE_HOME}" | tee -a "${GITHUB_ENV}" + echo "PIP_CACHE_DIR=${XDG_CACHE_HOME}/pip" | tee -a "${GITHUB_ENV}" + echo "UV_CACHE_DIR=${XDG_CACHE_HOME}/uv" | tee -a "${GITHUB_ENV}" + echo "PRE_COMMIT_HOME=${XDG_CACHE_HOME}/pip/.pre-commit" | tee -a "${GITHUB_ENV}" + + # Do not use ccache on self-hosted runners, as it will download/upload caches which is slow. + # Self-hosted runners usually have more CPU power to compile without ccache. + - name: Setup ccache (GitHub-hosted runners) + id: setup-ccache + if: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + uses: hendrikmuhs/ccache-action@v1 + with: + create-symlink: true + evict-old-files: "7d" + append-timestamp: false + key: ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }}-${{ hashFiles('**/*.cc') }} + restore-keys: | + ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }}-${{ hashFiles('**/*.cc') }} + ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }} + ${{ runner.os }}-${{ runner.arch }} + + - name: Set environment (CUDA) + if: contains(matrix.runner.toolkit, 'CUDA') + run: | + TOOLKIT="${{ matrix.runner.toolkit }}" + CUDA_VERSION="${TOOLKIT##*-}" + CUDA_VERSION_MAJMIN="$(echo ${CUDA_VERSION} | cut -d '.' -f-2)" + CUDA_VERSION_MAJMIN_NODOT="${CUDA_VERSION_MAJMIN//./}" + if [[ "${TOOLKIT}" == "Nightly-"* ]]; then + # Use torch nightly builds + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/nightly/cu${CUDA_VERSION_MAJMIN_NODOT}" + else + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}" + fi + export UV_INDEX="${PIP_EXTRA_INDEX_URL}" + export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_CUDA=ON" + + echo "USE_CUDA=ON" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN=${CUDA_VERSION_MAJMIN}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN_NODOT=${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" + echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" + echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" + echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" + + if [[ ! -x "$(command -v nvcc)" ]]; then + export PATH="/usr/local/cuda/bin:${PATH}" + export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + echo "PATH=${PATH}" | tee -a "${GITHUB_ENV}" + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" | tee -a "${GITHUB_ENV}" + fi + if [[ -x "$(command -v nvcc)" ]]; then + echo "\$ $(command -v nvcc) --version" && nvcc --version + else + echo "::warning::nvcc not found in PATH!" + fi + + - name: Setup Python and uv with caching + id: setup-uv + uses: astral-sh/setup-uv@v7 + with: + python-version: ${{ matrix.python-version }} + activate-environment: true + # Do not use cache for self-hosted runners, as it will download/upload caches which is slow. 
+ enable-cache: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + prune-cache: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + # Use runner tool_cache for self-hosted runners + cache-local-path: ${{ env.UV_CACHE_DIR }} + ignore-nothing-to-cache: true + # Extra cache key to upload/download caches on GitHub-hosted runners + cache-suffix: uv-${{ runner.os }}-${{ runner.arch }}-${{ matrix.python-version }}-${{ matrix.runner.name }}-${{ matrix.runner.toolkit }} + cache-dependency-glob: | + pyproject.toml + requirements*.txt + + - name: Setup environments + id: setup-venv + run: | + set -e + + uv venv --python "${{ matrix.python-version }}" new + + source new/bin/activate + uv pip install -v -r requirements-test.txt + uv pip install -v . + + - name: Install Main version (Baseline) + run: | + set -e + git clean -dxf -e new/ -e .cache/ + git checkout main + git submodule update --init --recursive + uv venv --python "${{ matrix.python-version }}" old + source old/bin/activate + + uv pip install -v -r requirements-test.txt + uv pip install -v . + rm -rf tilelang build + + uv venv --python "${{ matrix.python-version }}" test_regression + source test_regression/bin/activate + uv pip install -v -r requirements-test.txt + + - name: Clear uv cache for self-hosted runners (if setup failed) + if: >- + ${{ + failure() && + startsWith(matrix.runner.name, 'self-hosted') && + (steps.setup-uv.conclusion == 'failure' || steps.setup-venv.conclusion == 'failure') + }} + run: | + echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure." + uv cache clean + + - name: Enable core dump generation (Linux / GitHub-hosted runners) + if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kernel.core_uses_pid=0 + sudo sysctl -w fs.suid_dumpable=1 + sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable + + - name: Enable core dump generation (macOS / GitHub-hosted runners) + if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kern.coredump=1 + sudo sysctl -w kern.sugid_coredump=1 + sysctl kern.corefile kern.coredump kern.sugid_coredump + + - name: Run performance regression test + run: | + source test_regression/bin/activate + OLD_PYTHON=./old/bin/python NEW_PYTHON=./new/bin/python \ + PERF_REGRESSION_MD=regression_result.md PERF_REGRESSION_PNG=regression_result.png \ + python ./maint/scripts/test_perf_regression.py + + - name: Read markdown table + id: read_md + run: | + echo "content<> $GITHUB_OUTPUT + cat regression_result.md >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT + + - name: Upload result image as artifact + uses: actions/upload-artifact@v6 + with: + name: perf-regression-${{ github.run_id }} + path: regression_result.png + + - name: Post test results as PR comment + uses: actions/github-script@v8 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + // Read the file directly instead of passing via env/outputs to avoid escaping issues + const md = fs.readFileSync('regression_result.md', 'utf8'); + + const runUrl = `${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}`; + + const body = + 'Performance Regression Test Report\n' + + '============================\n\n' + + `Triggered by: 
@${context.payload.comment.user.login}\n` + + `Workflow run: ${runUrl}\n\n` + + 'Results\n' + + '-------\n\n' + + md + '\n\n' + + 'Artifacts\n' + + '---------\n\n' + + '- regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.\n'; + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body + }); diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index 953303102..2197015b6 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -25,7 +25,7 @@ jobs: runs-on: [self-hosted, nvidia] steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: recursive diff --git a/.gitignore b/.gitignore index 752f6cb76..727b6a14e 100644 --- a/.gitignore +++ b/.gitignore @@ -108,3 +108,18 @@ cmake-build-*/ # pre-commit cache .pre-commit-cache/* + +# host checks logs +maint/host_checks/logs/* + +# ncu +*.ncu-rep + +# csv +*.csv + +# clang-tidy +/run-clang-tidy.py + +# perf regression test +.perf_regression/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 615f173b9..a99de631d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,19 +9,17 @@ default_stages: [pre-commit, pre-push, manual] exclude: '^(build|3rdparty)/.*$' # exclude build and 3rdparty directories repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v6.0.0 + rev: v6.0.0 # May not sync with requirements-lint.txt, but it's OK for now hooks: - id: check-symlinks - id: destroyed-symlinks - # FIXME: enable these hooks - # - id: trailing-whitespace - # - id: end-of-file-fixer + - id: trailing-whitespace + - id: end-of-file-fixer - id: check-added-large-files - id: check-merge-conflict fail_fast: true - # FIXME: enable these hooks - # - id: check-executables-have-shebangs - # - id: check-shebang-scripts-are-executable + - id: check-executables-have-shebangs + - id: check-shebang-scripts-are-executable - id: detect-private-key - id: check-yaml - id: check-toml @@ -32,30 +30,17 @@ repos: args: [--ignore-case] files: ^docs/spelling_wordlist\.txt$ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v21.1.2 # sync with requirements-lint.txt + rev: v21.1.8 # sync with requirements-lint.txt hooks: - id: clang-format - exclude: | - (?ix)( - ^.+\.(cu|cuh)$| - ^.+\.json$ - ) + types_or: [c++, c] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.3 # sync with requirements-lint.txt + rev: v0.14.14 # sync with requirements-lint.txt hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] - - repo: https://github.com/google/yapf - rev: v0.43.0 # sync with requirements-lint.txt - hooks: - - id: yapf - name: yapf-multiproc-bugfix - # yapf is not multiprocess safe, so we run a dummy yapf first. 
- args: [--in-place, docs/conf.py] - always_run: true - pass_filenames: false - - id: yapf - args: [--recursive, --in-place] + - id: ruff-format + args: [--exit-non-zero-on-format] - repo: https://github.com/codespell-project/codespell rev: v2.4.1 # sync with requirements-lint.txt hooks: @@ -67,3 +52,8 @@ repos: ^.+\.svg$| ^.*\brequirements\b.*\.txt$ ) + - repo: https://github.com/jackdewinter/pymarkdown + rev: v0.9.35 + hooks: + - id: pymarkdown + args: ["--config", ".pymarkdown", "fix"] diff --git a/.pymarkdown b/.pymarkdown new file mode 100644 index 000000000..5394265ed --- /dev/null +++ b/.pymarkdown @@ -0,0 +1,37 @@ +{ + "plugins": { + "md003": { + "style": "atx" + }, + "md004": { + "style": "dash" + }, + "md013": { + "enabled": false + }, + "md026": { + "enabled": false + }, + "md029": { + "enabled": false + }, + "md031": { + "enabled": false + }, + "md032": { + "enabled": false + }, + "md033": { + "enabled": false + }, + "md034": { + "enabled": false + }, + "md040": { + "enabled": false + }, + "md041": { + "enabled": false + } + } +} diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 1c45ca35d..b38bb492a 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 1c45ca35dd5c215e0c1db1f40f01556f467f52a8 +Subproject commit b38bb492a1a55b5abb0c345962143c0f9c482cfb diff --git a/3rdparty/tvm b/3rdparty/tvm index 1815c3e0b..8d494caca 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 1815c3e0b6ec4ead36370bbd1562025d8529017c +Subproject commit 8d494cacae52b2ec73f2717431190b1ecd5df6ce diff --git a/CMakeLists.txt b/CMakeLists.txt index 72e1d9795..4e520dbcb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -136,14 +136,21 @@ file(GLOB TILE_LANG_SRCS src/*.cc src/layout/*.cc src/transform/*.cc + src/transform/common/*.cc src/op/*.cc src/target/utils.cc + src/target/codegen_c_host.cc src/target/codegen_cpp.cc src/target/rt_mod_cpp.cc # intrin_rule doesn't have system dependency src/target/intrin_rule*.cc ) +# Always include CPU-safe runtime helpers +list(APPEND TILE_LANG_SRCS + src/runtime/error_helpers.cc +) + # Track if the user explicitly selected a backend via cache options. set(TILELANG_BACKEND_USER_SELECTED OFF) foreach(BACKEND IN LISTS TILELANG_BACKENDS) @@ -204,17 +211,55 @@ elseif(USE_CUDA) # Set `USE_CUDA=/usr/local/cuda-x.y` cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA) + # ============================================================================ + # CUDA Driver Stub Library (libcuda_stub.so) + # ============================================================================ + # This library provides drop-in replacements for CUDA driver API functions. + # Instead of linking directly against libcuda.so (which would fail on + # CPU-only machines), we link against this stub which loads libcuda.so + # lazily at runtime on first API call. + # + # The stub exports global C functions matching the CUDA driver API: + # - cuModuleLoadData, cuLaunchKernel, cuMemsetD32_v2, etc. + # These can be called directly without any wrapper macros. 
+ # ============================================================================ + add_library(cuda_stub SHARED src/target/stubs/cuda.cc) + target_include_directories(cuda_stub PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + # Export symbols with visibility="default" when building + target_compile_definitions(cuda_stub PRIVATE TILELANG_CUDA_STUB_EXPORTS) + # Use dlopen/dlsym for runtime library loading + target_link_libraries(cuda_stub PRIVATE ${CMAKE_DL_LIBS}) + set_target_properties(cuda_stub PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + # Use consistent naming + OUTPUT_NAME "cuda_stub" + ) + file(GLOB TILE_LANG_CUDA_SRCS - src/runtime/*.cc + src/runtime/runtime.cc src/target/ptx.cc src/target/codegen_cuda.cc + src/target/codegen_py.cc + src/target/codegen_utils.cc + src/target/codegen_cutedsl.cc src/target/rt_mod_cuda.cc + src/target/rt_mod_cutedsl.cc ) list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS}) list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS}) endif() +set(USE_Z3 ON CACHE STRING "Use Z3 SMT solver for TileLang optimizations") +set(USE_PYPI_Z3 ON CACHE BOOL "Use Z3 provided by PyPI z3-solver package") + +if(USE_Z3 AND USE_PYPI_Z3) + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/pypi-z3") + find_package(Z3 REQUIRED) +endif() + # Include tvm after configs have been populated add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL) @@ -222,7 +267,11 @@ add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL) add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=) add_library(tilelang_objs OBJECT ${TILE_LANG_SRCS}) + +# Set debug mode compile definitions +# Enable the TVM debug option, i.e., TVM_LOG_DEBUG if(CMAKE_BUILD_TYPE STREQUAL "Debug") + message(STATUS "Building TileLang with DEBUG mode") target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG") endif() @@ -232,6 +281,18 @@ add_library(tilelang SHARED $) add_library(tilelang_module SHARED $) target_link_libraries(tilelang PUBLIC tvm_runtime tvm) target_link_libraries(tilelang_module PUBLIC tvm) + +# Place dev build outputs under build/lib for consistency +set_target_properties(tilelang PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" +) +set_target_properties(tilelang_module PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" +) # Build cython extension find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) @@ -251,26 +312,105 @@ if(NOT "${SKBUILD_SABI_VERSION}" STREQUAL "") endif() python_add_library(tilelang_cython_wrapper MODULE "${CMAKE_BINARY_DIR}/tilelang_cython_wrapper.cpp" ${USE_SABI} WITH_SOABI) -# Install extension into the tilelang package directory + +# Ensure dev builds drop the extension into build/lib alongside other shared libs +set_target_properties(tilelang_cython_wrapper PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" +) + +# Install the extension into tilelang/lib inside the wheel install(TARGETS tilelang_cython_wrapper - LIBRARY DESTINATION tilelang - RUNTIME DESTINATION tilelang - ARCHIVE DESTINATION tilelang) + LIBRARY DESTINATION tilelang/lib + RUNTIME DESTINATION 
tilelang/lib + ARCHIVE DESTINATION tilelang/lib) + +# Copy libz3.so to build folder to workaround isolated build env issue +if(USE_Z3 AND USE_PYPI_Z3) + get_target_property(Z3_LIBRARY_PATH z3::libz3 IMPORTED_LOCATION) + install(FILES "${Z3_LIBRARY_PATH}" DESTINATION "${CMAKE_BINARY_DIR}/lib") + if(APPLE) + set_target_properties(tvm PROPERTIES BUILD_RPATH "@loader_path") + else() + set_target_properties(tvm PROPERTIES BUILD_RPATH "\$ORIGIN") + endif() +endif() + +set(TILELANG_OUTPUT_TARGETS + tilelang + tilelang_module + tvm + tvm_runtime +) + +if(USE_CUDA) + # Link against CUDA stub library instead of libcuda.so + # This enables lazy loading of libcuda.so at runtime, allowing + # `import tilelang` to succeed on CPU-only machines. + foreach(target IN LISTS TILELANG_OUTPUT_TARGETS) + target_link_libraries(${target} PUBLIC cuda_stub) + endforeach() + # Include CUDA stub in output targets for RPATH configuration + list(APPEND TILELANG_OUTPUT_TARGETS cuda_stub) +endif() + +unset(PATCHELF_EXECUTABLE CACHE) -# let libtilelang to search tvm/tvm_runtime in same dir if(APPLE) - set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") - set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") - set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") - set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") + set(TILELANG_INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") + if(USE_Z3 AND USE_PYPI_Z3) + # Some z3 is placed in lib/ and some in bin/, we add both in rpath + string(APPEND TILELANG_INSTALL_RPATH ";@loader_path/../../z3/lib;@loader_path/../../z3/bin") + endif() elseif(UNIX) - set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") - set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") - set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") - set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") + set(TILELANG_INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") + if(USE_Z3 AND USE_PYPI_Z3) + string(APPEND TILELANG_INSTALL_RPATH ":\$ORIGIN/../../z3/lib") + endif() + if(USE_CUDA) + string(APPEND TILELANG_INSTALL_RPATH ":\$ORIGIN/../../nvidia/cu${CUDAToolkit_VERSION_MAJOR}/lib") + endif() + find_program(PATCHELF_EXECUTABLE patchelf) + if (NOT PATCHELF_EXECUTABLE) + message(STATUS "`patchelf` not found.") + endif() +endif() + +# Let libtilelang search for tvm/tvm_runtime in the same directory +foreach(target IN LISTS TILELANG_OUTPUT_TARGETS) + set_target_properties(${target} PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}") + set_target_properties(${target} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ) +endforeach() + +# Exclude libcuda.so to allow importing on a CPU-only machine +if(USE_CUDA AND PATCHELF_EXECUTABLE) + # Run `patchelf` on built libraries to remove libcuda.so dependency. + # Use `install(CODE ...)` instead of `add_custom_command(... POST_BUILD ...)` + # to avoid race conditions during linking. 
+ foreach(target IN LISTS TILELANG_OUTPUT_TARGETS) + install(CODE " + execute_process( + COMMAND ${PATCHELF_EXECUTABLE} --remove-needed libcuda.so.1 --remove-needed libcuda.so \"$\" + WORKING_DIRECTORY \"${CMAKE_INSTALL_PREFIX}\" + RESULT_VARIABLE patchelf_result + ) + if(patchelf_result EQUAL 0) + message(STATUS \"`patchelf` successfully removed dependency `libcuda.so` from $\") + else() + message(WARNING \"`patchelf` failed to remove dependency `libcuda.so` from $\") + endif() + ") + endforeach() endif() install( - TARGETS tvm tvm_runtime tilelang_module tilelang + TARGETS ${TILELANG_OUTPUT_TARGETS} LIBRARY DESTINATION tilelang/lib + RUNTIME DESTINATION tilelang/lib + ARCHIVE DESTINATION tilelang/lib ) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 9e380d831..5eba9044a 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -17,23 +17,23 @@ diverse, inclusive, and healthy community. Examples of behavior that contributes to a positive environment for our community include: -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our mistakes, +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience -* Focusing on what is best not just for us as individuals, but for the overall +- Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: -* The use of sexualized language or imagery, and sexual attention or advances of +- The use of sexualized language or imagery, and sexual attention or advances of any kind -* Trolling, insulting or derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email address, +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in a +- Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e4b45e24b..45284e980 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ That would be awesome if you want to contribute something to TileLang! -### Table of Contents +## Table of Contents - [Report Bugs](#report-bugs) - [Ask Questions](#ask-questions) @@ -81,6 +81,8 @@ in the main directory. This installation is removable by: python3 -m pip uninstall tilelang ``` +We also recommend installing TileLang in a more manual way for better control over the build process, by compiling the C++ extensions first and set the `PYTHONPATH`. See [Working from Source via `PYTHONPATH`](https://tilelang.com/get_started/Installation.html#working-from-source-via-pythonpath) for detailed instructions. + ## Lint Check To check the linting, run: diff --git a/LICENSE b/LICENSE index 2122252e9..09dd51c8c 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,7 @@ MIT License Copyright (c) Tile-AI. 
- **During the period from December 1, 2024, to Mar 14, 2025, this project is + **During the period from December 1, 2024, to Mar 14, 2025, this project is subject to additional collaboration terms with Microsoft Corporation.** Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/README.md b/README.md index d7cdabee5..30c518e05 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,9 @@ # Tile Language [![PyPI version](https://badge.fury.io/py/tilelang.svg)](https://badge.fury.io/py/tilelang) -[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/tile-ai/tilelang) [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?logo=discord&logoColor=white)](https://discord.gg/TUrHyJnKPG) - +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/tile-ai/tilelang) +[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?logo=discord&logoColor=white)](https://discord.gg/TUrHyJnKPG) +[![Puzzles](https://img.shields.io/badge/🧩_Learn-TileLang_Puzzles-blueviolet)](https://github.com/tile-ai/tilelang-puzzles) Tile Language (**tile-lang**) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). By employing a Pythonic syntax with an underlying compiler infrastructure on top of [TVM](https://tvm.apache.org/), tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance. @@ -13,6 +14,10 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to ## Latest News +- 02/02/2026 🧩: Check out [TileLang Puzzles](https://github.com/tile-ai/tilelang-puzzles), a fun and interactive way to learn TileLang programming with 10 progressively harder puzzles! +- 12/18/2025 🚀: Added [CuTeDSL backend](https://github.com/tile-ai/tilelang/pull/1421) support, enabling compilation to NVIDIA CUTLASS CuTe DSL! Join us in building and optimizing this exciting new backend: [Issue #1454](https://github.com/tile-ai/tilelang/issues/1454). +- 12/17/2025 🔬: Integrated [Z3 theorem prover](https://github.com/tile-ai/tilelang/pull/1367) into TVM Arith Analyzer, bringing SMT-based symbolic reasoning for enhanced optimizations and automatic correctness verification! +- 10/31/2025 🔧: Migrated to [apache-tvm-ffi](https://github.com/tile-ai/tilelang/pull/1108), significantly reducing CPU overhead! - 10/30/2025 📦: We have released v0.1.6.post2, which is the last version compatible with Python 3.8. - 10/07/2025 🍎: Added Apple Metal Device support, check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details. - 09/29/2025 🎉: Thrilled to announce that ​​AscendC​​ and ​Ascend​NPU IR​​ backends targeting Huawei Ascend chips are now supported! @@ -21,7 +26,7 @@ Check out the preview here: This includes implementations across two branches: [ascendc_pto](https://github.com/tile-ai/tilelang-ascend) and [npuir](https://github.com/tile-ai/tilelang-ascend/tree/npuir). -Feel free to explore and share your feedback! +Feel free to explore and share your feedback! - 07/04/2025 🚀: Introduced `T.gemm_sp` for 2:4 sparse tensor core support, check out [Pull Request #526](https://github.com/tile-ai/tilelang/pull/526) for details. - 06/05/2025 ✨: Added [NVRTC Backend](https://github.com/tile-ai/tilelang/pull/461) to significantly reduce compilation time for cute templates! 
- 04/14/2025 🚀: Added high-performance FlashMLA implementation for AMD MI300X, achieving performance parity with hand-optimized assembly kernels of Aiter! See [example_mla_amd](./examples/deepseek_mla/amd/README.md) for details. @@ -46,7 +51,6 @@ Although tile-lang aims to be portable across a range of Devices, it has been sp Within the `examples` directory, you will also find additional complex kernels—such as convolutions, forward/backward passes for FlashAttention, more operators will continuously be added. - ## Benchmark Summary TileLang achieves exceptional performance across a variety of computational patterns. Comprehensive benchmark scripts and settings are available at [tilelang-benchmark](https://github.com/tile-ai/tilelang-benchmark). Below are selected results showcasing its capabilities: @@ -61,7 +65,7 @@ TileLang achieves exceptional performance across a variety of computational patt mla decode performance bs128 on H100 - + - Flash Attention Performance on H100
operator performance on H100 @@ -106,9 +110,9 @@ pip install -e . -v # remove -e option if you don't want to install in editable ### Method 2: Build from Source We currently provide three ways to install **tile-lang** from source: - - [Install from Source (using your own TVM installation)](./docs/get_started/Installation.md#method-1-install-from-source-using-your-own-tvm-installation) - - [Install from Source (using the bundled TVM submodule)](./docs/get_started/Installation.md#method-2-install-from-source-using-the-bundled-tvm-submodule) - - [Install Using the Provided Script](./docs/get_started/Installation.md#method-3-install-using-the-provided-script) +- [Install from Source (using your own TVM installation)](./docs/get_started/Installation.md#method-1-install-from-source-using-your-own-tvm-installation) +- [Install from Source (using the bundled TVM submodule)](./docs/get_started/Installation.md#method-2-install-from-source-using-the-bundled-tvm-submodule) +- [Install Using the Provided Script](./docs/get_started/Installation.md#method-3-install-using-the-provided-script) ### Method 3: Install with Nightly Version @@ -130,93 +134,95 @@ In this section, you'll learn how to write and execute a straightforward GEMM (m Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware. ```python -import tilelang -import tilelang.language as T - # @tilelang.jit(target="cuda") # target currently can be "cuda" or "hip" or "cpu". # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - - @T.prim_func - def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - # Initialize Kernel Context - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Enable rasterization for better L2 cache locality (Optional) - # T.use_swizzle(panel_size=10, enable=True) - - # Clear local accumulation - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - # Copy tile of A - # This is a sugar syntax for parallelized copy - T.copy(A[by * block_M, ko * block_K], A_shared) - - # Copy tile of B - T.copy(B[ko * block_K, bx * block_N], B_shared) - - # Perform a tile-level GEMM on the shared buffers - # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs - T.gemm(A_shared, B_shared, C_local) - - # relu - for i, j in T.Parallel(block_M, block_N): - C_local[i, j] = T.max(C_local[i, j], 0) - - # Copy result back to global memory - T.copy(C_local, C[by * block_M, bx * block_N]) - - return matmul_relu_kernel - - -M = 1024 # M = T.dynamic("m") if you want to use dynamic shape -N = 1024 -K = 1024 -block_M = 128 -block_N = 128 -block_K = 32 - -# 1. Define the kernel (matmul) and compile/lower it into an executable module -matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) - -# 3. 
Test the kernel in Python with PyTorch data -import torch - -# Create random input tensors on the GPU -a = torch.randn(M, K, device="cuda", dtype=torch.float16) -b = torch.randn(K, N, device="cuda", dtype=torch.float16) -c = torch.empty(M, N, device="cuda", dtype=torch.float16) +def matmul_relu( + A, B, + block_M: int = 64, + block_N: int = 64, + block_K: int = 64, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float32, +): + # declare compilation shape constant + M, N, K = T.const('M, N, K') -# Run the kernel through the Profiler -matmul_relu_kernel(a, b, c) + # annotate input tensor shape + A: T.Tensor[[M, K], dtype] + B: T.Tensor[[K, N], dtype] -print(c) -# Reference multiplication using PyTorch -ref_c = torch.relu(a @ b) + # allocate output tensor + C = T.empty([M, N], dtype) -# Validate correctness -torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) -print("Kernel output matches PyTorch reference.") + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) -# 4. Retrieve and inspect the generated CUDA source (optional) -# cuda_source = jit_kernel.get_kernel_source() -# print("Generated CUDA kernel:\n", cuda_source) + # Enable rasterization for better L2 cache locality (Optional) + # T.use_swizzle(panel_size=10, enable=True) -# 5.Profile latency with kernel -profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + # Clear local accumulation + T.clear(C_local) -latency = profiler.do_bench() + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(A_shared, B_shared, C_local) + + # relu + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + # You can write multiple cuda kernel in one function, they execute sequentially + # with T.Kernel(...) as ... 
+ + # Return the tensor, you can also return multiple tensors + return C + + +M, N, K = 1024, 1024, 1024 + +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(K, N, device="cuda", dtype=torch.float16) +c_ref = torch.relu(a @ b) + +# Call the kernel +c = matmul_relu(a, b) +torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=1e-2) + +# Call the kernel with overwritten compilation constants +c = matmul_relu(a, b, block_M=128, block_N=128, block_K=64) +torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=1e-2) + +# Retrieve the compiled kernel +kernel = matmul_relu.compile(a, b) # use torch.Tensor +kernel = matmul_relu.compile( # use T.Tensor as placeholder + T.Tensor((M, K), T.float16), + T.Tensor((K, N), T.float16) +) +kernel = matmul_relu.compile( # directly specify the shape constants + M=M, N=N, K=K, + block_M=128, block_N=128, block_K=64 +) +print(kernel.get_kernel_source()) +c = kernel(a, b) + +# 5.Profile latency with kernel +profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) +latency = profiler.do_bench() print(f"Latency: {latency} ms") ``` diff --git a/THIRDPARTYNOTICES.txt b/THIRDPARTYNOTICES.txt index b7c481841..3558662a8 100644 --- a/THIRDPARTYNOTICES.txt +++ b/THIRDPARTYNOTICES.txt @@ -1,5 +1,5 @@ -BitBLAS uses third-party material as listed below. The attached notices are -provided for informational purposes only. +BitBLAS uses third-party material as listed below. The attached notices are +provided for informational purposes only. Notice for apache/tvm ------------------------------- diff --git a/VERSION b/VERSION index 5ed6219f4..c5578b40d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.6.post2 +0.1.7.post3 diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py index 6401276ac..3dd82aa5e 100644 --- a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) import flash_attn diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py 
b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py index aefe4d420..0018e9c93 100644 --- a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -39,16 +36,15 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_N = 64 num_stages = 2 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] - dtype = "float16" - accum_dtype = "float" - block_mask_dtype = "bool" + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.bool def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def MMA0( K: T.Tensor(shape, dtype), @@ -60,11 +56,10 @@ def MMA0( by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -79,22 +74,24 @@ def MMA1( by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. 
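For context on the `scores_max` bookkeeping added in the `Softmax` macro above (the running max is now clamped against `scores_max_prev` before rescaling): below is a plain-PyTorch sketch of the same online-softmax tile update, using base-2 exponentials to match the `log2(e)`-scaled scores in this kernel. The function name and the split between this update and the subsequent `P @ V` accumulation are illustrative, not the TileLang API.

```python
# Illustrative online-softmax tile update in plain PyTorch (not TileLang).
# acc_o, logsum, scores_max_prev carry the running state across K/V tiles;
# scores is the [block_M, block_N] score tile for the current iteration,
# already multiplied by (1/sqrt(dim)) * log2(e) as in the kernel above.
import torch

def online_softmax_step(scores, acc_o, logsum, scores_max_prev):
    scores_max = torch.maximum(scores_max_prev, scores.amax(dim=1))
    # "Check_inf": fully masked rows (max == -inf) are clamped to 0 so the
    # rescale factor below stays finite, as noted in the kernel comments.
    scores_max = torch.where(scores_max == float("-inf"),
                             torch.zeros_like(scores_max), scores_max)
    scores_scale = torch.exp2(scores_max_prev - scores_max)   # rescale factor for old state
    p = torch.exp2(scores - scores_max[:, None])               # unnormalized probs for this tile
    logsum = logsum * scores_scale + p.sum(dim=1)
    acc_o = acc_o * scores_scale[:, None]                      # caller then adds p.to(v.dtype) @ v
    return p, acc_o, logsum, scores_max
```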
@@ -114,22 +111,21 @@ def Softmax( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -142,31 +138,29 @@ def main( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_mask = T.alloc_local([downsample_len], block_mask_dtype) + block_mask = T.alloc_fragment([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - for vj in T.serial(downsample_len): - block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + T.copy(BlockSparseMask[bz, by, bx, :], block_mask) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - if block_mask[k]: + if block_mask[k] != 0: MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main @@ -175,26 +169,23 @@ def main( def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / 
downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - program = blocksparse_flashattn( - BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) kernel = tilelang.compile(program, out_idx=4) def benchmark_fn(): diff --git a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py index e4828ce5f..85d754ae3 100644 --- a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", 
attn, v) return ref_output ref_latency = do_bench( diff --git a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py index 86ac894bc..7ebca93a6 100644 --- a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -56,7 +53,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) if mask_val == True: @@ -72,8 +68,7 @@ def _fwd_kernel_inner( # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N if LAST_K_BLOCK: - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, - float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -153,7 +148,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -191,24 +186,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -253,7 +236,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -271,24 +253,22 @@ def backward(ctx, do): def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, 
device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) diff --git a/benchmark/mamba2/README.md b/benchmark/mamba2/README.md index 0b6de19b1..f0b4b7e80 100644 --- a/benchmark/mamba2/README.md +++ b/benchmark/mamba2/README.md @@ -45,7 +45,6 @@ PY | 16384 | 2.531 | 135.711 | | 32768 | 5.076 | 135.379 | - ## Compare with Baselines - Triton: v3.5.0, mamba-ssm: v2.2.6.post3 @@ -56,4 +55,4 @@ PY Mamba2_chunk_scan Performance Comparison on H100
Performance comparison across compilers on NVIDIA H100
- \ No newline at end of file + diff --git a/benchmark/mamba2/benchmark_mamba_chunk_scan.py b/benchmark/mamba2/benchmark_mamba_chunk_scan.py index aff810f66..55f802b4f 100644 --- a/benchmark/mamba2/benchmark_mamba_chunk_scan.py +++ b/benchmark/mamba2/benchmark_mamba_chunk_scan.py @@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] decay = torch.exp(dt_segment_sum) scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") - causal_mask = torch.tril( - torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) scores_decay = scores_decay.masked_fill(~causal_mask, 0) - out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), - rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) - out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange( - C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) out = out + out_prev out = rearrange(out, "b c l h p -> b (c l) h p") if D is not None: @@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): - @helion.kernel() def helion_mamba2_chunk_scan_kernel( cb: torch.Tensor, @@ -118,8 +118,7 @@ def helion_mamba2_chunk_scan_kernel( dtype = cb.dtype accum_dtype = torch.float32 - assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == - dtype) + assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype out = torch.empty_like(x) @@ -127,11 +126,10 @@ def helion_mamba2_chunk_scan_kernel( for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile( [nheads, chunk_size, headdim, batch, nchunks], - block_size=[1, block_m, block_n, 1, 1], + block_size=[1, block_m, block_n, 1, 1], ): acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype) - dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, - tile_m].to(torch.float32) + dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32) scale_m_local = torch.exp2(dA_cumsum_local_m * p) C_local = C[ @@ -152,10 +150,8 @@ def helion_mamba2_chunk_scan_kernel( tile_m, tile_k, ] - dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, - tile_k].to(torch.float32) - cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - - dA_cumsum_local_k[None, :] * p) + dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) + cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p) dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) cb_local = (cb_local * dt_local[None, :]).to(dtype) pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :] @@ -169,11 +165,9 @@ def helion_mamba2_chunk_scan_kernel( acc_o = hl.dot(cb_local, x_local, acc=acc_o) D_local = D[tile_h.begin].to(torch.float32) - x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, 
tile_h.begin, - tile_n].to(torch.float32) + x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32) acc_o += x_residual * D_local - out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, - tile_n] = acc_o.to(dtype=dtype) + out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype) return out @@ -182,12 +176,7 @@ def helion_mamba2_chunk_scan_kernel( def get_configs(): - iter_params = dict( - block_M=[64, 128, 256], - block_N=[32, 64], - block_K=[64, 128, 256], - block_Dstate=[128], - num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -198,56 +187,58 @@ def get_configs(): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def chunk_scan_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - block_Dstate=128, - num_stages=2, - threads=128): - dtype = "float16" - accum_dtype = "float" +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, +): + dtype = T.float16 + accum_dtype = T.float32 nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 @T.prim_func def main( - cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore - x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore - dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore - prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore - D: T.Tensor((nheads), dtype), # type: ignore - Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore ): - with T.Kernel( - nheads, - T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) acc_o_shared = T.alloc_shared((block_M, block_N), dtype) - cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") + cb_shared = T.alloc_shared((block_M, block_K), dtype) cb_local = T.alloc_fragment((block_M, block_K), dtype) - dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared") + dA_cs_k_shared = T.alloc_shared((block_K), dtype) dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype) dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) - 
dt_shared = T.alloc_shared((block_K), dtype, scope="shared") + dt_shared = T.alloc_shared((block_K), dtype) dt_local = T.alloc_fragment((block_K), accum_dtype) - x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn") - dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared") + x_shared = T.alloc_shared((block_K, block_N), dtype) + dA_cs_m_shared = T.alloc_shared((block_M), dtype) scale_m_local = T.alloc_fragment((block_M), accum_dtype) C_shared = T.alloc_shared((block_M, block_Dstate), dtype) prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) D_local = T.alloc_fragment((1), accum_dtype) - x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn") + x_residual_shared = T.alloc_shared((block_M, block_N), dtype) x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype) batch_idx = by % batch @@ -257,27 +248,31 @@ def main( m_idx = bx // T.ceildiv(headdim, block_N) n_idx = bx % T.ceildiv(headdim, block_N) - T.annotate_layout({ - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), - cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), - x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) - }) + T.annotate_layout( + { + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) T.no_set_max_nreg() - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], - dA_cs_m_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) T.copy(dA_cs_m_shared, dA_cs_m_local) T.clear(acc_o) for i in T.Parallel(block_M): scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) T.copy( - C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) - T.copy( - prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, - 0:block_Dstate], prev_state_shared) + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] *= scale_m_local[i] @@ -286,34 +281,47 @@ def main( for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - cb[batch_idx, chunk_idx, bz // (nheads // ngroups), - m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], - cb_shared) + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) T.copy(cb_shared, cb_local) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cs_k_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) T.copy(dA_cs_k_shared, dA_cs_k_local) for i, j in T.Parallel(block_M, block_K): - cb_local[i, - j] = cb_local[i, - j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dt_shared, 
dt_local) for i, j in T.Parallel(block_M, block_K): cb_local[i, j] *= dt_local[j] for i, j in T.Parallel(block_M, block_K): - cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, - cb_local[i, j], 0) + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) T.gemm(cb_local, x_shared, acc_o) D_local[0] = D[bz] T.copy( - x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], - x_residual_shared) + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) T.copy(x_residual_shared, x_residual_local) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] += x_residual_local[i, j] * D_local[0] @@ -321,24 +329,37 @@ def main( T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) return main if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) nchunks = math.ceil(seq_len / chunk_size) total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate @@ -360,8 +381,7 @@ def main( D = torch.randn(heads).half().cuda() print("Benchmarking Triton...") - triton_latency = do_bench( - lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), 
_n_warmup=10, _n_repeat=10) + triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10) print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}") print("Benchmarking Helion...") diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py index c64f4fabf..dca98a676 100644 --- a/benchmark/matmul/benchmark_matmul.py +++ b/benchmark/matmul/benchmark_matmul.py @@ -2,10 +2,10 @@ import itertools import logging -import tilelang import tilelang.language as T from tilelang.autotuner import autotune from tilelang import jit + # Configure logger logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -61,9 +61,9 @@ def get_configs(args, kwargs): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, ).with_arch(arch) func = carve_template.equivalent_function() @@ -101,9 +101,7 @@ def get_configs(args, kwargs): policy=[T.GemmWarpPolicy.Square], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ -112,7 +110,9 @@ def get_configs(args, kwargs): warmup=3, rep=20, ) -@jit(out_idx=[2],) +@jit( + out_idx=[2], +) def matmul( M, N, @@ -154,14 +154,14 @@ def matmul( # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -176,7 +176,6 @@ def main( # Bind x-dimension to block index in N, # y-dimension to block index in M. 
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) @@ -188,8 +187,6 @@ def main( # Enable (or disable) swizzling optimization T.use_swizzle(panel_size=10, enable=enable_rasteration) - # to utilize swizzle tma layout - T.annotate_layout({C_shared: tilelang.layout.make_swizzled_layout(C_shared)}) # Clear out the accumulation buffer T.clear(C_local) diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py index 94e36b385..4ef860c21 100644 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -6,7 +6,8 @@ import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.autotuner import autotune import itertools @@ -48,22 +49,22 @@ def tl_matmul( enable_rasteration=False, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config - # chunk = 32 if in_dtype == "float16" else 64 + # chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" block_M = block_row_warps * warp_row_tiles @@ -103,12 +104,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -116,10 +116,12 @@ def main( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10, enable=enable_rasteration) @@ -127,7 +129,6 @@ def main( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -137,7 +138,6 @@ def main( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a(A_local, A_shared, ki) @@ -194,9 +194,9 @@ def get_configs(args, kwargs): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float16, ).with_arch(arch) func = 
carve_template.equivalent_function() @@ -223,7 +223,6 @@ def get_configs(args, kwargs): for config in configs: print(config) else: - iter_params = dict( block_row_warps=[1, 2, 4], block_col_warps=[1, 2, 4], @@ -233,9 +232,7 @@ def get_configs(args, kwargs): stage=[0, 2], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ -247,14 +244,16 @@ def get_configs(args, kwargs): ref_prog=ref_program, skip_check=True, ) -@tl.jit(out_idx=[2],) +@tl.jit( + out_idx=[2], +) def matmul( M, N, K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float16, with_roller=False, block_row_warps=None, block_col_warps=None, @@ -291,19 +290,14 @@ def kernel(): parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--with_roller", - type=bool, - default=False, - help="Whether to use roller to deduce search spaces") - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") + parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces") + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") args = parser.parse_args() M, N, K = args.m, args.n, args.k - in_dtype = args.dtype - out_dtype = "float32" if in_dtype == "int8" else "float16" - accum_dtype = "float32" if in_dtype == "int8" else "float16" + in_dtype = T.dtype(args.dtype) + out_dtype = T.float32 if in_dtype == T.int8 else T.float16 + accum_dtype = T.float32 if in_dtype == T.int8 else T.float16 with_roller = args.with_roller with_roller = True # Compute total floating-point operations diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py index 4e4ed6128..7ecffc26a 100644 --- a/benchmark/matmul/benchmark_matmul_sp.py +++ b/benchmark/matmul/benchmark_matmul_sp.py @@ -9,7 +9,7 @@ from tilelang.autotuner import autotune from tilelang import jit from tilelang.contrib import nvcc -from tilelang.layout import make_metadata_layout +from tilelang.layout import make_cutlass_metadata_layout # Configure logger logger = logging.getLogger(__name__) @@ -70,7 +70,8 @@ def get_configs(M, N, K): thread_num, policy, enable_rasterization, - )) + ) + ) configs = [ { @@ -81,12 +82,13 @@ def get_configs(M, N, K): "thread_num": c[4], "policy": c[5], "enable_rasterization": c[6], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs -def matmul_sp(M, N, K, accum_dtype): +def matmul_sp(M, N, K, in_dtype, accum_dtype): """ Create an autotuned matrix multiplication kernel for matrices of shape: - A: (M, K) @@ -126,7 +128,9 @@ def matmul_sp(M, N, K, accum_dtype): warmup=3, rep=20, ) - @jit(out_idx=[2],) + @jit( + out_idx=[2], + ) def kernel( block_M=None, block_N=None, @@ -161,15 +165,14 @@ def kernel( """ # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float16" e_factor, e_dtype = ARCH_INFO[arch] @T.prim_func def main( - A_sparse: T.Tensor((M, K 
// 2), dtype), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), in_dtype), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), in_dtype), + C: T.Tensor((M, N), accum_dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -183,13 +186,11 @@ def main( """ # Bind x-dimension to block index in N, # y-dimension to block index in M. - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): # Allocate shared memory for A sub-block of shape (block_M, block_K) - A_shared = T.alloc_shared((block_M, block_K // 2), dtype) + A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) - B_shared = T.alloc_shared((block_K, block_N), dtype) + B_shared = T.alloc_shared((block_K, block_N), in_dtype) # Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) # Allocate a local fragment for intermediate accumulation @@ -202,14 +203,12 @@ def main( T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout({ - E: - make_metadata_layout( - E, mma_dtype="float16", backend="cutlass", block_k=block_K), - E_shared: - make_metadata_layout( - E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K), + } + ) # Loop over sub-blocks in K dimension, pipelined by num_stages for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): # Load a sub-block of A from global memory into A_shared @@ -220,7 +219,7 @@ def main( T.copy(B[k * block_K, bx * block_N], B_shared) # Perform a partial matrix multiplication: # C_local += A_shared @ B_shared - T.gemm_sp( + T.gemm_sp_v2( A_shared, E_shared, B_shared, @@ -244,18 +243,13 @@ def main( parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") parser.add_argument("--disable_cache", action="store_true") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") parser.add_argument( "--bench_torch_sparse", type=str, - choices=['cutlass', 'cusparselt'], + choices=["cutlass", "cusparselt"], default=None, - help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported" + help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported", ) args = parser.parse_args() @@ -268,7 +262,7 @@ def main( total_flops = 2 * M * N * K # matmul(...) 
returns (best_latency, best_config, ref_latency) - best_result = matmul_sp(M, N, K, args.accum_dtype) + best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype) best_latency = best_result.latency best_config = best_result.config A = torch.randn(M, K, dtype=torch.float16, device="cuda") @@ -277,7 +271,8 @@ def main( if args.bench_torch_sparse is not None: from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor - if args.bench_torch_sparse == 'cutlass': + + if args.bench_torch_sparse == "cutlass": SparseSemiStructuredTensor._FORCE_CUTLASS = True A_sp = to_sparse_semi_structured(A, transposed=False) torch_sparse_latency = do_bench(lambda: A_sp @ B) @@ -288,8 +283,6 @@ def main( print(f"Best config: {best_config}") if args.bench_torch_sparse is not None: - print( - f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}" - ) + print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}") print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}") diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py index 36b910355..64714b649 100644 --- a/benchmark/matmul_fp8/benchmark_matmul.py +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -1,7 +1,7 @@ import argparse import itertools +import torch import logging -import tilelang import tilelang.language as T from tilelang.autotuner import autotune from tilelang import jit @@ -62,9 +62,9 @@ def get_configs(args, kwargs): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, ).with_arch(arch) func = carve_template.equivalent_function() @@ -99,12 +99,11 @@ def get_configs(args, kwargs): block_K=[64, 128], num_stages=[0, 1, 2, 3], thread_num=[128, 256], + k_pack=[1, 2], policy=[T.GemmWarpPolicy.Square], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ -114,7 +113,9 @@ def get_configs(args, kwargs): warmup=3, rep=20, ) -@jit(out_idx=[2],) +@jit( + out_idx=[2], +) def matmul( M, N, @@ -125,6 +126,7 @@ def matmul( block_K=None, num_stages=None, thread_num=None, + k_pack=None, policy=None, enable_rasteration=None, ): @@ -156,14 +158,14 @@ def matmul( # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float8_e4m3" - accum_dtype = "float" + dtype = T.float8_e4m3fnuz if torch.version.hip is not None else T.float8_e4m3fn + accum_dtype = T.float32 @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -178,7 +180,6 @@ def main( # Bind x-dimension to block index in N, # y-dimension to block index in M. 
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) @@ -190,8 +191,6 @@ def main( # Enable (or disable) swizzling optimization T.use_swizzle(panel_size=10, enable=enable_rasteration) - # to utilize swizzle tma layout - T.annotate_layout({C_shared: tilelang.layout.make_swizzled_layout(C_shared)}) # Clear out the accumulation buffer T.clear(C_local) @@ -210,6 +209,7 @@ def main( C_local, transpose_B=True, policy=policy, + k_pack=k_pack, ) # Write back the results from C_local to the global memory C T.copy(C_local, C_shared) diff --git a/cmake/load_tvm.cmake b/cmake/load_tvm.cmake index f013c3ba6..cb21be95f 100644 --- a/cmake/load_tvm.cmake +++ b/cmake/load_tvm.cmake @@ -3,12 +3,15 @@ set(TVM_BUILD_FROM_SOURCE TRUE) set(TVM_SOURCE ${CMAKE_SOURCE_DIR}/3rdparty/tvm) -if(DEFINED $ENV{TVM_ROOT}) +if(DEFINED ENV{TVM_ROOT}) if(EXISTS $ENV{TVM_ROOT}/cmake/config.cmake) set(TVM_SOURCE $ENV{TVM_ROOT}) + message(STATUS "Using TVM_ROOT from environment variable: ${TVM_SOURCE}") endif() endif() +message(STATUS "Using TVM source: ${TVM_SOURCE}") + set(TVM_INCLUDES ${TVM_SOURCE}/include ${TVM_SOURCE}/src diff --git a/cmake/pypi-z3/FindZ3.cmake b/cmake/pypi-z3/FindZ3.cmake new file mode 100644 index 000000000..d7920f8f9 --- /dev/null +++ b/cmake/pypi-z3/FindZ3.cmake @@ -0,0 +1,30 @@ +if(Z3_FOUND) + return() +endif() +find_package(Python3 COMPONENTS Interpreter REQUIRED) +execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])" + OUTPUT_VARIABLE Z3_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE Z3_PYTHON_RESULT +) +if(NOT Z3_PYTHON_RESULT EQUAL 0 OR Z3_PATH STREQUAL "") + message(FATAL_ERROR "Failed to locate z3 Python package. Ensure z3-solver>=4.13.0 is installed.") +endif() +message("-- Find Z3 in path: ${Z3_PATH}") +find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS ${Z3_PATH}/include) +find_library(Z3_LIBRARY NO_DEFAULT_PATH NAMES z3 libz3 PATHS ${Z3_PATH}/bin ${Z3_PATH}/lib ${Z3_PATH}/lib64) +message("-- Found Z3 include dir: ${Z3_INCLUDE_DIR}") +message("-- Found Z3 library: ${Z3_LIBRARY}") +add_library(z3::libz3 SHARED IMPORTED GLOBAL) +set_target_properties(z3::libz3 + PROPERTIES + IMPORTED_LOCATION ${Z3_LIBRARY} + INTERFACE_INCLUDE_DIRECTORIES ${Z3_INCLUDE_DIR} +) +if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY) + message(FATAL_ERROR "Could not find Z3 library or include directory") +endif() +set(Z3_CXX_INCLUDE_DIRS ${Z3_INCLUDE_DIR}) +set(Z3_C_INCLUDE_DIRS ${Z3_INCLUDE_DIR}) +set(Z3_FOUND TRUE) diff --git a/docker/Dockerfile.cu118 b/docker/Dockerfile.cu118 index 9256fc09b..969b0e43c 100644 --- a/docker/Dockerfile.cu118 +++ b/docker/Dockerfile.cu118 @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/pytorch:22.12-py3 +FROM nvcr.io/nvidia/pytorch:22.12-py3 WORKDIR /root @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . 
-v CMD bash diff --git a/docker/Dockerfile.cu120 b/docker/Dockerfile.cu120 index c89ce82ef..341fe40c0 100644 --- a/docker/Dockerfile.cu120 +++ b/docker/Dockerfile.cu120 @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/pytorch:23.01-py3 +FROM nvcr.io/nvidia/pytorch:23.01-py3 WORKDIR /root @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu121 b/docker/Dockerfile.cu121 index 5b092773d..f91029d75 100644 --- a/docker/Dockerfile.cu121 +++ b/docker/Dockerfile.cu121 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu123 b/docker/Dockerfile.cu123 index 2715536a8..b3d1217fd 100644 --- a/docker/Dockerfile.cu123 +++ b/docker/Dockerfile.cu123 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu124 b/docker/Dockerfile.cu124 index fb9654f48..335f52565 100644 --- a/docker/Dockerfile.cu124 +++ b/docker/Dockerfile.cu124 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu125 b/docker/Dockerfile.cu125 index c409667cb..148e44b41 100644 --- a/docker/Dockerfile.cu125 +++ b/docker/Dockerfile.cu125 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . 
-v CMD bash diff --git a/docker/Dockerfile.cu126 b/docker/Dockerfile.cu126 index 93593b5df..c031c2bc9 100644 --- a/docker/Dockerfile.cu126 +++ b/docker/Dockerfile.cu126 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu128 b/docker/Dockerfile.cu128 index db5e1cb57..2b895ecd8 100644 --- a/docker/Dockerfile.cu128 +++ b/docker/Dockerfile.cu128 @@ -26,6 +26,6 @@ RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev z RUN pip install cython RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && cmake -S . -B build -DUSE_CUDA=ON && cmake --build build -j + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 1fb23a9f3..5f61f0e2e 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -9,23 +9,43 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential git wget \ libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + rocm-dev rocm-libs hip-dev hipblas-dev rocblas-dev \ && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/* ENV PATH="/opt/conda/bin:${PATH}" ENV LIBGL_ALWAYS_INDIRECT=1 +ENV USE_ROCM=1 +ENV USE_CUDA=0 +ENV ROCM_HOME=/opt/rocm +ENV HIP_PLATFORM=amd +ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" RUN conda run -n py_3.10 conda install pip cmake -y && \ conda run -n py_3.10 conda install -c conda-forge libstdcxx-ng=12 -y && \ conda clean --all -RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev +RUN apt-get update && apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev && \ + apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/* -RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \ - conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh" +# Copy local tilelang directory instead of cloning from git +# Build from tilelang root: docker build -f docker/Dockerfile.rocm -t mi300:latest . +COPY . /root/tilelang -RUN conda init bash +RUN mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \ + conda run -n py_3.10 bash -c "export USE_ROCM=1 USE_CUDA=0 && pip install 'numpy<2.0' --force-reinstall" && \ + conda run -n py_3.10 bash -c "cd /root/tilelang && \ + # Backup and modify pyproject.toml to remove torch from dependencies \ + cp pyproject.toml pyproject.toml.bak && \ + sed -i '/^[[:space:]]*\"torch/d' pyproject.toml && \ + # Install tilelang with all dependencies except torch \ + USE_ROCM=1 USE_CUDA=0 pip install -e . 
-v && \ + # Restore original pyproject.toml \ + mv pyproject.toml.bak pyproject.toml" + +RUN conda init bash && \ + echo "conda activate py_3.10" >> /root/.bashrc SHELL ["/bin/bash", "-l", "-c"] -CMD ["bash", "-c", "source ~/.bashrc && conda activate py_3.10 && exec bash"] \ No newline at end of file +ENTRYPOINT ["/bin/bash", "--login", "-i"] diff --git a/docs/.gitignore b/docs/.gitignore index 4d8eb4049..79ba97163 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,2 +1,2 @@ _build/ -autoapi/ \ No newline at end of file +autoapi/ diff --git a/docs/CNAME b/docs/CNAME index ca903c694..6862cd2e9 100644 --- a/docs/CNAME +++ b/docs/CNAME @@ -1 +1 @@ -tilelang.com \ No newline at end of file +tilelang.com diff --git a/docs/README.md b/docs/README.md index 349c0eccc..896d778d2 100644 --- a/docs/README.md +++ b/docs/README.md @@ -27,4 +27,4 @@ cd _build/html python3 -m http.server ``` -Then you can view the documentation in your browser at `http://localhost:8000` (the port can be customized by appending ` -p PORT_NUMBER` in the python command above). +Then you can view the documentation in your browser at `http://localhost:8000` (the port can be customized by appending `-p PORT_NUMBER` in the python command above). diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 000000000..a1fee9c3d --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,10 @@ +/* Reduce the displayed size of the sidebar logo in Furo */ +.sidebar-logo { + max-height: 125px; + width: auto; +} + +/* Optional: keep container from growing too tall due to spacing */ +.sidebar-logo-container { + line-height: 0; +} diff --git a/docs/_static/img/logo-row.svg b/docs/_static/img/logo-row.svg index 633243f3a..e73244b74 100644 --- a/docs/_static/img/logo-row.svg +++ b/docs/_static/img/logo-row.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/_static/img/logo-v2.png b/docs/_static/img/logo-v2.png new file mode 100644 index 000000000..410773f60 Binary files /dev/null and b/docs/_static/img/logo-v2.png differ diff --git a/docs/_static/img/logo.png b/docs/_static/img/logo.png new file mode 100644 index 000000000..5d04697ce Binary files /dev/null and b/docs/_static/img/logo.png differ diff --git a/docs/_static/img/sparse_mma_storage_example.png b/docs/_static/img/sparse_mma_storage_example.png new file mode 100644 index 000000000..0b1639819 Binary files /dev/null and b/docs/_static/img/sparse_mma_storage_example.png differ diff --git a/docs/compiler_internals/tensor_checks.md b/docs/compiler_internals/tensor_checks.md new file mode 100644 index 000000000..ed5a9e691 --- /dev/null +++ b/docs/compiler_internals/tensor_checks.md @@ -0,0 +1,386 @@ +# Tensor Checks (Host-Side Auto-Validation) + +This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind. + +## Why Host-Side Checks +- ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars. +- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches. 
+- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.
+
+## How To Inspect Host Source
+You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:
+
+```python
+print(matmul_relu_kernel.get_host_source())
+```
+
+---
+
+## What The Host Checks
+
+### 1) Argument count and pointer kind
+- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
+- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error.
+
+### 2) Tensor checks (per tensor, after nullability decision)
+- Nullability
+  - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`.
+  - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`.
+- Rank (`ndim`)
+  - Runtime `ndim` must equal the compile-time rank.
+- Data type (`dtype`)
+  - Match the triple `(code, bits, lanes)` with tolerance:
+    - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`.
+    - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`.
+    - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match).
+  - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
+- Shape
+  - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
+  - Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints.
+- Strides
+  - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality.
+  - Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`).
+- `byte_offset`
+  - Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
+- Device info
+  - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
+  - When multiple tensors participate, assert that `device_id` matches across them.
+- Data pointer
+  - Must be non-NULL when the tensor is required to be non-null by the nullability rule.
+
+### 3) Scalar checks
+- `T.int*` family: require integer; error: `Expect arg[i] to be int`.
+- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`.
+
+---
+
+## Shapes and Symbolic Equations: Linear Solving
+When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example:
+
+```python
+@T.prim_func
+def main(
+    A: T.Tensor((m,), dtype),
+    B: T.Tensor((m + n,), dtype),
+    C: T.Tensor((n * k,), dtype),
+):
+    ...
+```
+
+This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime.
+
+---
+
+## Nullability Rules and Examples
+Which tensors may be NULL?
+
+- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL.
+- Examples:
+
+1) Must be non-NULL (used)
+```python
+@T.prim_func
+def main(A: T.Tensor((M, K), dtype)):
+    A[0] = 1
+```
+Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.
+
+2) Still must be non-NULL (constant-true branch)
+```python
+some_cond: bool = True
+@T.prim_func
+def main(A: T.Tensor((M, K), dtype)):
+    if some_cond:
+        A[0] = 1
+```
+
+3) Nullable (constant-false branch, statically unreachable)
+```python
+some_cond: bool = False
+@T.prim_func
+def main(A: T.Tensor((M, K), dtype)):
+    if some_cond:
+        A[0] = 1
+```
+
+4) Must be non-NULL (runtime condition)
+```python
+@T.prim_func
+def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
+    if some_cond:
+        A[0] = 1
+```
+Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable.
+
+---
+
+## Device Type Codes (DLPack)
+Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`.
+Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors.
+
+---
+
+## Common Error Examples (What you’ll see)
+- Argument count mismatch (num_args)
+  - Trigger: missing/extra argument
+  - Error: `: num_args should be N; expected: , got: N`
+
+- Pointer-typed argument expected
+  - Trigger: scalar passed where a tensor is expected
+  - Error: `: Expect arg[i] to be pointer`
+
+- Rank (ndim) mismatch
+  - Trigger: runtime rank differs from compile-time rank
+  - Error: `..ndim is expected to equal R, but got mismatched ndim`
+
+- Dtype mismatch
+  - Trigger: dtype not equal to the compiled dtype and not within the tolerance set
+  - Error: `..dtype is expected to be , but got incompatible dtype`
+
+- Shape constraint violation
+  - Trigger: a dimension doesn’t match a constant/symbol binding
+  - Error: `Argument ..shape[i] has an unsatisfied constraint: ... == `
+
+- Strides check failed (e.g., non-contiguous layout)
+  - Trigger: transposed/sliced tensors that violate expected strides
+  - Error: `Argument ..strides[j] has an unsatisfied constraint: ... == `
+
+- Device type mismatch
+  - Trigger: calling a CUDA kernel with CPU tensors, etc.
+  - Error: `..device_type mismatch [expected: ()] ...`
+
+- Device id mismatch
+  - Trigger: mixing tensors from different GPUs
+  - Error: `Argument ..device_id has an unsatisfied constraint: ... == ...`
+
+- NULL data pointer
+  - Trigger: tensor required to be non-null has a NULL data pointer
+  - Error: `. is expected to have non-NULL data pointer, but got NULL`
+
+- Scalar type mismatch
+  - Trigger: passing float to `T.int32`, or non-boolean to `T.bool`
+  - Error: `: Expect arg[i] to be int/boolean`
+
+---
+
+## Troubleshooting Tips
+- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields.
+- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
+- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
+- Align dtype: use `.to()` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance.
+- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).
+
+---
+
+## FAQ
+- Can I disable the checks?
+  - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call.
+- Is the overhead noticeable?
+  - The checks are lightweight (branches and field reads). Compared to Python-side checks, it’s faster; the dominating cost remains the Python→C boundary. Overall it’s cheaper than equivalent checks in Python.
+
+---
+
+## Reference Example (Matmul + ReLU)
+
+```python
+@T.prim_func
+def matmul_relu_kernel(
+    A: T.Tensor((M, K), dtype),
+    B: T.Tensor((K, N), dtype),
+    C: T.Tensor((M, N), dtype),
+):
+    # Initialize Kernel Context
+    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+        A_shared = T.alloc_shared((block_M, block_K), dtype)
+        B_shared = T.alloc_shared((block_K, block_N), dtype)
+        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+        T.clear(C_local)
+        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
+            T.copy(A[by * block_M, ko * block_K], A_shared)
+            T.copy(B[ko * block_K, bx * block_N], B_shared)
+            T.gemm(A_shared, B_shared, C_local)
+        T.copy(C_local, C[by * block_M, bx * block_N])
+
+# For debugging, print the host source
+print(matmul_relu_kernel.get_host_source())
+```
+
+The host will insert all checks described above for this example.
+
+---
+
+## Quick Error Reference (Short List)
+- Argument count
+  - Trigger: missing/extra args; Error: `num_args should be N; expected: , got: N`.
+- Pointer kind
+  - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`.
+- Rank (ndim)
+  - Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`.
+- Dtype
+  - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be `.
+- Shape
+  - Trigger: constant/symbol binding violated; Error: `shape[i] ... == `.
+- Strides
+  - Trigger: layout mismatch; Error: `strides[j] ... == `.
+- Device type
+  - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`.
+- Device id
+  - Trigger: tensors on different GPUs; Error: `device_id ... == ...`.
+- Data pointer
+  - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`.
+- Scalar types
+  - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`.
+
+---
+
+## Host Error Troubleshooting (Minimal Repros)
+
+Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with:
+
+```python
+# Convention:
+# A: float16 [M, K]
+# B: float16 [K, N]
+# C: float16 [M, N]
+# Target: CUDA (device_type=2)
+fn = matmul_relu_kernel # your compiled function
+M = N = K = 1024
+```
+
+Adjust dtype/device if your kernel differs.
+
+### 0. Tip: print the host source
+```python
+print(fn.get_host_source())
+```
+
+### 1. num_args mismatch
+```python
+import torch
+
+A = torch.empty((M, K), device='cuda', dtype=torch.float16)
+B = torch.empty((K, N), device='cuda', dtype=torch.float16)
+# Missing C
+fn(A, B)
+```
+Expected: `: num_args should be 3; expected: , got: 3`.
+
+Fix: pass all arguments per the signature.
+
+### 2. Expect pointer (tensor) but got scalar
+```python
+import torch
+
+B = torch.empty((K, N), device='cuda', dtype=torch.float16)
+C = torch.empty((M, N), device='cuda', dtype=torch.float16)
+fn(1, B, C)
+```
+Expected: `: Expect arg[0] to be pointer`.
+
+Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor).
+
+### 3. 
ndim mismatch +```python +import torch + +A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16) # rank=3 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.ndim is expected to equal 2, but got mismatched ndim`. + +Fix: ensure runtime rank equals compiled rank. + +### 4. dtype mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float32) # should be float16 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.dtype is expected to be float16, but got incompatible dtype`. + +Fix: `A = A.to(torch.float16)` or create with the correct dtype. + +### 5. Shape constant/symbol mismatch +```python +import torch + +A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16) # K mismatched +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .A_handle.shape[i] has an unsatisfied constraint: ... == `. + +Fix: satisfy linear constraints and constants across tensors. + +### 6. Strides check failure (non-contiguous) +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +A_nc = A.t() # transpose -> non-contiguous +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A_nc, B, C) +``` +Expected: `Argument .A_handle.strides[1] has an unsatisfied constraint: ... == 1`. + +Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel. + +### 7. device_type mismatch +```python +import torch + +A = torch.empty((M, K), device='cpu', dtype=torch.float16) +B = torch.empty((K, N), device='cpu', dtype=torch.float16) +C = torch.empty((M, N), device='cpu', dtype=torch.float16) +fn(A, B, C) # CUDA-targeted kernel +``` +Expected: `.A_handle.device_type mismatch [expected: 2 (cuda)] ...`. + +Fix: move tensors to the CUDA device. + +### 8. device_id mismatch (multi-GPU) +```python +import torch + +A = torch.empty((M, K), device='cuda:0', dtype=torch.float16) +B = torch.empty((K, N), device='cuda:1', dtype=torch.float16) +C = torch.empty((M, N), device='cuda:0', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .B_handle.device_id has an unsatisfied constraint: ... == ...`. + +Fix: place all tensors on the same GPU (e.g., `cuda:0`). + +### 9. NULL data pointer (advanced) +This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this. + +Expected: `. is expected to have non-NULL data pointer, but got NULL`. + +Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles. + +### 10. Scalar type mismatch (int / bool) +```python +import tilelang.language as T + +@T.prim_func +def scalar_check(x: T.int32, flag: T.bool()): + T.evaluate(0) + +scalar_check(1.0, True) # x is float -> Expect arg[0] to be int +scalar_check(1, 2.5) # flag is float -> Expect arg[1] to be boolean +``` + +Fix: pass correct scalar types, e.g., `scalar_check(1, True)`. + +--- + +## Closing Notes +- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently. 
+- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly. diff --git a/docs/conf.py b/docs/conf.py index 1b1289038..877b5582e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,5 @@ # General information about the project. -project = "Tile Language
" +project = "TileLang
" author = "Tile Lang Contributors" copyright = f"2025-2025, {author}" @@ -20,33 +20,27 @@ "autoapi.extension", ] -autoapi_type = 'python' -autoapi_dirs = ['../tilelang'] +autoapi_type = "python" +autoapi_dirs = ["../tilelang"] autoapi_options = [ - 'members', - 'undoc-members', - 'show-inheritance', - 'show-module-summary', - 'special-members', + "members", + "undoc-members", + "show-inheritance", + "show-module-summary", + "special-members", ] autoapi_keep_files = False # Useful for debugging the generated rst files autoapi_generate_api_docs = True -autodoc_typehints = 'description' +autodoc_typehints = "description" autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"] -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} +source_suffix = {".rst": "restructuredtext", ".md": "markdown"} -myst_enable_extensions = [ - "colon_fence", - "deflist", -] +myst_enable_extensions = ["colon_fence", "deflist"] redirects = {"get_started/try_out": "../index.html#getting-started"} @@ -62,13 +56,11 @@ html_theme = "furo" templates_path = [] html_static_path = ["_static"] -footer_copyright = "© 2025-2025 Tile Language" +html_css_files = ["custom.css"] +footer_copyright = "© 2025-2026 TileLang" footer_note = " " -html_theme_options = { - "light_logo": "img/logo-row.svg", - "dark_logo": "img/logo-row.svg", -} +html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"} header_links = [ ("Home", "https://github.com/tile-ai/tilelang"), diff --git a/docs/deeplearning_operators/deepseek_mla.md b/docs/deeplearning_operators/deepseek_mla.md index 08175778f..ed02b58b1 100644 --- a/docs/deeplearning_operators/deepseek_mla.md +++ b/docs/deeplearning_operators/deepseek_mla.md @@ -1,8 +1,7 @@ # 🚀 Write High Performance FlashMLA with TileLang on Hopper -
- Author: Yu Cheng + Author: Yu Cheng Author: Lei Wang
@@ -32,14 +31,14 @@ Figure 1: Performance under batch size=64 Figure 2: Performance under batch size=128 ``` -As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. +As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this. ## Implementation First, let's review the core computation logic of traditional FlashAttention: -```python +```python # acc_s: [block_M, block_N] # scores_max: [block_M] # scores_scale: [block_M] @@ -62,7 +61,7 @@ Compared to traditional attention operators like MHA (Multi-Headed Attention) or This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling. -Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. +Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory. @@ -106,7 +105,6 @@ T.use_swizzle(panel_size: int, order: str = "row") Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col". - ### Shared Memory Swizzling In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance. @@ -123,17 +121,14 @@ T.annotate_layout({ Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout. - ### Warp-Specialization The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. 
However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects. In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation. - ### Pipeline - Pipeline is a technique used to improve memory access efficiency by overlapping memory access and computation. In TileLang, pipeline can be implemented through the `T.pipelined` annotation: ```python @@ -142,14 +137,12 @@ T.pipelined(range: int, stage: int) Here, `range` specifies the range of the pipeline, and `stage` specifies the stage of the pipeline. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases. - ### Split-KV We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results. In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. - ## 🚀 On AMD MI300X Accelerators Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms. @@ -167,7 +160,7 @@ Key implementation differences between Hopper and MI300X architectures include: # Original shared memory allocation Q_shared = T.alloc_shared([block_H, dim], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) - + # Optimized register allocation Q_local = T.alloc_fragment([block_H, dim], dtype) Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) diff --git a/docs/deeplearning_operators/elementwise.md b/docs/deeplearning_operators/elementwise.md index 5e1243c26..6aa8e4085 100644 --- a/docs/deeplearning_operators/elementwise.md +++ b/docs/deeplearning_operators/elementwise.md @@ -8,7 +8,7 @@ :class: myclass1 myclass2 :name: a-tip-reference - This document is still **experimental** and may be incomplete. + This document is still **experimental** and may be incomplete. Suggestions and improvements are highly encouraged—please submit a PR! 
::: @@ -24,7 +24,7 @@ Please note that this tutorial does not delve deeply into the design principles ## Elementwise add in TileLang ```python -def elementwise_add(N, threads=256, dtype="bfloat16"): +def elementwise_add(N, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -43,7 +43,7 @@ Those familiar with CUDA programming might wonder where `threadIdx` fits into th The program can be compiled using the following code: ```python -program = elementwise_add(1024, threads=256, dtype="bfloat16") +program = elementwise_add(1024, threads=256, dtype=T.bfloat16) kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython") ``` Launching the kernel is straightforward, just call it directly like a function: @@ -89,7 +89,7 @@ def elementwise_add( In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this: ```python -program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16") +program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16) kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython") ``` @@ -102,7 +102,7 @@ TileLang automatically incorporates boundary-checking conditions; however, this When compiling the example below, let's set `N` to 2047: ```python -def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"): +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -176,7 +176,7 @@ While TileLang incorporates various optimizations for the aforementioned case, i In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design. ```python -def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"): +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -212,7 +212,7 @@ Aha, this CUDA code aligns closely with conventional programming practices, maki But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations. 
```python -def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"): +def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -280,8 +280,8 @@ To evaluate complexity, one could implement the same elementwise addition operat ```c++ template -__global__ void elementwise_add(nv_bfloat16* C, - const nv_bfloat16* A, +__global__ void elementwise_add(nv_bfloat16* C, + const nv_bfloat16* A, const nv_bfloat16* B, int N) { using namespace cute; diff --git a/docs/deeplearning_operators/gemv.md b/docs/deeplearning_operators/gemv.md index c75a961b8..c2dddf47f 100644 --- a/docs/deeplearning_operators/gemv.md +++ b/docs/deeplearning_operators/gemv.md @@ -6,7 +6,7 @@
:::{warning} - This document is still **experimental** and may be incomplete. + This document is still **experimental** and may be incomplete. Suggestions and improvements are highly encouraged—please submit a PR! ::: @@ -206,7 +206,6 @@ def splitk_gemv( return main ``` - ## Vectorized Reads GEMV is less computation intensive than GEMM as the computation intensity and memory throughput will be the optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`: @@ -254,7 +253,6 @@ def splitk_gemv_vectorized( With vectorized read, now the kernel finishes in **~0.0084 ms**, which is getting close to cuBLAS performance. - ## `tvm_thread_allreduce` Instead of `atomicAdd` [`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) has implemented optimization when making an all-reduce across a number of threads, which should outperfrom out plain smem + `atomidAdd`: @@ -294,7 +292,7 @@ def splitk_gemv_vectorized_tvm( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle"), ): @@ -379,7 +377,7 @@ def get_best_config(N, K): C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle"), ): @@ -459,6 +457,5 @@ This corresponds closely to our `TileLang` program, with necessary synchronizati | splitk_gemv_vectorized | 0.00809 ms | | splitk_gemv_vectorized_tvm | 0.00675 ms | - Triton Time: 0.0077344514429569244 -In this tutorial, we implemented a simple GEMV kernel and learn that `TileLang` exposes low level control to user such as thread-level programming and CUDA primitives. \ No newline at end of file +In this tutorial, we implemented a simple GEMV kernel and learn that `TileLang` exposes low level control to user such as thread-level programming and CUDA primitives. diff --git a/docs/deeplearning_operators/matmul.md b/docs/deeplearning_operators/matmul.md index fea036ebe..12189eb8f 100644 --- a/docs/deeplearning_operators/matmul.md +++ b/docs/deeplearning_operators/matmul.md @@ -14,11 +14,11 @@ TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction: -* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM. +- **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM. -* **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. 
This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc. +- **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc. -* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc. +- **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc. ```{figure} ../_static/img/overview.png :width: 50% @@ -52,12 +52,12 @@ While Level 1 in TileLang can be very comfortable for general users—since it r Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses: -* **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.). -* **`T.alloc_shared(...)`** to allocate GPU shared memory. -* **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation. -* **`T.Pipelined(...)`** to express software pipelining across the K dimension. -* **`T.Parallel(...)`** to parallelize data copy loops. -* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs). +- **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.). +- **`T.alloc_shared(...)`** to allocate GPU shared memory. +- **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation. +- **`T.Pipelined(...)`** to express software pipelining across the K dimension. +- **`T.Parallel(...)`** to parallelize data copy loops. +- **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs). ```python import tilelang @@ -147,14 +147,12 @@ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, - This sets up the block grid dimensions based on N/block_N and M/block_M. - `threads=128` specifies that each thread block uses 128 threads. The compiler will infer how loops map to these threads. - ```{figure} ../_static/img/Parallel.png :alt: Parallel :align: center ``` - 2. **Shared & Fragment Memory**: ```python @@ -182,7 +180,6 @@ for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): ``` - 4. 
**Parallel Copy**: ```python @@ -252,8 +249,8 @@ For more advanced usage—including partial lowering, explicitly controlling thr ## Further Resources -* [TileLang GitHub](https://github.com/tile-ai/tilelang) -* [BitBLAS](https://github.com/tile-ai/bitblas) -* [Triton](https://github.com/openai/triton) -* [Cutlass](https://github.com/NVIDIA/cutlass) -* [PyCUDA](https://documen.tician.de/pycuda/) +- [TileLang GitHub](https://github.com/tile-ai/tilelang) +- [BitBLAS](https://github.com/tile-ai/bitblas) +- [Triton](https://github.com/openai/triton) +- [Cutlass](https://github.com/NVIDIA/cutlass) +- [PyCUDA](https://documen.tician.de/pycuda/) diff --git a/docs/deeplearning_operators/matmul_sparse.md b/docs/deeplearning_operators/matmul_sparse.md new file mode 100644 index 000000000..8caa6182f --- /dev/null +++ b/docs/deeplearning_operators/matmul_sparse.md @@ -0,0 +1,261 @@ +# Sparse Matrix-Matrix Multiplication with Tile Library + +
+ Author: botbw +
+
+:::{warning}
+    This document is still **experimental** and may be incomplete.
+
+    This feature is still **experimental** and needs further optimization.
+
+    Suggestions and improvements are highly encouraged—please submit a PR!
+:::
+
+:::{tip}
+It's suggested to go through `docs/deeplearning_operators/matmul.md` first.
+
+Example code can be found at `examples/gemm_sp`.
+:::
+
+## Structured sparsity in the NVIDIA Ampere architecture
+
+Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation.
+
+:::{warning}
+    This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X.
+:::
+
+```{figure} ../_static/img/sparse_mma_storage_example.png
+:align: center
+
+Figure: Sparse MMA storage example (from PTX doc)
+```
+
+## Compress a dense tensor
+
+To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata.
+
+Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`).
+
+A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression.
+
+```python
+from tilelang.utils.sparse import compress
+A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
+```
+
+Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern.
+
+> NOTE: When using the CUTLASS compressor, there is no direct positional correspondence between `A_sparse`/`A` and `E` (i.e., the 4-element group at [n, k] does not correspond to the 4-bit metadata at [n, k] if the metadata is viewed as an int4 tensor).
+The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads).
+For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**.
+
+## `T.gemm_sp` with CUTLASS's compressor
+
+:::{warning}
+
+It is strongly recommended to use `T.gemm_sp_v2` due to its greater flexibility and faster compilation time.
+
+:::
+
+A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata.
+
+See the comments in the kernel code below for the required modifications.
+ +```python +def matmul_sp_sm80( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + trans_A, + trans_B, +): + is_8_bit = "8" in in_dtype + metadata_dtype = 'int32' if is_8_bit else 'int16' + E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ # Annotate reordered cutlass metadata layout + E: + make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: + make_cutlass_metadata_layout( + E_shared, mma_dtype=in_dtype, arch="8.0"), + }) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main +``` + +Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`. + +## `T.gemm_sp_v2` with a custom compressor + +To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`. + +Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors. + +The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs. 
+
+Suppose we have the following row vector:
+```python
+t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten()
+```
+
+The non-zero elements and their corresponding indices are:
+
+```python
+t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten()
+indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten()
+```
+
+The corresponding int16 metadata is:
+```python
+# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000])
+# Note: storage uses little-endian order: tensor(0b1000111001001101, dtype=torch.int16)
+# Note: the line above is not runnable Python, since the interpreter does not
+# treat the binary literal as a two's-complement int16
+metadata_int16 = tensor(-29107)
+```
+
+You can decode an int16 metadata tensor using the following utility:
+```python
+def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
+    assert meta.dtype is torch.int16
+    groups_per_meta = 16 // 4
+    out = []
+    for g in range(groups_per_meta):
+        group_bits = (meta >> (g * 4)) & 0xF
+        idx0 = group_bits & 0x3
+        idx1 = (group_bits >> 2) & 0x3
+        out.append(torch.stack([idx0, idx1], dim=-1))
+    return torch.concat(out, dim=-1).view(meta.shape[0], -1)
+```
+
+The compressor can be implemented at either the `PyTorch`/`NumPy` level or the kernel level.
+
+For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.
+
+If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.
+ +```python + +@tilelang.jit(out_idx=[1, 2], pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, +}) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: # NOTE: Make sure compressor metadata layout + T.annotate_layout({ # is same with your computation kernel + E: + make_cutlass_metadata_layout( + E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: + make_cutlass_metadata_layout( + E_shared, + mma_dtype="float16", + arch="8.0", + block_k=block_K), + }) + T.clear(A_sp_shared) + T.clear(E_shared) + non_zero_cnt = T.alloc_local((1, ), dtype="uint8") + non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8") + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + T.clear(non_zero_cnt) + T.clear(non_zero_elt_log_idx) + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel +``` + +## A note on `gemm_sp` and `gemm_sp_v2` + +Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout. + +However, fixing a specific layout introduces several potential issues: + +1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling. + +2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically. + +3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.) + +`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout. 
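+
+As a minimal illustration of the migration path mentioned earlier (a sketch only, reusing the operand names from the sm80 example above; `T.gemm_sp_v2` is assumed to take the same argument list, as implied by the "simply replace occurrences" guidance):
+
+```python
+# Before: CUTLASS-backed path; E/E_shared require the CUTLASS metadata layout annotation.
+T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
+
+# After: PTX-backed path; the T.annotate_layout for E/E_shared becomes optional.
+T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
+```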
diff --git a/docs/get_started/Installation.md b/docs/get_started/Installation.md index 3d5c6db9d..b23026d9b 100644 --- a/docs/get_started/Installation.md +++ b/docs/get_started/Installation.md @@ -8,25 +8,25 @@ - **Python Version**: >= 3.8 - **CUDA Version**: 12.0 <= CUDA < 13 -The easiest way to install **tile-lang** is directly from PyPI using pip. To install the latest version, run the following command in your terminal: +The easiest way to install tilelang is directly from PyPI using pip. To install the latest version, run the following command in your terminal: ```bash pip install tilelang ``` -Alternatively, you may choose to install **tile-lang** using prebuilt packages available on the Release Page: +Alternatively, you may choose to install tilelang using prebuilt packages available on the Release Page: ```bash pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl ``` -To install the latest version of **tile-lang** from the GitHub repository, you can run the following command: +To install the latest version of tilelang from the GitHub repository, you can run the following command: ```bash pip install git+https://github.com/tile-ai/tilelang.git ``` -After installing **tile-lang**, you can verify the installation by running: +After installing tilelang, you can verify the installation by running: ```bash python -c "import tilelang; print(tilelang.__version__)" @@ -40,18 +40,18 @@ python -c "import tilelang; print(tilelang.__version__)" - **Python Version**: >= 3.8 - **CUDA Version**: >= 10.0 -```bash -docker run -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.01-py3 -``` +If you prefer Docker, please skip to the [Install Using Docker](#install-using-docker) section. This section focuses on building from source on a native Linux environment. -To build and install **tile-lang** directly from source, follow these steps. This process requires certain pre-requisites from Apache TVM, which can be installed on Ubuntu/Debian-based systems using the following commands: +First, install the OS-level prerequisites on Ubuntu/Debian-based systems using the following commands: ```bash apt-get update apt-get install -y python3 python3-dev python3-setuptools gcc zlib1g-dev build-essential cmake libedit-dev ``` -After installing the prerequisites, you can clone the **tile-lang** repository and install it using pip: +Then, clone the tilelang repository and install it using pip. The `-v` flag enables verbose output during the build process. + +> **Note**: Use the `--recursive` flag to include necessary submodules. Tilelang currently depends on a customized version of TVM, which is included as a submodule. If you prefer [Building with Existing TVM Installation](#using-existing-tvm), you can skip cloning the TVM submodule (but still need other dependencies). ```bash git clone --recursive https://github.com/tile-ai/tilelang.git @@ -59,13 +59,19 @@ cd tilelang pip install . -v ``` -If you want to install **tile-lang** in development mode, you can run the following command: +If you want to install tilelang in development mode, you can use the `-e` flag so that any changes to the Python files will be reflected immediately without reinstallation. ```bash pip install -e . -v ``` -If you prefer to work directly from the source tree via `PYTHONPATH`, make sure the native extension is built first: +> **Note**: changes to C++ files require rebuilding the tilelang C++ library. See [Faster Rebuild for Developers](#faster-rebuild-for-developers) below. 
A default `build` directory will be created if you use `pip install`, so you can also directly run `make` in the `build` directory to rebuild it as [Working from Source via PYTHONPATH](#working-from-source-via-pythonpath) suggested below. + +(working-from-source-via-pythonpath)= + +### Working from Source via `PYTHONPATH` (Recommended for Developers) + +If you prefer to work directly from the source tree via `PYTHONPATH` instead of using pip, make sure the native extension (`libtilelang.so`) is built first: ```bash mkdir -p build @@ -73,6 +79,14 @@ cd build cmake .. -DUSE_CUDA=ON make -j ``` + +We also recommend using `ninja` to speed up compilation: + +```bash +cmake .. -DUSE_CUDA=ON -G Ninja +ninja +``` + Then add the repository root to `PYTHONPATH` before importing `tilelang`, for example: ```bash @@ -85,17 +99,23 @@ Some useful CMake options you can toggle while configuring: - `-DUSE_ROCM=ON` selects ROCm support when building on AMD GPUs. - `-DNO_VERSION_LABEL=ON` disables the backend/git suffix in `tilelang.__version__`. -We currently provide four methods to install **tile-lang**: +(using-existing-tvm)= + +### Building with Customized TVM Path -1. [Install Using Docker](#install-method-1) (Recommended) -2. [Install from Source (using the bundled TVM submodule)](#install-method-2) -3. [Install from Source (using your own TVM installation)](#install-method-3) +If you already have a TVM codebase, use the `TVM_ROOT` environment variable to specify the location of your existing TVM repository when building tilelang: + +```bash +TVM_ROOT= pip install . -v +``` -(install-method-1)= +> **Note**: This will still rebuild the TVM-related libraries (stored in `TL_LIBS`). And this method often leads to some path issues. Check `env.py` to see some environment variables which are not set properly. -### Method 1: Install Using Docker (Recommended) +(install-using-docker)= -For users who prefer a containerized environment with all dependencies pre-configured, **tile-lang** provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems and is the **recommended approach** for most users. +## Install Using Docker + +For users who prefer a containerized environment with all dependencies pre-configured, tilelang provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems. **Prerequisites:** - Docker installed on your system @@ -142,82 +162,106 @@ docker run -itd \ - `--name tilelang_b200`: Assigns a name to the container for easy management - `/bin/zsh`: Uses zsh as the default shell -4. **Access the Container**: +4. **Access the Container and Verify Installation**: ```bash docker exec -it tilelang_b200 /bin/zsh -``` - -5. **Verify Installation**: - -Once inside the container, verify that **tile-lang** is working correctly: - -```bash +# Inside the container: python -c "import tilelang; print(tilelang.__version__)" ``` -You can now run TileLang examples and develop your applications within the containerized environment. The Docker image comes with all necessary dependencies pre-installed, including CUDA toolkit, TVM, and TileLang itself. +### ROCm container build (gfx942/gfx950) -**Example Usage:** +If you want a ready-to-use ROCm image that builds TileLang from source, use +`docker/Dockerfile.rocm`. This is the recommended path for a clean, reproducible +environment. 
-After accessing the container, you can run TileLang examples: +If you are already inside another ROCm container (for example, the `sglang` +image) and just need to rebuild TileLang in-place, follow the steps below. -```bash -cd /home/tilelang/examples -python elementwise/test_example_elementwise.py -``` - -This Docker-based installation method provides a complete, isolated environment that works seamlessly on systems with compatible NVIDIA GPUs like the B200, ensuring optimal performance for TileLang applications. - -(install-method-2)= - -### Method 2: Install from Source (Using the Bundled TVM Submodule) - -If you already have a compatible TVM installation, follow these steps: - -1. **Clone the Repository**: +If you are using the `sglang` ROCm container and need to build TileLang in it (for example on MI300 `gfx942` or MI355 `gfx950`), the build requires extra system libraries, Cython, and a valid `llvm-config`. The following steps match the build flow used in `sglang/docker/rocm.Dockerfile`: ```bash -git clone --recursive https://github.com/tile-ai/tilelang -cd tilelang -``` - -**Note**: Use the `--recursive` flag to include necessary submodules. - -2. **Configure Build Options**: - -Create a build directory and specify your existing TVM path: - -```bash -pip install . -v -``` - -(install-method-3)= - -### Method 3: Install from Source (Using Your Own TVM Installation) - -If you prefer to use the built-in TVM version, follow these instructions: - -1. **Clone the Repository**: - -```bash -git clone --recursive https://github.com/tile-ai/tilelang -cd tilelang +# Inside the container (as root) +apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget curl ca-certificates gnupg \ + libgtest-dev libgmock-dev \ + libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev \ + python3 python3-dev python3-setuptools python3-pip \ + gcc libtinfo-dev zlib1g-dev libedit-dev libxml2-dev \ + cmake ninja-build pkg-config libstdc++6 \ + && rm -rf /var/lib/apt/lists/* + +# Prefer the container venv (avoid system pip) +export PATH="/opt/venv/bin:${PATH}" + +# Build GoogleTest static libs (Ubuntu package ships sources only) +cmake -S /usr/src/googletest -B /tmp/build-gtest -DBUILD_GTEST=ON -DBUILD_GMOCK=ON -DCMAKE_BUILD_TYPE=Release +cmake --build /tmp/build-gtest -j"$(nproc)" +cp -v /tmp/build-gtest/lib/*.a /usr/lib/x86_64-linux-gnu/ +rm -rf /tmp/build-gtest + +# Keep setuptools < 80 (compat with some base images) +pip install --upgrade "setuptools>=77.0.3,<80" wheel cmake ninja scikit-build-core + +# Locate ROCm llvm-config (install LLVM 18 if missing) +LLVM_CONFIG_PATH="" +for p in /opt/rocm/llvm/bin/llvm-config /opt/rocm/llvm-*/bin/llvm-config /opt/rocm-*/llvm*/bin/llvm-config; do + if [ -x "$p" ]; then LLVM_CONFIG_PATH="$p"; break; fi +done +if [ -z "$LLVM_CONFIG_PATH" ]; then + echo "ROCm llvm-config not found; installing LLVM 18..." 
+ curl -fsSL https://apt.llvm.org/llvm.sh -o /tmp/llvm.sh + chmod +x /tmp/llvm.sh + /tmp/llvm.sh 18 + LLVM_CONFIG_PATH="$(command -v llvm-config-18)" + if [ -z "$LLVM_CONFIG_PATH" ]; then + echo "ERROR: llvm-config-18 not found after install" + exit 1 + fi +fi +export LLVM_CONFIG="$LLVM_CONFIG_PATH" +export PATH="$(dirname "$LLVM_CONFIG"):/usr/local/bin:${PATH}" + +# Optional shim for tools that expect llvm-config-16 +mkdir -p /usr/local/bin +printf "#!/usr/bin/env bash\nexec \"%s\" \"\$@\"\n" "$LLVM_CONFIG_PATH" > /usr/local/bin/llvm-config-16 +chmod +x /usr/local/bin/llvm-config-16 + +# TVM Python bits need Cython (for system Python used by the build) +pip install --no-cache-dir "cython>=0.29.36,<3.0" + +# Clone + build TileLang (ROCm) +# Default location: /opt/tilelang (adjust if you prefer a different path). +git clone --recursive https://github.com/tile-ai/tilelang.git /opt/tilelang +cd /opt/tilelang +git submodule update --init --recursive +export CMAKE_ARGS="-DUSE_CUDA=OFF -DUSE_ROCM=ON -DROCM_PATH=/opt/rocm -DLLVM_CONFIG=${LLVM_CONFIG}" + +# Avoid pulling CUDA wheels / reinstalling torch by skipping dependency resolution. +# Assume torch is already installed in the container. +pip install -e . -v --no-build-isolation --no-deps + +# Manually install required runtime deps when using --no-deps. +# Note: skip torch-c-dlpack-ext on ROCm (its wheel expects CUDA libs). +pip install "apache-tvm-ffi>=0.1.6" "z3-solver>=4.13.0" +# If you already installed torch-c-dlpack-ext and hit `libtorch_cuda.so` errors: +# pip uninstall -y torch-c-dlpack-ext + +# If you hit Cython compile errors like `PyLong_SHIFT`/`digit` not declared, +# disable the stable ABI (abi3) for editable builds: +# export CMAKE_ARGS="-DUSE_CUDA=OFF -DUSE_ROCM=ON -DROCM_PATH=/opt/rocm -DLLVM_CONFIG=${LLVM_CONFIG} -DSKBUILD_SABI_VERSION=" +# pip install -e . -v --no-build-isolation --no-deps + +# Verify +python -c "import tilelang; print(tilelang.__version__)" ``` -**Note**: Ensure the `--recursive` flag is included to fetch submodules. - -2. **Configure Build Options**: - -Copy the configuration file and enable the desired backends (e.g., LLVM and CUDA): - -```bash -TVM_ROOT= pip install . -v -``` +If you still want to use `pip install -e . -v --no-build-isolation` without `--no-deps`, pip will try to resolve TileLang dependencies and may download CUDA wheels (e.g., `nvidia_cudnn`, `nvidia_nvshmem`) and reinstall `torch`. To avoid that in ROCm containers, keep `--no-deps` and ensure required packages are already installed. ## Install with Nightly Version -For users who want access to the latest features and improvements before official releases, we provide nightly builds of **tile-lang**. +For users who want access to the latest features and improvements before official releases, we provide nightly builds of tilelang. ```bash pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/ @@ -252,24 +296,28 @@ Set `NO_TOOLCHAIN_VERSION=ON` to disable this. ### Run-time environment variables - +Please refer to the `env.py` file for a full list of supported run-time environment variables. -## IDE Configs +## Other Tips -Building tilelang locally will automatically `compile_commands.json` file in `build` dir. +### IDE Configs + +Building tilelang locally will automatically generate a `compile_commands.json` file in `build` dir. VSCode with clangd and [clangd extension](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) should be able to index that without extra configuration. 
-## Compile cache +### Compile Cache -`ccache` will be automatically used if found. +The default path of the compile cache is `~/.tilelang/cache`. `ccache` will be automatically used if found. -## Repairing wheels +### Repairing Wheels If you plan to use your wheel in other environment, -it's recommend to use auditwheel (on Linux) or delocate (on Darwin) +it's recommended to use auditwheel (on Linux) or delocate (on Darwin) to repair them. -## Faster rebuild for developers +(faster-rebuild-for-developers)= + +### Faster Rebuild for Developers `pip install` introduces extra [un]packaging and takes ~30 sec to complete, even if no source change. @@ -278,8 +326,17 @@ Developers who needs to recompile frequently could use: ```bash pip install -r requirements-dev.txt + +# For first time compilation pip install -e . -v --no-build-isolation +# Or manually compile with cmake/ninja. Remember to set PYTHONPATH properly. +mkdir build +cd build +cmake .. -G Ninja +ninja + +# Rebuild when you change the cpp code cd build; ninja ``` diff --git a/docs/get_started/overview.md b/docs/get_started/overview.md index 18fa9f193..a7c154f31 100644 --- a/docs/get_started/overview.md +++ b/docs/get_started/overview.md @@ -15,49 +15,49 @@ Figure 1: High-level overview of the TileLang compilation flow. ## Programming Interfaces 1. **Beginner Level (Hardware-Unaware)** - - Intended for users who need to write code that is independent of specific hardware details. - - The goal is to let developers focus on the basic logic without worrying about memory hierarchies or hardware-specific optimizations. + - Intended for users who need to write code that is independent of specific hardware details. + - The goal is to let developers focus on the basic logic without worrying about memory hierarchies or hardware-specific optimizations. - *Note:* This interface is not yet fully implemented. 2. **Developer Level (Hardware-Aware with Tile Library)** - - Designed for developers who have a basic understanding of GPU memory hierarchies and performance considerations. - - Provides a **Tile Library**, containing predefined operations and patterns optimized for various hardware architectures. + - Designed for developers who have a basic understanding of GPU memory hierarchies and performance considerations. + - Provides a **Tile Library**, containing predefined operations and patterns optimized for various hardware architectures. - Users at this level can leverage these ready-made primitives without diving into low-level threading details. 3. **Expert Level (Hardware-Aware with Thread Primitives)** - - For highly experienced users who have an in-depth understanding of low-level hardware characteristics (e.g., threading models, memory coalescing). - - Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels. + - For highly experienced users who have an in-depth understanding of low-level hardware characteristics (e.g., threading models, memory coalescing). + - Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels. - This level grants maximum flexibility for specialized optimizations tailored to specific GPU or multi-core architectures. ## Compilation Flow -1. **Tile Program** +1. **Tile Program** A high-level specification of the computation. 
Depending on the user’s expertise, they may write a purely hardware-unaware tile program or incorporate constructs from the Tile Library or thread primitives. -2. **Tile Program with Tile Library** +2. **Tile Program with Tile Library** When developers choose from the Tile Library, the original Tile Program is expanded with specialized library calls. These calls encapsulate efficient implementation patterns for different operations. -3. **Tile Program with Thread Primitives** +3. **Tile Program with Thread Primitives** Expert-level developers can explicitly use low-level threading constructs to hand-optimize data layout, synchronization, and memory usage. -4. **IRModule** +4. **IRModule** After the program is composed with libraries or thread primitives, it is lowered to an intermediate representation (IR) that captures the necessary hardware details. -5. **Source Code Generation (C/CUDA/HIP/LLVM/…)** +5. **Source Code Generation (C/CUDA/HIP/LLVM/…)** From the IR, the system generates target-specific source code. This source code is tuned for the desired backends or GPU architectures (e.g., NVIDIA, AMD). -6. **Hardware-Specific Executable/Runtime** +6. **Hardware-Specific Executable/Runtime** Finally, the generated source is compiled into hardware-specific executables, ready to run on the corresponding devices. The pipeline supports multiple GPU backends and can be extended to additional architectures. ## Tile-based Programming Model -[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in ``TileLang``, -illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining, +[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in ``TileLang``, +illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining, and operator calls to manage data movement and computation with fine-grained control. -In particular, this snippet ([Figure 2](#fig-overview-gemm) (a)) demonstrates how multi-level tiling -leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization +In particular, this snippet ([Figure 2](#fig-overview-gemm) (a)) demonstrates how multi-level tiling +leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization and reduce latency. -Overall, [Figure 2](#fig-overview-gemm) (b) showcases how the Python-like syntax of ``TileLang`` +Overall, [Figure 2](#fig-overview-gemm) (b) showcases how the Python-like syntax of ``TileLang`` allows developers to reason about performance-critical optimizations within a user-friendly programming model. ```{figure} ../_static/img/MatmulExample.png diff --git a/docs/get_started/targets.md b/docs/get_started/targets.md index c2b3f2fb5..3a464bd66 100644 --- a/docs/get_started/targets.md +++ b/docs/get_started/targets.md @@ -14,6 +14,7 @@ the generated code. The most frequent choices are listed below: | --------- | ----------- | | `auto` | Detects CUDA → HIP → Metal in that order. Useful when running the same script across machines. | | `cuda` | NVIDIA GPUs. Supports options such as `-arch=sm_80`, `-max_num_threads=1024`, etc. | +| `cutedsl` | NVIDIA CUTLASS/CuTe DSL backend. Requires `nvidia-cutlass-dsl`. `cuda` options can also be applied to this target. | | `hip` | AMD GPUs via ROCm. Options like `-mcpu=gfx90a` can be appended. | | `metal` | Apple Silicon GPUs (arm64 Macs). 
| | `llvm` | CPU execution; accepts the standard TVM LLVM switches. | diff --git a/docs/index.md b/docs/index.md index 5d9a158f8..1c78ea2f6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,10 +2,10 @@ [GitHub](https://github.com/tile-ai/tilelang) -Tile Language (tile-lang) is a concise domain-specific language designed to streamline -the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). -By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM, -tile-lang allows developers to focus on productivity without sacrificing the +Tile Language (tile-lang) is a concise domain-specific language designed to streamline +the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). +By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM, +tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance. :::{toctree} @@ -17,13 +17,26 @@ get_started/overview get_started/targets ::: - :::{toctree} :maxdepth: 1 :caption: TUTORIALS tutorials/debug_tools_for_tilelang tutorials/auto_tuning +tutorials/logging +::: + +:::{toctree} +:maxdepth: 1 +:caption: PROGRAMMING GUIDES + +programming_guides/overview +programming_guides/language_basics +programming_guides/instructions +programming_guides/control_flow +programming_guides/python_compatibility +programming_guides/autotuning +programming_guides/type_system ::: :::{toctree} @@ -33,6 +46,7 @@ tutorials/auto_tuning deeplearning_operators/elementwise deeplearning_operators/gemv deeplearning_operators/matmul +deeplearning_operators/matmul_sparse deeplearning_operators/deepseek_mla ::: @@ -42,6 +56,7 @@ deeplearning_operators/deepseek_mla compiler_internals/letstmt_inline compiler_internals/inject_fence_proxy +compiler_internals/tensor_checks ::: :::{toctree} diff --git a/docs/programming_guides/autotuning.md b/docs/programming_guides/autotuning.md new file mode 100644 index 000000000..9cc5a2d94 --- /dev/null +++ b/docs/programming_guides/autotuning.md @@ -0,0 +1,308 @@ +# Autotuning + +TileLang includes a built‑in autotuner that searches configuration spaces +for the best performing kernel, compiles candidates in parallel, validates +correctness, benchmarks them, and caches the best result for reuse. + +This guide covers two workflows: +- Decorator‑based: `@tilelang.autotune(configs=...)` stacked on `@tilelang.jit` +- Programmatic: `AutoTuner.from_kernel(...).set_*().run()` + +It also explains input tensor supply, validation, caching, and environment +variables that affect parallelism and cache behavior. + +## 1) Decorator‑based Autotune + +Use `@tilelang.autotune` above `@tilelang.jit` and expose tunable parameters as +function arguments with defaults. The autotuner overrides these parameters with +values from your config space. 
+ +```python +import tilelang +import tilelang.language as T + +def matmul_configs(M, N, K): + # Example space — tailor to your target + tiles = [64, 128] + stages = [2, 3] + threads = [128, 256] + return [ + dict(block_M=BM, block_N=BN, block_K=BK, num_stages=S, threads=TH) + for BM in tiles + for BN in tiles + for BK in [32, 64] + for S in stages + for TH in threads + ] + +@tilelang.autotune(configs=matmul_configs, warmup=25, rep=100, timeout=60) +@tilelang.jit(out_idx=[-1]) +def matmul(M: int, N: int, K: int, + block_M: int = 128, block_N: int = 128, block_K: int = 32, + threads: int = 128, num_stages: int = 3, + dtype: str = 'float16', accum_dtype: str = 'float32'): + + @T.prim_func + def kernel(A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_s = T.alloc_shared((block_M, block_K), dtype) + B_s = T.alloc_shared((block_K, block_N), dtype) + C_f = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_s) + T.copy(B[ko * block_K, bx * block_N], B_s) + T.gemm(A_s, B_s, C_f) + + T.copy(C_f, C[by * block_M, bx * block_N]) + + return kernel + +# Usage +# Provide inputs via context (recommended for reproducibility across configs) +import torch +M = N = K = 1024 +A = torch.randn(M, K, device='cuda', dtype=torch.float16) +B = torch.randn(K, N, device='cuda', dtype=torch.float16) +C = torch.empty(M, N, device='cuda', dtype=torch.float16) + +from tilelang.autotuner import set_autotune_inputs +with set_autotune_inputs(A, B, C): + tuned_kernel = matmul(M, N, K) # compiles, tunes, returns best kernel + tuned_kernel(A, B, C) # run best kernel +``` + +Notes +- `configs` can be a list of dicts or a callable `(args...) -> list[dict]`. Each + dict’s keys must match the tunable function arguments (e.g., `block_M`). +- The decorator returns a callable that runs autotune once per argument tuple + and caches the resulting best kernel in‑process. +- For explicit input control during tuning, wrap the call with + `set_autotune_inputs(...)`. Otherwise, `supply_type` (below) is used. + +## 2) Programmatic Autotune + +Use the `AutoTuner` class to manage configs and arguments more explicitly. 
+ +```python +from tilelang.autotuner import AutoTuner + +kernel_factory = matmul # the function above (already @tilelang.jit) +tuner = AutoTuner.from_kernel(kernel_factory(M, N, K), configs=matmul_configs(M, N, K)) + +tuner.set_profile_args( + warmup=25, rep=100, timeout=60, + supply_type=tilelang.TensorSupplyType.Auto, # or provide supply_prog/ref_prog + ref_prog=lambda A, B, C: torch.allclose(C, (A @ B).to(C.dtype), rtol=1e-2, atol=1e-2), +) + +tuner.set_compile_args( + target='auto', # or 'cuda'/'hip'/'metal' + execution_backend='auto', # resolves per-target + out_idx=[-1], # which outputs to return if multiple + pass_configs={ # optional TVM passes/flags + # tilelang.PassConfigKey.EXAMPLE_KEY: value, + }, +) + +artifact = tuner.run() # compiles + runs + validates all configs +best_kernel = artifact.kernel # JITKernel +best_latency = artifact.latency +best_config = artifact.config + +# Reuse best kernel +best_kernel(A, B, C) +``` + +### Example Gallery (in repo) +- examples/gdn/example_chunk_delta_h.py:101 — uses `@autotune` to sweep configs +- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py:451 — uses `@tilelang.autotune` +- examples/quickstart.py:84 — profiles a tuned kernel with `get_profiler` +- examples/hadamard_transform/example_hadamard.py:152 — profiler with custom warmup +- examples/dynamic_shape/example_dynamic.py:94 — profiler for dynamic shapes +- examples/gemm/example_gemm_persistent.py:135 — compare persistent vs non‑persistent + +Click any path to open the code and compare patterns. + +## Input Tensor Supply + +The tuner needs inputs to compile and benchmark kernels. Provide them in one of +three ways (priority order): + +1) Context manager (fixed inputs across configs) +```python +with set_autotune_inputs(A, B, C): + tuned = matmul(M, N, K) +``` + +2) Custom supplier program +```python +def supply_prog(signature): + # signature holds KernelParam objects describing shapes/dtypes + # Return a list of torch tensors matching the kernel’s arguments + return [A, B, C] + +tuner.set_profile_args(supply_prog=supply_prog) +``` + +3) Built‑in generators via `supply_type` +- `TensorSupplyType.Auto` (default): heuristic per dtype (uniform ints / fp ranges) +- `Integer`, `Uniform`, `Normal`, `Randn`, `Zero`, `One` + +Important +- Built‑in generators require static shapes; if your PrimFunc uses symbolic + dimensions (T.dyn), supply concrete inputs via (1) or (2). +- Float8 dtypes require PyTorch 2.1+ for `torch.float8_*` support. + +## Correctness Checking and Tolerances + +Use one of the following validation methods: +- `ref_prog`: Provide a reference program that receives the same inputs and + checks results. You can return a boolean or raise on mismatch. +- `manual_check_prog`: A callable that inspects outputs and raises on mismatch. +- `skip_check=True`: Skip correctness checks (faster, use with caution). + +Control numeric drift via: +- `rtol` and `atol` (defaults 1e‑2) +- `max_mismatched_ratio` (default 1%) + +## Configuration Spaces and Best Practices + +What to tune +- Tile sizes: `block_M`, `block_N`, `block_K` +- Software pipelining: `num_stages` +- Threads per block: `threads` (or (x, y) tuple) +- Optional: dtype variants, epilogues, small scheduling knobs + +Tips +- Start from a working baseline. Tune a small, meaningful space first. +- Respect hardware limits (shared memory bytes, registers per thread/block, + max threads per block). Eliminate impossible configs up‑front. +- Keep block sizes multiples of vector widths and warp sizes when relevant. 
+- Use `set_autotune_inputs` to ensure each config is measured on identical data. +- Record your best configs and bake them as defaults when stable. + +## Parallel Compilation/Benchmarking and Timeouts + +The tuner compiles configurations in parallel using a thread pool and benchmarks +them with a per‑config timeout. On CUDA, each worker sets the current device to +avoid context issues. + +Notes +- `timeout` uses POSIX signals; on non‑Unix systems, it may not take effect. +- Logs are written to `autotuner.log` in the working directory. + +## Caching + +The autotuner caches best artifacts both in‑memory (per process) and on disk under +`$TILELANG_CACHE_DIR/autotuner`. The cache key includes: +- TileLang version, function source, closure free‑vars +- Config list, compile args, profile args + +Disk cache contents (per key) +- Best config and latency: `best_config.json`, `latency.json` +- Kernel sources and library: `device_kernel.cu`, `host_kernel.cu`, `kernel_lib.so` (or `kernel.cubin`/`executable.so` depending on backend) +- Function and params: `function.pkl`, `params.pkl` + +Control via env vars (tilelang.env) +- `TILELANG_CACHE_DIR` (default `~/.tilelang/cache`) +- `TILELANG_TMP_DIR` (default `$TILELANG_CACHE_DIR/tmp`) +- Disable all kernel caches: `TILELANG_DISABLE_CACHE=1` +- Disable autotune disk cache only: `TILELANG_AUTO_TUNING_DISABLE_CACHE=1` + +CPU worker control +- `TILELANG_AUTO_TUNING_CPU_UTILITIES` (fraction, default 0.9) +- `TILELANG_AUTO_TUNING_CPU_COUNTS` (int, `-1` auto) +- `TILELANG_AUTO_TUNING_MAX_CPU_COUNT` (int, `-1` unlimited) + +Backend notes +- NVRTC backend persists `.cubin` and a Python launcher. +- Torch/DLPack backend may not save artifacts to disk; in this case, only + in‑memory caching applies and a warning is logged. + +## Alternative: Manual Sweeps with par_compile + +If you prefer manual control, use `JITImpl.par_compile` to compile a batch of +configs and drive your own benchmarking: + +```python +@tilelang.jit +def factory(M, N, K, block_M=128, block_N=128, block_K=32): + @T.prim_func + def k(A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16')): + ... + return k + +impl = factory # JITImpl +cfgs = [ + dict(block_M=64, block_N=128, block_K=32), + dict(block_M=128, block_N=128, block_K=64), +] +kernels = impl.par_compile(cfgs, num_workers=4) +# Now benchmark kernels[i](A, B, C) yourself +``` + +## Recording and Reusing Best Configs + +The programmatic path returns an `AutotuneResult` that can be saved and later +reloaded. This is useful for CI, multi‑host workflows, or shipping tuned configs. + +```python +artifact = tuner.run() # AutotuneResult + +# Save to disk +from pathlib import Path +save_dir = Path('out/best/matmul_1024') +artifact.save_to_disk(save_dir, verbose=True) + +# Reload later +from tilelang.autotuner.param import AutotuneResult, CompileArgs +restored = AutotuneResult.load_from_disk(save_dir, CompileArgs()) +best = restored.kernel +best(A, B, C) +``` + +Notes +- DLPack/Torch execution backend may not persist compiled binaries; in that + case, re‑compilation is needed on load or use a different backend. +- The directory contains human‑readable JSONs (best config/latency) and sources. 
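+
+A common CI-style reuse pattern, sketched from the calls above (the cache directory path is illustrative, and `tuner`, `A`, `B`, `C` are the objects defined earlier on this page), is to load a saved result when it exists and fall back to tuning otherwise:
+
+```python
+from pathlib import Path
+from tilelang.autotuner.param import AutotuneResult, CompileArgs
+
+save_dir = Path('out/best/matmul_1024')  # illustrative location
+if save_dir.exists():
+    # Reuse the previously tuned artifact (may re-compile depending on backend)
+    result = AutotuneResult.load_from_disk(save_dir, CompileArgs())
+else:
+    # Tune from scratch and persist the best artifact for the next run
+    result = tuner.run()
+    result.save_to_disk(save_dir)
+
+result.kernel(A, B, C)
+```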
+ +## Advanced: Config Space Callables + +Derive config spaces from problem sizes to keep searches targeted and legal: + +```python +def matmul_configs(M, N, K): + large = min(M, N, K) >= 1024 + tiles = [128] if large else [64, 128] + for BM in tiles: + for BN in tiles: + for BK in [32, 64]: + for S in [2, 3]: + for TH in [128, 256]: + yield dict(block_M=BM, block_N=BN, block_K=BK, + num_stages=S, threads=TH) +``` + +## Device and Backend Selection + +Tune compile‑time options explicitly: +- `target='auto'|'cuda'|'hip'|'metal'` (normalized to a TVM Target) +- `execution_backend='auto'|'tvm_ffi'|'cython'|'nvrtc'|'torch'` +- `pass_configs={...}` to toggle TileLang/TVM passes for experiments + +On CUDA with multiple GPUs, the tuner sets the current device per worker thread +to avoid context mixups. + +## Troubleshooting +- “No configurations to tune”: Ensure `configs` is a non‑empty list or callable. +- Timeouts: Increase `timeout`; ensure inputs fit device memory; verify that + your reference check isn’t the bottleneck. +- Dynamic shapes: Provide concrete inputs via `set_autotune_inputs` or a custom + `supply_prog`. +- Disk cache disabled: Check `TILELANG_AUTO_TUNING_DISABLE_CACHE` and backend. diff --git a/docs/programming_guides/control_flow.md b/docs/programming_guides/control_flow.md new file mode 100644 index 000000000..259441349 --- /dev/null +++ b/docs/programming_guides/control_flow.md @@ -0,0 +1,149 @@ +# Control Flow + +This guide covers the control‑flow primitives in TileLang and how they lower to +efficient GPU code. You will use these to structure loops, handle boundaries, +and express pipelined compute. + +## Overview +- Conditionals: `if` / `elif` / `else`, ternary (`x if c else y`) +- Loops: `T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined` +- While loops: `while` with a TIR condition +- Flow control: Python `break` / `continue` +- Safety: automatic OOB guards via the LegalizeSafeMemoryAccess pass + +The examples assume `import tilelang.language as T`. + +## Conditionals + +Standard Python `if`/`elif`/`else` is supported inside `@T.prim_func` kernels. +Conditions should be TIR expressions (e.g., `i < N`). Python plain booleans are +treated as compile‑time constants and will be folded. + +```python +for i in T.serial(N): + if i < N: # TIR condition + C[i] = A[i] + B[i] + else: + pass + +# Ternary +x = (A[i] if i < N else 0) +``` + +Short‑circuit boolean ops are supported. For multi‑dimensional bounds, use +`T.any_of` / `T.all_of` for clarity: + +```python +if T.all_of(i < M, j < N): + C[i, j] = A[i, j] + B[i, j] +``` + +Boundary handling note +- The LegalizeSafeMemoryAccess pass automatically inserts guards when an access + may be out‑of‑bounds, and elides them when proven safe. You can often omit + explicit `if` checks for simple edge handling, but keep them when you need + custom logic or clarity. + +## Loops + +### Serial + +`T.serial` creates a plain for‑loop. Common forms: + +```python +for i in T.serial(N): + ... # 0..N-1 + +for i in T.serial(0, N, 2): + ... # 0, 2, 4, ... +``` + +### Unroll + +`T.unroll` requests loop unrolling for small trip counts. + +```python +for k in T.unroll(K_TILE): + acc += a[k] * b[k] +``` + +Advanced: TileLang forwards unroll hints to TIR; factor/explicit knobs are +available for expert tuning. + +### Parallel (elementwise) + +`T.Parallel(ext0, ext1, ...)` builds nested loops that map well to elementwise +operations. 
The body receives all indices in one `for` header: + +```python +for i, j in T.Parallel(M, N): + C[i, j] = A[i, j] + B[i, j] +``` + +Optional hints: +- `coalesced_width=` controls memory coalescing width (used for vectorization checks). +- `loop_layout=` accepts a `T.Fragment` to annotate the layout of the entire + nested parallel loop. The annotation is attached to the outermost loop only + and must have `InputDim == number of nested parallel extents`. + +### Pipelined (software pipelining) + +`T.Pipelined(iters, num_stages=...)` overlaps producer/consumer stages (e.g., +Global→Shared copies with compute). This is the backbone of GEMM/attention +pipelines. + +```python +for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) # stage: copy A tile + T.copy(B[ko * BK, bx * BN], B_s) # stage: copy B tile + T.gemm(A_s, B_s, C_f) # stage: compute +``` + +### Persistent (advanced) + +`T.Persistent(domain, wave_size, index, group_size=...)` exposes persistent +thread‑block style looping. It is an advanced construct that TileLang lowers in +later passes and is typically used by specialized templates. + +## While Loops + +`while` is supported when the condition is a TIR expression. Avoid infinite +loops; TileLang will error if it detects a constant‑true condition. + +```python +i = 0 +while i < N: + ... + if done: + break + i += 1 +``` + +## Break and Continue + +Use Python `break`/`continue` to exit or skip within `T.serial`/`T.unroll`/ +`T.Parallel`/`while` loops. Keep the body clean after a `break`/`continue` for +readability; the compiler will ignore the dead path. + +## Putting It Together: Residual Tile Handling + +Below is a typical edge pattern for a 2D kernel. With LegalizeSafeMemoryAccess, +the explicit guard can be omitted when you don’t need a custom edge path. + +```python +for i, j in T.Parallel(M, N): + gi = by * BM + i + gj = bx * BN + j + if T.all_of(gi < M, gj < N): # optional in many cases + C[gi, gj] = A[gi, gj] + B[gi, gj] +``` + +## Debugging Conditions + +Use `T.print` to inspect values under predicates. For buffers, TileLang prints +from a single thread to avoid duplicate outputs. + +```python +if i == 0: + T.print(C, msg='C tile:') +``` diff --git a/docs/programming_guides/instructions.md b/docs/programming_guides/instructions.md new file mode 100644 index 000000000..20beb8325 --- /dev/null +++ b/docs/programming_guides/instructions.md @@ -0,0 +1,184 @@ +# Instructions + +This page summarizes the core TileLang “instructions” available at the DSL +level, how they map to hardware concepts, and how to use them correctly. + +## Quick Categories +- Data movement: `T.copy`, `T.c2d_im2col`, staging Global ↔ Shared ↔ Fragment +- Compute primitives: `T.gemm`/`T.gemm_sp`, elementwise math (`T.exp`, `T.max`), + reductions (`T.reduce_sum`, `T.cumsum`, warp reducers) +- Control helpers: `T.clear`/`T.fill`, `T.reshape`/`T.view` +- Diagnostics: `T.print`, `T.device_assert` +- Advanced: atomics, memory barriers, warp‑group ops + +## Data Movement + +Use `T.copy(src, dst, *, coalesced_width=None, disable_tma=False, eviction_policy=None, loop_layout=None)` +to move tiles between memory scopes. It accepts `tir.Buffer`, `BufferLoad`, or +`BufferRegion`; extents are inferred or broadcast when possible. 
+ +```python +# Global → Shared tiles (extents inferred from dst) +T.copy(A[by * BM, ko * BK], A_s) +T.copy(B[ko * BK, bx * BN], B_s) + +# Fragment/Register → Global (store result) +T.copy(C_f, C[by * BM, bx * BN]) +``` + +Semantics +- Extents are deduced from arguments; missing sides broadcast to the other’s rank. +- Access patterns are legalized and coalesced during lowering. Explicit + vectorization is not required in HL mode. +- Safety: the LegalizeSafeMemoryAccess pass inserts boundary guards when an + access may be out‑of‑bounds and drops them when proven safe. + +Other helpers +- `T.c2d_im2col(img, col, ...)`: convenience for conv‑style transforms. + +## Compute Primitives + +GEMM and sparse GEMM +- `T.gemm(A_shared, B_shared, C_fragment)`: computes a tile GEMM using shared + inputs and a fragment accumulator; lowered to target‑specific tensor cores. +- `T.gemm_sp(...)`: 2:4 sparse tensor core variant (see examples and README). + +Reductions and scans +- `T.reduce_sum`, `T.reduce_max`, `T.reduce_min`, `T.cumsum`, plus warp + reducers (`T.warp_reduce_sum`, etc.). +- Allocate and initialize accumulators via `T.alloc_fragment` + `T.clear` or + `T.fill`. + +Elementwise math +- Most math ops mirror TVM TIR: `T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, + `T.sigmoid`, etc. Compose freely inside loops. + +Reshape/view (no copy) +- `T.reshape(buf, new_shape)` and `T.view(buf, shape=None, dtype=None)` create + new views that share storage, with shape/dtype checks enforced. + +## Synchronization (HL usage) + +In HL pipelines, you usually don’t need to write explicit barriers. Passes such +as PipelinePlanning/InjectSoftwarePipeline/InjectTmaBarrier orchestrate +producer/consumer ordering and thread synchronization behind the scenes. + +If you need debugging or explicit checks: +- `T.device_assert(cond, msg='')` emits device‑side asserts on CUDA targets. +- `T.print(obj, msg='...')` prints scalars or buffers safely from one thread. + +## Putting It Together: GEMM Tile + +```python +@T.prim_func +def gemm( + A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16'), +): + with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by): + A_s = T.alloc_shared((BM, BK), 'float16') + B_s = T.alloc_shared((BK, BN), 'float16') + C_f = T.alloc_fragment((BM, BN), 'float32') + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) # Global → Shared + T.copy(B[ko * BK, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) # compute into fragment + + T.copy(C_f, C[by * BM, bx * BN]) # store back +``` + +## Instruction Reference (Concise) + +Below is a concise list of TileLang instructions grouped by category. For full +signatures, behaviors, constraints, and examples, refer to API Reference +(`autoapi/tilelang/index`). + +Data movement +- `T.copy(src, dst, ...)`: Move tiles between Global/Shared/Fragment. +- `T.c2d_im2col(img, col, ...)`: 2D im2col transform for conv. + +Memory allocation and descriptors +- `T.alloc_shared(shape, dtype, scope='shared.dyn')`: Allocate shared buffer. +- `T.alloc_fragment(shape, dtype, scope='local.fragment')`: Allocate fragment. +- `T.alloc_var(dtype, [init], scope='local.var')`: Scalar var buffer (1 elem). +- `T.alloc_barrier(arrive_count)`: Allocate and initialize one or more mbarriers. +- `T.alloc_tmem(shape, dtype)`: Tensor memory (TMEM) buffer (Hopper+). +- `T.alloc_reducer(shape, dtype, op='sum', replication=None)`: Reducer buf. 
+- `T.alloc_descriptor(kind, dtype)`: Generic descriptor allocator. + - `T.alloc_wgmma_desc(dtype='uint64')` + - `T.alloc_tcgen05_smem_desc(dtype='uint64')` + - `T.alloc_tcgen05_instr_desc(dtype='uint32')` +- `T.empty(shape, dtype='float32')`: Declare function output tensors. + +Compute primitives +- `T.gemm(A_s, B_s, C_f)`: Tile GEMM into fragment accumulator. +- `T.gemm_sp(...)`: Sparse (2:4) tensor core GEMM. +- Reductions: `T.reduce_sum/max/min/abssum/absmax`, bitwise `and/or/xor`. +- Scans: `T.cumsum`, finalize: `T.finalize_reducer`. +- Warp reducers: `T.warp_reduce_sum/max/min/bitand/bitor`. +- Elementwise math: TIR ops (`T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, ...). +- Fast math: `T.__log/__log2/__log10/__exp/__exp2/__exp10/__sin/__cos/__tan`. +- IEEE math: `T.ieee_add/sub/mul/fmaf` (configurable rounding). +- Helpers: `T.clear(buf)`, `T.fill(buf, value)`. +- Views: `T.reshape(buf, shape)`, `T.view(buf, shape=None, dtype=None)`. + +Diagnostics +- `T.print(obj, msg='')`: Print scalar/buffer from one thread. +- `T.device_assert(cond, msg='')`: Device-side assert (CUDA). + +Logical helpers +- `T.any_of(a, b, ...)`, `T.all_of(a, b, ...)`: Multi-term predicates. + +Annotation helpers +- `T.use_swizzle(panel_size=..., enable=True)`: Rasterization hint. +- `T.annotate_layout({...})`: Attach explicit layouts to buffers. +- `T.annotate_safe_value(var, ...)`: Safety/const hints. +- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint. + +Synchronization helpers +- `T.pdl_trigger()`: Signal programmatic launch completion for the current kernel. +- `T.pdl_sync()`: Wait until kernel dependencies are satisfied. + +Atomics +- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`. +- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`. +- `T.atomic_max(dst, value, memory_order=None, return_prev=False)`. +- `T.atomic_min(dst, value, memory_order=None, return_prev=False)`. +- `T.atomic_load(dst)`, `T.atomic_store(dst, value)`. + +Custom intrinsics +- `T.dp4a(A, B, C)`: 4‑element dot‑product accumulate. +- `T.clamp(x, lo, hi)`: Clamp to [lo, hi]. +- `T.loop_break()`: Break from current loop via intrinsic. + +Barriers, TMA, warp‑group +- Barriers: `T.alloc_barrier(arrive_count)`. +- Parity ops: `T.mbarrier_wait_parity(barrier, parity)`, `T.mbarrier_arrive(barrier)`. +- Expect tx: `T.mbarrier_expect_tx(...)`; sugar: `T.barrier_wait(id, parity=None)`. +- TMA: `T.create_tma_descriptor(...)`, `T.tma_load(...)`, + `T.tma_store_arrive(...)`, `T.tma_store_wait(...)`. +- Proxy/fences: `T.fence_proxy_async(...)`, `T.warpgroup_fence_operand(...)`. +- Warp‑group: `T.warpgroup_arrive()`, `T.warpgroup_commit_batch()`, + `T.warpgroup_wait(num_mma)`, `T.wait_wgmma(id)`. + +Lane/warp index +- `T.get_lane_idx(warp_size=None)`: Lane id in warp. +- `T.get_warp_idx_sync(warp_size=None)`: Canonical warp id (sync). +- `T.get_warp_idx(warp_size=None)`: Canonical warp id (no sync). +- `T.get_warp_group_idx(warp_size=None, warps_per_group=None)`: Group id. + +Register control +- `T.set_max_nreg(reg_count, is_inc)`, `T.inc_max_nreg(n)`, `T.dec_max_nreg(n)`. +- `T.annotate_producer_reg_dealloc(n=24)`, `T.annotate_consumer_reg_alloc(n=240)`. +- `T.no_set_max_nreg()`, `T.disable_warp_group_reg_alloc()`. + +## Notes on Dtypes + +Dtypes accept three equivalent forms: +- String: `'float32'` +- TileLang dtype: `T.float32` +- Framework dtype: `torch.float32` +All are normalized internally. See Type System for details. 
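+
+As a small, self-contained sketch of that normalization (the kernel is just an element-wise add mirroring the basics guide; names and sizes are illustrative), the same factory accepts all three spellings:
+
+```python
+import torch
+import tilelang
+import tilelang.language as T
+
+@tilelang.jit
+def vec_add(N: int, dtype='float32', block: int = 256):
+
+    @T.prim_func
+    def kernel(
+        A: T.Tensor((N,), dtype),
+        B: T.Tensor((N,), dtype),
+        C: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block), threads=block) as bx:
+            for i in T.Parallel(block):
+                C[bx * block + i] = A[bx * block + i] + B[bx * block + i]
+
+    return kernel
+
+# Each call builds the same kernel; the dtype spellings are normalized internally.
+k_str = vec_add(1024, dtype='float16')      # string
+k_tl = vec_add(1024, dtype=T.float16)       # TileLang dtype object
+k_pt = vec_add(1024, dtype=torch.float16)   # framework (torch) dtype
+```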
diff --git a/docs/programming_guides/language_basics.md b/docs/programming_guides/language_basics.md new file mode 100644 index 000000000..1152680c9 --- /dev/null +++ b/docs/programming_guides/language_basics.md @@ -0,0 +1,234 @@ +# Language Basics + +This page introduces the core TileLang (tile‑lang) DSL that you’ll use to write +high‑performance kernels. It focuses on how to define a kernel, express +iteration, move data across memory scopes, and run it with JIT. + +The examples use the conventional aliases: + +```python +import tilelang +import tilelang.language as T +from tilelang import jit +``` + +## 1. Defining a Kernel with `@T.prim_func` + +TileLang kernels are TIR (TVM IR) functions produced by the `@T.prim_func` +decorator. Arguments are annotated with shapes and dtypes via `T.Tensor` or +`T.Buffer`. + +Note on dtypes +- You can pass dtypes as a string (e.g., 'float32'), a TileLang dtype (e.g., `T.float32`), + or a framework dtype (e.g., `torch.float32`). TileLang normalizes all of these. + See Type System for details. + +```python +@T.prim_func +def add_kernel( + A: T.Tensor((N,), dtype), # dtype could be 'float32' | T.float32 | torch.float32 + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), +): + ... # kernel body +``` + +- Shapes may be concrete integers or symbolic. For symbolic, you can pass + Python ints through the outer `@jit` wrapper (shown below), or annotate with + `T.dyn` when you want a named symbolic dimension. + +```python +# Named symbolic dimension (optional) +K = T.dyn['K'] +@T.prim_func +def uses_dyn(A: T.Tensor((K,), 'float32')): + ... +``` + +### Dynamic symbolic dimensions: two ways + +TileLang supports two complementary ways to introduce symbolic (dynamic) dims: + +- Type-level annotations via `T.dyn[...]` (recommended for function signatures) + - Use in `T.Tensor((T.dyn['K'], ...), dtype)` or bind once then reuse (as above). + - Inside the kernel body, prefer reading from the buffer’s shape, e.g. `M = A.shape[0]`. + +- Term-level variables via `T.dynamic(name, dtype)` + - Creates a TIR `tir.Var` you can use directly in expressions/loops. + - Handy when you need to reference the dimension symbol in the body. + +```python +# 1) Annotation-only symbol; read the bound size via shape +K = T.dyn['K'] # dtype defaults to int32 +@T.prim_func +def foo(A: T.Tensor((K,), 'float32')): + N = A.shape[0] + for i in T.serial(N): + ... + +# 2) Explicit Var symbol usable in the body +K = T.dynamic('K', 'int32') # or T.dynamic('K') defaults to int32 +@T.prim_func +def bar(A: T.Tensor((K,), 'float32')): + for i in T.serial(K): + ... +``` + +Notes +- `T.symbolic(name, dtype)` is a deprecated alias of `T.dynamic`; prefer `T.dynamic`. +- Under `@jit`, concrete sizes come from the actual tensor arguments at the first call. +- Symbols in annotations do not need to be separate kernel arguments; TileLang binds them from argument shapes. + +## 2. Launching Work with `T.Kernel` + +`with T.Kernel(...)` declares a launch context and creates block/thread +bindings. For GPU backends, specify a grid and threads per block. + +```python +with T.Kernel(grid_x, grid_y, threads=128) as (bx, by): + ... # bx/by are blockIdx.x/y +``` + +You rarely need raw thread indices; most kernels use structured loops +(`T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined`) inside a `T.Kernel`. + +## 3. 
Loops and Control Flow + +Core loop constructs map to familiar hardware patterns: + +- `T.serial(start, stop[, step])`: plain for‑loop +- `T.unroll(start, stop[, step])`: unrolled loop +- `T.Parallel(ext0, ext1, ...)`: nested parallel loops (elementwise‑friendly) +- `T.Pipelined(iters, num_stages=N)`: software pipelining for producer/consumer + +```python +for i in T.serial(N): + ... + +for i, j in T.Parallel(M, N): + C[i, j] = A[i, j] + B[i, j] + +for k in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + # overlap copy/compute across stages + ... +``` + +Conditionals use standard Python `if`/`else`. Guard edges with predicates when +tile sizes do not divide problem sizes evenly. + +## 4. Memory Scopes and Allocation + +TileLang exposes key software‑managed scopes: + +- Global: device memory (default for `T.Tensor` arguments) +- Shared: on‑chip, block‑visible (`T.alloc_shared(shape, dtype)`) +- Fragment and scalars: per‑thread fragments and scalar vars but in Shared View + (`T.alloc_fragment`, `T.alloc_var`) + +```python +A_shared = T.alloc_shared((BM, BK), 'float16') +B_shared = T.alloc_shared((BK, BN), 'float16') +C_local = T.alloc_fragment((BM, BN), 'float32') +T.clear(C_local) # zero accumulators +``` + +## 5. Moving Data: `T.copy` + +Use `T.copy(src, dst)` to move tiles between scopes. It accepts buffers, +buffer regions, or buffer loads; extents are inferred or can be broadcast. + +```python +# Global -> Shared (tile copy), extents inferred from dst +T.copy(A[by * BM, ko * BK], A_shared) +T.copy(B[ko * BK, bx * BN], B_shared) + +# Fragment -> Global (store back) +T.copy(C_local, C[by * BM, bx * BN]) +``` + +`T.copy` performs coalescing and scope‑specific lowering during compilation. + +## 6. A Minimal End‑to‑End Example (Vector Add) + +```python +import tilelang +import tilelang.language as T +from tilelang import jit + +@jit # infers target from tensors at first call +def add(N: int, block: int = 256, dtype: str = 'float32'): + + @T.prim_func + def add_kernel( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block), threads=block) as bx: + for i in T.Parallel(block): + gi = bx * block + i + # Optional — LegalizeSafeMemoryAccess inserts a guard when an access may be OOB + C[gi] = A[gi] + B[gi] + + return add_kernel + +# Host side (PyTorch shown; NumPy/DLPack also supported) +import torch +N = 1 << 20 +A = torch.randn(N, device='cuda', dtype=torch.float32) +B = torch.randn(N, device='cuda', dtype=torch.float32) +C = torch.empty(N, device='cuda', dtype=torch.float32) + +kernel = add(N) +kernel(A, B, C) # runs on GPU +torch.testing.assert_close(C, A + B) +``` + +Notes +- The `@jit` wrapper returns a callable kernel after the first compilation. +- You can pass compile‑time tunables (tile sizes, dtypes) through the outer + Python function and bake them into the generated TIR. + +## 7. Tiled GEMM Skeleton + +Below is a minimal pattern for a tiled GEMM using shared memory staging and a +fragment accumulator. It mirrors the quickstart style found in the repository. 
+ +```python +@T.prim_func +def gemm( + A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16'), +): + with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by): + A_s = T.alloc_shared((BM, BK), 'float16') + B_s = T.alloc_shared((BK, BN), 'float16') + C_f = T.alloc_fragment((BM, BN), 'float32') + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) + T.copy(B[ko * BK, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) # lowered to tensor‑core/ISA specific kernels + + T.copy(C_f, C[by * BM, bx * BN]) +``` + +## 8. Debugging and Printing + +Use `T.print` inside a kernel for quick introspection. TileLang emits printing +from a single thread for shared/fragment scopes to avoid floods. + +```python +T.print(C_f, msg='accumulator:') +T.print(A_s, msg='A tile:') +T.print(C[0], msg='C[0] = ') +``` + +## 9. Where to Go Next + +- Control flow details: see Programming Guides → Control Flow +- Memory topics: see Programming Guides → (removed cache/layout); basics are covered inline +- Autotuning tile sizes and mappings: Programming Guides → Autotuning +- Operator examples (GEMM, GEMV, attention): see Deep Learning Operators diff --git a/docs/programming_guides/overview.md b/docs/programming_guides/overview.md new file mode 100644 index 000000000..64b6d2039 --- /dev/null +++ b/docs/programming_guides/overview.md @@ -0,0 +1,27 @@ +# Programming Guides Overview + +This section provides a practical guide to writing high‑performance kernels with Tile Language (tile‑lang). +It mirrors the structure of a similar guide in another project and adapts it to tile‑lang concepts and APIs. + +- Audience: Developers implementing custom GPU/CPU kernels with tile‑lang +- Prereqs: Basic Python, NumPy/Tensor concepts, and familiarity with GPU programming notions +- Scope: Language basics, control flow, instructions, autotuning, and type system + +## What You’ll Learn +- How to structure kernels with TileLang’s core DSL constructs +- How to move data across global/shared/fragment and pipeline compute +- How to apply autotuning to tile sizes and schedules +- How to specify and work with dtypes in kernels + +## Suggested Reading Order +1. Language Basics +2. Control Flow +3. Instructions +4. Autotuning +5. Type System + +## Related Docs +- Tutorials: see existing guides in `tutorials/` +- Operators: examples in `deeplearning_operators/` + +> NOTE: This is a draft scaffold. Fill in code snippets and benchmarks as APIs evolve. diff --git a/docs/programming_guides/python_compatibility.md b/docs/programming_guides/python_compatibility.md new file mode 100644 index 000000000..b858e392a --- /dev/null +++ b/docs/programming_guides/python_compatibility.md @@ -0,0 +1,59 @@ +# Python Compatibility + +TileLang is a Python-embedded DSL, but not all Python syntax is supported inside +TileLang DSL. This guide clarifies what works, what doesn't, and how +to translate common Python patterns into TileLang equivalents. Specially, we focus on +the kernel part (scripts inside `with T.Kernel`) semantics. For host-side semantics when +using eager-style JIT, please stay tuned for our upcoming documentation. 
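+
+As a rough orientation before the tables below (the function, names, and shapes here are purely illustrative), everything outside the `T.prim_func` body is ordinary Python, while the code inside `with T.Kernel(...)` is parsed by TileLang and restricted to the subset documented on this page:
+
+```python
+import tilelang
+import tilelang.language as T
+
+@tilelang.jit
+def scale(N: int, factor: float = 2.0, dtype: str = 'float32'):
+    block = 256  # host-side: plain Python, no restrictions apply here
+
+    @T.prim_func
+    def kernel(A: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype)):
+        with T.Kernel(T.ceildiv(N, block), threads=block) as bx:
+            # kernel-side: only the constructs listed in the tables below
+            for i in T.Parallel(block):
+                B[bx * block + i] = A[bx * block + i] * factor
+
+    return kernel
+```
+
+When a specific construct is in question, the tables below are the reference; the host side follows normal Python rules.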
+ +The following codes use the conventional aliases: + +```python +import tilelang +import tilelang.language as T +from tilelang import jit +``` + +## Control Flow & Loops + +| Python Feature | Supported | Notes / Alternative | +|-------------------------|:---------:|------------------------------------------| +| `for i in range(n)` | ✅ | Maps to `T.serial(n)` | +| `for i in range(a,b,s)` | ✅ | Maps to `T.serial(a, b, s)` | +| `for x in list` | ❌ | Use index-based loop | +| `while condition` | ✅ | | +| `if` / `elif` / `else` | ✅ | | +| `x if cond else y` | ✅ | Ternary expression | +| `break` / `continue` | ✅ | | +| `enumerate()` / `zip()` | ❌ | | + +## Data Access + +| Python Feature | Supported | Notes / Alternative | +|-------------------------|:---------:|------------------------------------------| +| `a[i]` indexing | ✅ | Multi-dim indexing supported: `a[i, j, k]` | +| `a[i:j]` slicing | ✅ | Creates `BufferRegion` | +| `a[-1]` negative index | ✅ | | + +## Assignment & Arithmetic Operations + +| Python Feature | Supported | Notes / Alternative | +|-------------------------|:---------:|------------------------------------------| +| `x = expr` | ✅ | | +| `+`, `-`, `*`, `/`, `%` | ✅ | Maps to device-side arithmetic operations | +| `+=`, `-=`, `*=`, etc. | ✅ | Augmented assignment | +| `a = b = c` | ❌ | Use separate assignments | + +## Functions & Classes + +As a kernel script language, TileLang doesn't support functions or classes. You can use `@T.macro` to define reusable code blocks, which will be inlined at compile time like `__device__` function. + +## Statements & Built-in Functions + +| Python Feature | Supported | Notes / Alternative | +|-------------------------|:---------:|------------------------------------------| +| `with` | ⚠️ | Only `T.Kernel`, `T.ws` | +| `assert` | ⚠️ | Use `T.device_assert` or `T.assert` | +| `print()` | ⚠️ | Use `T.print()`; `print` works for Python expressions | +| `len()` | ❌ | Use `buffer.shape[dim]` | +| `type()`, `isinstance()`| ❌ | | diff --git a/docs/programming_guides/type_system.md b/docs/programming_guides/type_system.md new file mode 100644 index 000000000..60061df3f --- /dev/null +++ b/docs/programming_guides/type_system.md @@ -0,0 +1,41 @@ +# Type System + +This page lists the data types supported by TileLang and how to specify them in +kernels. For full details and the authoritative list, see the API Reference +(`autoapi/tilelang/index`) and `tilelang.language.v2.dtypes`. + +How to specify dtypes +- Use any of the following forms; TileLang normalizes them internally: + - String: `'float32'`, `'int8'`, `'bfloat16'`, ... + - TileLang dtype object: `T.float32`, `T.int8`, `T.bfloat16`, ... + - Framework dtype: `torch.float32`, `torch.int8`, `torch.bfloat16`, ... + +Common scalar types +- Boolean: `bool` +- Signed integers: `int8`, `int16`, `int32`, `int64` +- Unsigned integers: `uint8`, `uint16`, `uint32`, `uint64` +- Floating‑point: `float16` (half), `bfloat16`, `float32`, `float64` + +Float8 and low‑precision families +- Float8: `float8_e3m4`, `float8_e4m3`, `float8_e4m3b11fnuz`, `float8_e4m3fn`, + `float8_e4m3fnuz`, `float8_e5m2`, `float8_e5m2fnuz`, `float8_e8m0fnu` +- Float6: `float6_e2m3fn`, `float6_e3m2fn` +- Float4: `float4_e2m1fn` + +Vectorized element types (SIMD packs) +- For many base types, vector‑packed variants are available by lane count: + `x2`, `x4`, `x8`, `x16`, `x32`, `x64`. +- Examples: + - Integers: `int8x2`, `int8x4`, ..., `int32x2`, `int32x4`, ... + - Unsigned: `uint8x2`, `uint8x4`, ... 
+ - Floats: `float16x2`, `float16x4`, `float32x2`, `float32x4`, ... + - Float8/6/4 families also provide `x2/x4/x8/x16/x32/x64` where applicable, + e.g., `float8_e4m3x2`, `float8_e4m3x4`, `float6_e2m3fnx8`, `float4_e2m1fnx16`. + +Notes +- Availability of certain low‑precision formats (float8/6/4) depends on target + architecture and backend support. +- Choose accumulation dtypes explicitly for mixed‑precision compute (e.g., + GEMM with `float16` inputs and `float32` accumulators). +- The complete, up‑to‑date list is exposed in + `tilelang.language.v2.dtypes` and rendered in the API Reference. diff --git a/docs/tutorials/auto_tuning.md b/docs/tutorials/auto_tuning.md index 3f3cad832..33368a2f0 100644 --- a/docs/tutorials/auto_tuning.md +++ b/docs/tutorials/auto_tuning.md @@ -14,7 +14,7 @@ Auto-tuning a Tile Language program involves three main steps: ## Matrix Multiplication Example -The following example demonstrates auto-tuning matrix multiplication. Code has been simplified for readability - see `examples/gemm/example_gemm.py` for complete implementation. +The following example demonstrates auto-tuning matrix multiplication. Code has been simplified for readability - see `examples/gemm/example_gemm.py` for complete implementation. ### Step 1: Implement with Reserved Parameters Users can implement matrix multiplication in Tile Language while reserving parameters for optimization: @@ -145,4 +145,4 @@ for hint in roller_hints: config["thread_num"] = block_rows * block_cols * 32 config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization -``` \ No newline at end of file +``` diff --git a/docs/tutorials/debug_tools_for_tilelang.md b/docs/tutorials/debug_tools_for_tilelang.md index e18b13279..078440f34 100644 --- a/docs/tutorials/debug_tools_for_tilelang.md +++ b/docs/tutorials/debug_tools_for_tilelang.md @@ -12,7 +12,6 @@ A Tile Language program (hereafter referred to as a *program*) is transformed in 2. The program undergoes multiple *Passes* for transformation and optimization (the *lower* stage, see `tilelang/engine/lower.py`), finally producing an intermediate representation (e.g., LLVM or C for CPU, CUDA for NVIDIA GPUs, etc.). 3. The generated code is compiled by the respective compiler (e.g., nvcc) into a hardware-executable file. - ```{figure} ../_static/img/overview.png :width: 300 :alt: Overview of the compilation process @@ -22,9 +21,9 @@ A Tile Language program (hereafter referred to as a *program*) is transformed in During this process, users may encounter roughly three categories of issues: -* **Generation issues**: The Tile Language program fails to generate a valid hardware-executable file (i.e., errors during the lowering process). -* **Correctness issues**: The resulting executable runs, but produces incorrect results. -* **Performance issues**: The executable runs with performance significantly below the expected theoretical hardware limits. +- **Generation issues**: The Tile Language program fails to generate a valid hardware-executable file (i.e., errors during the lowering process). +- **Correctness issues**: The resulting executable runs, but produces incorrect results. +- **Performance issues**: The executable runs with performance significantly below the expected theoretical hardware limits. This tutorial focuses on the first two issues—how to debug generation and correctness problems. Performance tuning often requires using vendor-provided profiling tools (e.g., **Nsight Compute**, **rocProf**, etc.) 
for further hardware-level analysis, which we will address in future materials. @@ -52,7 +51,6 @@ func = matmul(1024, 1024, 1024, 128, 128, 32) TileLang essentially performs *progressive lowering*. For example, a `T.copy` may first be expanded into `T.Parallel` (see the pass `LowerTileOP`), which is then expanded again, eventually resulting in lower-level statements that can be translated to CUDA C code. - ```{figure} ../_static/img/ir_transform_diagram.png :width: 400 :alt: IR transformation diagram @@ -171,8 +169,138 @@ The output messages will include something like: msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0 ``` +### Visual Layout Inference For TileLang + The **Visual Layout Inference** tool automatically generates visual diagrams that illustrate the mapping between logical indices, thread IDs, and register file locations. + +When TileLang performs layout inference, it determines how fragment buffers are distributed across threads. The visual layout tool captures this information and generates: +1. **Textual output**: A human-readable description of the layout mapping +2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping + +The visual layout inference tool is controlled through the `TL_LAYOUT_VISUALIZATION_ENABLE` and `TL_LAYOUT_VISUALIZATION_FORMATS` pass configuration. By default, `TL_LAYOUT_VISUALIZATION_ENABLE` is **disabled** to avoid performance overhead during compilation. + +When enabled, `TL_LAYOUT_VISUALIZATION_FORMATS` accepts string values to control output formats: +- "txt": Text output only (same as default) +- "all": Generates all formats (TXT, PDF, PNG, SVG) +- "png": Generate PNG format only +- "pdf": Generate PDF format only +- "svg": Generate SVG format only +- "txt,svg": Generate multiple formats (comma-separated) in addition to text output + +The output messages of "txt" will include something like: +``` +C_local inferenced layout: + Shape: [32, 32] -> [8] + Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 + Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] +``` + +## AutoDD: Automatic Delta Debugging + +When dealing with complex TileLang programs that produce errors, manually isolating the bug can be tedious. **AutoDD** (Automatic Delta Debugging) is a built-in tool that automatically simplifies your program to the minimal code needed to reproduce a specific error. + +### What is Delta Debugging? + +Delta Debugging is an automated debugging technique that: +1. Takes a program that triggers a bug +2. Systematically removes code fragments +3. Checks if the simplified program still triggers the same bug +4. Produces the minimal code that reproduces the bug + +AutoDD uses a Probability Distribution Driven Delta Debugging (PDD) algorithm for efficient minimization. + +### Why Use AutoDD? + +- **Large codebases**: Real projects often have hundreds of lines of configuration, helper functions, and logging +- **Hard-to-locate errors**: Error messages may point to TVM/CUDA internals rather than your TileLang code +- **Time-saving**: Manually deleting code to isolate bugs is very time-consuming + +AutoDD can reduce a 200+ line program to just 30 lines, directly exposing the root cause. 
+ +### Basic Usage + +```bash +python -m tilelang.autodd --err-msg "" -o +``` + +### Parameters + +| Parameter | Description | +|-----------|-------------| +| `source` | Path to the input Python source file | +| `--err-msg` | Error message to match (searched in stdout or stderr) | +| `-o, --output` | Path to the minimized output file | +| `--backend` | Execution backend: `runner` (faster) or `subproc` (more stable), default `runner` | +| `--timeout` | Timeout for each task in seconds, default 60 | +| `-j, --jobs` | Number of parallel jobs, default 1 | + +### Example + +Suppose you have a complex TileLang program with a GEMM shape mismatch bug: + +```python +# buggy_matmul.py (200+ lines) +@tilelang.jit +def buggy_matmul(M, N, K, block_M, block_N, block_K, ...): + @T.prim_func + def matmul_kernel(...): + with T.Kernel(...) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_M, block_N), dtype) # Bug: should be (block_K, block_N) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + # ... lots of other code ... + T.gemm(A_shared, B_shared, C_local) # Error here + return matmul_kernel +``` + +Run AutoDD to minimize: + +```bash +python -m tilelang.autodd buggy_matmul.py --err-msg "Dimension mismatch" -o minimized.py -j 4 +``` + +AutoDD will produce a minimal reproduction: + +```python +# minimized.py (~30 lines) +import tilelang.language as T + +def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs): + @T.prim_func + def matmul_kernel(): + with T.Kernel(): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_M, block_N), dtype) # Bug exposed! + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.gemm(A_shared, B_shared, C_local) +``` + +### How AutoDD Works + +AutoDD uses AST (Abstract Syntax Tree) analysis with multiple rewrite rules: + +1. **Fast Reducers**: Remove statements, simplify if/for constructs +2. **Canonicalizers**: Expand with statements, add `*args, **kwargs` for compatibility +3. **Simplifiers**: Replace expressions with constants, simplify function calls +4. **Slow Reducers**: Remove arbitrary expressions, reduce integer constants + +### Tips + +- **Error message matching**: Use a unique substring from the error output +- **Timeout**: Increase `--timeout` for programs with long compilation times +- **Parallel jobs**: Use `-j 4` or higher to speed up minimization +- **Backend**: Try `--backend subproc` if `runner` is unstable + +### Complete Example + +A complete example is available in `examples/autodd/`: +- `tilelang_buggy.py`: A complex program with a bug (~200 lines) +- `tilelang_minimized_expected.py`: Expected output after AutoDD (~30 lines) +- `README.md`: Detailed documentation + ## Conclusion By carefully examining intermediate representations (IR) before final code generation—and by leveraging runtime printing through `T.print`—one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs. +For complex programs where manual debugging is tedious, **AutoDD** provides automated delta debugging to quickly isolate the minimal code that reproduces a bug. 
+ For advanced performance tuning (e.g., analyzing memory bandwidth or occupancy), more specialized profiling tools such as **Nsight Compute**, **rocProf**, or vendor-specific profilers may be required. Those aspects will be covered in future documents. diff --git a/docs/tutorials/logging.md b/docs/tutorials/logging.md new file mode 100644 index 000000000..1a015432d --- /dev/null +++ b/docs/tutorials/logging.md @@ -0,0 +1,116 @@ +Logging in Tilelang/TVM +=================================================== +
+Author: SiriusNEO +
+ +## TVM Logging Overview + +Tilelang currently utilizes the logging system from TVM. The implementation can be found in: + +- [include/tvm/runtime/logging.h](https://github.com/apache/tvm/blob/main/include/tvm/runtime/logging.h): Macro definitions +- [src/runtime/logging.cc](https://github.com/apache/tvm/blob/main/src/runtime/logging.cc): Logging logic implementation + +The design style is inspired by [Google's glog](https://google.github.io/glog/stable/). + +## Logging Categories + +There are three primary macro types: + +```c++ +LOG(INFO) << "aaa"; +DLOG(INFO) << "aaa"; +VLOG(1) << "aaa"; +``` + +- **LOG**: Standard logging preserved in code for displaying necessary information at different levels during runtime. Most Tilelang C++ error reporting is implemented via `LOG(FATAL) << "error msg"`. +- **DLOG**: Debug logging for developer debugging output. DLOG is controlled at build time by the TVM_LOG_DEBUG environment variable and is **eliminated in Release builds through dead code elimination**. + - The key difference between LOG(DEBUG) and DLOG is this build-time elimination. We recommend using DLOG over LOG(DEBUG), as the latter has overlapping functionality and gets compiled into the release runtime. +- **VLOG**: [Verbose logging](https://google.github.io/glog/stable/logging/#verbose-logging), primarily for debugging. Its main feature is customizable verbosity levels. For example, VLOG(n) where n can be 1, 2, 3, 4, 5, or 6, enabling complex tracing requirements. In contrast, LOG and DLOG typically use predefined verbose levels like INFO and DEBUG. + - In practical Tilelang development, VLOG is used less frequently. + - TVM's VLOG is implemented using DLOG, thus inheriting DLOG's characteristics. + +Additional useful macros include various **CHECK** variants: + +```c++ +CHECK(cond) << "error msg"; +DCHECK(cond) << "error msg"; +ICHECK(cond) << "error msg"; +``` + +The implementation routes errors to LogFatal: + +```c++ +#define CHECK(x) \ + if (!(x)) \ + ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ + << "Check failed: (" #x << ") is false: " +``` +- **DCHECK**: Debug mode CHECK, only compiled in debug builds +- **ICHECK**: Internal Check that should exist in Release builds. When ICHECK fails, the entire system should report an error. + +## Logging Verbose Levels + +TVM defines 5 levels for LOG and DLOG (adding DEBUG compared to glog): + +```c++ +#define TVM_LOG_LEVEL_DEBUG 0 +#define TVM_LOG_LEVEL_INFO 1 +#define TVM_LOG_LEVEL_WARNING 2 +#define TVM_LOG_LEVEL_ERROR 3 +#define TVM_LOG_LEVEL_FATAL 4 +``` + +## Using Logging in TileLang Development + +### Guidelines + +For temporary debugging output in your code, there are no restrictions (you can even use std::cout). Just remember to remove it before submitting a PR. + +For meaningful logging that should remain in the Tilelang codebase: + +- Critical correctness checks: Use ICHECK with sufficient error messages to facilitate debugging when issues arise. +- Complex Pass debugging: For passes requiring intermediate output that may need future review (e.g., LayoutInference), use DLOG. +- General INFO/WARNING messages: Use standard LOG. + +### Enabling Log Output in Tilelang + +To specify current log level at runtime, we need to set the environment variable `TVM_LOG_LEVEL`. An example usage is: + +```c++ +TVM_LOG_DEBUG=1 python3 code.py +``` + +which enables all DEBUG/INFO (level <= 1) logs for all files. + +#### Detailed Rules for TVM_LOG_DEBUG Specification + +The parsing logic is in `logging.cc`. 
Reference: [HyperAI Zhihu Article](https://zhuanlan.zhihu.com/p/1933106843468665163). + +Launch Python with `TVM_LOG_DEBUG=`, where `` is a comma-separated list of level assignments in the form `=`. Important notes: + +- The special filename DEFAULT sets the LOG level for all files. +- `` can be set to -1 to disable LOG for that file. +- `` is the C++ source filename (e.g., .cc, not .h) relative to the `src/` directory in the TVM repository. The `src/` prefix is optional when specifying file paths. + +### Enabling Debug Mode + +To enable DLOG/DCHECK, developers need to first build Tilelang in Debug mode: + +```bash +cmake .. -DCMAKE_BUILD_TYPE=Debug -DUSE_CUDA=ON +``` + +Tilelang's CMake logic automatically adds the `TVM_LOG_DEBUG` macro, compiling all DLOG statements: + +```cmake +target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG") +``` + +Then you also need to specify the runtime environment variables. For example, to use `DLOG(INFO) << "xxx"` for debugging, run your code with INFO level (1): `TVM_LOG_DEBUG=1`. + +:::{note} + **Important**: There are two TVM_LOG_DEBUG variables. (1) Compile-time macro: Determines whether debug content (like DLOG) is compiled into the .so file. Referenced in C++ source via #ifdef TVM_LOG_DEBUG. This is automatically enabled when using Debug build mode in CMake. (2) Runtime environment variable: Controls logging level at runtime. TVM provides a specification for this variable, allowing control over per-file logging levels. + + These two should ideally have different names, but TVM uses the same name for both, which can cause confusion. +::: diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py index d47866e1e..27986ce78 100644 --- a/examples/amd/example_amd_flash_attn_bwd.py +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import tilelang import tilelang.language as T -from tilelang.primitives.gemm.base import GemmWarpPolicy +from tilelang.tileop.base import GemmWarpPolicy import itertools import argparse from functools import partial @@ -11,22 +11,20 @@ def ref_program(Q, K, V, is_causal, groups=1): - assert Q.size( - 2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" - assert Q.size( - 2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" dim = Q.size(-1) K_ref = K.repeat_interleave(groups, dim=2) V_ref = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref) lse = torch.logsumexp(scores, dim=-1).float() return output, lse @@ -45,23 +43,23 @@ def get_fwd_configs(): valid_configs = [] - for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, - threads, num_stages, - 
enable_rasterization, k_pack, - panel_size, qk_coalesced_width, - v_coalesced_width): - valid_configs.append({ - "block_M": m, - "block_N": n, - "num_split_q": s, - "threads": t, - "num_stages": stages, - "enable_rasterization": r, - "k_pack": k, - "panel_size": p, - "qk_coalesced_width": qkw, - "v_coalesced_width": vw, - }) + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) return valid_configs @@ -85,23 +83,23 @@ def fast_flashattn( qk_coalesced_width: int, v_coalesced_width: int, ): - scale = (1.0 / dim)**0.5 + scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 vec_size = qk_coalesced_width v_vec_size = v_coalesced_width @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - LSE: T.Tensor([batch, heads, seq_len], accum_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + LSE: T.Tensor([batch, heads, seq_len], accum_dtype), ): with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -111,10 +109,10 @@ def main( num_q_blocks = T.ceildiv(seq_len, block_M) - bx_loop_var = T.alloc_var("int32") + bx_loop_var = T.alloc_var(T.int32) bx_loop_var = b_split - with T.While(bx_loop_var < num_q_blocks): + while bx_loop_var < num_q_blocks: acc_o = T.alloc_fragment([block_M, dim], accum_dtype) m_i = T.alloc_fragment([block_M], accum_dtype) l_i = T.alloc_fragment([block_M], accum_dtype) @@ -135,33 +133,21 @@ def main( m_prev = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype) - T.copy( - Q[bz, q_block_offset:q_block_offset + block_M, by, :], - Q_shared, - coalesced_width=vec_size) + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) - loop_end_k = ( - T.ceildiv(q_block_offset + - block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) row_sum = T.alloc_fragment([block_M], accum_dtype) for k in T.Pipelined(loop_end_k, num_stages=num_stages): kv_idx = k * block_N - T.copy( - K[bz, kv_idx:kv_idx + block_N, by // groups, :], - K_shared, - coalesced_width=vec_size) - T.copy( - V[bz, kv_idx:kv_idx + block_N, by // groups, :], - V_shared, - coalesced_width=v_vec_size) + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm( @@ -178,6 +164,8 @@ def main( T.copy(m_i, 
m_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) for i in T.Parallel(block_M): if m_prev[i] == -T.infinity(accum_dtype): @@ -214,8 +202,7 @@ def main( for i in T.Parallel(block_M): if q_block_offset + i < seq_len: - lse_val = T.if_then_else(l_i[i] > 0, - T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) + lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) LSE[bz, by, q_block_offset + i] = lse_val bx_loop_var = current_bx + num_split_q @@ -232,30 +219,30 @@ def get_bwd_configs(): panel_size = [7, 8, 9, 10] configs = [] - for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, - enable_rasterization, panel_size): - configs.append({ - "block_M": m, - "block_N": n, - "num_stages": stages, - "threads": t, - "enable_rasterization": r, - "panel_size": p, - }) + for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size): + configs.append( + { + "block_M": m, + "block_N": n, + "num_stages": stages, + "threads": t, + "enable_rasterization": r, + "panel_size": p, + } + ) return configs @tilelang.jit(out_idx=[2]) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 32 @T.prim_func - def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), - Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): + def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by): o = T.alloc_fragment([blk, blk], dtype) do = T.alloc_fragment([blk, blk], dtype) @@ -263,36 +250,51 @@ def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep @tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True) @tilelang.jit -def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, block_N: int, - num_stages: int, threads: int, enable_rasterization: bool, panel_size: int): - sm_scale = (1.0 / dim)**0.5 +def flashattn_bwd( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + panel_size: int, +): + sm_scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func - def flash_bwd_kernel(Q: T.Tensor(q_shape, - dtype), K: T.Tensor(kv_shape, - dtype), V: T.Tensor(kv_shape, dtype), - dO: T.Tensor(q_shape, dtype), lse: T.Tensor([batch, heads, seq_len], - accum_dtype), - Delta: T.Tensor([batch, heads, seq_len], - accum_dtype), dQ: 
T.Tensor(q_shape, accum_dtype), - dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)): + def flash_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + dO: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch, heads, seq_len], accum_dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + dQ: T.Tensor(q_shape, accum_dtype), + dK: T.Tensor(kv_shape, accum_dtype), + dV: T.Tensor(kv_shape, accum_dtype), + ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -313,8 +315,8 @@ def flash_bwd_kernel(Q: T.Tensor(q_shape, dk = T.alloc_fragment([block_M, dim], accum_dtype) dq = T.alloc_fragment([block_N, dim], accum_dtype) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) @@ -322,22 +324,21 @@ def flash_bwd_kernel(Q: T.Tensor(q_shape, loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared) T.clear(qkT) T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, - P_acc[i, j], 0.0) + P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared) T.clear(dP) T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -345,7 +346,7 @@ def flash_bwd_kernel(Q: T.Tensor(q_shape, T.copy(P_acc, p_cast) T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared) for i, j in T.Parallel(block_M, block_N): p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale @@ -367,8 +368,8 @@ def flash_bwd_kernel(Q: T.Tensor(q_shape, @tilelang.jit(out_idx=[1]) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 64 @@ -376,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.copy( - dQ_in[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ_in[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post @@ -444,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100): return np.median(times) -def main(batch: int = 1, - heads: int = 8, - seq_len: int = 4096, - dim: int = 
128, - is_causal: bool = False, - groups: int = 1): - +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): device = "cuda" dtype = torch.float16 torch.manual_seed(42) torch.cuda.manual_seed(42) - print( - f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}" - ) + print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}") flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 5 * flops_per_gemm @@ -515,22 +508,19 @@ def main(batch: int = 1, o_ref.backward(dO) print("Verifying backward pass correctness...") - dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison( - dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) + dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) if dq_close: print("dQ is correct.") else: print("dQ mismatch detected.") - dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison( - dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) + dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) if dk_close: print("dK is correct.") else: print("dK mismatch detected.") - dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison( - dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) + dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) if dv_close: print("dV is correct.") else: @@ -551,9 +541,7 @@ def run_reference_fwd_bwd(): torch.cuda.synchronize() ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100) - print( - f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops" - ) + print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops") def run_complete_fwd_bwd(): o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v) @@ -591,12 +579,12 @@ def run_complete_fwd_bwd(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=8, help='heads') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--groups', type=int, default=1, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=1024, help="sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/amd/example_amd_flash_attn_fwd.py b/examples/amd/example_amd_flash_attn_fwd.py index 6ec5db1e5..581619220 100644 --- a/examples/amd/example_amd_flash_attn_fwd.py +++ b/examples/amd/example_amd_flash_attn_fwd.py @@ -2,29 +2,42 @@ import torch.nn.functional as F import tilelang import 
tilelang.language as T -from tilelang.primitives.gemm.base import GemmWarpPolicy +from tilelang.tileop.base import GemmWarpPolicy import itertools import argparse from functools import partial +# Custom supply function to ensure tensors are created on GPU +def supply_tensors_gpu(params): + """Supply function that creates tensors on GPU for ROCm/HIP.""" + tensors = [] + for param in params: + if hasattr(param, "shape") and hasattr(param, "dtype"): + # Force creation on GPU device + shape = [int(s) for s in param.shape] + tensor = torch.randn(shape, dtype=param.dtype, device="cuda") + tensors.append(tensor) + else: + tensors.append(param) + return tensors + + def ref_program(Q, K, V, is_causal, groups=1): - assert Q.size( - 2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" - assert Q.size( - 2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -43,27 +56,27 @@ def get_configs(): valid_configs = [] - for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, - threads, num_stages, - enable_rasterization, k_pack, - panel_size, qk_coalesced_width, - v_coalesced_width): - valid_configs.append({ - "block_M": m, - "block_N": n, - "num_split_q": s, - "threads": t, - "num_stages": stages, - "enable_rasterization": r, - "k_pack": k, - "panel_size": p, - "qk_coalesced_width": qkw, - "v_coalesced_width": vw, - }) + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) return valid_configs -@tilelang.autotune(configs=get_configs(), cache_input_tensors=True) +@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu) @tilelang.jit(out_idx=[3]) def fast_flashattn( batch, @@ -83,22 +96,22 @@ def fast_flashattn( qk_coalesced_width: int, v_coalesced_width: int, ): - scale = (1.0 / dim)**0.5 + scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 vec_size = qk_coalesced_width v_vec_size = v_coalesced_width @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - 
Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -108,10 +121,10 @@ def main( num_q_blocks = T.ceildiv(seq_len, block_M) - bx = T.alloc_var("int32") + bx = T.alloc_var(T.int32) bx = b_split - with T.While(bx < num_q_blocks): + while bx < num_q_blocks: acc_o = T.alloc_fragment([block_M, dim], accum_dtype) m_i = T.alloc_fragment([block_M], accum_dtype) l_i = T.alloc_fragment([block_M], accum_dtype) @@ -132,32 +145,21 @@ def main( m_prev = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype) - T.copy( - Q[bz, q_block_offset:q_block_offset + block_M, by, :], - Q_shared, - coalesced_width=vec_size) + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) - loop_end_k = T.ceildiv(q_block_offset + block_M, - block_N) if is_causal else T.ceildiv(seq_len, block_N) + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) row_sum = T.alloc_fragment([block_M], accum_dtype) for k in T.Pipelined(loop_end_k, num_stages=num_stages): kv_idx = k * block_N - T.copy( - K[bz, kv_idx:kv_idx + block_N, by // groups, :], - K_shared, - coalesced_width=vec_size) - T.copy( - V[bz, kv_idx:kv_idx + block_N, by // groups, :], - V_shared, - coalesced_width=v_vec_size) + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm( @@ -171,6 +173,8 @@ def main( T.copy(m_i, m_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) for i in T.Parallel(block_M): sf = T.exp(m_prev[i] * scale - m_i[i] * scale) @@ -205,13 +209,7 @@ def main( return main -def main(batch: int = 1, - heads: int = 8, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 1): - +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul if is_causal: @@ -233,18 +231,16 @@ def main(batch: int = 1, print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") latency = profiler.do_bench(warmup=100) - print( - f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops" - ) + print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=8, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--groups', type=int, 
default=1, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/analyze/README.md b/examples/analyze/README.md index 8171d8826..1c2788b0b 100644 --- a/examples/analyze/README.md +++ b/examples/analyze/README.md @@ -21,9 +21,9 @@ M = N = K = 1024 def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128): @T.prim_func - def main(A: T.Tensor((M, K), "float16"), - B: T.Tensor((N, K), "float16"), - C: T.Tensor((M, N), "float")): + def main(A: T.Tensor((M, K), T.float16), + B: T.Tensor((N, K), T.float16), + C: T.Tensor((M, N), T.float)): # ... (kernel definition) return main @@ -40,9 +40,9 @@ from tilelang.carver.arch import CUDA def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128): @T.prim_func - def main(data: T.Tensor((N, H, W, C), "float16"), - kernel: T.Tensor((K, K, C, F), "float16"), - out: T.Tensor((N, (H-K+1), (W-K+1), F), "float")): + def main(data: T.Tensor((N, H, W, C), T.float16), + kernel: T.Tensor((K, K, C, F), T.float16), + out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)): # ... (convolution kernel definition) return main @@ -64,10 +64,10 @@ class AnalysisResult: ``` ### `Analyzer` Class Methods #### `analysis(fn, device)` -* ​Parameters: - * fn: TVM IRModule or PrimFunc - * device: Device configuration object -* Returns: AnalysisResult +- ​Parameters: + - fn: TVM IRModule or PrimFunc + - device: Device configuration object +- Returns: AnalysisResult #### Supported Architectures ```python # Extendable to custom hardware via: "compute_capability": (cores_per_SM, clock_GHz, flops_per_cycle, max_SM_count) diff --git a/examples/analyze/example_conv_analyze.py b/examples/analyze/example_conv_analyze.py index 540fcf4b7..06e5a86e9 100644 --- a/examples/analyze/example_conv_analyze.py +++ b/examples/analyze/example_conv_analyze.py @@ -2,7 +2,6 @@ from tilelang.tools import Analyzer from tilelang.carver.arch import CUDA from tilelang.carver.arch import CDNA -from tilelang.layout import make_swizzled_layout import torch N = 64 @@ -25,38 +24,21 @@ def check_hopper(): return False -def kernel(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 is_hopper = check_hopper() @T.prim_func def conv( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) 
as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -65,12 +47,6 @@ def conv( kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: make_swizzled_layout(out_shared), - data_shared: make_swizzled_layout(data_shared), - kernel_shared: make_swizzled_layout(kernel_shared), - }) - T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): if is_hopper: @@ -81,10 +57,8 @@ def conv( m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) diff --git a/examples/analyze/example_gemm_analyze.py b/examples/analyze/example_gemm_analyze.py index bfd934f6a..0367af126 100644 --- a/examples/analyze/example_gemm_analyze.py +++ b/examples/analyze/example_gemm_analyze.py @@ -15,14 +15,14 @@ def kernel( thread_num=None, enable_rasteration=None, ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/examples/attention_sink/README.md b/examples/attention_sink/README.md index ed4b7004e..2cba8f0cc 100644 --- a/examples/attention_sink/README.md +++ b/examples/attention_sink/README.md @@ -2,7 +2,6 @@ We compare with an optimized version of the official Triton implementation [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py). - ## Algorithm ### Forward The only change from vanilla FlashAttention is that `sinks` should be taken into consideration in the softmax, which requires an extra rescaling at the epilogue stage. @@ -43,4 +42,4 @@ where $P_{b, h, q}$ is the proportion of $sink_h$ in the softmax in the $b$-th b | 16384 | 64 | 309.46 | **400.62** | 1.29x | | 16384 | 128 | 418.99 | **549.11** | 1.31x | -> The backward performance will be further optimized in the future. \ No newline at end of file +> The backward performance will be further optimized in the future. 
diff --git a/examples/attention_sink/benchmark_gqa_sink_fwd.py b/examples/attention_sink/benchmark_gqa_sink_fwd.py index 1b7de6b6f..211ef1d18 100644 --- a/examples/attention_sink/benchmark_gqa_sink_fwd.py +++ b/examples/attention_sink/benchmark_gqa_sink_fwd.py @@ -1,6 +1,7 @@ import torch import argparse from tilelang.profiler import do_bench +from tilelang import language as T import triton import triton.language as tl from triton.tools.tensor_descriptor import TensorDescriptor @@ -51,8 +52,7 @@ def triton_kernel( q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: lo, hi = 0, start_q + (start_m + 1) * BLOCK_M @@ -120,7 +120,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T BANDWIDTH=window_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - start_q=seq_kv - seq_q) + start_q=seq_kv - seq_q, + ) return o @@ -135,14 +136,14 @@ def main( dtype: str = "float16", tune: bool = False, ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -170,15 +171,14 @@ def main( block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) if torch.allclose( - triton_program(Q, K, V, sinks, window_size), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2): + triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ): print("Checks for triton passed.✅") else: print("Checks for triton failed.❌") @@ -198,20 +198,14 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + 
parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, - args.dtype, args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/benchmark_mha_sink_fwd.py b/examples/attention_sink/benchmark_mha_sink_fwd.py index f50b94535..50747e6b0 100644 --- a/examples/attention_sink/benchmark_mha_sink_fwd.py +++ b/examples/attention_sink/benchmark_mha_sink_fwd.py @@ -1,6 +1,7 @@ import torch import argparse from tilelang.profiler import do_bench +from tilelang import language as T import triton import triton.language as tl from triton.tools.tensor_descriptor import TensorDescriptor @@ -50,8 +51,7 @@ def triton_kernel( q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: lo, hi = 0, start_q + (start_m + 1) * BLOCK_M @@ -117,26 +117,29 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T BANDWIDTH=window_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - start_q=seq_kv - seq_q) + start_q=seq_kv - seq_q, + ) return o -def main(batch: int = 1, - heads: int = 32, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16", - tune: bool = False): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -163,15 +166,14 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500) @@ 
-184,19 +186,13 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index eec43db99..cfdcd21b5 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -13,50 +13,50 @@ def get_bwd_configs(): sm_version = sm_major * 10 + sm_minor if sm_version == 80: return 64, 32, 1, 128 - elif sm_version == 90: - return 128, 32, 2, 256 else: - raise ValueError(f"Unsupported SM version: {sm_version}") + return 128, 32, 2, 256 @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd( - batch, - heads, - seq_len, - dim, - groups=1, - window_size=None, # None for full attention - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): - + batch, + heads, + seq_len, + dim, + groups=1, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, heads, seq_len, dim] kv_shape = [batch, head_kv, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(kv_shape, dtype), # type: ignore - V: T.Tensor(kv_shape, dtype), # type: ignore - Output: T.Tensor(q_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Sinks: T.Tensor([heads], 
dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + Output: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -72,8 +72,7 @@ def flash_fwd( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([heads], dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -81,31 +80,30 @@ def flash_fwd( sinks[i] = Sinks[by] end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) - start = T.max(0, - (bx * block_M - window_size) // block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined(start, end, num_stages=num_stages): - T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, - 0, -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # NOTE(wt): check_inf is necessary for sliding window attention. 
for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -122,32 +120,33 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -156,65 +155,61 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 64 @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: 
T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd(batch, - heads, - seq_len, - dim, - groups, - window_size=None, - sm_scale=None, - dtype="float16"): # None for full attention +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype=T.float16): # None for full attention if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, heads, seq_len, dim] kv_shape = [batch, head_kv, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 block_M, block_N, num_stages, threads = get_bwd_configs() @@ -223,15 +218,15 @@ def flashattn_bwd(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(kv_shape, dtype), # type: ignore - V: T.Tensor(kv_shape, dtype), # type: ignore - dO: T.Tensor(q_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(kv_shape, accum_dtype), # type: ignore - dV: T.Tensor(kv_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + dO: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(kv_shape, accum_dtype), # type: ignore + dV: T.Tensor(kv_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -251,44 +246,44 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim], accum_dtype) dk_shared = T.alloc_shared([block_M, dim], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx // groups, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx // groups, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, bx // groups, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx // groups, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) - loop_ed = T.min( - T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv( - seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + 
else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) for i, j in T.Parallel(block_M, block_N): if window_size is not None: qkT[i, j] = T.if_then_else( - by * block_M + i <= k * block_N + j and - by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) else: - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -297,50 +292,46 @@ def flash_bwd( T.copy(dsT_cast, dsT_shared) T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) - T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared) + T.atomic_add(dV[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dk_shared) + T.atomic_add(dK[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dk_shared) return flash_bwd @tilelang.jit(out_idx=-1) -def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len] @T.prim_func def flash_bwd_dsink( - Sinks: T.Tensor([heads], dtype), # type: ignore - Delta: T.Tensor(shape, accum_dtype), # type: ignore - lse: T.Tensor(shape, accum_dtype), # type: ignore - dsinks: T.Tensor(shape, dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz): - sink = T.alloc_local([1], dtype) lse_fragment = T.alloc_fragment([block], accum_dtype) delta_fragment = T.alloc_fragment([block], accum_dtype) dsink_fragment = T.alloc_fragment([block], dtype) - sink[0] = Sinks[bx] - T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment) - T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) + sink = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): 
- dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - - lse_fragment[i]) * delta_fragment[i] - T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block]) + dsink_fragment[i] = -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, sinks, window_size, groups): - def maybe_contiguous(x): if x.stride(-1) != 1: return x.contiguous() @@ -348,7 +339,7 @@ def maybe_contiguous(x): q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)] BATCH, H, N_CTX, D_HEAD = q.shape - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) o, lse = kernel(q, k, v, sinks) ctx.save_for_backward(q, k, v, sinks, o, lse) @@ -361,7 +352,7 @@ def backward(ctx, do): q, k, v, sinks, o, lse = ctx.saved_tensors BATCH, H, N_CTX, D_HEAD = q.shape groups = ctx.groups - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) @@ -386,13 +377,14 @@ def backward(ctx, do): # Adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() batch_size, num_keys, num_key_value_heads, head_dim = key.shape @@ -428,32 +420,32 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def main(BATCH: int = 1, - H: int = 8, - N_CTX: int = 512, - D_HEAD: int = 64, - groups: int = 2, - window_size: Optional[int] = None, - dtype: str = "float16"): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = "float16", +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= N_CTX - flops_per_matmul = 2.0 * BATCH * H * min( - window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 total_flops = 5 * flops_per_matmul - Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, 
dtype=torch_dtype, device="cuda").requires_grad_()) - K = torch.randn( - BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() V = torch.randn_like(K).requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_() dO = torch.randn_like(Q) @@ -474,19 +466,14 @@ def main(BATCH: int = 1, # Checks rtol, atol = { - "float16": (1e-2, 1e-2), - "bfloat16": (2e-2, 2e-2), + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), }[dtype] - assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}' - assert torch.allclose( - dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' - assert torch.allclose( - dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' - assert torch.allclose( - dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}' - assert torch.allclose( - dsinks, dsinks_ref, rtol=rtol, - atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}' + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" print("All checks passed for tilelang kernels.✅") @@ -505,19 +492,57 @@ def tl_bwd(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + with torch.no_grad(): + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + V = torch.randn_like(K) + sinks = torch.randn(H, dtype=torch_dtype, device="cuda") + dO = torch.randn_like(Q) + fwd = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + O, lse = fwd(Q, K, V, sinks) + + def maybe_contiguous(x): + return x if x.stride(-1) == 1 else x.contiguous() + + do, q, k, v, sinks_c, o = [maybe_contiguous(x) for x in (dO, Q, K, V, sinks, O)] + k_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + Delta = k_prep(o, do) + k_bwd = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + k_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + q_shape = (BATCH, H, N_CTX, D_HEAD) + head_kv = H // groups + kv_shape = (BATCH, head_kv, N_CTX, D_HEAD) + dq = torch.zeros(q_shape, dtype=torch.float32, device="cuda") + dk = torch.zeros(kv_shape, dtype=torch.float32, device="cuda") + dv = torch.zeros(kv_shape, dtype=torch.float32, device="cuda") + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + _ = k_dsink(sinks_c, Delta, lse).sum(0).sum(1) + + def run_kernel_only(): + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + + latency_ms = do_bench(run_kernel_only, backend="cupti") + return latency_ms + + if __name__ == 
"__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='Batch size') - parser.add_argument('--h', type=int, default=64, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--d_head', type=int, default=128, help='Head dimension') - parser.add_argument('--groups', type=int, default=8, help='Groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--groups", type=int, default=8, help="Groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_gqa_sink_bwd_varlen.py b/examples/attention_sink/example_gqa_sink_bwd_varlen.py new file mode 100644 index 000000000..64a5a39a8 --- /dev/null +++ b/examples/attention_sink/example_gqa_sink_bwd_varlen.py @@ -0,0 +1,798 @@ +import torch +import tilelang +from tilelang.profiler import do_bench +import tilelang.language as T +import argparse +from typing import Optional +import sys +import os + +sys.path.append(os.path.join(os.path.dirname(__file__), "../flash_attention")) +from varlen_utils import generate_random_padding_mask, generate_qkv + + +def get_bwd_configs(): + sm_major, sm_minor = torch.cuda.get_device_capability() + sm_version = sm_major * 10 + sm_minor + if sm_version == 80: + return 64, 32, 1, 128 + else: + return 128, 32, 2, 256 + + +@tilelang.jit( + out_idx=[6, 7], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd( + batch_size, + groups, + UQ, + UKV, + N_CTX, + heads, + max_seq_len, + dim, + is_causal, + window_size=None, # None for full causal attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype=T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [UQ, heads, dim] + kv_shape = [UKV, head_kv, dim] + o_shape = [UQ, heads, dim] + accum_dtype = T.float32 + + @T.prim_func + def main( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + Sinks: T.Tensor([heads], dtype), + Output_unpad: T.Tensor(o_shape, dtype), + lse: T.Tensor([batch_size, heads, N_CTX], accum_dtype), + ): + with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, 
dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], dtype) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx + + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[head_idx] + + offset = kv_current_seqlen - q_current_seqlen # always align on the right + max_visible_k_idx = offset + (bx + 1) * block_M + + # Determine loop range based on causal mask and sliding window + if is_causal: + if window_size is not None: + start = T.max(0, (offset + bx * block_M - window_size + 1) // block_N) + end = T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + else: + start = 0 + end = T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + else: + if window_size is not None: + start = T.max(0, (offset + bx * block_M - window_size + 1) // block_N) + end = T.ceildiv(kv_current_seqlen, block_N) + else: + start = 0 + end = T.ceildiv(kv_current_seqlen, block_N) + + loop_range = end - start + + for k in T.Pipelined(loop_range, num_stages=num_stages): + actual_k = k + start + T.copy(K_unpad[kv_start_idx + actual_k * block_N : kv_start_idx + (actual_k + 1) * block_N, kv_head_idx, :], K_shared) + + # Build mask considering causal, sliding window, and padding + if is_causal: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + offset + k_idx = actual_k * block_N + j + acc_s[i, j] = T.if_then_else( + (q_idx < k_idx) + or (q_idx >= k_idx + window_size) + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i + offset < actual_k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + offset + k_idx = actual_k * block_N + j + acc_s[i, j] = T.if_then_else( + (q_idx >= k_idx + window_size) + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(V_unpad[kv_start_idx + actual_k * block_N : kv_start_idx + (actual_k + 1) * block_N, kv_head_idx, :], V_shared) + 
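# Online-softmax update for this KV tile (scale = sm_scale * log2(e)):
+ #   m_new = max(m_old, rowmax(acc_s))           running row maximum
+ #   alpha = exp2(m_old * scale - m_new * scale)  rescale factor for the old state
+ #   p     = exp2(acc_s * scale - m_new * scale)  unnormalized probabilities
+ #   l_new = l_old * alpha + rowsum(p);  acc_o_new = acc_o * alpha + p @ V
+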
T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + # Handle case where scores_max is -inf (query sees no keys due to causal mask or sliding window) + # This can happen when q_len > k_len (offset < 0) in causal attention, or with sliding window + for i in T.Parallel(block_M): + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + # Attention sink: add sink contribution to logsum + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i] + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = acc_o[i, d] + + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + for i in T.Parallel(block_M): + if bx * block_M + i < q_current_seqlen: + lse[bz, head_idx, bx * block_M + i] = logsum[i] + + return main + + +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch_size, heads, UQ, N_CTX, max_seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [UQ, heads, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + Delta: T.Tensor([batch_size, heads, N_CTX], accum_dtype), + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch_size) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + + q_start_idx = cu_seqlens_q[bz] + q_end_idx = cu_seqlens_q[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + for i, j in T.Parallel(blk, blk): + if by * blk + i < q_current_seqlen and k * blk + j < dim: + o[i, j] = O[q_start_idx + by * blk + i, bx, k * blk + j] + do[i, j] = dO[q_start_idx + by * blk + i, bx, k * blk + j] + else: + o[i, j] = 0.0 + do[i, j] = 0.0 + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + + for i in T.Parallel(blk): + if by * blk + i < q_current_seqlen: + Delta[bz, bx, by * blk + i] = delta[i] + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # Reorder dq for atomic add: [seq, head, dim] -> permuted layout + return T.Layout(dQ.shape, lambda l, h, d: [h, l, d]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(UQ, heads, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [UQ, heads, dim] + blk = 64 + + 
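# flash_bwd_post converts the float32 dQ accumulator (filled with atomic adds in
+ # the permuted layout from make_dq_layout) back to the output dtype; the layout
+ # annotation below tells the copy how dQ is actually laid out in memory.
+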
@T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), + dQ_out: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(UQ, blk), heads, threads=128) as (bx, by): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bx * blk : (bx + 1) * blk, by, :], + dQ_out[bx * blk : (bx + 1) * blk, by, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd( + batch_size, + groups, + UQ, + UKV, + N_CTX, + heads, + max_seq_len, + dim, + is_causal, + window_size=None, + sm_scale=None, + dtype=T.float16, +): + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [UQ, heads, dim] + kv_shape = [UKV, head_kv, dim] + accum_dtype = T.float32 + + block_M, block_N, num_stages, threads = get_bwd_configs() + + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + dO: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch_size, heads, N_CTX], accum_dtype), + Delta: T.Tensor([batch_size, heads, N_CTX], accum_dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + dQ: T.Tensor(q_shape, accum_dtype), + dK: T.Tensor(kv_shape, accum_dtype), + dV: T.Tensor(kv_shape, accum_dtype), + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch_size, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim], accum_dtype) + + q_start_idx = cu_seqlens_q[bz] + kv_start_idx = cu_seqlens_k[bz] + q_end_idx = cu_seqlens_q[bz + 1] + k_end_idx = cu_seqlens_k[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[kv_start_idx + by * block_M : kv_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[kv_start_idx + by * block_M : kv_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + + # For varlen causal attention, we need to account for offset between q and kv lengths + # In forward: Q at pos q can see KV at pos k if q + offset >= k (where offset = kv_len - q_len) + # In backward: KV at pos kv_pos is seen by Q at pos q_pos if kv_pos <= q_pos + offset + offset = kv_current_seqlen - q_current_seqlen + + # loop_st: first Q block that can see this KV block + # kv_pos <= q_pos + offset => by * block_M <= k * block_N + offset + # => k >= (by * block_M - offset) / block_N + loop_st = T.max(0, T.floordiv(by * block_M - offset, 
block_N)) if is_causal else 0 + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M - offset + window_size, block_N), T.ceildiv(q_current_seqlen, block_N)) + if window_size is not None + else T.ceildiv(q_current_seqlen, block_N) + ) + + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[q_start_idx + k * block_N : q_start_idx + (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + # Causal: kv_pos <= q_pos + offset + # Sliding window: kv_pos > q_pos + offset - window_size + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k * block_N + j + offset) + and (by * block_M + i > k * block_N + j + offset - window_size) + and (by * block_M + i < kv_current_seqlen and k * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + # Causal: kv_pos <= q_pos + offset + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k * block_N + j + offset) + and (by * block_M + i < kv_current_seqlen and k * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + (by * block_M + i > k * block_N + j + offset - window_size) + and (by * block_M + i < kv_current_seqlen and k * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + by * block_M + i < kv_current_seqlen and k * block_N + j < q_current_seqlen, + qkT[i, j], + 0, + ) + + T.copy(dO[q_start_idx + k * block_N : q_start_idx + (k + 1) * block_N, bx, :], dst=do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.atomic_add(dQ[q_start_idx + k * block_N : q_start_idx + (k + 1) * block_N, bx, :], dq) + + T.copy(dv, dv_shared) + T.atomic_add(dV[kv_start_idx + by * block_M : kv_start_idx + (by + 1) * block_M, bx // groups, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[kv_start_idx + by * block_M : kv_start_idx + (by + 1) * block_M, bx // groups, :], dk_shared) + + return flash_bwd + + +@tilelang.jit(out_idx=-1) +def flashattn_bwd_dsink(batch_size, heads, N_CTX, max_seq_len, block=256, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch_size, heads, N_CTX] + + @T.prim_func + def flash_bwd_dsink( + Sinks: T.Tensor([heads], dtype), + Delta: T.Tensor(shape, accum_dtype), + lse: T.Tensor(shape, accum_dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + dsinks: T.Tensor(shape, dtype), + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, block), batch_size, threads=256) as (bx, by, bz): + lse_fragment = T.alloc_fragment([block], accum_dtype) + delta_fragment = T.alloc_fragment([block], accum_dtype) + dsink_fragment = T.alloc_fragment([block], dtype) + + # Get actual sequence length for this 
batch item + q_start_idx = cu_seqlens_q[bz] + q_end_idx = cu_seqlens_q[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + + sink = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) + for i in T.Parallel(block): + # Only compute for valid positions, set 0 for positions beyond sequence length + dsink_fragment[i] = T.if_then_else( + by * block + i < q_current_seqlen, + -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i], + 0, + ) + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) + + return flash_bwd_dsink + + +class _attention(torch.autograd.Function): + @staticmethod + def forward( + ctx, q_unpad, k_unpad, v_unpad, sinks, cu_seqlens_q, cu_seqlens_k, N_CTX, max_seqlen_q, max_seqlen_k, window_size, groups, is_causal + ): + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + q_unpad, k_unpad, v_unpad, sinks = [maybe_contiguous(x) for x in (q_unpad, k_unpad, v_unpad, sinks)] + UQ, H, D_HEAD = q_unpad.shape + UKV = k_unpad.shape[0] + batch_size = cu_seqlens_q.shape[0] - 1 + dtype = T.float16 if q_unpad.dtype == torch.float16 else T.bfloat16 + + kernel = flashattn_fwd( + batch_size, + groups, + UQ, + UKV, + N_CTX, + H, + max_seqlen_q, + D_HEAD, + is_causal, + window_size=window_size, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype=dtype, + ) + o_unpad, lse = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, sinks) + + ctx.save_for_backward(q_unpad, k_unpad, v_unpad, sinks, o_unpad, lse, cu_seqlens_q, cu_seqlens_k) + ctx.window_size = window_size + ctx.groups = groups + ctx.is_causal = is_causal + ctx.N_CTX = N_CTX + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.batch_size = batch_size + return o_unpad + + @staticmethod + def backward(ctx, do): + q_unpad, k_unpad, v_unpad, sinks, o_unpad, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + UQ, H, D_HEAD = q_unpad.shape + UKV = k_unpad.shape[0] + groups = ctx.groups + batch_size = ctx.batch_size + dtype = T.float16 if q_unpad.dtype == torch.float16 else T.bfloat16 + + kernel_prep = flashattn_bwd_preprocess(batch_size, H, UQ, ctx.N_CTX, ctx.max_seqlen_q, D_HEAD, dtype=dtype) + kernel_post = flashattn_bwd_postprocess(UQ, H, D_HEAD, dtype=dtype) + delta = kernel_prep(o_unpad, do, cu_seqlens_q) + + kernel = flashattn_bwd( + batch_size, + groups, + UQ, + UKV, + ctx.N_CTX, + H, + ctx.max_seqlen_q, + D_HEAD, + ctx.is_causal, + window_size=ctx.window_size, + dtype=dtype, + ) + + head_kv = H // groups + dq = torch.zeros_like(q_unpad, dtype=torch.float32) + dk = torch.zeros([UKV, head_kv, D_HEAD], dtype=torch.float32, device=q_unpad.device) + dv = torch.zeros([UKV, head_kv, D_HEAD], dtype=torch.float32, device=q_unpad.device) + + kernel(q_unpad, k_unpad, v_unpad, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) + dq = kernel_post(dq) + dk = dk.to(q_unpad.dtype) + dv = dv.to(q_unpad.dtype) + + kernel_dsink = flashattn_bwd_dsink(batch_size, H, ctx.N_CTX, ctx.max_seqlen_q, dtype=dtype) + dsinks = kernel_dsink(sinks, delta, lse, cu_seqlens_q).sum(0).sum(1) + + return dq, dk, dv, dsinks, None, None, None, None, None, None, None, None + + +attention = _attention.apply + + +def ref_program( + q_unpad: torch.Tensor, + k_unpad: torch.Tensor, + v_unpad: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sinks: torch.Tensor, + batch_size: int, + is_causal: bool, + 
sliding_window: Optional[int] = None, + groups: int = 1, +) -> torch.Tensor: + """Reference implementation for varlen attention with sinks.""" + total_q, num_heads, head_dim = q_unpad.shape + _, num_key_value_heads, _ = k_unpad.shape + + sm_scale = 1.0 / head_dim**0.5 + + output = torch.zeros_like(q_unpad) + + for b in range(batch_size): + q_start = cu_seqlens_q[b].item() + q_end = cu_seqlens_q[b + 1].item() + k_start = cu_seqlens_k[b].item() + k_end = cu_seqlens_k[b + 1].item() + + q_len = q_end - q_start + k_len = k_end - k_start + + if q_len == 0: + continue + + q_seq = q_unpad[q_start:q_end] # [q_len, heads, dim] + k_seq = k_unpad[k_start:k_end] # [k_len, head_kv, dim] + v_seq = v_unpad[k_start:k_end] # [k_len, head_kv, dim] + + # Reshape for GQA + q_seq = q_seq.view(q_len, num_key_value_heads, groups, head_dim) + sinks_expanded = sinks.view(num_key_value_heads, groups, 1, 1).float() + + k_seq = k_seq.unsqueeze(2) # [k_len, head_kv, 1, dim] + v_seq = v_seq.unsqueeze(2) # [k_len, head_kv, 1, dim] + + logits = torch.einsum("qhgd,khgd->hgqk", q_seq.float(), k_seq.float()) * sm_scale + + start_q = k_len - q_len + pos_keys = torch.arange(k_len, device=q_unpad.device) + pos_queries = torch.arange(q_len, device=q_unpad.device) + start_q + + if is_causal: + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + else: + mask = torch.zeros(q_len, k_len, device=q_unpad.device) + + if sliding_window is not None: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = logits + mask[None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks_expanded, logits_max) + sinks_exp = torch.exp(sinks_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks_exp + scores = unnormalized_scores / normalizer + + out = torch.einsum("hgqk,khgd->qhgd", scores, v_seq.float()) + out = out.reshape(q_len, num_heads, head_dim).to(q_unpad.dtype) + + output[q_start:q_end] = out + + return output + + +def main( + batch: int = 1, + heads: int = 64, + q_seqlen: int = 2048, + k_seqlen: int = 2048, + dim: int = 128, + groups: int = 16, + is_causal: bool = True, + window_size: Optional[int] = None, +): + assert heads % groups == 0, "heads must be divisible by groups" + + flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim + total_flops = 5 * flops_per_matmul # fwd + bwd + + if is_causal: + total_flops *= 0.5 + + if window_size is not None: + print(f"Using sliding window attention with window_size={window_size}") + flops_per_matmul = 2.0 * batch * heads * min(window_size, k_seqlen // 2) * q_seqlen * dim + total_flops = 5 * flops_per_matmul + + dtype = torch.float16 + device = torch.device("cuda") + + head_kv = heads // groups + q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device) + k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + sinks = torch.randn(heads, dtype=dtype, device=device) + + query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, 
+ _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + q_unpad = q_unpad.requires_grad_(True) + k_unpad = k_unpad.requires_grad_(True) + v_unpad = v_unpad.requires_grad_(True) + sinks = sinks.requires_grad_(True) + + dO_unpad = torch.randn_like(q_unpad) + + # TileLang forward + backward + # N_CTX is the padded sequence length used for tensor allocation + N_CTX = q_seqlen + O_unpad = attention( + q_unpad, k_unpad, v_unpad, sinks, cu_seqlens_q, cu_seqlens_k, N_CTX, max_seqlen_q, max_seqlen_k, window_size, groups, is_causal + ) + O_unpad.backward(dO_unpad, retain_graph=True) + dQ, q_unpad.grad = q_unpad.grad.clone(), None + dK, k_unpad.grad = k_unpad.grad.clone(), None + dV, v_unpad.grad = v_unpad.grad.clone(), None + dsinks, sinks.grad = sinks.grad.clone(), None + + # Reference forward + backward + O_ref_unpad = ref_program( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sinks, + batch, + is_causal, + sliding_window=window_size, + groups=groups, + ) + O_ref_unpad.backward(dO_unpad, retain_graph=True) + dQ_ref, q_unpad.grad = q_unpad.grad.clone(), None + dK_ref, k_unpad.grad = k_unpad.grad.clone(), None + dV_ref, v_unpad.grad = v_unpad.grad.clone(), None + dsinks_ref, sinks.grad = sinks.grad.clone(), None + + # Checks + # Sliding window attention has slightly higher numerical error due to more complex masking + rtol, atol = (2e-2, 2e-2) if window_size is not None else (1e-2, 1e-2) + assert torch.allclose(O_unpad, O_ref_unpad, rtol=rtol, atol=atol), f"O max err: {(O_unpad - O_ref_unpad).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dQ max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" + + print("All checks passed for tilelang kernels.✅") + + # Benchmark backward + def torch_bwd(): + O_ref_unpad.backward(dO_unpad, retain_graph=True) + + def tl_bwd(): + O_unpad.backward(dO_unpad, retain_graph=True) + + latency = do_bench(torch_bwd, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(tl_bwd, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="head dim") + parser.add_argument("--is_causal", action="store_true", help="causal attention") + parser.add_argument("--window_size", type=int, default=None, help="sliding window size (default: None for full attention)") + args = parser.parse_args() + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal, args.window_size) diff --git 
a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py index 7765603af..fa73df0af 100644 --- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -6,7 +6,6 @@ from tilelang.autotuner import autotune from tilelang.profiler import do_bench import tilelang.language as T -from tilelang.layout import make_swizzled_layout import itertools import argparse from typing import Optional @@ -23,9 +22,11 @@ def get_configs(): rep=100, ) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( batch, heads, @@ -39,106 +40,30 @@ def flashattn( block_N=128, num_stages=2, threads=256, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): - if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, head_kv, seq_kv, dim] - accum_dtype = "float" + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) - else: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # NOTE(wt): check_inf is necessary for sliding window attention. 
- for i in T.Parallel(block_M): - if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -155,58 +80,83 @@ def main( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) - - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.max(0, (bx * block_M + past_len - window_size) // - block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined( - start, - end, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + 
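# Recompute the row maximum of the current tile and merge it with the previous
+ # running maximum so earlier tiles keep a consistent scaling.
+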
T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Following functions are adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() batch_size, num_keys, num_key_value_heads, head_dim = key.shape @@ -242,23 +192,15 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - groups, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H // groups, 
Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks @@ -270,17 +212,17 @@ def main( dim: int = 128, groups: int = 8, window_size: Optional[int] = None, - dtype: str = "float16", + dtype: T.dtype = T.float16, tune: bool = False, ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -308,15 +250,14 @@ def main( block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") # Benchmark tilelang @@ -325,22 +266,51 @@ def main( print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9)) +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + groups: int = 8, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), backend="cupti") + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, 
default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, - args.dtype, args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_gqa_sink_fwd_varlen.py b/examples/attention_sink/example_gqa_sink_fwd_varlen.py new file mode 100644 index 000000000..16838dd86 --- /dev/null +++ b/examples/attention_sink/example_gqa_sink_fwd_varlen.py @@ -0,0 +1,401 @@ +# ruff: noqa +# Using varlen (variable length) format with attention sink + +import argparse +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +from tilelang.profiler import do_bench +from typing import Optional +import sys +import os + +sys.path.append(os.path.join(os.path.dirname(__file__), "../flash_attention")) +from varlen_utils import generate_random_padding_mask, generate_qkv + + +@tilelang.jit( + out_idx=[7], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_sink( + batch_size, + groups, + UQ, + UKV, + heads, + dim, + is_causal, + window_size=None, # None for full causal attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [UQ, heads, dim] + kv_shape = [UKV, head_kv, dim] + o_shape = [UQ, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + max_seqlen_q: T.int32, + Sinks: T.Tensor([heads], dtype), + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], dtype) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = 
cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx + + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[head_idx] + + offset = kv_current_seqlen - q_current_seqlen # always align on the right + max_visible_k_idx = offset + (bx + 1) * block_M + + # Determine loop range based on causal mask and sliding window + if is_causal: + if window_size is not None: + # Sliding window + causal: start from window boundary + start = T.max(0, (offset + bx * block_M - window_size + 1) // block_N) + end = T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + else: + # Full causal attention + start = 0 + end = T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + else: + if window_size is not None: + start = T.max(0, (offset + bx * block_M - window_size + 1) // block_N) + end = T.ceildiv(kv_current_seqlen, block_N) + else: + start = 0 + end = T.ceildiv(kv_current_seqlen, block_N) + + loop_range = end - start + + for k in T.Pipelined(loop_range, num_stages=num_stages): + actual_k = k + start + T.copy(K_unpad[kv_start_idx + actual_k * block_N : kv_start_idx + (actual_k + 1) * block_N, kv_head_idx, :], K_shared) + + # Build mask considering causal, sliding window, and padding + if is_causal: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + offset + k_idx = actual_k * block_N + j + # Causal + sliding window mask + acc_s[i, j] = T.if_then_else( + (q_idx < k_idx) # causal: can't see future + or (q_idx >= k_idx + window_size) # sliding window: too old + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i + offset < actual_k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + offset + k_idx = actual_k * block_N + j + acc_s[i, j] = T.if_then_else( + (q_idx >= k_idx + window_size) # sliding window: too old + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + # Check_inf for sliding window attention + if window_size is not None: + for i in T.Parallel(block_M): + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) 
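+ # Accumulate the softmax denominator below: l_new = l_old * alpha + rowsum(p).
+ # The attention-sink term exp2(sink * log2(e) - m_final * scale) is added to
+ # logsum once, after the KV loop, so the sink only enters the normalizer.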
+ T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V_unpad[kv_start_idx + actual_k * block_N : kv_start_idx + (actual_k + 1) * block_N, kv_head_idx, :], V_shared) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Attention sink: add sink contribution to logsum + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, dim): + # When sq > skv, some tokens can see nothing (for causal) + acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i] + + T.copy(acc_o, O_shared) + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + + return main + + +def ref_program( + q_unpad: torch.Tensor, + k_unpad: torch.Tensor, + v_unpad: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sinks: torch.Tensor, + batch_size: int, + is_causal: bool, + sliding_window: Optional[int] = None, + groups: int = 1, +) -> torch.Tensor: + """Reference implementation for varlen attention with sinks.""" + # q_unpad: [total_q, heads, dim] + # k_unpad: [total_kv, head_kv, dim] + # v_unpad: [total_kv, head_kv, dim] + total_q, num_heads, head_dim = q_unpad.shape + _, num_key_value_heads, _ = k_unpad.shape + + sm_scale = 1.0 / head_dim**0.5 + + output = torch.zeros_like(q_unpad) + + for b in range(batch_size): + q_start = cu_seqlens_q[b].item() + q_end = cu_seqlens_q[b + 1].item() + k_start = cu_seqlens_k[b].item() + k_end = cu_seqlens_k[b + 1].item() + + q_len = q_end - q_start + k_len = k_end - k_start + + if q_len == 0: + continue + + # Extract sequences for this batch + q_seq = q_unpad[q_start:q_end] # [q_len, heads, dim] + k_seq = k_unpad[k_start:k_end] # [k_len, head_kv, dim] + v_seq = v_unpad[k_start:k_end] # [k_len, head_kv, dim] + + # Reshape for GQA + q_seq = q_seq.view(q_len, num_key_value_heads, groups, head_dim) # [q_len, head_kv, groups, dim] + sinks_expanded = sinks.view(num_key_value_heads, groups, 1, 1).float() # [head_kv, groups, 1, 1] + + k_seq = k_seq.unsqueeze(2) # [k_len, head_kv, 1, dim] + v_seq = v_seq.unsqueeze(2) # [k_len, head_kv, 1, dim] + + # Compute attention + # q_seq: [q_len, head_kv, groups, dim], k_seq: [k_len, head_kv, 1, dim] + logits = torch.einsum("qhgd,khgd->hgqk", q_seq.float(), k_seq.float()) * sm_scale + + # Build mask + start_q = k_len - q_len # offset for causal alignment + pos_keys = torch.arange(k_len, device=q_unpad.device) + pos_queries = torch.arange(q_len, device=q_unpad.device) + start_q + + if is_causal: + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + else: + mask = torch.zeros(q_len, k_len, device=q_unpad.device) + + if sliding_window is not None: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = logits + mask[None, None, :, :] # [head_kv, groups, q_len, k_len] + + # Apply sink-adjusted softmax + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks_expanded, logits_max) + sinks_exp = torch.exp(sinks_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - 
logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks_exp + scores = unnormalized_scores / normalizer + + # Compute output + out = torch.einsum("hgqk,khgd->qhgd", scores, v_seq.float()) + out = out.reshape(q_len, num_heads, head_dim).to(q_unpad.dtype) + + output[q_start:q_end] = out + + return output + + +def main( + batch: int = 1, + heads: int = 64, + q_seqlen: int = 2048, + k_seqlen: int = 2048, + dim: int = 128, + groups: int = 16, + is_causal: bool = True, + window_size: Optional[int] = None, +): + assert heads % groups == 0, "heads must be divisible by groups" + + flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim + total_flops = 2 * flops_per_matmul + + tilelang.testing.set_random_seed(0) + + if is_causal: + total_flops *= 0.5 + + if window_size is not None: + print(f"Using sliding window attention with window_size={window_size}") + flops_per_matmul = 2.0 * batch * heads * min(window_size, k_seqlen // 2) * q_seqlen * dim + total_flops = 2 * flops_per_matmul + + dtype = torch.float16 + device = torch.device("cuda") + + head_kv = heads // groups + q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device) + k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + sinks = torch.randn(heads, dtype=dtype, device=device) + + query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + UQ = q_unpad.shape[0] + UKV = k_unpad.shape[0] + + kernel = flashattn_sink( + batch, groups, UQ, UKV, heads, dim, is_causal, window_size=window_size, block_M=128, block_N=128, num_stages=2, threads=256 + ) + + out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, sinks) + out = output_pad_fn(out_unpad) + + # Reference implementation + ref_out_unpad = ref_program( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sinks, + batch, + is_causal, + sliding_window=window_size, + groups=groups, + ) + ref_out = output_pad_fn(ref_out_unpad) + + torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2) + + print("All checks passed.✅") + latency = do_bench( + lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, sinks), + warmup=500, + ) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="head dim") + parser.add_argument("--is_causal", action="store_true", help="causal attention") + parser.add_argument("--window_size", type=int, default=None, help="sliding window size (default: None for full attention)") + args 
= parser.parse_args() + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal, args.window_size) diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index 866668e41..66905f55d 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -20,40 +20,42 @@ def get_bwd_configs(): @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd( - batch, - heads, - seq_len, - dim, - window_size=None, # None for full attention, - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): - + batch, + heads, + seq_len, + dim, + window_size=None, # None for full attention, + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Sinks: T.Tensor([heads], dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -69,8 +71,7 @@ def flash_fwd( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([heads], dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -78,31 +79,30 @@ def flash_fwd( sinks[i] = Sinks[by] end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) - start = T.max(0, - (bx * block_M - window_size) // block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined(start, end, num_stages=num_stages): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, - 0, -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by, k * 
block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -119,32 +119,33 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -153,49 +154,52 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], 
pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 64 @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd( batch, heads, @@ -203,32 +207,31 @@ def flashattn_bwd( dim, window_size=None, # None for full attention sm_scale=None, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): - block_M, block_N, num_stages, threads = get_bwd_configs() if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -252,43 +255,43 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, 
block_N) - loop_ed = T.min( - T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv( - seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) for i, j in T.Parallel(block_M, block_N): if window_size is not None: qkT[i, j] = T.if_then_else( - by * block_M + i <= k * block_N + j and - by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) else: - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -297,51 +300,48 @@ def flash_bwd( T.copy(dsT_cast, dsT_shared) T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) - T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) - T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) return flash_bwd @tilelang.jit(out_idx=-1) -def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len] @T.prim_func def flash_bwd_dsink( - Sinks: T.Tensor([heads], dtype), # type: ignore - Delta: T.Tensor(shape, accum_dtype), # type: ignore - lse: T.Tensor(shape, accum_dtype), # type: ignore - dsinks: T.Tensor(shape, accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz): - sink = T.alloc_local([1], dtype) lse_fragment = T.alloc_fragment([block], accum_dtype) delta_fragment = T.alloc_fragment([block], accum_dtype) dsink_fragment = T.alloc_fragment([block], accum_dtype) - sink[0] = Sinks[bx] - T.copy(lse[bz, bx, by * block:(by + 1) * 
block], lse_fragment) - T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) + sink = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - - lse_fragment[i]) * delta_fragment[i] - T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block]) + dsink_fragment[i] = -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, sinks, window_size): BATCH, H, N_CTX, D_HEAD = q.shape - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype) o, lse = kernel(q, k, v, sinks) ctx.save_for_backward(q, k, v, sinks, o, lse) @@ -359,7 +359,7 @@ def maybe_contiguous(x): return x do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)] - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) delta = kernel_prep(o, do) @@ -381,15 +381,15 @@ def maybe_contiguous(x): # Adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function's interface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -424,29 +424,23 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def main(BATCH: int = 1, - H: int = 1, - N_CTX: int = 512, - D_HEAD: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16"): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: T.dtype = T.float16): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= N_CTX - flops_per_matmul = 2.0 * BATCH * H * min( - window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation else: - print('Using full 
attention.') + print("Using full attention.") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 total_flops = 5 * flops_per_matmul - Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()) + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() K = torch.randn_like(Q).requires_grad_() V = torch.randn_like(Q).requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_() @@ -468,19 +462,14 @@ def main(BATCH: int = 1, # Checks rtol, atol = { - "float16": (1e-2, 1e-2), - "bfloat16": (2e-2, 2e-2), + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), }[dtype] - assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}' - assert torch.allclose( - dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' - assert torch.allclose( - dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' - assert torch.allclose( - dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}' - assert torch.allclose( - dsinks, dsinks_ref, rtol=rtol, - atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}' + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" print("All checks passed for tilelang kernels.✅") @@ -499,18 +488,53 @@ def tl_bwd(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 512, + D_HEAD: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + with torch.no_grad(): + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + K = torch.randn_like(Q) + V = torch.randn_like(Q) + sinks = torch.randn(H, dtype=torch_dtype, device=Q.device) + dO = torch.randn_like(Q) + fwd = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size=window_size, dtype=dtype) + O, lse = fwd(Q, K, V, sinks) + + def maybe_contiguous(x): + return x if x.stride(-1) == 1 else x.contiguous() + + do, q, k, v, sinks_c, o = [maybe_contiguous(x) for x in (dO, Q, K, V, sinks, O)] + k_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + Delta = k_prep(o, do) + k_bwd = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype) + k_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + shape = (BATCH, H, N_CTX, D_HEAD) + dq = torch.zeros(shape, dtype=torch.float32, device=Q.device) + dk = torch.empty(shape, dtype=torch_dtype, device=Q.device) + dv = torch.empty(shape, dtype=torch_dtype, device=Q.device) + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + _ = k_dsink(sinks_c, Delta, lse).sum(0).sum(1) + + def run_kernel_only(): + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + + latency_ms = do_bench(run_kernel_only, backend="cupti") + return latency_ms + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='Batch size') - parser.add_argument('--h', type=int, 
default=64, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--d_head', type=int, default=128, help='Head dimension') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py index 2449b090c..f24aa38b7 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -5,7 +5,6 @@ from tilelang.autotuner import autotune from tilelang.profiler import do_bench import tilelang.language as T -from tilelang.layout import make_swizzled_layout import itertools import argparse from typing import Optional @@ -18,117 +17,45 @@ def get_configs(): @autotune(configs=get_configs(), warmup=500, rep=100) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - window_size=None, # None for full attention - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - accum_dtype = "float" + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) - else: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), 
- acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # NOTE(wt): check_inf is necessary for sliding window attention. - for i in T.Parallel(block_M): - if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -145,53 +72,76 @@ def main( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) - - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.max(0, (bx * block_M + past_len - window_size) // - block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined(start, end, 
num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function's interface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -226,41 +176,36 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, 
num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks -def main(batch: int = 1, - heads: int = 1, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16", - tune: bool = False): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -287,19 +232,17 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") - latency = do_bench( - lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) + latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) print("Ref: {:.2f} ms".format(latency)) print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) @@ -307,21 +250,37 @@ def main(batch: int = 1, print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, heads, seq_q, seq_kv, dim, window_size, block_M=block_M, block_N=block_N, num_stages=num_stages, 
threads=threads, dtype=dtype + ) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), backend="cupti") + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py index 352844075..b47c8175f 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -6,7 +6,6 @@ from tilelang.autotuner import autotune from tilelang.profiler import do_bench import tilelang.language as T -from tilelang.layout import make_swizzled_layout import itertools import argparse from typing import Optional @@ -19,119 +18,46 @@ def get_configs(): @autotune(configs=get_configs(), warmup=500, rep=100) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - window_size=None, # None for full attention - sm_scale=None, - block_M=128, - block_N=128, - num_stages=2, - threads=256, - dtype: str = "float16"): - + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=128, + block_N=128, + num_stages=2, + threads=256, + dtype: T.dtype = T.float16, +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - accum_dtype = "float" + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater 
than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) - else: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # NOTE(wt): check_inf is necessary for sliding window attention. - for i in T.Parallel(block_M): - if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -148,60 +74,84 @@ def main( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) - - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.max(0, (bx * block_M + past_len - window_size) // - block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined( - start, - end, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. 
+ for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Following functions are adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function'sinterface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function'sinterface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -236,41 +186,36 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks -def main(batch: int = 1, - heads: int = 32, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: Optional[int] = None, - 
dtype: str = "float16", - tune: bool = False): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -297,15 +242,14 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) @@ -313,21 +257,38 @@ def main(batch: int = 1, print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, heads, seq_q, seq_kv, dim, window_size, block_M=block_M, block_N=block_N, num_stages=num_stages, threads=threads, dtype=dtype + ) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), backend="cupti") + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, 
which means full attention)") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/regression_attention_sink.py b/examples/attention_sink/regression_attention_sink.py new file mode 100644 index 000000000..e2453173c --- /dev/null +++ b/examples/attention_sink/regression_attention_sink.py @@ -0,0 +1,64 @@ +import tilelang.testing +import example_mha_sink_fwd_bhsd +import example_mha_sink_fwd_bhsd_wgmma_pipelined +import example_mha_sink_bwd_bhsd +import example_gqa_sink_bwd_bhsd +import example_gqa_sink_fwd_bhsd_wgmma_pipelined + + +def regression_example_mha_sink_fwd_bhsd(): + tilelang.testing.process_func(example_mha_sink_fwd_bhsd.run_regression_perf) + + +def regression_example_mha_sink_fwd_bhsd_sliding_window(): + tilelang.testing.process_func( + example_mha_sink_fwd_bhsd.run_regression_perf, "regression_example_mha_sink_fwd_bhsd_sliding_window", window_size=128 + ) + + +def regression_example_mha_sink_fwd_bhsd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf) + + +def regression_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): + tilelang.testing.process_func( + example_mha_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf, + "regression_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window", + window_size=128, + ) + + +def regression_example_gqa_sink_fwd_bhsd_wgmma_pipelined(): + tilelang.testing.process_func(example_gqa_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf) + + +def regression_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): + tilelang.testing.process_func( + example_gqa_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf, + "regression_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window", + window_size=128, + ) + + +def regression_example_mha_sink_bwd_bhsd(): + tilelang.testing.process_func(example_mha_sink_bwd_bhsd.run_regression_perf) + + +def regression_example_mha_sink_bwd_bhsd_sliding_window(): + tilelang.testing.process_func( + example_mha_sink_bwd_bhsd.run_regression_perf, "regression_example_mha_sink_bwd_bhsd_sliding_window", window_size=128 + ) + + +def regression_example_gqa_sink_bwd_bhsd(): + tilelang.testing.process_func(example_gqa_sink_bwd_bhsd.run_regression_perf) + + +def regression_example_gqa_sink_bwd_bhsd_sliding_window(): + tilelang.testing.process_func( + example_gqa_sink_bwd_bhsd.run_regression_perf, "regression_example_gqa_sink_bwd_bhsd_sliding_window", window_size=128 + ) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/attention_sink/test_example_attention_sink.py b/examples/attention_sink/test_example_attention_sink.py index 57242c199..31a1ff1b3 100644 --- a/examples/attention_sink/test_example_attention_sink.py +++ b/examples/attention_sink/test_example_attention_sink.py @@ -5,6 +5,8 @@ import example_gqa_sink_fwd_bhsd_wgmma_pipelined import example_mha_sink_bwd_bhsd import example_gqa_sink_bwd_bhsd +import example_gqa_sink_fwd_varlen +import example_gqa_sink_bwd_varlen @tilelang.testing.requires_cuda @@ -61,5 +63,12 @@ def test_example_gqa_sink_bwd_bhsd_sliding_window(): example_gqa_sink_bwd_bhsd.main(window_size=128) 
+@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_gqa_sink_varlen(): + example_gqa_sink_fwd_varlen.main() # non-causal + example_gqa_sink_bwd_varlen.main() # causal + + if __name__ == "__main__": tilelang.testing.main() diff --git a/examples/autodd/README.md b/examples/autodd/README.md new file mode 100644 index 000000000..9ae9f9816 --- /dev/null +++ b/examples/autodd/README.md @@ -0,0 +1,126 @@ +# AutoDD - Automatic Delta Debugging for TileLang + +AutoDD (Automatic Delta Debugging) is a built-in debugging tool for TileLang that automatically simplifies complex Python programs to the minimal code needed to reproduce a specific error. This is extremely useful for debugging large, complex TileLang programs. + +## What is Delta Debugging? + +Delta Debugging is an automated debugging technique with the core idea: +1. Given a program that triggers a bug +2. Systematically remove code fragments from the program +3. Check if the simplified program still triggers the same bug +4. Eventually obtain the minimal code that triggers the bug + +AutoDD uses a Probability Distribution Driven Delta Debugging (PDD) algorithm for efficient search of minimized code. + +## Why AutoDD? + +When developing TileLang programs, bugs are often hidden in complex code: + +- **Lots of irrelevant code**: Real projects may have hundreds of lines of configuration, helper functions, logging, etc. +- **Hard to locate**: Error messages may point to underlying TVM/CUDA rather than TileLang code +- **Tedious debugging**: Manually deleting code to locate bugs is very time-consuming + +AutoDD automates this process, reducing hundreds of lines of code to just a few dozen, directly exposing the root cause of the problem. + +## Usage + +### Basic Usage + +```bash +python -m tilelang.autodd --err-msg "" -o +``` + +### Parameters + +| Parameter | Description | +|-----------|-------------| +| `source` | Path to the input Python source file | +| `--err-msg` | Error message to match (searched in stdout or stderr) | +| `-o, --output` | Path to the minimized output file | +| `--backend` | Execution backend: `runner` (faster) or `subproc` (more stable), default `runner` | +| `--timeout` | Timeout for each task in seconds, default 60 | +| `-j, --jobs` | Number of parallel jobs, default 1 | + +### Example + +Run AutoDD on `tilelang_buggy.py` in this directory: + +```bash +# Use 4 parallel jobs, search for "Dimension mismatch" error +python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py -j 4 + +# Or use subprocess backend (more stable but slower) +python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py --backend subproc +``` + +## Example Files + +### `tilelang_buggy.py` + +A complex TileLang program with a bug (~200 lines), containing: +- Multiple useless helper functions (`calculate_optimal_block_size`, `get_memory_requirements`, etc.) +- A complex configuration class (`MatmulConfig`) +- Unused benchmark code (`benchmark_pytorch`) +- **A GEMM shape mismatch bug** + +The bug is on line 124: +```python +B_shared = T.alloc_shared((block_M, block_N), dtype) # Wrong! Should be (block_K, block_N) +``` + +### `tilelang_minimized_expected.py` + +The expected output after AutoDD simplification (~30 lines). 
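+
+Conceptually, this file is what the reduction loop described under "What is Delta Debugging?" converges to: AutoDD repeatedly proposes smaller candidate programs and accepts a candidate only if running it still produces the target error message. Below is a minimal, line-based sketch of that outer loop (hypothetical helper names; the real tool rewrites the AST using the PDD strategy and parallel jobs rather than deleting raw text lines):
+
+```python
+import subprocess
+import sys
+
+
+def still_fails(candidate_source: str, err_msg: str) -> bool:
+    """Return True if running the candidate still prints the target error."""
+    try:
+        proc = subprocess.run(
+            [sys.executable, "-c", candidate_source],
+            capture_output=True,
+            text=True,
+            timeout=60,
+        )
+    except subprocess.TimeoutExpired:
+        return False
+    return err_msg in proc.stdout or err_msg in proc.stderr
+
+
+def minimize_lines(lines: list[str], err_msg: str) -> list[str]:
+    """Greedily drop chunks of lines while the error keeps reproducing."""
+    chunk = max(1, len(lines) // 2)
+    while chunk >= 1:
+        i = 0
+        while i < len(lines):
+            candidate = lines[:i] + lines[i + chunk:]
+            if still_fails("\n".join(candidate), err_msg):
+                lines = candidate  # removal accepted: the error still reproduces
+            else:
+                i += chunk  # this chunk is needed to trigger the error
+        chunk //= 2
+    return lines
+```
+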
The simplified code clearly shows the root cause of the bug: + +```python +def buggy_matmul(...): + @T.prim_func + def matmul_kernel(): + with T.Kernel(): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_M, block_N), dtype) # Bug! + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.gemm(A_shared, B_shared, C_local) # Error occurs here +``` + +## How AutoDD Works + +AutoDD uses AST (Abstract Syntax Tree) analysis and multiple rewrite rules to simplify code: + +### 1. Fast Reducers +- **Statement removal**: Directly remove statements that don't affect bug reproduction +- **If statement simplification**: Simplify `if cond: body` to `body` +- **For loop simplification**: Bind loop variables to constants + +### 2. Canonicalizers +- **With statement expansion**: Convert `with expr as var` to explicit assignment +- **Function argument extension**: Add `*args, **kwargs` for compatibility + +### 3. Simplifiers +- **Assignment simplification**: Replace complex expressions with constants +- **Function call simplification**: Simplify `f(x)` to `x` +- **Binary operation simplification**: Simplify `a + b` to `a` or `b` + +### 4. Slow Reducers +- **Expression removal**: Remove arbitrary expressions +- **Argument removal**: Remove function arguments +- **Integer reduction**: Gradually reduce large integers + +## Use Cases + +1. **TileLang kernel debugging**: Simplify complex TileLang programs to locate bugs +2. **Bug report submission**: Generate minimal reproduction code for easier issue tracking +3. **Understanding errors**: Easier to understand the nature of errors after removing irrelevant code +4. **Regression testing**: Simplified code can serve as regression test cases + +## Notes + +1. **Error message matching**: The `--err-msg` parameter needs to exactly match a string in the error output +2. **Timeout setting**: For programs with long compilation times, you may need to increase `--timeout` +3. **Parallel jobs**: Increasing `-j` can speed up the simplification process but consumes more resources +4. **Backend selection**: If the `runner` backend is unstable, try the `subproc` backend + +## References + +- [Delta Debugging Paper](https://www.st.cs.uni-saarland.de/papers/tse2002/) +- [TileLang Documentation](https://github.com/tile-ai/tilelang) diff --git a/examples/autodd/tilelang_buggy.py b/examples/autodd/tilelang_buggy.py new file mode 100644 index 000000000..d2c5469bb --- /dev/null +++ b/examples/autodd/tilelang_buggy.py @@ -0,0 +1,229 @@ +""" +A complex TileLang program with lots of redundant code and a bug that triggers an error. +AutoDD will simplify it to the minimal code needed to reproduce the error. + +This example demonstrates how AutoDD can help developers quickly isolate bugs +in complex TileLang programs by automatically removing irrelevant code. + +To run AutoDD on this file: + python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py -j 4 + +The bug in this file: B_shared has shape (block_M, block_N) instead of (block_K, block_N), +causing a GEMM dimension mismatch error. 
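+
+AutoDD should reduce this file to roughly the kernel shown in
+tilelang_minimized_expected.py: the two T.alloc_shared calls, the C_local
+fragment and the failing T.gemm(A_shared, B_shared, C_local), with the
+helper functions and the MatmulConfig class stripped away.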
+""" + +import tilelang +import tilelang.language as T +import torch + + +# Useless helper function - will be removed by AutoDD +def calculate_optimal_block_size(M, N, K): + """Calculate optimal block size - this function is completely useless""" + options = [32, 64, 128, 256] + best = 128 + for opt in options: + if M % opt == 0 and N % opt == 0: + best = opt + break + return best, best, 32 + + +def get_memory_requirements(M, N, K, block_M, block_N, block_K, dtype_size=2): + """Calculate memory requirements - completely useless""" + shared_mem_a = block_M * block_K * dtype_size + shared_mem_b = block_K * block_N * dtype_size + total_shared = shared_mem_a + shared_mem_b + return total_shared + + +def validate_parameters(M, N, K, block_M, block_N, block_K): + """Validate parameters - redundant check""" + if M <= 0 or N <= 0 or K <= 0: + raise ValueError("Matrix dimensions must be positive") + if block_M <= 0 or block_N <= 0 or block_K <= 0: + raise ValueError("Block sizes must be positive") + if M % block_M != 0: + print(f"Warning: M ({M}) not divisible by block_M ({block_M})") + if N % block_N != 0: + print(f"Warning: N ({N}) not divisible by block_N ({block_N})") + if K % block_K != 0: + print(f"Warning: K ({K}) not divisible by block_K ({block_K})") + return True + + +class MatmulConfig: + """Configuration class - increases code complexity but is actually useless""" + + def __init__(self, M, N, K): + self.M = M + self.N = N + self.K = K + self.block_M = 128 + self.block_N = 128 + self.block_K = 32 + self.num_stages = 3 + self.threads = 128 + self.dtype = "float16" + self.accum_dtype = "float32" + + def get_grid_size(self): + grid_x = (self.N + self.block_N - 1) // self.block_N + grid_y = (self.M + self.block_M - 1) // self.block_M + return grid_x, grid_y + + def get_shared_memory_size(self): + return get_memory_requirements(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K) + + def validate(self): + return validate_parameters(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K) + + +def create_reference_output(a, b, activation="relu"): + """Create reference output - not actually used in verification""" + result = a @ b + if activation == "relu": + result = torch.relu(result) + elif activation == "gelu": + result = torch.nn.functional.gelu(result) + elif activation == "sigmoid": + result = torch.sigmoid(result) + return result + + +def benchmark_pytorch(M, N, K, num_iters=10, warmup=5): + """PyTorch benchmark - not used""" + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Warmup + for _ in range(warmup): + _ = a @ b + torch.cuda.synchronize() + + # Benchmark + import time + + start = time.time() + for _ in range(num_iters): + _ = a @ b + torch.cuda.synchronize() + end = time.time() + + return (end - start) / num_iters * 1000 # ms + + +# Main TileLang kernel - contains a BUG: GEMM shape mismatch! +@tilelang.jit +def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def matmul_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + # Allocate shared memory + A_shared = T.alloc_shared((block_M, block_K), dtype) + # BUG: the first dimension of B_shared should be block_K, but block_M is used here! + B_shared = T.alloc_shared((block_M, block_N), dtype) # Wrong shape! 
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Allocate some useless temp variables + temp_buffer = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Zero out + T.clear(C_local) + T.clear(temp_buffer) + + # Main loop + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy a tile of A + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy a tile of B - shape can mismatch here too + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # GEMM computation - shape mismatch will cause an error + # A_shared: (block_M, block_K) + # B_shared: (block_M, block_N) <- should be (block_K, block_N) + T.gemm(A_shared, B_shared, C_local) + + # ReLU activation + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Some useless postprocessing + for i, j in T.Parallel(block_M, block_N): + if temp_buffer[i, j] > 0: + C_local[i, j] = C_local[i, j] + 0.0 + + # Write back result + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_kernel + + +def run_kernel(config): + """Run kernel - includes extra redundant logic""" + # Validate parameters + config.validate() + + # Get config + M, N, K = config.M, config.N, config.K + block_M, block_N, block_K = config.block_M, config.block_N, config.block_K + + # Calculate some useless statistics + grid_size = config.get_grid_size() + shared_mem = config.get_shared_memory_size() + print(f"Grid size: {grid_size}") + print(f"Shared memory: {shared_mem} bytes") + + # Create test data + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + c = torch.empty(M, N, device="cuda", dtype=torch.float16) + + # Compile and run kernel - will trigger the BUG here + kernel = buggy_matmul(M, N, K, block_M, block_N, block_K) + kernel(a, b, c) + + # Validate results (if it can get here) + ref_c = torch.relu(a @ b) + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + return c + + +def main(): + # Useless printing + print("=" * 60) + print("TileLang Matmul Kernel Test") + print("=" * 60) + + # Create config + M, N, K = 512, 512, 512 + config = MatmulConfig(M, N, K) + + # Calculate some useless values + optimal_block = calculate_optimal_block_size(M, N, K) + print(f"Optimal block size: {optimal_block}") + + # Run PyTorch benchmark - result is not used + # pytorch_time = benchmark_pytorch(M, N, K) + # print(f"PyTorch time: {pytorch_time:.3f} ms") + + # Run our kernel - will trigger the error here + try: + result = run_kernel(config) + print(f"Result shape: {result.shape}") + except Exception as e: + print(f"Error: {e}") + raise + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/examples/autodd/tilelang_minimized_expected.py b/examples/autodd/tilelang_minimized_expected.py new file mode 100644 index 000000000..3dc88f992 --- /dev/null +++ b/examples/autodd/tilelang_minimized_expected.py @@ -0,0 +1,49 @@ +""" +This is the expected output after running AutoDD on tilelang_buggy.py. +AutoDD automatically simplified the 200+ line buggy program to ~30 lines +while preserving the ability to reproduce the error. 
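+
+A file like this can be regenerated with:
+    python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py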
+ +The minimized code clearly shows the root cause of the bug: +- A_shared has shape (block_M, block_K) +- B_shared has shape (block_M, block_N) - should be (block_K, block_N) +- This causes a dimension mismatch in T.gemm() +""" + +import tilelang.language as T + + +class MatmulConfig: + def __init__(self, *args, **kwargs): + self.M = 1 + self.N = 1 + self.K = 1 + self.block_M = 2 + self.block_N = 1 + self.block_K = 1 + + +def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs): + @T.prim_func + def matmul_kernel(): + with T.Kernel(): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_M, block_N), dtype) # Bug: should be (block_K, block_N) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.gemm(A_shared, B_shared, C_local) + + +def run_kernel(config, *args, **kwargs): + M, N, K = (config.M, config.N, config.K) + block_M, block_N, block_K = (config.block_M, config.block_N, config.block_K) + buggy_matmul(M, N, K, block_M, block_N, block_K) + + +def main(*args, **kwargs): + config = MatmulConfig() + try: + run_kernel(config) + except Exception as e: + print(f"{e}") + + +main() diff --git a/examples/bitnet-1.58b/.gitignore b/examples/bitnet-1.58b/.gitignore index 6ea887496..2bcdfd92b 100644 --- a/examples/bitnet-1.58b/.gitignore +++ b/examples/bitnet-1.58b/.gitignore @@ -1 +1 @@ -models/ \ No newline at end of file +models/ diff --git a/examples/bitnet-1.58b/README.md b/examples/bitnet-1.58b/README.md index 2b587eab4..b9898741b 100644 --- a/examples/bitnet-1.58b/README.md +++ b/examples/bitnet-1.58b/README.md @@ -2,7 +2,6 @@ license: mit --- - This is a Tilelang Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`. 
## Make Checkpoints for vLLM @@ -43,7 +42,6 @@ python3 inference_with_bitblas_format.py | bitnet-3b-1.58bits | vllm-tilelang | 379.25 | 117.43 | 752.55 | | bitnet-3b-1.58bits | vllm-tilelang-cuda-graph | 2543.58 | 1621.08 | 2731.79 | - ## BitBLAS Results ### Performance @@ -94,4 +92,4 @@ The differences between the reported numbers and the reproduced results are poss journal={arXiv preprint arXiv:2402.17764}, year={2024} } -``` \ No newline at end of file +``` diff --git a/examples/bitnet-1.58b/benchmark.sh b/examples/bitnet-1.58b/benchmark.sh index 6a2550d45..839443dc6 100755 --- a/examples/bitnet-1.58b/benchmark.sh +++ b/examples/bitnet-1.58b/benchmark.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 | tee b16_i32_o128.log python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 | tee b1_i512_o64.log diff --git a/examples/bitnet-1.58b/benchmark_generate.py b/examples/bitnet-1.58b/benchmark_generate.py index d6f21ed50..d678b91a4 100644 --- a/examples/bitnet-1.58b/benchmark_generate.py +++ b/examples/bitnet-1.58b/benchmark_generate.py @@ -12,8 +12,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): # Encode the input prompts as a batch - input_ids = tokenizer( - prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) + input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) # Generate cos and sin values (commented out as not used in generation) seq_length = input_ids.size(1) @@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): end_time = time.time() # Decode the output ids to text - generated_texts = [ - tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids - ] + generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids] generation_time = end_time - start_time num_tokens = sum(len(output_id) for output_id in output_ids) @@ -52,8 +49,8 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): def profile(model, input_data): - import numpy as np + model = model.cuda() model.eval() @@ -74,25 +71,29 @@ def get_runtime(num_repeats=1): return np.mean(times) -model_path = '1bitLLM/bitnet_b1_58-3B' +model_path = "1bitLLM/bitnet_b1_58-3B" def main(): parser = argparse.ArgumentParser() - parser.add_argument('--bs', default=16, type=int) - parser.add_argument('--in_seq_len', default=32, type=int) - parser.add_argument('--out_seq_len', default=128, type=int) - parser.add_argument('--bitblas', action='store_true') + parser.add_argument("--bs", default=16, type=int) + parser.add_argument("--in_seq_len", default=32, type=int) + parser.add_argument("--out_seq_len", default=128, type=int) + parser.add_argument("--bitblas", action="store_true") args = parser.parse_args() bs = args.bs in_seq_len = args.in_seq_len out_seq_len = args.out_seq_len is_bitblas = args.bitblas - model = BitnetForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=True, - torch_dtype=torch.float16, - ).cuda().half() + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) if is_bitblas: with torch.no_grad(): model.quantize() @@ -109,5 +110,5 @@ def main(): print(generate_text_batch(model, tokenizer, prompts, max_length=max_length)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git 
a/examples/bitnet-1.58b/benchmark_inference_latency.py b/examples/bitnet-1.58b/benchmark_inference_latency.py index 9ce7a3898..788fc5565 100644 --- a/examples/bitnet-1.58b/benchmark_inference_latency.py +++ b/examples/bitnet-1.58b/benchmark_inference_latency.py @@ -6,13 +6,14 @@ torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) def profile(model, input_data): import time import numpy as np + model = model.cuda() model.eval() @@ -35,8 +36,8 @@ def get_runtime(num_repeats=1): def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', - device_map='auto', + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, @@ -52,5 +53,5 @@ def main(): print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/configuration_bitnet.py b/examples/bitnet-1.58b/configuration_bitnet.py index 5f4937b87..63c499db3 100644 --- a/examples/bitnet-1.58b/configuration_bitnet.py +++ b/examples/bitnet-1.58b/configuration_bitnet.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" LLaMA model configuration""" +"""LLaMA model configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -180,16 +180,10 @@ def _rope_scaling_validation(self): return if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " - f"got {self.rope_scaling}") + raise ValueError(f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}") rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, - float) or rope_scaling_factor <= 1.0: - raise ValueError( - f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}") + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/examples/bitnet-1.58b/eval_correctness.py b/examples/bitnet-1.58b/eval_correctness.py index ac1e34072..11d47004b 100644 --- a/examples/bitnet-1.58b/eval_correctness.py +++ b/examples/bitnet-1.58b/eval_correctness.py @@ -47,8 +47,8 @@ def generate_text(model, tokenizer, prompt, max_length=100): def profile(model, input_data): - import numpy as np + model = model.cuda() model.eval() @@ -69,18 +69,22 @@ def get_runtime(num_repeats=1): return np.mean(times) -model_path = '1bitLLM/bitnet_b1_58-3B' +model_path = "1bitLLM/bitnet_b1_58-3B" def main(): - model = BitnetForCausalLM.from_pretrained( - model_path, - 
use_flash_attention_2=False, - torch_dtype=torch.float16, - ).cuda().half() + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=False, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) - input_id = tokenizer("Hello")['input_ids'] + input_id = tokenizer("Hello")["input_ids"] input_id = torch.tensor(input_id).unsqueeze(0).cuda() print("original model generated text:") @@ -91,5 +95,5 @@ def main(): print(generate_text(model, tokenizer, "Hello", max_length=100)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/eval_gpu_memory.py b/examples/bitnet-1.58b/eval_gpu_memory.py index 597cbbfcd..00c914cb3 100644 --- a/examples/bitnet-1.58b/eval_gpu_memory.py +++ b/examples/bitnet-1.58b/eval_gpu_memory.py @@ -6,13 +6,14 @@ torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) def profile(model, input_data): import time import numpy as np + model = model.cuda() model.eval() @@ -35,17 +36,17 @@ def get_runtime(num_repeats=1): def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', - device_map='auto', + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, ).half() - print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB") with torch.no_grad(): model._post_process_weights() - print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/eval_ppl.py b/examples/bitnet-1.58b/eval_ppl.py index 61c8488e4..97db2d0f5 100644 --- a/examples/bitnet-1.58b/eval_ppl.py +++ b/examples/bitnet-1.58b/eval_ppl.py @@ -15,9 +15,9 @@ torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--seed', default=0, type=int) -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) -parser.add_argument('--seqlen', default=2048, type=int) +parser.add_argument("--seed", default=0, type=int) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) +parser.add_argument("--seqlen", default=2048, type=int) def calulate_loss(model, input, loss_fct): @@ -29,12 +29,16 @@ def calulate_loss(model, input, loss_fct): def main(args): - datasets = ['c4', 'wikitext2'] - model = BitnetForCausalLM.from_pretrained( - args.hf_path, - use_flash_attention_2=True, - torch_dtype=torch.float16, - ).cuda().half() + datasets = ["c4", "wikitext2"] + model = ( + BitnetForCausalLM.from_pretrained( + args.hf_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) with torch.no_grad(): model._post_process_weights() tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False) @@ -48,9 +52,9 @@ def main(args): for ii in progress: input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1) loss = calulate_loss(model, input, loss_fct) - count += (input.size(-1) - 1) + count += input.size(-1) - 1 acc_loss += loss.item() - progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}") + progress.set_description(f"avg_loss = {acc_loss / count / 
math.log(2)}") avg_loss = acc_loss / count / math.log(2) ppl.append(2**avg_loss) @@ -60,7 +64,7 @@ def main(args): print("Avg PPL:", sum(ppl) / len(ppl)) -if __name__ == '__main__': +if __name__ == "__main__": torch.set_grad_enabled(False) args = parser.parse_args() random.seed(args.seed) diff --git a/examples/bitnet-1.58b/eval_utils.py b/examples/bitnet-1.58b/eval_utils.py index 46241eedf..72480c392 100644 --- a/examples/bitnet-1.58b/eval_utils.py +++ b/examples/bitnet-1.58b/eval_utils.py @@ -15,21 +15,17 @@ def set_seed(seed): def get_test_dataset(dataset_name, tokenizer, seqlen=2048): if dataset_name == "wikitext2": - testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - testdata = "".join(testdata['text']).split('\n') + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + testdata = "".join(testdata["text"]).split("\n") elif dataset_name == "c4": - testdata = load_dataset( - 'allenai/c4', - data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, - split='validation')['text'] + testdata = load_dataset("allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation")[ + "text" + ] else: raise NotImplementedError testdata = [item for item in testdata if item != ""] - tokenized_text = [ - tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] - for item in testdata - ] + tokenized_text = [tokenizer(item, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] for item in testdata] data, doc = [], [tokenizer.bos_token_id] for sen in tokenized_text: @@ -45,7 +41,6 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048): class LMEvalAdaptor(BaseLM): - def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): super().__init__() @@ -137,5 +132,4 @@ def _model_call(self, inps): return out def _model_generate(self, context, max_length, eos_token_id): - return self.model.generate( - context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) + return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py index e5af16cc4..a31261d3e 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py @@ -76,13 +76,13 @@ def bitnet_158_int8xint2_decode( reduce_thread=32, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" storage_nbit = 8 num_bits = 2 @@ -94,7 +94,7 @@ def bitnet_158_int8xint2_decode( MAX_TRANSACTION_SIZE_IN_BITS = 128 micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits micro_size_k_compressed = micro_size_k // num_elems_per_byte - storage_dtype = "int8" + storage_dtype = T.int8 block_K = reduce_thread * micro_size_k use_dp4a = True @@ -102,17 +102,17 @@ def bitnet_158_int8xint2_decode( @T.prim_func def kernel( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer(C_shape, out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( - T.ceildiv(N, 
n_partition), - M, - threads=(reduce_thread, n_partition), + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), ) as ( - bx, - by, + bx, + by, ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) @@ -133,8 +133,7 @@ def kernel( for v in T.vectorized(micro_size_k_compressed): B_quant_local[v] = B[ bx * n_partition + ni, - ko * (reduce_thread * micro_size_k_compressed) + - kr * micro_size_k_compressed + v, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, ] T.call_extern( @@ -156,9 +155,9 @@ def kernel( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -168,7 +167,8 @@ def kernel( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: C[by, bx * n_partition + ni] = reduced_accum_res[0] @@ -194,12 +194,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): # interleave weight numpy implementation -def interleave_weight(qweight, nbits=4, target_dtype="float16"): - assert target_dtype in ["float16", "int8"] +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert target_dtype in [T.float16, T.int8] # reinterpret the data type of qweight to int32 qweight = qweight.view(np.int32) new_qweight = np.zeros_like(qweight) - bits_stride = 8 if target_dtype == "int8" else 16 + bits_stride = 8 if target_dtype == T.int8 else 16 mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f num_groups = 32 // bits_stride elems_per_group = bits_stride // nbits @@ -209,7 +209,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift - if nbits == 1 and target_dtype == "int8": + if nbits == 1 and target_dtype == T.int8: # special handling for 1b interleave n16_weight = new_qweight & np.int32(0xF0F00F0F) n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 @@ -217,12 +217,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 return n16_weight.view(np.int8) - elif nbits == 2 and target_dtype == "float16": + elif nbits == 2 and target_dtype == T.float16: n8_weight = new_qweight & np.int32(0xFF0000FF) n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 return n8_weight.view(np.int8) - elif nbits == 1 and target_dtype == "float16": + elif nbits == 1 and target_dtype == T.float16: n8_weight = new_qweight & 0xF000000F n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 @@ -234,13 +234,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): return new_qweight.view(np.int8) -def assert_bitnet_158_int8xint2_decode_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - fast_decoding=True): +def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, 
accum_dtype, fast_decoding) print(program) kernel = tilelang.compile(program) @@ -265,4 +259,4 @@ def assert_bitnet_158_int8xint2_decode_correctness(M, if __name__ == "__main__": - assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, "int8", "int32", "int32") + assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py index d8b1f6228..f4a60098a 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -8,11 +8,13 @@ from tilelang import tvm as tvm from tvm import DataType from tilelang.intrinsics.mma_layout import ( - make_mma_swizzle_layout as make_swizzle_layout,) + make_mma_swizzle_layout as make_swizzle_layout, +) import numpy as np from tilelang.intrinsics.mma_macro_generator import ( - INT4TensorCoreIntrinEmitter,) + INT4TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func torch.manual_seed(42) @@ -86,9 +88,9 @@ def bitnet_158_int8xint2_prefill( Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low‑precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C. The returned prim_func expects: - - A: shape (M, K) with dtype `in_dtype` ("float16" or "int8"). + - A: shape (M, K) with dtype `in_dtype` (T.float16 or T.int8). - B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte). - - C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32"). + - C: output buffer shape (M, N) with dtype `out_dtype` (T.float16, T.float32, or T.int32). Details: - Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter. @@ -96,15 +98,15 @@ def bitnet_158_int8xint2_prefill( - block_row_warps, block_col_warps: number of warps per block in row/col. - warp_row_tiles, warp_col_tiles: tiles per warp. - chunk: K-sized chunk per block (block_K). - - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32"). + - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == T.int32). - Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior. - Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values. Parameters: M, N, K (int): Global matrix dimensions. - in_dtype (str): Input and decoded B element dtype; "float16" or "int8". - out_dtype (str): Output C dtype; one of "float16", "float32", "int32". - accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32"). + in_dtype (str): Input and decoded B element dtype; T.float16 or T.int8. + out_dtype (str): Output C dtype; one of T.float16, T.float32, T.int32. + accum_dtype (str): Accumulator dtype used by MMA (e.g., T.int32). fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used). block_row_warps (int): Warps in block row dimension. block_col_warps (int): Warps in block column dimension. 
@@ -116,18 +118,18 @@ def bitnet_158_int8xint2_prefill( T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution. """ assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if accum_dtype == "int32": + if accum_dtype == T.int32: micro_size_k = 32 num_elems_per_byte = 4 @@ -136,7 +138,7 @@ def bitnet_158_int8xint2_prefill( local_size_compressed = local_size // num_elems_per_byte shared_scope = "shared.dyn" - storage_dtype = "int8" + storage_dtype = T.int8 # Pipeline Stage stage = 2 @@ -181,38 +183,36 @@ def bitnet_158_int8xint2_prefill( @T.prim_func def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), ): """ - GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. + GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. - This kernel: - - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. - - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. - - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. - - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. + This kernel: + - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. + - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. + - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. + - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. - Parameters: - A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. - B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. - C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). + Parameters: + A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. + B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. + C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). - Side effects: - Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. 
+ Side effects: + Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. """ with T.Kernel( - T.ceildiv(N, block_N), - T.ceildiv(M, block_M), - threads=threads, - prelude=decode_i2s_to_i8s, + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, + prelude=decode_i2s_to_i8s, ) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) - B_dequantize_shared = T.alloc_shared( - B_dequantize_shared_shape, in_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) @@ -221,12 +221,14 @@ def main( B_local = T.alloc_local([local_size_compressed], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) - thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + thread_bindings = T.get_thread_binding(0) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -234,7 +236,6 @@ def main( T.clear(C_frag) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -243,12 +244,9 @@ def main( for j, k in T.Parallel(block_N, block_K // num_elems_per_byte): B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] - for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): - index = ( - i * threads * local_size_compressed + - thread_bindings * local_size_compressed + v) + index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v vi, vj = T.index_to_coordinates(index, B_shared_shape) B_local[v] = B_shared[vi, vj] @@ -260,12 +258,11 @@ def main( ) for v in T.vectorized(0, local_size): - index = (i * threads * local_size + thread_bindings * local_size + v) + index = i * threads * local_size + thread_bindings * local_size + v vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape) B_dequantize_shared[vi, vj] = B_dequantize_local[v] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_frag, @@ -320,12 +317,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): # interleave weight numpy implementation -def interleave_weight(qweight, nbits=4, target_dtype="float16"): - assert target_dtype in ["float16", "int8"] +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert target_dtype in [T.float16, T.int8] # reinterpret the data type of qweight to int32 qweight = qweight.view(np.int32) new_qweight = np.zeros_like(qweight) - bits_stride = 8 if target_dtype == "int8" else 16 + bits_stride = 8 if target_dtype == T.int8 else 16 mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f num_groups = 
32 // bits_stride elems_per_group = bits_stride // nbits @@ -335,7 +332,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift - if nbits == 1 and target_dtype == "int8": + if nbits == 1 and target_dtype == T.int8: # special handling for 1b interleave n16_weight = new_qweight & np.int32(0xF0F00F0F) n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 @@ -343,12 +340,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 return n16_weight.view(np.int8) - elif nbits == 2 and target_dtype == "float16": + elif nbits == 2 and target_dtype == T.float16: n8_weight = new_qweight & np.int32(0xFF0000FF) n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 return n8_weight.view(np.int8) - elif nbits == 1 and target_dtype == "float16": + elif nbits == 1 and target_dtype == T.float16: n8_weight = new_qweight & 0xF000000F n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 @@ -360,13 +357,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): return new_qweight.view(np.int8) -def assert_bitnet_158_int8xint2_prefill_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - fast_decoding=True): +def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) print(program) kernel = tilelang.compile(program) @@ -391,4 +382,4 @@ def assert_bitnet_158_int8xint2_prefill_correctness(M, if __name__ == "__main__": - assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, "int8", "int32", "int32") + assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py index 986463598..e3d35df4b 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py @@ -6,7 +6,8 @@ import tvm.tl.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from bitblas.base import simplify_prim_func torch.manual_seed(0) @@ -37,18 +38,18 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -56,7 +57,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 64 warp_col_tiles = 64 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -101,12 +102,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: 
T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -116,10 +116,12 @@ def main( thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -127,7 +129,6 @@ def main( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -137,7 +138,6 @@ def main( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -183,7 +183,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None print(src_code) - if in_dtype == "int8": + if in_dtype == T.int8: A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8) B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8) else: @@ -209,12 +209,12 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") - assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) if __name__ == "__main__": # bitblas.testing.main() - # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") - # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") - assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + # assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32) + assert_tl_matmul_correctness(16384, 16384, 16384, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/load_from_quantized.py b/examples/bitnet-1.58b/load_from_quantized.py index 26a32f974..8c775aa4c 100644 --- a/examples/bitnet-1.58b/load_from_quantized.py +++ b/examples/bitnet-1.58b/load_from_quantized.py @@ -49,7 +49,13 @@ def generate_text(model, tokenizer, prompt, max_length=100): def main(): # load quantized model - qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) diff --git a/examples/bitnet-1.58b/maint/README.md b/examples/bitnet-1.58b/maint/README.md index 63cc3e275..6bccdf93a 100644 --- a/examples/bitnet-1.58b/maint/README.md +++ b/examples/bitnet-1.58b/maint/README.md @@ -2,7 +2,6 @@ license: mit --- - This is a BitBLAS Implementation for the reproduced 1.58bit model from 
[1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with BitBLAS INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`. ## Latest News @@ -88,4 +87,4 @@ The differences between the reported numbers and the reproduced results are poss journal={arXiv preprint arXiv:2402.17764}, year={2024} } -``` \ No newline at end of file +``` diff --git a/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py index 1e29a553a..2604ef387 100644 --- a/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py +++ b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py @@ -25,9 +25,9 @@ args = parser.parse_args() model_name_or_path = args.model_name_or_path -saved_model_path = os.path.join( - dirpath, "models", - f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +saved_model_path = ( + os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +) def generate_text(model, tokenizer, prompt, max_length=100): @@ -67,7 +67,10 @@ def main(): model_name_or_path, use_flash_attention_2=False, torch_dtype=torch.float16, - ).cuda().half()) + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") @@ -112,10 +115,16 @@ def main(): file_path = cached_file(model_name_or_path, file) os.system(f"cp {file_path} {saved_model_path}") # load quantized model - qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) print("quantized model generated text:") print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh b/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh index 741c3a124..b0430588a 100755 --- a/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh +++ b/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + # retrieve the native model input and saved model directory MODEL_DIR=$1 SAVED_MODEL_DIR=$2 diff --git a/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh b/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh index a2df0eb8c..66356d3d8 100755 --- a/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh +++ b/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + # require git lfs if ! 
command -v git-lfs &> /dev/null; then echo "Please install git-lfs first by running 'sudo apt install git-lfs'" diff --git a/examples/bitnet-1.58b/maint/quantize_config.json b/examples/bitnet-1.58b/maint/quantize_config.json index e2b24123a..80fbf02f0 100644 --- a/examples/bitnet-1.58b/maint/quantize_config.json +++ b/examples/bitnet-1.58b/maint/quantize_config.json @@ -7,4 +7,4 @@ "model_name_or_path": "1bitLLM/bitnet_b1_58-3B", "quant_method": "bitnet", "checkpoint_format": "bitnet" -} \ No newline at end of file +} diff --git a/examples/bitnet-1.58b/maint/upload_models.sh b/examples/bitnet-1.58b/maint/upload_models.sh index b764b0da6..7c6d76e32 100755 --- a/examples/bitnet-1.58b/maint/upload_models.sh +++ b/examples/bitnet-1.58b/maint/upload_models.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + MODEL_DIR=$1 REMOTE_DIR=$2 diff --git a/examples/bitnet-1.58b/modeling_bitnet.py b/examples/bitnet-1.58b/modeling_bitnet.py index 6e3c42b6f..1830995ee 100644 --- a/examples/bitnet-1.58b/modeling_bitnet.py +++ b/examples/bitnet-1.58b/modeling_bitnet.py @@ -64,8 +64,7 @@ def find_layers(module, layers=None, name=""): return {name: module} res = {} for name1, child in module.named_children(): - res.update( - find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) return res @@ -87,7 +86,6 @@ def _get_unpad_data(attention_mask): class BitnetRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): """ BitnetRMSNorm is equivalent to T5LayerNorm @@ -108,34 +106,23 @@ def forward(self, hidden_states): class BitnetRotaryEmbedding(nn.Module): - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / ( - self.base - **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer( - "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): @@ -156,14 +143,12 @@ def cos_cached(self): @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, - None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, 
None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, - str) and device_type != "mps" else "cpu" + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) @@ -174,8 +159,8 @@ def forward(self, x, position_ids): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -207,7 +192,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BitnetMLP(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -245,7 +229,6 @@ def forward(self, x): class BitnetMLPFuseGateUp(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -272,8 +255,7 @@ def __init__(self, config): def from_bit_mlp(cls, bit_mlp: BitnetMLP): module = cls(bit_mlp.config) # assign the weights - module.gate_up_proj.weight = nn.Parameter( - torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) + module.gate_up_proj.weight = nn.Parameter(torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) module.down_proj = bit_mlp.down_proj module.ffn_layernorm = bit_mlp.ffn_layernorm return module @@ -295,8 +277,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, - head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -311,7 +292,8 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") + "when creating this class." + ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -325,8 +307,8 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." 
+ ) self.q_proj = BitLinear( self.hidden_size, @@ -387,10 +369,8 @@ def forward( value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -399,30 +379,24 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -448,7 +422,8 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") + "when creating this class." + ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -462,8 +437,8 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." 
+ ) self.qkv_proj = BitLinear( self.hidden_size, @@ -497,17 +472,12 @@ def from_bit_attention(cls, bit_attention: BitnetAttention): module = cls(bit_attention.config, bit_attention.layer_idx) # assign the weights module.qkv_proj.weight = nn.Parameter( - torch.cat([ - bit_attention.q_proj.weight, bit_attention.k_proj.weight, - bit_attention.v_proj.weight - ], - dim=0)) + torch.cat([bit_attention.q_proj.weight, bit_attention.k_proj.weight, bit_attention.v_proj.weight], dim=0) + ) if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None: module.qkv_proj.bias = nn.Parameter( - torch.cat([ - bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias - ], - dim=0)) + torch.cat([bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias], dim=0) + ) module.o_proj = bit_attention.o_proj module.inner_attn_ln = bit_attention.inner_attn_ln if bit_attention.config.rope_scaling is None: @@ -528,16 +498,13 @@ def forward( bsz, q_len, _ = hidden_states.size() qkv_states = self.qkv_proj(hidden_states) query_states, key_states, value_states = torch.split( - qkv_states, [ - self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim - ], - dim=-1) + qkv_states, + [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], + dim=-1, + ) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -546,30 +513,24 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 
attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -622,10 +583,8 @@ def forward( # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -635,8 +594,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -665,14 +623,14 @@ def forward( logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") + f" {target_dtype}." + ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) + attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.inner_attn_ln(attn_output) @@ -683,14 +641,9 @@ def forward( return attn_output, attn_weights, past_key_value - def _flash_attention_forward(self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None): + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. 
@@ -720,7 +673,8 @@ def _flash_attention_forward(self, if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length) + query_states, key_states, value_states, attention_mask, query_length + ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -740,13 +694,7 @@ def _flash_attention_forward(self, attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal) + attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal) return attn_output @@ -754,28 +702,24 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, - device=query_layer.device) # There is a memcpy here, that is very bad. + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -794,13 +738,11 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class BitnetDecoderLayer(nn.Module): - def __init__(self, config: BitnetConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx) + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = BitnetMLP(config) self.input_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -834,7 +776,8 @@ def forward( if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`", - stacklevel=2) + stacklevel=2, + ) residual = hidden_states @@ -925,8 +868,7 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = dtype = self.config._pre_quantization_dtype else: dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) + layer.self_attn.past_key_value = cache_cls(self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) def _reset_cache(self): for layer in self.model.layers: @@ -1025,9 +967,7 @@ def __init__(self, config: BitnetConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([ - BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList([BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1055,21 +995,15 @@ def forward( cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one") if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.") use_cache = False if inputs_embeds is None: @@ -1083,10 +1017,7 @@ def forward( if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1143,12 +1074,9 @@ def forward( next_cache = None if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if isinstance(next_decoder_cache, Cache) else next_decoder_cache) + next_cache = next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1172,14 +1100,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache target_length = self.config.max_position_embeddings else: # dynamic cache - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1) - - causal_mask = torch.full((sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device) + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) @@ -1188,10 +1111,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq( - 0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( - padding_mask, min_dtype) + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. @@ -1201,8 +1122,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[:mask_shape[0], :mask_shape[1], - offset:mask_shape[2] + offset, :mask_shape[3]] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = mask_slice return causal_mask @@ -1279,9 +1199,7 @@ def forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -1327,13 +1245,9 @@ def forward( attentions=outputs.attentions, ) - def prepare_inputs_for_generation(self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - **kwargs): + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs + ): # With static cache, the `past_key_values` is None # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False @@ -1344,13 +1258,13 @@ def prepare_inputs_for_generation(self, past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - past_length = cache_position[ - 0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None else None) - cache_length = past_length if max_cache_length is None else torch.min( - max_cache_length, past_length) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] @@ -1361,7 +1275,7 @@ def prepare_inputs_for_generation(self, # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: @@ -1369,8 +1283,7 @@ def prepare_inputs_for_generation(self, # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if (max_cache_length is not None and attention_mask is not None and - cache_length + input_ids.shape[1] > max_cache_length): + if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length: attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids") @@ -1379,7 +1292,7 @@ def prepare_inputs_for_generation(self, position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1]:] + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1392,39 +1305,38 @@ def prepare_inputs_for_generation(self, input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cache_position is None: - cache_position = torch.arange( - past_length, past_length + input_length, device=input_ids.device) + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) else: cache_position = cache_position[-input_length:] if has_static_cache: past_key_values = None - model_inputs.update({ - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - }) + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: - reordered_past += (tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past),) + reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) return reordered_past @staticmethod def recursive_set(model, name, attr): - ''' - set layers.25.mlp.up_proj to attr - ''' + """ + set layers.25.mlp.up_proj to attr + """ - names = name.split('.') + names = name.split(".") obj = model for n in names[:-1]: obj = getattr(obj, n) @@ -1521,6 +1433,7 @@ def from_quantized( fuse_gateup = quant_config.get("fuse_gateup", True) import accelerate + if checkpoint_format == "bitblas": model = cls(config) for name, module in model.named_modules(): @@ -1567,7 +1480,6 @@ def from_quantized( LLAMA_START_DOCSTRING, ) class BitnetForSequenceClassification(BitnetPreTrainedModel): - def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -1631,8 +1543,7 @@ def forward( else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, - self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: @@ -1646,8 +1557,7 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or - labels.dtype == torch.int): + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 
self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" diff --git a/examples/bitnet-1.58b/nvidia_measure_memory.sh b/examples/bitnet-1.58b/nvidia_measure_memory.sh index e8998f309..82cf4855f 100755 --- a/examples/bitnet-1.58b/nvidia_measure_memory.sh +++ b/examples/bitnet-1.58b/nvidia_measure_memory.sh @@ -1 +1,3 @@ +#!/usr/bin/env bash + nvidia-smi --query-gpu=memory.used --format=csv -lms 500 diff --git a/examples/bitnet-1.58b/tokenization_bitnet.py b/examples/bitnet-1.58b/tokenization_bitnet.py index 6fea3252a..2adfd6dee 100644 --- a/examples/bitnet-1.58b/tokenization_bitnet.py +++ b/examples/bitnet-1.58b/tokenization_bitnet.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for LLaMA.""" + import os from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -37,12 +38,10 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": - "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": - "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @@ -159,14 +158,10 @@ def __init__( **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken( - bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken( - eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken( - unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken( - pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token if legacy is None: logger.warning_once( @@ -174,7 +169,8 @@ def __init__( " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." " If you want to use the new behavior, set `legacy=False`. 
This should only be set if you understand what it" " means, and thoroughly read the reason why this was added as explained in" - " https://github.com/huggingface/transformers/pull/24565") + " https://github.com/huggingface/transformers/pull/24565" + ) legacy = True self.legacy = legacy @@ -214,8 +210,7 @@ def get_spm_processor(self, from_slow=False): with open(self.vocab_file, "rb") as f: sp_model = f.read() - model_pb2 = import_protobuf( - f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") + model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") model = model_pb2.ModelProto.FromString(sp_model) normalizer_spec = model_pb2.NormalizerSpec() normalizer_spec.add_dummy_prefix = False @@ -261,8 +256,7 @@ def tokenize(self, text: "TextInput", **kwargs) -> List[str]: tokens = super().tokenize(text, **kwargs) - if len(tokens - ) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: tokens = tokens[1:] return tokens @@ -284,7 +278,7 @@ def _tokenize(self, text, **kwargs): # 1. Encode string + prefix ex: " Hey" tokens = self.sp_model.encode(self.unk_token + text, out_type=str) # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] - return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" @@ -332,12 +326,9 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") return - out_vocab_file = os.path.join(save_directory, - (filename_prefix + "-" if filename_prefix else "") + - VOCAB_FILES_NAMES["vocab_file"]) + out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile( - self.vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: @@ -357,10 +348,9 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): return output - def get_special_tokens_mask(self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, - already_has_special_tokens: bool = False) -> List[int]: + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` method. @@ -377,20 +367,16 @@ def get_special_tokens_mask(self, `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
""" if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) + return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) bos_token_id = [1] if self.add_bos_token else [] eos_token_id = [1] if self.add_eos_token else [] if token_ids_1 is None: return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + - ([0] * len(token_ids_1)) + eos_token_id) + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id - def create_token_type_ids_from_sequences(self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: + def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT sequence pair mask has the following format: @@ -473,9 +459,9 @@ def default_chat_template(self): "{% elif message['role'] == 'assistant' %}" "{{ ' ' + content.strip() + ' ' + eos_token }}" "{% endif %}" - "{% endfor %}") - template = template.replace("USE_DEFAULT_PROMPT", - "true" if self.use_default_system_prompt else "false") + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) diff --git a/examples/bitnet-1.58b/utils_quant.py b/examples/bitnet-1.58b/utils_quant.py index 5f5db5dbc..5a50edb39 100644 --- a/examples/bitnet-1.58b/utils_quant.py +++ b/examples/bitnet-1.58b/utils_quant.py @@ -24,15 +24,14 @@ def weight_quant(weight, num_bits=1): def activation_quant(x, num_bits=8): dtype = x.dtype x = x.float() - Qn = -(2**(num_bits - 1)) - Qp = 2**(num_bits - 1) - 1 + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) result = (x * s).round().clamp(Qn, Qp) / s return result.type(dtype) class BitLinearBitBLAS(nn.Module): - def __init__( self, in_features: int, @@ -68,7 +67,7 @@ def __init__( self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING) self.format = "bitnet" - self.Qp = 2**(self.input_bits - 1) - 1 + self.Qp = 2 ** (self.input_bits - 1) - 1 def _get_or_create_bitblas_operator(self, config, enable_tuning): if global_operator_cache.size() == 0: @@ -99,8 +98,7 @@ def replace_weight_param_with_qweight(self): @classmethod def from_bit_linear(cls, bitlinear, weight_group=1): - bitblas_linear = cls( - bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) + bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group) bitblas_linear.register_buffer("qweight", qweight) bitblas_linear.register_buffer("sw", sw) @@ -158,8 +156,8 @@ def weight_quant(weight): @torch.compile def activation_quant(self, x, num_bits=8): x = x.float() - Qn = -(2**(num_bits - 1)) - Qp = 2**(num_bits - 1) - 1 + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) result = (x * s).round().clamp(Qn, Qp) return 
result.type(torch.int8), s @@ -173,9 +171,8 @@ def post_quant_process(self, input, si, sw): # for the correctness evaluation. def native_forward(self, input): - quant_input = (input + (activation_quant(input, self.input_bits) - input).detach()) - quant_weight = ( - self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()) + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() out = nn.functional.linear(quant_input, quant_weight) if self.bias is not None: @@ -214,7 +211,6 @@ def forward(self, input): # Naive BitLinear from HuggingFace class BitLinear(nn.Linear): - def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): super(BitLinear, self).__init__(*kargs, **kwargs) """ @@ -224,10 +220,8 @@ def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): self.input_bits = input_bits def forward(self, input): - quant_input = input + (activation_quant(input, self.input_bits) - input).detach() - quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - - self.weight).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() out = nn.functional.linear(quant_input, quant_weight) if self.bias is not None: diff --git a/examples/bitnet-1.58b/vllm_workspace/conftest.py b/examples/bitnet-1.58b/vllm_workspace/conftest.py index 951f38991..e9e2997ef 100644 --- a/examples/bitnet-1.58b/vllm_workspace/conftest.py +++ b/examples/bitnet-1.58b/vllm_workspace/conftest.py @@ -20,7 +20,7 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import TokenizerPoolConfig -from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) +from vllm.distributed import destroy_distributed_environment, destroy_model_parallel from vllm.inputs import TextPrompt from vllm.logger import init_logger from vllm.sequence import SampleLogprobs @@ -56,12 +56,13 @@ class _ImageAssetsBase(UserList[ImageAsset]): class _ImageAssets(_ImageAssetsBase): - def __init__(self) -> None: - super().__init__([ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), - ]) + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: """ @@ -136,7 +137,6 @@ def image_assets() -> _ImageAssets: class HfRunner: - def wrap_device(self, input: _T) -> _T: if not is_cpu(): return input.to("cuda") @@ -166,7 +166,8 @@ def __init__( SentenceTransformer( model_name, device="cpu", - ).to(dtype=torch_dtype)) + ).to(dtype=torch_dtype) + ) else: if is_vision_model: auto_cls = AutoModelForVision2Seq @@ -184,7 +185,8 @@ def __init__( torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs, - )) + ) + ) self.tokenizer = AutoTokenizer.from_pretrained( model_name, @@ -204,8 +206,7 @@ def __init__( ) except Exception: logger.warning( - "Unable to auto-load processor from HuggingFace for " - "model %s. Using tokenizer instead.", + "Unable to auto-load processor from HuggingFace for model %s. 
Using tokenizer instead.", model_name, ) self.processor = self.tokenizer @@ -362,7 +363,7 @@ def generate_greedy_logprobs_limit( last_hidden_states, self.model.get_output_embeddings().weight.t(), ) - if (getattr(self.model.get_output_embeddings(), "bias", None) is not None): + if getattr(self.model.get_output_embeddings(), "bias", None) is not None: logits += self.model.get_output_embeddings().bias.unsqueeze(0) logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) @@ -389,8 +390,7 @@ def generate_greedy_logprobs_limit( all_output_strs.append(self.tokenizer.decode(output_ids)) outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) @@ -409,7 +409,6 @@ def hf_runner(): class VllmRunner: - def __init__( self, model_name: str, @@ -514,12 +513,10 @@ def generate_greedy_logprobs( num_logprobs: int, images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams( - temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) + greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] def generate_beam_search( self, diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py index 55a24543e..ea18239cb 100644 --- a/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py @@ -32,15 +32,14 @@ ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitblas", - # set enforce_eager = False to enable cuda graph - # set enforce_eager = True to disable cuda graph - enforce_eager=False, + ckpt_path, + dtype="half", + quantization="bitblas", + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: - bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], - max_tokens=1024) + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024) print("bitnet inference:") print(bitbnet_outputs[0][0]) print(bitbnet_outputs[0][1]) diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py index 4f5f87f6f..f631fb306 100644 --- a/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py @@ -33,13 +33,13 @@ ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitnet_bitblas", - gpu_memory_utilization=0.5, - # set enforce_eager = False to enable cuda graph - # set enforce_eager = True to disable cuda graph - enforce_eager=False, + ckpt_path, + dtype="half", + 
quantization="bitnet_bitblas", + gpu_memory_utilization=0.5, + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128) print("bitnet inference output:") diff --git a/examples/bitnet-1.58b/vllm_workspace/utils.py b/examples/bitnet-1.58b/vllm_workspace/utils.py index daa9d8f52..e96b19e28 100644 --- a/examples/bitnet-1.58b/vllm_workspace/utils.py +++ b/examples/bitnet-1.58b/vllm_workspace/utils.py @@ -3,8 +3,7 @@ TokensText = Tuple[List[int], str] -def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], - name_0: str, name_1: str): +def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str): """ Compare the two sequences generated by different models, which should be equal. @@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok output_ids_0, output_str_0 = outputs_0 output_ids_1, output_str_1 = outputs_1 - assert output_str_0 == output_str_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]] -def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], - outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): +def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): """ Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. @@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], # Loop through generated tokens. for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): - # If generated tokens don't match, then if output_id_0 != output_id_1: # Each predicted token must be in top N logprobs of the other - assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" # Break out since sequences will now diverge. break diff --git a/examples/blocksparse_attention/README.md b/examples/blocksparse_attention/README.md index 89f75b81d..34bf3c637 100644 --- a/examples/blocksparse_attention/README.md +++ b/examples/blocksparse_attention/README.md @@ -1,6 +1,5 @@ # Block-Sparse Flash-Attention -Tilelang implementation of block-sparse flash-attention kernels. - -The kernels have been used in [Rectified Sparse Attention](https://arxiv.org/abs/2506.04108) and [SeerAttention-R](https://arxiv.org/abs/2506.08889). +Tilelang implementation of block-sparse flash-attention kernels. 
+The kernels have been used in [Rectified Sparse Attention](https://arxiv.org/abs/2506.04108) and [SeerAttention-R](https://arxiv.org/abs/2506.08889). diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py index 014f0c5fc..b94e602f6 100644 --- a/examples/blocksparse_attention/block_sparse_attn_triton.py +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -1,7 +1,6 @@ # ruff: noqa: E712 import math import torch - import triton import triton.language as tl import torch.nn.functional as F @@ -15,10 +14,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -56,7 +52,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) # print @@ -73,8 +68,7 @@ def _fwd_kernel_inner( # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N if LAST_K_BLOCK: - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, - float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -154,7 +148,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -192,24 +186,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -254,7 +236,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -278,9 +259,9 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = 
torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) @@ -288,9 +269,7 @@ def test_topk_sparse_attention(): downsample_len = math.ceil(SEQ_LEN / downsample_factor) print("downsample_len", downsample_len) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 print("x_ds.shape", x_ds.shape) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -302,22 +281,21 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # print("ref_output", ref_output) # print("triton_output", triton_output) # Verify accuracy - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -329,9 +307,9 @@ def test_topk_sparse_attention_qlt_kl(): torch.manual_seed(0) # Create inputs. - q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) # softmax scale sm_scale = 1.0 / (D_HEAD**0.5) @@ -339,8 +317,7 @@ def test_topk_sparse_attention_qlt_kl(): print("downsample_factor", downsample_factor) downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension print("downsample_len", downsample_len) - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) # Force the first column to be high so that the first block is always selected. 
x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -351,26 +328,25 @@ def test_topk_sparse_attention_qlt_kl(): past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # Verify accuracy. - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference when qlen < klen" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" print("Pass topk sparse attention test with qlen < klen") diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index 7e90db7e5..9a394710f 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -1,8 +1,8 @@ import math import torch - import tilelang import tilelang.language as T +from tilelang.profiler import do_bench import torch.nn.functional as F @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -30,105 +27,34 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F @tilelang.jit( - out_idx=[4], pass_configs={ + out_idx=[4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): block_M = 64 block_N = 64 num_stages = 1 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] - dtype = "float16" - accum_dtype = "float" - block_mask_dtype = "bool" + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.bool def kernel_func(block_M, block_N, num_stages, threads): - - @T.macro - def MMA0( - K: T.Tensor(shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: 
T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def blocksparse_flashattn( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -141,31 +67,59 @@ def blocksparse_flashattn( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_mask = T.alloc_local([downsample_len], block_mask_dtype) + block_mask = T.alloc_fragment([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - for vj in T.serial(downsample_len): - block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + T.copy(BlockSparseMask[bz, by, bx, :], block_mask) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[k] != 0: - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. 
+ # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return blocksparse_flashattn @@ -180,18 +134,16 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -202,15 +154,15 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) @@ -224,5 +176,26 @@ def main(): test_topk_sparse_attention() +def run_regression_perf(): + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 32, 256, 64 + TOPK = 2 + BLOCK = 64 + torch.manual_seed(0) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, 
SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + + def run_kernel_only(): + kernel(q, k, v, block_mask) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index e29982162..a93e4de13 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -2,209 +2,172 @@ import torch import torch.nn.functional as F import tilelang -from tilelang.autotuner import * import tilelang.language as T from einops import rearrange, einsum import argparse import time import math +from tilelang.profiler import do_bench from heuristic import num_splits_heuristic -def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, page_block_size, num_stages, threads, num_pages): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) - def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, - max_num_blocks_per_seq, max_selected_blocks): - shape_q = [batch, heads, dim] - shape_k = [num_pages, page_block_size, heads_kv, dim] - shape_v = [num_pages, page_block_size, heads_kv, dim_v] - shape_indices = [batch, heads_kv, max_selected_blocks] - shape_block_table = [batch, max_num_blocks_per_seq] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - assert block_N <= page_block_size and page_block_size % block_N == 0 - block_ratio = page_block_size // block_N - - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) 
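# --- editor sketch (plain Python, hypothetical names, not part of the patch): the host-side
# index arithmetic that the split kernel below performs per selected KV block. Work along the
# `num_split` grid axis is partitioned as evenly as possible, and a logical block of size
# block_N is mapped to a physical page of size page_block_size via block_table.
def split_bounds(num_blocks, num_split, sid):
    # mirrors: blocks_per_split = num_blocks // num_split, remainder spread over leading splits
    blocks_per_split, remaining = divmod(num_blocks, num_split)
    count = blocks_per_split + (1 if sid < remaining else 0)
    start = blocks_per_split * sid + min(sid, remaining)
    return start, count

def logical_to_physical(logical_block_idx, block_table_row, page_block_size, block_N):
    # mirrors: block_table_idx = logical // block_ratio, block_tile_idx = logical % block_ratio
    block_ratio = page_block_size // block_N
    page = block_table_row[logical_block_idx // block_ratio]
    row_start = (logical_block_idx % block_ratio) * block_N  # offset of the block_N tile in the page
    return page, row_start

# e.g. page_block_size=256, block_N=64: logical block 5 lives in block_table_row[1], rows 64..127
page, row_start = logical_to_physical(5, [7, 3, 9], 256, 64)
assert (page, row_start) == (3, 64)
assert sum(split_bounds(10, 4, sid)[1] for sid in range(4)) == 10  # all blocks covered exactly once
# --- end editor sketch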
- logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - num_blocks = max_selected_blocks - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - for k in T.Pipelined(loop_range, num_stages=num_stages): - logical_block_idx = block_indices[bid, cur_kv_head, start + k] - if logical_block_idx >= 0: - has_valid_block = True - block_table_idx = T.floordiv(logical_block_idx, block_ratio) - block_tile_idx = T.floormod(logical_block_idx, block_ratio) - physical_block_idx = block_table[bid, block_table_idx] - T.copy( - K[physical_block_idx, - block_tile_idx * block_N:(block_tile_idx + 1) * block_N, - cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) - if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else( - logical_block_idx * block_N + j >= cache_seqlens[bid], - -T.infinity(accum_dtype), acc_s[i, j]) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i], - scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + num_split = T.dynamic("num_split") + max_num_blocks_per_seq = T.dynamic("max_num_blocks_per_seq") + max_selected_blocks = T.dynamic("max_selected_blocks") + + shape_q = [batch, heads, dim] + shape_k = [num_pages, page_block_size, heads_kv, dim] + shape_v = [num_pages, page_block_size, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_block_table = [batch, max_num_blocks_per_seq] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + assert block_N <= page_block_size and page_block_size % block_N == 0 + block_ratio = page_block_size // block_N + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + # flash_attn_split + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + 
scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + logical_block_idx = block_indices[bid, cur_kv_head, start + k] + if logical_block_idx >= 0: + has_valid_block = True + block_table_idx = T.floordiv(logical_block_idx, block_ratio) + block_tile_idx = T.floormod(logical_block_idx, block_ratio) + physical_block_idx = block_table[bid, block_table_idx] + T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy( - V[physical_block_idx, - block_tile_idx * block_N:(block_tile_idx + 1) * block_N, - cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] - + acc_s[i, j] = T.if_then_else( + logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] + ) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: 
T.Tensor(shape_o, dtype), - ): - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - max_split = T.alloc_local([1], "int32") - - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_local_split[0] = glse[bz, by, k] - if (lse_local_split[0] != 0): - max_split[0] = k - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) - - for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split[0]: - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - if k <= max_split[0]: - for i in T.Parallel(dim_v): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, - Output_partial) - combine(glse, Output_partial, Output) - - return main - - return kernel_func + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split = glse[bz, by, k] + 
scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + print(main) + return main class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -216,19 +179,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, self.page_block_size = page_block_size self.num_pages = num_pages self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_N, - block_H=self.block_H, - page_block_size=page_block_size, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - num_pages=num_pages, - max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"), - max_selected_blocks=T.dynamic("max_selected_blocks"), - ) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -250,40 +200,35 @@ def forward(self, query, key, value, block_indices, cache_seqlens, block_table): num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') - - output = self.kernel( - query, - key, - value, - block_indices, - cache_seqlens, - block_table, - glse, - output_partial, + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + page_block_size=self.page_block_size, + num_stages=2, + threads=128, + num_pages=self.num_pages, + )(query, key, value, block_indices, cache_seqlens, block_table, glse, output_partial) return output -def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, - block_table, page_block_size, block_size): +def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size): """ Paged version of sparse attention reference implementation. 
- + Args: query: [batch, heads, dim] - key_cache: [num_pages, page_block_size, heads_kv, dim] + key_cache: [num_pages, page_block_size, heads_kv, dim] value_cache: [num_pages, page_block_size, heads_kv, dim] block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices cache_seqlens: [batch] - actual sequence lengths @@ -299,12 +244,8 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ # Reconstruct the full key and value tensors from paged cache max_cache_seqlen = max(cache_seqlens).item() - key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), - dtype=key_cache.dtype, - device=key_cache.device) - value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), - dtype=value_cache.dtype, - device=value_cache.device) + key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device) + value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device) # Reconstruct full tensors from paged cache using block_table for b in range(batch): @@ -320,20 +261,14 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ actual_block_size = end_token - start_token # Copy from paged cache to full tensors - key_full[b, :, start_token:end_token, :] = key_cache[ - physical_block_idx, :actual_block_size, :, :].transpose(0, 1) - value_full[b, :, start_token:end_token, :] = value_cache[ - physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) # Reshape query for grouped attention - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] # Compute attention scores - scores = einsum( - query, key_full, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key_full, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] # Create sparse mask based on block_indices sparse_mask = torch.zeros_like(scores) @@ -349,24 +284,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ sparse_mask[b, :, h, start_pos:end_pos] = 1 # Apply sparse mask - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) # Apply causal mask based on actual sequence lengths range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) + scores = scores.masked_fill(pad_mask, float("-inf")) # Compute attention weights attention = F.softmax(scores / scale, dim=-1) # Apply attention to values - out = einsum(attention, value_full, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] + out = einsum(attention, value_full, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] # Reshape output back to original format - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = 
rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -374,17 +308,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) - output = flash_attn_with_kvcache( - query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) + output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) output = output.squeeze(1) return output def main(args): - - batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( + args.batch, + args.heads, + args.heads_kv, + args.max_cache_seqlen, + args.dim, + args.dim_v, + ) sparse_ratio = args.sparse_ratio block_N = args.block_N page_block_size = args.page_block_size @@ -396,35 +336,30 @@ def main(args): dtype = torch.float16 # Generate random inputs - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - cache_seqlens = torch.randint( - max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") print("cache_seqlens: ", cache_seqlens) - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") # Create paged KV cache - K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device='cuda') - V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), - dtype=dtype, - device='cuda') + K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") # Create block table and block indices for dense case (all blocks selected) max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) print("max_num_blocks_per_seq: ", max_num_blocks_per_seq) - block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device='cuda') - block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), - dtype=torch.int32, - device='cuda') + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") # Fill block table and block indices and cache # Create a pool of available physical blocks - total_blocks_needed = sum( - int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) available_blocks = list(range(total_blocks_needed)) import random + random.seed(42) # For reproducibility random.shuffle(available_blocks) @@ -459,10 +394,8 @@ def 
main(args): actual_block_size = end_token - start_token # Copy K and V data to the paged cache - K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, - start_token:end_token, :, :] - V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, - start_token:end_token, :, :] + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :] # Fill block_indices for sparse attention # For dense case (verification), we select all blocks in reverse order @@ -497,10 +430,9 @@ def main(args): remaining_blocks = [b for b in all_blocks if b not in selected_blocks] if remaining_blocks: import random + random.seed(42) # For reproducibility - additional_blocks = random.sample( - remaining_blocks, - min(num_selected - recent_blocks, len(remaining_blocks))) + additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks))) selected_blocks.extend(additional_blocks) # Sort selected blocks in reverse order (most recent first) @@ -513,25 +445,20 @@ def main(args): block_indices[seq_idx, head_idx, i] = -1 # Initialize sparse attention module - sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, - num_blocks) - output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, - block_table) + sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) + output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) import flash_attn # noqa: F401 - output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, - block_table, page_block_size, block_N) + output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N) output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) # Check correctness if sparse_ratio == 0.0: max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item() mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item() - assert torch.allclose( - output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" + assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" 
else: - max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item() mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item() @@ -573,18 +500,144 @@ def main(args): print(f"Speedup: {kernel_time_fa / kernel_time:.2f}x") +def run_regression_perf(args): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( + args.batch, + args.heads, + args.heads_kv, + args.max_cache_seqlen, + args.dim, + args.dim_v, + ) + sparse_ratio = args.sparse_ratio + block_N = args.block_N + page_block_size = args.page_block_size + num_pages = args.num_pages + max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N)) + dtype = torch.float16 + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + K_cache = torch.zeros((num_pages, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_pages, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") + max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") + total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + available_blocks = list(range(total_blocks_needed)) + import random + + random.seed(42) + random.shuffle(available_blocks) + block_assignment = {} + block_idx_counter = 0 + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + for block_idx in range(num_blocks_needed): + physical_block_idx = available_blocks[block_idx_counter] + block_table[seq_idx, block_idx] = physical_block_idx + block_assignment[(seq_idx, block_idx)] = physical_block_idx + block_idx_counter += 1 + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + for block_idx in range(num_blocks_needed): + physical_block_idx = block_assignment[(seq_idx, block_idx)] + start_token = block_idx * page_block_size + end_token = min(start_token + page_block_size, seq_len) + actual_block_size = end_token - start_token + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :] + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_tile = int(math.ceil(seq_len / block_N)) + if sparse_ratio == 0.0: + selected_blocks = min(num_tile, max_selected_blocks) + for head_idx in range(heads_kv): + for i in range(selected_blocks): + block_indices[seq_idx, head_idx, i] = num_tile - 1 - i + for i in range(selected_blocks, max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + else: + num_selected = int(num_tile * (1.0 - sparse_ratio)) + num_selected = max(1, min(num_selected, max_selected_blocks)) + all_blocks = list(range(num_tile)) + for head_idx in range(heads_kv): + selected_blocks = [] + recent_blocks = 1 + selected_blocks.append(num_tile - 1) + if num_selected > 
recent_blocks: + remaining_blocks = [b for b in all_blocks if b not in selected_blocks] + if remaining_blocks: + import random + + random.seed(42) + additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks))) + selected_blocks.extend(additional_blocks) + + selected_blocks.sort(reverse=True) + + for i in range(len(selected_blocks)): + block_indices[seq_idx, head_idx, i] = selected_blocks[i] + for i in range(len(selected_blocks), max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_N + max_selected_blocks = block_indices.shape[-1] + + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = sparse_kernel.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=sparse_kernel.block_H, + page_block_size=sparse_kernel.page_block_size, + num_stages=2, + threads=128, + num_pages=sparse_kernel.num_pages, + ) + + def run_kernel_only(): + kernel(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, glse, output_partial) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.0, help='sparse ratio') - parser.add_argument('--block_N', type=int, default=64, help='block_N') - parser.add_argument('--page_block_size', type=int, default=256, help='block size of pages') - parser.add_argument('--num_pages', type=int, default=1024, help='total number of pages') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio") + parser.add_argument("--block_N", type=int, default=64, help="block_N") + parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages") + parser.add_argument("--num_pages", 
type=int, default=1024, help="total number of pages") args = parser.parse_args() main(args) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index ae3004267..54148e69b 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -7,191 +7,156 @@ import time import math from heuristic import num_splits_heuristic - - -def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" +from tilelang.profiler import do_bench + + +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, num_stages, threads): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, - max_selected_blocks): - shape_q = [batch, heads, dim] - shape_k = [batch, max_cache_seqlen, heads_kv, dim] - shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] - shape_indices = [batch, heads_kv, max_selected_blocks] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - num_blocks = max_selected_blocks - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - - for k in T.Pipelined(loop_range, num_stages=num_stages): - i_s = block_indices[bid, cur_kv_head, start + k] - if i_s >= 0: - 
has_valid_block = True - T.copy(K[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) - if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition - for i, j in T.Parallel(block_H, block_N): - acc_s[i, - j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], - -T.infinity(accum_dtype), acc_s[i, j]) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i], - scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + num_split = T.dynamic("num_split") + max_cache_seqlen = T.dynamic("max_cache_seqlen") + max_selected_blocks = T.dynamic("max_selected_blocks") + + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + + for k in T.Pipelined(loop_range, num_stages=num_stages): + i_s = block_indices[bid, cur_kv_head, start + k] + if i_s >= 0: + has_valid_block = True + T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = 
T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] - + acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - max_split = T.alloc_local([1], "int32") - - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_local_split[0] = glse[bz, by, k] - if (lse_local_split[0] != 0): - max_split[0] = k - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) - - for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split[0]: - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - if k <= max_split[0]: - for i in T.Parallel(dim_v): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: 
T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) - flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial) - combine(glse, Output_partial, Output) - - return main - - return kernel_func + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + return main class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -200,18 +165,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.dim = dim self.dim_v = dim_v self.block_size = block_size - self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_size, - block_H=self.block_H, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks")) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -233,25 +187,27 @@ def forward(self, query, key, value, block_indices, cache_seqlens): num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') - - output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) + total_mblocks, num_sm, num_n_blocks, 
num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + num_stages=2, + threads=128, + )(query, key, value, block_indices, cache_seqlens, glse, output_partial) return output -def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, - max_cache_seqlen, block_size): +def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size): """ Args: query: [batch, heads, dim] @@ -273,61 +229,51 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql block_H = 64 actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32) - actual_num_blocks = actual_num_blocks[:, - 0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks # get num_split num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 132 num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') - kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, block_N=block_size, block_H=block_H, - num_split=T.dynamic("num_split"), num_stages=2, threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks")) + ) output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) return output -def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # 
[batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values based on block_indices @@ -336,149 +282,141 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache valid_indices = block_indices[b, h] # Extract indices for this batch and head for idx in valid_indices: if idx >= 0: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out -def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def debug(name, expect, actual, atol=1e-3, rtol=1e-3): +def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3): all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) print(name + " all_close={}".format(all_close)) if not all_close: diff = (expect - actual).abs() - print("all_close={}, max={}, min={}, mean={}".format(all_close, - diff.max().item(), - diff.min().item(), - diff.mean().item())) + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) first_index = tuple(max_indices[0].tolist()) print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -def main(batch=8, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - 
sparse_ratio=0.8, - block_size=32): +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + dtype = torch.float16 sparse_ratio = sparse_ratio block_size = block_size max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) - print("max_selected_blocks: ", max_selected_blocks) - dtype = torch.float16 - - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') - # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') - # # Ensure at least one element equals cache_seqlen - # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - # # cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence - - print("cache_seqlens: ", cache_seqlens) + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() - print("max_valid_num_blocks: ", max_valid_num_blocks) - # Initialize block_indices with -1 (for padding blocks) - block_indices = torch.full((batch, heads_kv, max_selected_blocks), - -1, - dtype=torch.int32, - device='cuda') - # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) - # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - valid_indices = torch.randperm( - max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] - # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] - block_indices[b, h, :len(valid_indices)] = valid_indices + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices - # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) - # print("block_indices: ", block_indices) - actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] - print("actual_num_blocks: ", actual_num_blocks) - # print(block_indices.shape, actual_num_blocks.shape) - max_num_blocks = torch.max(max_valid_num_blocks).item() - print("max_num_blocks: ", max_num_blocks) # parity reference - ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, 
max_num_blocks, block_size) sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) - debug("output", ref, out, atol=1e-3, rtol=1e-3) - - import flash_attn # noqa: F401 + assert_close("output", ref, out, atol=1e-3, rtol=1e-3) ## latency reference for _ in range(10): - ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, - max_num_blocks, block_size) + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) torch.cuda.synchronize() start = time.time() for _ in range(100): - ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, - max_num_blocks, block_size) + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) torch.cuda.synchronize() print("dense time: ", (time.time() - start) / 100 * 1000) for _ in range(10): - # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) torch.cuda.synchronize() start = time.time() for _ in range(100): - # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) torch.cuda.synchronize() print("sparse time: ", (time.time() - start) / 100 * 1000) +def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + dtype = torch.float16 + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") + + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() + if max_valid_block > 0: + for h in range(heads_kv): + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices + + block_indices, _ = block_indices.sort(dim=-1, descending=True) + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_size + + def run_kernel_only(): + sparse_kernel(Q, K, V, block_indices, cache_seqlens) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, 
default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index ad62817dd..e588ec54c 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -1,184 +1,156 @@ import torch import torch.nn.functional as F import tilelang -from tilelang.autotuner import * import tilelang.language as T from einops import rearrange, einsum import argparse - import time import math from heuristic import num_splits_heuristic - - -def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" +from tilelang.profiler import do_bench + + +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, num_stages, threads): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): - shape_q = [batch, heads, dim] - shape_k = [batch, max_cache_seqlen, heads_kv, dim] - shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] - shape_mask = [batch, heads_kv, num_blocks] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - # O_shared = 
T.alloc_shared([valid_block_H, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - for k in T.Pipelined(loop_range, num_stages=num_stages): - if block_mask[bid, hid, start + k]: - has_valid_block = True - T.copy( - K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], - K_shared) - T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else((start + k) * block_N + j - >= cache_seqlens[bx], - -T.infinity(accum_dtype), acc_s[i, j]) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy( - V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], - V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] + num_split = T.dynamic("num_split") + max_cache_seqlen = T.dynamic("max_cache_seqlen") + num_blocks = T.dynamic("num_blocks") + + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_mask = [batch, heads_kv, num_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], 
accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[bid, hid, start + k]: + has_valid_block = True + T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else((start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = 
glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) - combine(glse, Output_partial, Output) - - return main - - return kernel_func + return main class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -187,18 +159,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.dim = dim self.dim_v = dim_v self.block_size = block_size - self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_size, - block_H=self.block_H, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks")) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -209,32 +170,33 @@ def forward(self, query, key, value, block_mask, cache_seqlens): dim_v = self.dim_v dim = self.dim block_size = self.block_size - block_H = self.block_H max_cache_seqlen = key.shape[1] # get num_split 
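# A minimal PyTorch reference sketch of the split-K combine performed by the
# `with T.Kernel(heads, batch, threads=128)` block above — an illustration, not code
# from this patch. It assumes `glse` holds the per-split base-2 log-sum-exp values and
# `output_partial` the per-split partial outputs; the device kernel additionally skips
# padding splits whose lse entry is zero, which this sketch ignores. It can be used to
# cross-check `Output` against `glse`/`Output_partial` dumped from the kernel.
import torch

def combine_splits_reference(glse: torch.Tensor, output_partial: torch.Tensor) -> torch.Tensor:
    # glse:           [batch, heads, num_split]
    # output_partial: [batch, heads, num_split, dim_v]
    lse_max = glse.max(dim=-1, keepdim=True).values
    lse_logsum = torch.log2(torch.exp2(glse - lse_max).sum(dim=-1, keepdim=True)) + lse_max
    scale = torch.exp2(glse - lse_logsum)                      # per-split weights; they sum to 1
    return (output_partial * scale.unsqueeze(-1)).sum(dim=2)   # [batch, heads, dim_v]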
max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size - num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - # num_sm = 132 num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - # print("num_split: ", num_split) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') - output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + num_stages=2, + threads=128, + )(query, key, value, block_mask, cache_seqlens, glse, output_partial) return output @@ -258,64 +220,52 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_H = 64 actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32) - actual_num_blocks = actual_num_blocks[:, - 0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks max_selected_blocks = actual_num_blocks.max().item() # get num_split num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 132 num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, block_N=block_size, block_H=block_H, - num_split=T.dynamic("num_split"), num_stages=2, threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks")) - glse = torch.empty((batch, heads, num_split), 
dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') - # print(kernel.get_kernel_source()) + ) output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) - return output -def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values @@ -323,59 +273,45 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se for h in range(heads_kv): for idx in range(num_blocks): if block_mask[b, h, idx]: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out -def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def debug(name, expect, actual, 
atol=1e-3, rtol=1e-3): +def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3): all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) print(name + " all_close={}".format(all_close)) if not all_close: - # print(expect[3, 28]) - # print(actual[3, 28]) diff = (expect - actual).abs() - print("all_close={}, max={}, min={}, mean={}".format(all_close, - diff.max().item(), - diff.min().item(), - diff.mean().item())) + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) first_index = tuple(max_indices[0].tolist()) print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -def main(batch=8, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -383,14 +319,13 @@ def main(batch=8, print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') print("cache_seqlens: ", cache_seqlens) @@ -402,7 +337,7 @@ def main(batch=8, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_mask with false (for padding blocks) - block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -410,29 +345,26 @@ def main(batch=8, valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch if valid_num_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True # print("block_mask: ", block_mask) # parity reference - ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, 
num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = model(Q, K, V, block_mask, cache_seqlens) - debug("output", ref, out, atol=1e-3, rtol=1e-3) + assert_close("output", ref, out, atol=1e-3, rtol=1e-3) import flash_attn # noqa: F401 ## latency reference for _ in range(10): - ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) torch.cuda.synchronize() start = time.time() for _ in range(100): - ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) torch.cuda.synchronize() print("dense time: ", (time.time() - start) / 100 * 1000) @@ -449,17 +381,83 @@ def main(batch=8, print("sparse time: ", (time.time() - start) / 100 * 1000) +def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + dtype = torch.float16 + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + random_index = torch.randint(0, batch, (1,), device="cuda").item() + cache_seqlens[random_index] = max_cache_seqlen + + num_blocks = (max_cache_seqlen + block_size - 1) // block_size + + valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int() + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") + + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() + valid_num_block = valid_num_blocks[b].item() + if valid_num_block > 0: + for h in range(heads_kv): + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] + block_mask[b, h, perm] = True + + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_size + max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size + + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = sparse_kernel.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), 
dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=sparse_kernel.block_H, + num_stages=2, + threads=128, + ) + + def run_kernel_only(): + kernel(Q, K, V, block_mask, cache_seqlens, glse, output_partial) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py index 85b72b775..91d85a1a4 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -5,19 +5,15 @@ import argparse from einops import rearrange, einsum import torch.nn.functional as F - import math import time from heuristic import num_splits_heuristic +from tilelang.profiler import do_bench @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], ) @triton.jit def _split_kernel( @@ -79,16 +75,11 @@ def _split_kernel( loop_range = blocks_per_split q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h - k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ - None, :] * stride_k_s + offs_d[:, None] * stride_k_d - v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, - None] * stride_v_s + offs_d[ - None, :] * stride_v_d + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * 
stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h - q = tl.load( - q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, - mask=offs_h[:, None] < gqa_group_size) + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) for i in range(loop_range): block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s) @@ -119,23 +110,18 @@ def _split_kernel( acc = acc * l_recip acc = acc.to(o_partial_ptr.dtype.element_ty) - lse_partial_ptr += batch_idx * stride_lse_b + ( - head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) - o_partial_ptr += batch_idx * stride_o_b + ( - head_idx_q + - offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], ) @triton.jit def _merge_kernel( @@ -163,18 +149,15 @@ def _merge_kernel( offs_d = tl.arange(0, BLOCK_D) lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h - lse = tl.load( - lse_offsets + offs_splits * lse_partial_stride_split, - mask=offs_splits < num_splits, - other=float("-inf")) + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) lse_max = tl.max(lse) o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_partial = tl.load( - o_offsets + offs_splits[:, None] * o_partial_stride_split + - offs_d[None, :] * o_partial_stride_d, - mask=offs_splits[:, None] < num_splits) + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) @@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton( num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 64 # num_sm = self.num_sm num_splits = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + 
) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) @@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton( return output -def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] dim_v = value.shape[-1] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values based on block_indices @@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache valid_indices = block_indices[b, h] # Extract indices for this batch and head for idx in valid_indices: if idx >= 0: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def ref_program_fa(query, key, value, cache_seqlens): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def main(batch=64, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): - +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, 
block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -367,49 +329,31 @@ def main(batch=64, max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - block_H = 64 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') - # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence - - print("cache_seqlens: ", cache_seqlens) - + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() - print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_indices with -1 (for padding blocks) - block_indices = torch.full((batch, heads_kv, max_selected_blocks), - -1, - dtype=torch.int32, - device='cuda') + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - valid_indices = torch.randperm( - max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] - block_indices[b, h, :len(valid_indices)] = valid_indices + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) - # print("block_indices: ", block_indices) - actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] - print("actual_num_blocks: ", actual_num_blocks) - # print(block_indices.shape, actual_num_blocks.shape) - max_num_blocks = torch.max(max_valid_num_blocks).item() - print("max_num_blocks: ", max_num_blocks) - ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) triton_out = block_sparse_flash_decode_gqa_indice_triton( Q, @@ -423,8 +367,7 @@ def main(batch=64, ) print("max difference: ", torch.max(torch.abs(ref - triton_out))) - assert torch.allclose( - ref, 
triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" print("Passed the ref test!") # Measure performance @@ -447,6 +390,7 @@ def main(batch=64, avg_time = elapsed_time / 1000 avg_flops = total_flops / avg_time print(f"Average time: {avg_time:.6f} seconds") + print(f"Average FLOPS: {avg_flops:.2f} GFLOPS") # Measure performance of reference implementation import flash_attn # noqa: F401 @@ -460,21 +404,19 @@ def main(batch=64, avg_time_ref = elapsed_time_ref / 1000 avg_flops_ref = total_flops / avg_time_ref print(f"Average time of ref: {avg_time_ref:.6f} seconds") - + print(f"Average FLOPS of ref: {avg_flops_ref:.2f} GFLOPS") print(f"Speedup: {avg_time_ref / avg_time:.2f}x") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py index 348572526..232bcacaf 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py @@ -4,19 +4,14 @@ import argparse from einops import rearrange, einsum import torch.nn.functional as F - import math import time from heuristic import num_splits_heuristic @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], ) @triton.jit def _split_kernel( @@ -77,16 +72,11 @@ def _split_kernel( loop_range = blocks_per_split q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h - k_cache_ptr += batch_idx * stride_k_b + 
head_idx_kv * stride_k_h + offs_n[ - None, :] * stride_k_s + offs_d[:, None] * stride_k_d - v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, - None] * stride_v_s + offs_d[ - None, :] * stride_v_d + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h - q = tl.load( - q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, - mask=offs_h[:, None] < gqa_group_size) + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) for block_idx in range(loop_range): start_n = (start + block_idx) * BLOCK_N @@ -117,23 +107,18 @@ def _split_kernel( acc = acc * l_recip acc = acc.to(o_partial_ptr.dtype.element_ty) - lse_partial_ptr += batch_idx * stride_lse_b + ( - head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) - o_partial_ptr += batch_idx * stride_o_b + ( - head_idx_q + - offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], ) @triton.jit def _merge_kernel( @@ -161,18 +146,15 @@ def _merge_kernel( offs_d = tl.arange(0, BLOCK_D) lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h - lse = tl.load( - lse_offsets + offs_splits * lse_partial_stride_split, - mask=offs_splits < num_splits, - other=float("-inf")) + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) lse_max = tl.max(lse) o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_partial = tl.load( - o_offsets + offs_splits[:, None] * o_partial_stride_split + - offs_d[None, :] * o_partial_stride_d, - mask=offs_splits[:, None] < num_splits) + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) @@ -207,19 +189,13 @@ def block_sparse_flash_decode_gqa_mask_triton( num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * 
heads_kv * num_m_blocks num_sm = 64 # num_sm = self.num_sm num_splits = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) @@ -292,24 +268,18 @@ def block_sparse_flash_decode_gqa_mask_triton( return output -def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values @@ -317,43 +287,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se for h in range(heads_kv): for idx in range(num_blocks): if block_mask[b, h, idx]: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def ref_program_fa(query, key, value, cache_seqlens): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def 
main(batch=64, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): - +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v block_size = block_size sparse_ratio = sparse_ratio @@ -363,14 +324,13 @@ def main(batch=64, dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence num_blocks = (max_cache_seqlen + block_size - 1) // block_size @@ -379,7 +339,7 @@ def main(batch=64, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_mask with false (for padding blocks) - block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -387,11 +347,10 @@ def main(batch=64, valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch if valid_num_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True - ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) triton_out = block_sparse_flash_decode_gqa_mask_triton( Q, @@ -404,8 +363,7 @@ def main(batch=64, ) # print("max difference: ", torch.max(torch.abs(ref - triton_out))) - assert torch.allclose( - ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" print("Passed the ref test!") # Measure performance @@ -448,15 +406,13 @@ def main(batch=64, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - 
parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/heuristic.py b/examples/blocksparse_attention/heuristic.py index b60a81dc3..0e6fc5281 100644 --- a/examples/blocksparse_attention/heuristic.py +++ b/examples/blocksparse_attention/heuristic.py @@ -1,8 +1,7 @@ import math -def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, - is_causal_or_local, max_splits): +def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits): """ Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency. 
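Editor's note on the num_splits_heuristic signature reformatted above: its docstring says it trades GPU occupancy against memory efficiency, and the Triton decode example earlier in this patch calls it with total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128. The sketch below only illustrates that occupancy trade-off; the name simple_num_splits and the 0.8 threshold are assumptions made for this example and are not the body of heuristic.py, which this patch does not change.

import math


def simple_num_splits(total_mblocks, num_SMs, num_n_blocks, max_splits=128):
    # Illustrative sketch, not the repo's implementation. If there are already
    # enough independent (batch, kv-head, m-block) tiles to keep most SMs busy,
    # splitting the KV dimension only adds merge overhead.
    if total_mblocks >= 0.8 * num_SMs:
        return 1
    best_splits, best_eff = 1, 0.0
    # Never split finer than the number of KV blocks or the number of SMs.
    for s in range(1, min(max_splits, num_SMs, num_n_blocks) + 1):
        n_waves = total_mblocks * s / num_SMs   # work measured in waves of SMs
        eff = n_waves / math.ceil(n_waves)      # useful fraction of the last wave
        if eff > best_eff + 1e-4:               # ties favor fewer splits
            best_eff, best_splits = eff, s
    return best_splits


print(simple_num_splits(total_mblocks=16, num_SMs=64, num_n_blocks=32))  # -> 4

With a small total_mblocks (few batches and KV heads, as in the decode example), such a heuristic raises the split count until each wave of SMs is nearly full, which is the behavior the docstring describes.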
diff --git a/examples/blocksparse_attention/regression_example_blocksparse_attention.py b/examples/blocksparse_attention/regression_example_blocksparse_attention.py new file mode 100644 index 000000000..26fa60df5 --- /dev/null +++ b/examples/blocksparse_attention/regression_example_blocksparse_attention.py @@ -0,0 +1,20 @@ +import tilelang.testing +import example_tilelang_block_sparse_attn +import example_tilelang_sparse_gqa_decode_varlen_indice +import example_tilelang_sparse_gqa_decode_varlen_mask + + +def regression_example_tilelang_block_sparse_attn(): + tilelang.testing.process_func(example_tilelang_block_sparse_attn.run_regression_perf) + + +def regression_example_tilelang_sparse_gqa_decode_varlen_indice(): + tilelang.testing.process_func(example_tilelang_sparse_gqa_decode_varlen_indice.run_regression_perf, batch=1, max_cache_seqlen=2048) + + +def regression_example_tilelang_sparse_gqa_decode_varlen_mask(): + tilelang.testing.process_func(example_tilelang_sparse_gqa_decode_varlen_mask.run_regression_perf, batch=1, max_cache_seqlen=2048) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/blocksparse_attention/test_example_blocksparse_attention.py b/examples/blocksparse_attention/test_example_blocksparse_attention.py index adda1f0f1..dd33f46c4 100644 --- a/examples/blocksparse_attention/test_example_blocksparse_attention.py +++ b/examples/blocksparse_attention/test_example_blocksparse_attention.py @@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask(): def test_example_triton_sparse_gqa_decode_varlen_indice(): example_triton_sparse_gqa_decode_varlen_indice.main( - batch=8, - heads=8, - heads_kv=4, - max_cache_seqlen=2048, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32) + batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) def test_example_triton_sparse_gqa_decode_varlen_mask(): example_triton_sparse_gqa_decode_varlen_mask.main( - batch=16, - heads=16, - heads_kv=8, - max_cache_seqlen=1024, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32) + batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) if __name__ == "__main__": diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 7b9cff7c1..289421548 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -2,10 +2,9 @@ import itertools import tilelang import tilelang.language as T -from tilelang.engine.param import KernelParam from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType import torch -from typing import List +from tilelang.profiler import do_bench DEFAULT_BLOCK_M = 128 DEFAULT_BLOCK_N = 128 @@ -13,25 +12,8 @@ DEFAULT_NUM_STAGES = 2 DEFAULT_THREAD_NUM = 128 DEFAULT_ENABLE_RASTERIZATION = True - -parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark") -parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") -parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") -parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") -parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") -parser.add_argument( - "--use_autotune", action="store_true", default=False, help="Whether to use autotune") - -args, _ = parser.parse_known_args() -M, N, K = args.m, args.n, args.k 
-sparsity = args.sparsity -use_autotune = args.use_autotune default_tensor_supply = get_tensor_supply(TensorSupplyType.Auto) -print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}") -print(f"Target Block Sparsity: {sparsity}") -print(f"Using Autotuner: {use_autotune}\n") - def get_configs(): block_M = [64, 128, 256] @@ -41,76 +23,52 @@ def get_configs(): thread_num = [128, 256] enable_rasterization = [True, False] - _configs = list( - itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) - return [{ - "block_M": c[0], - "block_N": c[1], - "block_K": c[2], - "num_stages": c[3], - "thread_num": c[4], - "enable_rasteration": c[5], - } for c in _configs] + return [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], + } + for c in _configs + ] def ref_program(A, B, BlockMask, block_M, block_N, block_K): + M, K = A.shape + _, N = B.shape ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device) for i in range(M // block_M): for j in range(N // block_N): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if BlockMask[i, j, k]: - accu += ( - A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32)) - ref_c[i * block_M:(i + 1) * block_M, - j * block_N:(j + 1) * block_N] = accu.to(torch.float16) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c -def supply_program(params: List[KernelParam]): - input_tensors = [] - - for p in params: - # Check if the kernel parameter is BlockMask tensor. - # Here, BlockMask is uniquely identified by having 3 dimensions. - if len(p.shape) != 3: - # For non-BlockMask tensors, use the default tensor generation logic. - input_tensors.append(default_tensor_supply(p)) - else: - # For BlockMask tensor, randomly set elements to True based on desired - # sparsity level. 
- block_mask = torch.zeros(p.shape, dtype=torch.bool, device=torch.cuda.current_device()) - block_mask[:, :, :] = torch.rand(p.shape) > sparsity - input_tensors.append(block_mask) - - return input_tensors - - -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit(out_idx=[-1]) -def blocksparse_matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): - +def blocksparse_matmul( + M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 +): block_mask_shape = (M // block_M, N // block_N, K // block_K) @T.prim_func def block_sparse_matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -134,7 +92,20 @@ def block_sparse_matmul( def main(): - + parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark") + parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") + parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune") + + args, _ = parser.parse_known_args() + M, N, K = args.m, args.n, args.k + sparsity = args.sparsity + use_autotune = args.use_autotune + print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}") + print(f"Target Block Sparsity: {sparsity}") + print(f"Using Autotuner: {use_autotune}\n") # Initialize input matrices A and B on the GPU with half precision a = torch.randn(M, K).cuda().half() b = torch.randn(K, N).cuda().half() @@ -147,8 +118,7 @@ def main(): best_config = kernel.config best_latency = kernel.latency - block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[ - "block_K"] + block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"] print(f"Best Config: {best_config}") print(f"Sparsity Ratio: {sparsity}") @@ -163,10 +133,10 @@ def main(): block_K=DEFAULT_BLOCK_K, num_stages=DEFAULT_NUM_STAGES, thread_num=DEFAULT_THREAD_NUM, - enable_rasteration=DEFAULT_ENABLE_RASTERIZATION) + enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + ) block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") - # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity @@ -185,5 +155,34 @@ def main(): print(e) +def run_regression_perf(): + M = N = K = 1024 + sparsity = 0.5 + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + kernel = blocksparse_matmul( + M, + N, + K, + block_M=DEFAULT_BLOCK_M, + block_N=DEFAULT_BLOCK_N, + block_K=DEFAULT_BLOCK_K, + num_stages=DEFAULT_NUM_STAGES, + thread_num=DEFAULT_THREAD_NUM, + 
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + ) + block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + + def run_kernel_only(): + kernel(a, b, block_mask) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/blocksparse_gemm/regression_example_blocksparse_gemm.py b/examples/blocksparse_gemm/regression_example_blocksparse_gemm.py new file mode 100644 index 000000000..81900a00c --- /dev/null +++ b/examples/blocksparse_gemm/regression_example_blocksparse_gemm.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_blocksparse_gemm + + +def regression_example_blocksparse_gemm(): + tilelang.testing.process_func(example_blocksparse_gemm.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py index 102ac2021..db6beab1e 100644 --- a/examples/cast/example_group_per_split_token_cast_to_fp8.py +++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -5,8 +5,8 @@ from tilelang.utils.tensor import torch_assert_close # support bfloat16, float, float16 -dtype = "bfloat16" -accum_dtype = "float" +dtype = T.bfloat16 +accum_dtype = T.float32 @tilelang.jit(out_idx=[2, 3]) @@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): fp8_max = 448.0 @T.prim_func - def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor( - (BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor( - (BG, M_max, T.ceildiv(N, group_size)), accum_dtype)): - with T.Kernel( - T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): + def group_per_split_token_cast( + X: T.Tensor((M, N), dtype), + batch_sizes: T.Tensor((BG,), T.int32), + X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn), + X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype), + ): + with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): row = bx row_g_id = by bg = bz @@ -28,39 +30,29 @@ def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor y_amax_local = T.alloc_fragment((blk_m,), accum_dtype) y_s_local = T.alloc_fragment((blk_m,), accum_dtype) y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype) - y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") - row_offset = T.alloc_fragment((1,), "int32") + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) + row_offset = T.alloc_var(dtype=T.int32) - T.annotate_layout({ - y_local: - T.Fragment( - y_local.shape, - forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), - }) - - row_offset[0] = 0 + row_offset = 0 for i in T.serial(bg): - row_offset[0] += batch_sizes[i] + row_offset += batch_sizes[i] T.copy( - X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size], y_local) + X[row_offset + row * blk_m : row_offset + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], + y_local, + ) T.reduce_absmax(y_local, y_amax_local, dim=1) for i in T.Parallel(blk_m): y_amax_local[i] = T.max(y_amax_local[i], 1e-4) - y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], - y_amax_local[i] / fp8_max, 0) + y_s_local[i] = 
T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0) for i, j in T.Parallel(blk_m, group_size): y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) T.copy(y_q_local, y_q_local_fp8) for i, j in T.Parallel(blk_m, group_size): - y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], - y_q_local[i, j], 0) + y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0) for i in T.Parallel(blk_m): X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i] - T.copy( - y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size]) + T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) return group_per_split_token_cast @@ -127,8 +119,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: return x.squeeze(0) if remove_dim else x # Normal layout requires transposing - aligned_x = torch.transpose( - torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) aligned_x[:, :m, :] = x aligned_x = aligned_x[:, :m, :] return aligned_x.squeeze(0) if remove_dim else aligned_x @@ -146,15 +137,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() return x_fp8, (x_amax / 448.0).view(m, -1) -def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \ - Tuple[torch.Tensor, torch.Tensor]: + +def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # assert x.shape[0] == batch_sizes.sum() M_max = ceil_div(batch_sizes.max(), 128) * 128 split_x = torch.split(x, batch_sizes.tolist(), dim=0) padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x] num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1] - x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn), - torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float)) + x_fp8 = ( + torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn), + torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float), + ) for i in range(num_groups): x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i]) x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) @@ -164,11 +157,11 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \ def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): if batch_sizes is None: batch_sizes = [2048, 6144] - if dtype == "float": + if dtype == T.float: x = torch.randn(M, N, device="cuda", dtype=torch.float32) - elif dtype == "float16": + elif dtype == T.float16: x = torch.randn(M, N, device="cuda", dtype=torch.float16) - elif dtype == "bfloat16": + elif dtype == T.bfloat16: x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) else: raise ValueError(f"Unsupported dtype: {dtype}") @@ -206,5 +199,35 @@ def run_torch(): print("Torch: {:.2f} ms".format(latency)) +def run_regression_perf(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [2048, 6144] + if dtype == "float": + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + elif dtype == "float16": + x = torch.randn(M, N, device="cuda", dtype=torch.float16) + elif dtype == "bfloat16": + x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) + else: + raise 
ValueError(f"Unsupported dtype: {dtype}") + batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32) + M_max = int(ceil_div(batch_sizes.max(), 128) * 128) + + kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m) + + x_fp8, x_amax = kernel(x, batch_sizes) + x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes) + + torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01) + torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01) + + from tilelang.profiler import do_bench + + def run_tilelang(): + kernel(x, batch_sizes) + + return do_bench(run_tilelang, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py index 484a092f0..693e90d30 100644 --- a/examples/cast/example_per_token_cast_to_fp8.py +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -7,14 +7,15 @@ @tilelang.jit(out_idx=[1, 2]) def per_token_cast_to_fp8(M, N, blk_m): - dtype = "float" + dtype = T.float group_size = 128 fp8_min = -448.0 fp8_max = 448.0 @T.prim_func - def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), - X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)): + def per_token_cast( + X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), T.float8_e4m3fn), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype) + ): with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): row = bx row_g_id = by @@ -22,18 +23,9 @@ def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e y_amax_local = T.alloc_fragment((blk_m,), dtype) y_s_local = T.alloc_fragment((blk_m,), dtype) y_q_local = T.alloc_fragment((blk_m, group_size), dtype) - y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") - - T.annotate_layout({ - y_local: - T.Fragment( - y_local.shape, - forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), - }) - - T.copy( - X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size], - y_local) + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) + + T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local) T.reduce_absmax(y_local, y_amax_local, dim=1) for i in T.Parallel(blk_m): y_amax_local[i] = T.max(y_amax_local[i], 1e-4) @@ -43,9 +35,7 @@ def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e T.copy(y_q_local, y_q_local_fp8) for i in T.Parallel(blk_m): X_amax[row * blk_m + i, row_g_id] = y_s_local[i] - T.copy( - y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size]) + T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) return per_token_cast @@ -105,8 +95,7 @@ def main(M=8192, N=8192, blk_m=8): from example_triton_cast_to_fp8 import per_token_group_quant_fp8 def run_triton(): - x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8( - x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) + x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) return x_fp8_triton_, x_amax_triton_ x_fp8_triton, x_amax_triton = run_triton() @@ -114,5 +103,16 @@ def run_triton(): print("Triton: {:.2f} ms".format(latency)) +def run_regression_perf(M=8192, N=8192, blk_m=8): + kernel = per_token_cast_to_fp8(M, N, blk_m) + x = torch.randn(M, N, 
device="cuda", dtype=torch.float32) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(x) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/cast/example_triton_cast_to_fp8.py b/examples/cast/example_triton_cast_to_fp8.py index cc56defe7..1859433f1 100644 --- a/examples/cast/example_triton_cast_to_fp8.py +++ b/examples/cast/example_triton_cast_to_fp8.py @@ -128,9 +128,7 @@ def per_token_group_quant_fp8( Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ - assert (x.shape[-1] % - group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible " - f"by `group_size` {group_size}") + assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}" assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) diff --git a/examples/cast/regression_example_cast.py b/examples/cast/regression_example_cast.py new file mode 100644 index 000000000..4bdfb99e7 --- /dev/null +++ b/examples/cast/regression_example_cast.py @@ -0,0 +1,17 @@ +import tilelang.testing +import example_group_per_split_token_cast_to_fp8 +import example_per_token_cast_to_fp8 + + +def regression_example_group_per_split_token_cast_to_fp8(): + tilelang.testing.process_func( + example_group_per_split_token_cast_to_fp8.run_regression_perf, M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896] + ) + + +def regression_example_per_token_cast_to_fp8(): + tilelang.testing.process_func(example_per_token_cast_to_fp8.run_regression_perf, M=2048, N=512, blk_m=8) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/cast/test_example_cast.py b/examples/cast/test_example_cast.py index 1ca000eb2..e8b10a797 100644 --- a/examples/cast/test_example_cast.py +++ b/examples/cast/test_example_cast.py @@ -4,8 +4,7 @@ def test_example_group_per_split_token_cast_to_fp8(): - example_group_per_split_token_cast_to_fp8.main( - M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896]) + example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896]) def test_example_per_token_cast_to_fp8(): diff --git a/examples/compile_flags/usecase.py b/examples/compile_flags/usecase.py deleted file mode 100644 index 8451b04fc..000000000 --- a/examples/compile_flags/usecase.py +++ /dev/null @@ -1,56 +0,0 @@ -import tilelang -import tilelang.language as T - - -# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - # Initialize Kernel Context - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - T.copy(A[by * block_M, ko * block_K], A_shared) - T.copy(B[ko * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -M = 1024 -N = 1024 -K = 1024 -block_M = 128 -block_N = 128 -block_K = 32 - -func = matmul(M, N, K, block_M, block_N, block_K) - -jit_kernel = 
tilelang.compile( - func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr") -# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) -# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"]) - -import torch - -a = torch.randn(M, K, device="cuda", dtype=torch.float16) -b = torch.randn(K, N, device="cuda", dtype=torch.float16) - -c = jit_kernel(a, b) - -print(c) - -ref_c = a @ b - -torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) -print("Kernel output matches PyTorch reference.") diff --git a/examples/conftest.py b/examples/conftest.py index 9f49d40a9..4010e0d83 100644 --- a/examples/conftest.py +++ b/examples/conftest.py @@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): "warnings", "error", } - if (sum( - len(terminalreporter.stats.get(k, [])) - for k in known_types.difference({"skipped", "deselected"})) == 0): + if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0: terminalreporter.write_sep( "!", - (f"Error: No tests were collected. " - f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), + (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), ) pytest.exit("No tests were collected.", returncode=5) diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index b2696ba8f..1599d3464 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -14,7 +14,6 @@ def check_hopper(): def ref_program(stride, padding, dilation): - def main(A, B): A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W @@ -26,38 +25,21 @@ def main(A, B): @tilelang.jit(out_idx=[2]) -def convolution(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 is_hopper = check_hopper() @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -66,12 +48,6 @@ def main( kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - data_shared: tilelang.layout.make_swizzled_layout(data_shared), - kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), - }) - 
T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): if is_hopper: @@ -82,10 +58,8 @@ def main( m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) @@ -97,15 +71,15 @@ def main( def main(argv=None): parser = argparse.ArgumentParser() - parser.add_argument('--n', type=int, default=128, help='n') - parser.add_argument('--c', type=int, default=128, help='c') - parser.add_argument('--h', type=int, default=64, help='h') - parser.add_argument('--w', type=int, default=64, help='w') - parser.add_argument('--f', type=int, default=128, help='f') - parser.add_argument('--k', type=int, default=3, help='k') - parser.add_argument('--s', type=int, default=1, help='s') - parser.add_argument('--d', type=int, default=1, help='d') - parser.add_argument('--p', type=int, default=1, help='p') + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") args = parser.parse_args(argv) N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p @@ -125,5 +99,30 @@ def main(argv=None): print("All checks passed.✅") +def run_regression_perf(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + + args = parser.parse_args(argv) + N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p + + block_m = 64 + block_n = 128 + block_k = 32 + num_stages = 3 + threads = 256 + kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 393677489..c0c666402 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -14,7 +14,6 @@ def 
check_hopper(): def ref_program(stride, padding, dilation): - def main(A, B): A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W @@ -40,7 +39,8 @@ def get_configs(): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -50,7 +50,8 @@ def get_configs(): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs @@ -64,69 +65,32 @@ def get_heuristic_config() -> dict: sm_version = sm_major * 10 + sm_minor print(f"CUDA device capability: {sm_version}") if sm_version in {80}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 2, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} elif sm_version in {90}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 64, - "num_stages": 3, - "thread_num": 256, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} else: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 0, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} @tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[2]) -def convolution(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): +def convolution( + N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 +): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 is_hopper = check_hopper() @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -135,11 +99,6 @@ def main( kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - if is_hopper: - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - }) - T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): if is_hopper: @@ -150,10 +109,8 @@ def main( m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) 
and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) @@ -166,17 +123,19 @@ def main( return main -def main(n: int = 128, - c: int = 128, - h: int = 64, - w: int = 64, - f: int = 128, - k: int = 3, - s: int = 1, - d: int = 1, - p: int = 1, - use_autotune: bool = False, - with_roller: bool = True): +def main( + n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True, +): N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p ref_prog = ref_program(S, P, D) @@ -194,27 +153,38 @@ def main(n: int = 128, print(f"Ref latency: {ref_latency}") +def run_regression_perf( + n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True, +): + N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p + config = get_heuristic_config() + kernel = convolution(N, C, H, W, F, K, S, D, P, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument('--n', type=int, default=128, help='n') - parser.add_argument('--c', type=int, default=128, help='c') - parser.add_argument('--h', type=int, default=64, help='h') - parser.add_argument('--w', type=int, default=64, help='w') - parser.add_argument('--f', type=int, default=128, help='f') - parser.add_argument('--k', type=int, default=3, help='k') - parser.add_argument('--s', type=int, default=1, help='s') - parser.add_argument('--d', type=int, default=1, help='d') - parser.add_argument('--p', type=int, default=1, help='p') - parser.add_argument( - "--use_autotune", - action="store_true", - default=False, - help="Whether to use autotune for matmul configs") - parser.add_argument( - "--with_roller", - action="store_true", - default=True, - help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space") args = parser.parse_args() - main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, - args.with_roller) + main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller) diff --git a/examples/convolution/regression_example_convolution.py b/examples/convolution/regression_example_convolution.py new file mode 100644 index 000000000..18d4bcb68 --- /dev/null +++ 
b/examples/convolution/regression_example_convolution.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_convolution +import example_convolution_autotune + + +def regression_example_convolution(): + tilelang.testing.process_func(example_convolution.run_regression_perf) + + +def regression_example_convolution_autotune(): + tilelang.testing.process_func(example_convolution_autotune.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index 715f09a9b..18467a811 100644 --- a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -20,11 +20,11 @@ def tl_gemm( accum_dtype, ): assert in_dtype in [ - "float8_e4m3", + T.float8_e4m3fn, ], "Currently only float8_e4m3 is supported" assert out_dtype in [ - "bfloat16", - "float32", + T.bfloat16, + T.float32, ], "Currently only float16 and float32 are supported" group_size = 128 @@ -41,18 +41,17 @@ def tl_gemm( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - scales_a: T.Tensor(Scales_A_shape, "float32"), - scales_b: T.Tensor(Scales_B_shape, "float32"), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + scales_a: T.Tensor(Scales_A_shape, T.float32), + scales_b: T.Tensor(Scales_B_shape, T.float32), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype) - Scale_C_shared = T.alloc_shared((block_M), "float32") + Scale_C_shared = T.alloc_shared((block_M), T.float32) C_local = T.alloc_fragment(C_shared_shape, accum_dtype) C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype) @@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: m, n = x.shape x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( - m, n), (x_amax / 448.0).view(m, -1) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros( - ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) + x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( - x_view.size(0), x_view.size(2)) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): @@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): c_acc.zero_() for k in range(ceildiv(K, 128)): c = torch._scaled_mm( - A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128], - B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T, + 
A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128], + B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T, scale_a=A_scales[i, k].view(128, 1).contiguous(), scale_b=B_scales[j, k].view(1, 128).contiguous(), - out_dtype=torch.bfloat16) + out_dtype=torch.bfloat16, + ) c_acc += c.to(torch.float32) - C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype) + C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype) return C @@ -179,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp def main(): - assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32") + assert_tl_gemm_correctness(1024, 1024, 8192, 128, T.float8_e4m3fn, T.bfloat16, T.float32) if __name__ == "__main__": - for dtype in ["float8_e4m3"]: - for out_dtype in ["bfloat16", "float32"]: + for dtype in [T.float8_e4m3fn]: + for out_dtype in [T.bfloat16, T.float32]: for block_N in [16, 32, 64, 128]: - assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32") + assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, T.float32) diff --git a/examples/deepseek_mhc/example_mhc_post.py b/examples/deepseek_mhc/example_mhc_post.py new file mode 100644 index 000000000..feec31bc9 --- /dev/null +++ b/examples/deepseek_mhc/example_mhc_post.py @@ -0,0 +1,114 @@ +import math + +import torch + +import tilelang +import tilelang.language as T + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_post_tilelang(a, b, c, d, x, hc: int, hidden: int, n_thr: int = 128, h_blk: int = 1024) -> tilelang.JITKernel: + # rename for shorter code + n = T.dynamic("num_tokens") + h = hidden + + h_blk = math.gcd(hidden, h_blk) + a: T.Tensor((n, hc, hc), T.float32) + b: T.Tensor((n, hc, h), T.bfloat16) + c: T.Tensor((n, hc), T.float32) + d: T.Tensor((n, h), T.bfloat16) + x: T.Tensor((n, hc, h), T.bfloat16) + with T.Kernel(n, threads=n_thr) as i_n: + x_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + b_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + d_shared = T.alloc_shared(h_blk, T.bfloat16) + + x_local = T.alloc_fragment((hc, h_blk), T.float32) + b_local = T.alloc_fragment((hc, h_blk), T.float32) + d_local = T.alloc_fragment(h_blk, T.float32) + + a_local = T.alloc_fragment((hc, hc), T.float32) + c_local = T.alloc_fragment(hc, T.float32) + T.copy(a[i_n, 0, 0], a_local) + T.copy(c[i_n, 0], c_local) + + for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2): + T.copy(b[i_n, 0, i0_h * h_blk], b_shared) + T.copy(d[i_n, i0_h * h_blk], d_shared) + + T.copy(b_shared, b_local) + T.copy(d_shared, d_local) + for i_hco, i1_h in T.Parallel(hc, h_blk): + x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h] + for i_hci in T.serial(hc): + x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h] + T.copy(x_local, x_shared) + + T.copy(x_shared, x[i_n, 0, i0_h * h_blk]) + + +def mhc_post( + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, +) -> torch.Tensor: + out = torch.empty_like(residual) + mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, residual.shape[-2], residual.shape[-1]) + return out + + +def mhc_post_ref( + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, +) -> torch.Tensor: + term2 = 
torch.bmm(comb_res_mix.mT, residual.float()) + return (x.float().unsqueeze(-2) * post_layer_mix + term2).bfloat16() + + +def generate_test_data( + n: int, + h: int, + hc_mult: int, + device: str = "cuda", +) -> dict[str, torch.Tensor]: + """Generate test data for post operator.""" + torch.random.manual_seed(42) + + x = torch.randn((n, h), dtype=torch.bfloat16, device=device) + residual = torch.randn((n, hc_mult, h), dtype=torch.bfloat16, device=device) + post_layer_mix = torch.randn((n, hc_mult, 1), dtype=torch.float32, device=device) + comb_res_mix = torch.randn((n, hc_mult, hc_mult), dtype=torch.float32, device=device) + + return { + "x": x, + "residual": residual, + "post_layer_mix": post_layer_mix, + "comb_res_mix": comb_res_mix, + } + + +def test(n: int, h: int) -> None: + print(f"Testing mhc_post with {n=} {h=}") + test_data = generate_test_data(n=n, h=h, hc_mult=4) + out_tl = mhc_post(**test_data) + out_ref = mhc_post_ref(**test_data) + torch.testing.assert_close(out_tl, out_ref) + + +def main(): + for n in [4096]: + for h in [1280, 2560, 7168]: + test(n=n, h=h) + + +if __name__ == "__main__": + main() diff --git a/examples/deepseek_mhc/example_mhc_pre.py b/examples/deepseek_mhc/example_mhc_pre.py new file mode 100644 index 000000000..9dbd66839 --- /dev/null +++ b/examples/deepseek_mhc/example_mhc_pre.py @@ -0,0 +1,419 @@ +import math + +import tilelang +import tilelang.language as T +import torch + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual, + post_mix, + comb_mix, + layer_input, + hidden_size: int, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 16, + hc_mult: int = 4, +): + """Deeply fused kernels, everything other than gemm & sqrsum in mHC pre block.""" + num_tokens = T.dynamic("num_tokens") + hc_mult3 = hc_mult * (2 + hc_mult) + hidden_block = math.gcd(512, hidden_size) + + gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] + gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] + hc_scale: T.Tensor[[3], T.float32] + hc_base: T.Tensor[[hc_mult3], T.float32] + residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] + # outputs + post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] + comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] + layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] + + with T.Kernel(num_tokens, threads=96) as i: + ################################################################## + # _pre_norm_fn_fwd_norm + rms = T.alloc_fragment(1, T.float32) + mixes = T.alloc_fragment(hc_mult3, T.float32) + T.clear(mixes) + rms[0] = 0 + for i_split in T.serial(n_splits): + rms[0] += gemm_out_sqrsum[i_split, i] + rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps) + for j in T.Parallel(hc_mult3): + mixes[j] = 0 + for i_split in T.serial(n_splits): + mixes[j] += gemm_out_mul[i_split, i, j] + mixes[j] *= rms[0] + mixes_shared = T.alloc_shared(hc_mult3, T.float32) + T.copy(mixes, mixes_shared) + + if T.get_thread_binding() < 32: + ################################################################## + # _pre_split_mixes_fwd (post & comb) + cm = T.alloc_fragment((hc_mult, hc_mult), T.float32) + for j in T.Parallel(hc_mult): + post_mix[i, j] = 
T.sigmoid(mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult]) * hc_post_mult_value + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2] + hc_base[j * hc_mult + k + hc_mult * 2] + + ################################################################## + # _sinkhorn_fwd + row_sum = T.alloc_fragment(hc_mult, T.float32) + col_sum = T.alloc_fragment(hc_mult, T.float32) + + # comb = comb.softmax(-1) + eps + row_max = T.alloc_fragment(hc_mult, T.float32) + T.reduce_max(cm, row_max, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = T.exp(cm[j, k] - row_max[j]) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for _ in T.serial(sinkhorn_repeat - 1): + # comb = comb / (comb.sum(-1) + eps) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps) + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + # save comb_mix to global memory + for j, k in T.Parallel(hc_mult, hc_mult): + comb_mix[i, j * hc_mult + k] = cm[j, k] + else: + ################################################################## + # _pre_split_mixes_fwd (pre) + pre_mix_shared = T.alloc_shared(hc_mult, T.float32) + for j in T.Parallel(hc_mult): + pre_mix_shared[j] = ( + T.sigmoid( + mixes_shared[j] * hc_scale[0] + hc_base[j], + ) + + hc_pre_eps + ) + ################################################################### + # _pre_apply_mix_fwd + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): + xs = T.alloc_shared((hc_mult, hidden_block), T.float32) + xl = T.alloc_fragment((hc_mult, hidden_block), T.float32) + T.copy(residual[i, 0, i0_h * hidden_block], xs) + T.copy(xs, xl) + + ol = T.alloc_fragment(hidden_block, T.float32) + T.clear(ol) + + for i_hc in T.serial(hc_mult): + pre = pre_mix_shared[i_hc] + for i1_h in T.Parallel(hidden_block): + ol[i1_h] += pre * xl[i_hc, i1_h] + + T.copy(ol, layer_input[i, i0_h * hidden_block]) + + +@tilelang.jit +def mhc_pre_gemm_sqrsum_tilelang( + x, + fn, + out, + sqrsum, + hc_mult3: int, + hc_hidden_size: int, + token_block: int = 32, + hidden_block: int = 256, +) -> tilelang.JITKernel: + """Not highly optimized TileLang implementation of fused gemm and sqrsum in mHC pre block.""" + assert hc_mult3 <= 32 # should be 24 usually + num_tokens = T.dynamic("num_tokens") + assert hc_hidden_size % hidden_block == 0 + + x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) + fn: T.Tensor((hc_mult3, hc_hidden_size), T.float32) + out: T.Tensor((num_tokens, hc_mult3), T.float32) + sqrsum: T.Tensor((num_tokens), T.float32) + + with T.Kernel(T.ceildiv(num_tokens, token_block)) as px: + out_frag = T.alloc_fragment((token_block, 32), T.float32) + sqrsum_part = T.alloc_fragment((token_block, 4), T.float32) + T.clear(out_frag) + T.clear(sqrsum_part) + for pz in T.Pipelined(hc_hidden_size // hidden_block, num_stages=2): + x_smem_16 = T.alloc_shared((token_block, hidden_block), T.bfloat16) + fn_smem = T.alloc_shared((32, hidden_block), T.float32) + + T.annotate_layout({x_smem_16: tilelang.layout.make_swizzled_layout(x_smem_16)}) + + T.copy(x[px * token_block, pz * 
hidden_block], x_smem_16) + T.copy(fn[0, pz * hidden_block], fn_smem) + + x_frag_16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16) + T.copy(x_smem_16, x_frag_16) + x_frag = T.alloc_fragment((token_block, hidden_block), T.float32) + T.copy(x_frag_16, x_frag) + + for jj in T.serial(hidden_block // 4): + for i, j in T.Parallel(token_block, 4): + sqrsum_part[i, j] += x_frag[i, jj * 4 + j] * x_frag[i, jj * 4 + j] + + # should be TF32 gemm + T.gemm( + x_frag, + fn_smem, + out_frag, + transpose_A=False, + transpose_B=True, + wg_wait=0, + clear_accum=False, + ) + sqrsum_l = T.alloc_fragment(token_block, T.float32) + T.reduce_sum(sqrsum_part, sqrsum_l) + for i in T.Parallel(token_block): + sqrsum[px * token_block + i] = sqrsum_l[i] + for i, j in T.Parallel(token_block, 32): + if j < hc_mult3: + out[px * token_block + i, j] = out_frag[i, j] + + +def mhc_pre( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass for mHC pre block. + + Args: + residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16 + fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32 + hc_scale: shape (3,), dtype torch.float32 + hc_base: shape (hc_mult3,), dtype torch.float32 + rms_eps: RMS normalization epsilon + hc_pre_eps: pre-mix epsilon + hc_sinkhorn_eps: sinkhorn epsilon + hc_post_mult_value: post-mix multiplier value + sinkhorn_repeat: number of sinkhorn iterations + n_splits: split-k factor; TileLang version of mhc_pre_gemm_sqrsum doesn't support this + + Returns: + post_mix: shape (..., hc_mult), dtype torch.float32 + comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32 + layer_input: shape (..., hidden_size), dtype torch.bfloat16 + """ + + # Validate shapes + assert residual.dtype == torch.bfloat16 + assert fn.dtype == torch.float32 + assert hc_scale.dtype == torch.float32 + assert hc_base.dtype == torch.float32 + + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + + hc_hidden_size = hc_mult * hidden_size + assert fn.shape[0] == hc_mult3 + assert fn.shape[1] == hc_hidden_size + assert hc_scale.shape == (3,) + assert hc_base.shape == (hc_mult3,) + + outer_shape = residual.shape[:-2] + + residual_flat = residual.view(-1, hc_mult, hidden_size) + num_tokens = residual_flat.shape[0] + fn_flat = fn + + post_mix = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=residual.device) + comb_mix = torch.empty(num_tokens, hc_mult2, dtype=torch.float32, device=residual.device) + layer_input = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device) + + gemm_out_mul = torch.empty(n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device) + gemm_out_sqrsum = torch.empty(n_splits, num_tokens, dtype=torch.float32, device=residual.device) + assert n_splits == 1, "The simple TileLang version gemm_sqrsum doesn't support split-k" + mhc_pre_gemm_sqrsum_tilelang( + residual_flat.view(num_tokens, hc_mult * hidden_size), + fn_flat, + gemm_out_mul.squeeze(0), + gemm_out_sqrsum.squeeze(0), + hc_mult3, + hc_mult * hidden_size, + ) + + mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual_flat, + post_mix, + comb_mix, + layer_input, + hidden_size, + rms_eps, + hc_pre_eps, + 
hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + hc_mult, + ) + + post_mix = post_mix.view(*outer_shape, hc_mult, 1) + comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult) + layer_input = layer_input.view(*outer_shape, hidden_size) + + return post_mix, comb_mix, layer_input + + +def sinkhorn_normalize_ref(x: torch.Tensor, repeat: int, eps: float) -> torch.Tensor: + x = x.softmax(-1) + eps + x = x / (x.sum(-2, keepdim=True) + eps) + for _ in range(repeat - 1): + x = x / (x.sum(-1, keepdim=True) + eps) + x = x / (x.sum(-2, keepdim=True) + eps) + return x + + +def mhc_pre_ref( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hc_mult = residual.shape[-2] + + residual_flat = residual.flatten(-2, -1).float() + sqrsum = residual_flat.square().sum(-1) + mixes = residual_flat @ fn.T * (sqrsum.unsqueeze(-1) / fn.shape[-1] + rms_eps).rsqrt() + + hc_scale = torch.cat( + [ + hc_scale[0].expand(hc_mult), + hc_scale[1].expand(hc_mult), + hc_scale[2].expand(hc_mult * hc_mult), + ], + ) + mixes = mixes * hc_scale + hc_base + + pre_mix = mixes[:, :hc_mult].sigmoid().unsqueeze(-1) + hc_pre_eps + post_mix = (mixes[:, hc_mult : 2 * hc_mult].sigmoid() * hc_post_mult_value).unsqueeze(-1) + res_mix = mixes[:, 2 * hc_mult :].view(-1, hc_mult, hc_mult) + + res_mix = sinkhorn_normalize_ref(res_mix, repeat=sinkhorn_repeat, eps=hc_sinkhorn_eps) + + layer_input = (residual * pre_mix).sum(-2).bfloat16() + + return post_mix, res_mix, layer_input + + +def generate_test_data( + n: int, + hc_mult: int, + hidden_size: int, + rms_eps: float = 1e-6, + hc_pre_eps: float = 1e-6, + hc_sinkhorn_eps: float = 1e-6, + hc_post_mult_value: float = 1.0, + sinkhorn_repeat: int = 10, +) -> dict[str, torch.Tensor | float]: + """Generate test data for big fuse operator.""" + torch.random.manual_seed(42) + + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + device = "cuda" + + residual = ( + torch.randn((n, hc_mult, hidden_size), dtype=torch.float, device=device) + .mul(1 + torch.arange(hc_mult, device=device).mul(0.01).view(1, -1, 1)) + .bfloat16() + ) + + fn = ( + torch.randn((hc_mult3, hc_mult, hidden_size), dtype=torch.float, device=device) + * 1e-4 + * (1 + torch.arange(hc_mult, device=device).mul(0.01).view(1, -1, 1)) + ).flatten(1, 2) + + hc_scale = torch.randn((3,), dtype=torch.float, device=device) * 0.1 + + hc_base = torch.randn((hc_mult3,), dtype=torch.float, device=device) * 0.1 + + return { + "residual": residual, + "fn": fn, + "hc_scale": hc_scale, + "hc_base": hc_base, + "rms_eps": rms_eps, + "hc_pre_eps": hc_pre_eps, + "hc_sinkhorn_eps": hc_sinkhorn_eps, + "hc_post_mult_value": hc_post_mult_value, + "sinkhorn_repeat": sinkhorn_repeat, + } + + +def test(n: int, hidden_size: int, hc_mult: int) -> None: + print(f"Testing mhc_pre with {n=} {hidden_size=} {hc_mult=}") + test_data = generate_test_data( + n=n, + hc_mult=hc_mult, + hidden_size=hidden_size, + ) + + # Forward pass with big fuse + post_mix_fused, comb_mix_fused, layer_input_fused = mhc_pre(**test_data) + + # Forward pass with reference + post_mix_ref, comb_mix_ref, layer_input_ref = mhc_pre_ref(**test_data) + + # Compare outputs + torch.testing.assert_close(post_mix_fused, post_mix_ref) + torch.testing.assert_close(comb_mix_fused, comb_mix_ref) + torch.testing.assert_close(layer_input_fused, layer_input_ref) 
+ + +def main(): + for n1 in [512, 1024, 2048, 8192]: + for hidden_size in [1280, 2560, 4096]: + for hc_mult in [4]: + test(n=n1, hidden_size=hidden_size, hc_mult=hc_mult) + + +if __name__ == "__main__": + main() diff --git a/examples/deepseek_mhc/test_example_mhc.py b/examples/deepseek_mhc/test_example_mhc.py new file mode 100644 index 000000000..3d9ecad4d --- /dev/null +++ b/examples/deepseek_mhc/test_example_mhc.py @@ -0,0 +1,18 @@ +import tilelang.testing + +from example_mhc_post import main as main_post +from example_mhc_pre import main as main_pre + + +@tilelang.testing.requires_cuda +def test_mhc_post(): + main_post() + + +@tilelang.testing.requires_cuda +def test_mhc_pre(): + main_pre() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/deepseek_mla/README.md b/examples/deepseek_mla/README.md index e64b1c37d..bd3539d26 100644 --- a/examples/deepseek_mla/README.md +++ b/examples/deepseek_mla/README.md @@ -24,14 +24,14 @@ We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashIn
Figure 2: Performance under batch size=128
-As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. +As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this. ## Implementation First, let's review the core computation logic of traditional FlashAttention: -```python +```python # acc_s: [block_M, block_N] # scores_max: [block_M] # scores_scale: [block_M] @@ -54,7 +54,7 @@ Compared to traditional attention operators like MHA (Multi-Headed Attention) or This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling. -Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. +Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory. @@ -96,7 +96,6 @@ T.use_swizzle(panel_size: int, order: str = "row") Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col". - ### Shared Memory Swizzling In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance. @@ -113,17 +112,14 @@ T.annotate_layout({ Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout. - ### Warp-Specialization The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects. 
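For contrast, the TileLang kernels in this change set are written as straight-line tile code with no explicit producer or consumer logic. The fragment below is an illustrative sketch only, trimmed from the decode kernels elsewhere in this PR: the tensor names, block sizes and dtypes are placeholders, and the `@tilelang.jit`/`T.prim_func` scaffolding together with the online-softmax bookkeeping is omitted.

```python
# Sketch of the user-facing style (placeholders, not a complete kernel):
with T.Kernel(batch, heads // block_H, threads=threads) as (bx, by):
    Q_local = T.alloc_fragment([block_H, dim], dtype)
    KV_shared = T.alloc_shared([block_N, dim], dtype)
    acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)

    T.copy(Q[bx, by * block_H:(by + 1) * block_H, :], Q_local)
    for k in T.Pipelined(T.ceildiv(seqlen_kv, block_N), num_stages=2):
        # A plain copy: the compiler decides whether this becomes a TMA load
        # issued by a producer warpgroup.
        T.copy(KV[bx, k * block_N:(k + 1) * block_N, :], KV_shared)
        T.clear(acc_s)
        # A plain GEMM: the compiler maps this onto the consumer warpgroups.
        T.gemm(Q_local, KV_shared, acc_s, transpose_B=True)
        # ... online softmax and the P @ V GEMM follow, as in the full kernels ...
```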
In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation. - ### Pipeline - Pipeline is a technique used to improve memory access efficiency by overlapping memory access and computation. In TileLang, pipeline can be implemented through the `T.pipelined` annotation: ```python @@ -132,9 +128,8 @@ T.pipelined(range: int, stage: int) Here, `range` specifies the range of the pipeline, and `stage` specifies the stage of the pipeline. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases. - ### Split-KV We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results. -In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. \ No newline at end of file +In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py new file mode 100644 index 000000000..9eae48082 --- /dev/null +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py @@ -0,0 +1,290 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py +# ruff: noqa +import argparse +import math +import random +import torch + +import triton +import triton.language as tl + +import tilelang +from tilelang.profiler import do_bench + +try: + from aiter.mla import mla_decode_fwd +except ImportError: + print("aiter is AMD specific kernel library. 
Please make sure aiter is installed on your AMD device.") + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@torch.inference_mode() +def run_mla_aiter(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + assert d > dv, "mla with rope dim should be larger than no rope dim" + + qo_indptr = torch.zeros(b + 1, dtype=torch.int) + kv_indptr = torch.zeros(b + 1, dtype=torch.int) + seq_lens_qo = torch.empty(b, dtype=torch.int) + seq_lens_qo.fill_(1) + max_seqlen_qo = seq_lens_qo.max().item() + + kv_indptr[1 : b + 1] = torch.cumsum(cache_seqlens, dim=0) + qo_indptr[1 : b + 1] = torch.cumsum(seq_lens_qo, dim=0) + total_q = qo_indptr[-1].item() + + # set block_size to 1 + page_size = 1 + kv_buffer = blocked_k.view(-1, page_size, h_kv, d) + + flat_indices = [] + for i in range(b): + start = i * max_seqlen_pad + end = start + cache_seqlens[i] + flat_indices.append(torch.arange(start, end, dtype=torch.int)) + + kv_indices = torch.cat(flat_indices) + + kv_last_page_lens = torch.ones(b, dtype=torch.int) + + sm_scale = 1.0 / (d**0.5) + + def mla_aiter(): + out_aiter = torch.empty((total_q, h_q, dv), dtype=dtype).fill_(-1) + attn_logits_aiter, attn_lse_aiter = mla_decode_fwd( + q.view((total_q, h_q, d)), + kv_buffer, + out_aiter, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale, + ) + return out_aiter.view([b, s_q, h_q, dv]) + + out_aiter = mla_aiter() + t = triton.testing.do_bench(mla_aiter) + return out_aiter, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "mla_aiter": run_mla_aiter, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + 
torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["mla_aiter"]: + # flash_mla_triton doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.3f} TFLOPS, {bytes / 10**6 / perf_a:.3f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.3f} TFLOPS, {bytes / 10**6 / perf_b:.3f} GB/s") + return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE, f"target {target} not in {FUNC_TABLE}" + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.3f} TFLOPS, {bytes / 10**6 / perf_b:.3f} GB/s") + return bytes / 10**6 / perf_b + + +available_targets = [ + "torch", + "mla_aiter", +] + +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.bfloat16, + } + for batch in [64, 128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] + + +def get_args(): + parser = argparse.ArgumentParser() + 
parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="mla_aiter") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: + fout.write("name,batch,seqlen,head,bw\n") + for shape in shape_configs: + if args.all: + for target in available_targets: + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) + elif args.compare: + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" + ) + elif args.one: + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py index db460437f..399bb8e6e 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -8,6 +8,7 @@ def get_configs(): import itertools + BLOCK_N = [16, 32, 64, 128] BLOCK_H = [16, 32, 64, 128] num_split = [1, 2, 4, 8, 16, 32] @@ -15,45 +16,44 @@ def get_configs(): _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads)) - return [{ - "block_N": c[0], - "block_H": c[1], - "num_split": c[2], - "threads": c[3], - } for c in _configs] + return [ + { + "block_N": c[0], + "block_H": c[1], + "num_split": c[2], + "threads": c[3], + } + for c in _configs + ] @tilelang.autotune(configs=get_configs()) @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashmla_decode(batch, - heads, - kv_head_num, - seqlen_kv, - dim, - pe_dim, - block_N, - block_H, - num_split, - threads=128): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + }, +) +def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" - @T.macro - def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: 
T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by): + # flash_attn_split + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=threads) as (bx, by, bz): Q_local = T.alloc_fragment([block_H, dim], dtype) Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) KV_shared = T.alloc_shared([block_N, dim], dtype) @@ -69,34 +69,31 @@ def flash_attn( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv(seqlen_kv, block_N) + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=0): - T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) + kv_start = (seqlen_kv // num_split) * bz + k * block_N + kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N + T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.gemm( - Q_pe_local, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) - # T.copy(acc_s, S_shared) T.copy(acc_s, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] @@ -105,20 +102,50 @@ def flash_attn( T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) - @T.macro - def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: 
T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by): Q_local = T.alloc_fragment([block_H, dim], dtype) Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) KV_shared = T.alloc_shared([block_N, dim], dtype) @@ -134,34 +161,31 @@ def flash_attn_split( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=0): - kv_start = (seqlen_kv // num_split) * bz + k * block_N - kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N - T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared) - T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.gemm( - Q_pe_local, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, 
scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) + # T.copy(acc_s, S_shared) T.copy(acc_s, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] @@ -170,72 +194,7 @@ def flash_attn_split( T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] - for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) - T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim], dtype) - o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dim): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim): - Output[bz, by, i] = o_accum_local[i] - - @T.prim_func - def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn(Q, Q_pe, KV, K_pe, Output) + T.copy(acc_o, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) if num_split > 1: return main_split @@ -258,43 +217,36 @@ def ref_program(q, 
q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') - parser.add_argument('--autotune', action='store_true', help='auto tune') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + parser.add_argument("--autotune", action="store_true", help="auto tune") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim enable_autotune = args.autotune @@ -307,26 +259,16 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): num_split = 4 threads = 128 + print(f"Using {batch=}, {heads=}, {kv_heads=}, {kv_ctx=}, {dim=}, {pe_dim=}") + if enable_autotune: kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim) else: - kernel = flashmla_decode( - batch, - heads, - kv_heads, - kv_ctx, 
- dim, - pe_dim, - BLOCK_N, - BLOCK_H, - num_split, - threads=threads) + kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, threads=threads) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) input_tensors = profiler._get_inputs() tilelang_output = kernel(*input_tensors) ref_output = ref_program(*input_tensors) - print(f"Tilelang output: {tilelang_output}") - print(f"Ref output: {ref_output}") torch.testing.assert_close(tilelang_output, ref_output, rtol=0.01, atol=0.01) latency = profiler.do_bench(warmup=500) print(f"Latency: {latency} ms") diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py deleted file mode 100644 index 0006d9468..000000000 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py +++ /dev/null @@ -1,495 +0,0 @@ -# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py -# ruff: noqa -import argparse -import math -import random -import torch -import triton -import triton.language as tl - -import tilelang -from tilelang.profiler import do_bench - - -def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): - query = query.float() - key = key.float() - value = value.float() - key = key.repeat_interleave(h_q // h_kv, dim=0) - value = value.repeat_interleave(h_q // h_kv, dim=0) - attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) - if is_causal: - s_q = query.shape[-2] - s_k = key.shape[-2] - attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - attn_weight += attn_bias - lse = attn_weight.logsumexp(dim=-1) - attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) - return attn_weight @ value, lse - - -@torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): - blocked_v = blocked_k[..., :dv] - - def ref_mla(): - out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) - lse = torch.empty(b, h_q, s_q, dtype=torch.float32) - for i in range(b): - begin = i * max_seqlen_pad - end = begin + cache_seqlens[i] - O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), - h_q, - h_kv, - is_causal=causal, - ) - out[i] = O.transpose(0, 1) - lse[i] = LSE - return out, lse - - out_torch, lse_torch = ref_mla() - t = triton.testing.do_bench(ref_mla) - return out_torch, lse_torch, t - - -@triton.jit -def _mla_attn_kernel( - Q_nope, - Q_pe, - Kv_c_cache, - K_pe_cache, - Req_to_tokens, - B_seq_len, - O, - sm_scale, - stride_q_nope_bs, - stride_q_nope_h, - stride_q_pe_bs, - stride_q_pe_h, - stride_kv_c_bs, - stride_k_pe_bs, - stride_req_to_tokens_bs, - stride_o_b, - stride_o_h, - stride_o_s, - BLOCK_H: tl.constexpr, - BLOCK_N: tl.constexpr, - NUM_KV_SPLITS: tl.constexpr, - PAGE_SIZE: tl.constexpr, - HEAD_DIM_CKV: tl.constexpr, - HEAD_DIM_KPE: tl.constexpr, -): - cur_batch = tl.program_id(1) - cur_head_id = tl.program_id(0) - split_kv_id = tl.program_id(2) - - cur_batch_seq_len = tl.load(B_seq_len + cur_batch) - - offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) - cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - 
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] - q_nope = tl.load(Q_nope + offs_q_nope) - - offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) - offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] - q_pe = tl.load(Q_pe + offs_q_pe) - - e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") - e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) - acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) - - kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) - split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - - for start_n in range(split_kv_start, split_kv_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - kv_page_number = tl.load( - Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, - mask=offs_n < split_kv_end, - other=0, - ) - kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE - offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] - k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) - - qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) - - offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] - k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) - - qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) - qk *= sm_scale - - qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) - - v_c = tl.trans(k_c) - - n_e_max = tl.maximum(tl.max(qk, 1), e_max) - re_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max[:, None]) - acc *= re_scale[:, None] - acc += tl.dot(p.to(v_c.dtype), v_c) - - e_sum = e_sum * re_scale + tl.sum(p, 1) - e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] - tl.store(O + offs_o, acc / e_sum[:, None]) - offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV - tl.store(O + offs_o_1, e_max + tl.log(e_sum)) - - -def _mla_attn( - q_nope, - q_pe, - kv_c_cache, - k_pe_cache, - attn_logits, - req_to_tokens, - b_seq_len, - num_kv_splits, - sm_scale, - page_size, -): - batch_size, head_num = q_nope.shape[0], q_nope.shape[1] - head_dim_ckv = q_nope.shape[-1] - head_dim_kpe = q_pe.shape[-1] - - BLOCK_H = 16 - BLOCK_N = 64 - grid = ( - triton.cdiv(head_num, BLOCK_H), - batch_size, - num_kv_splits, - ) - _mla_attn_kernel[grid]( - q_nope, - q_pe, - kv_c_cache, - k_pe_cache, - req_to_tokens, - b_seq_len, - attn_logits, - sm_scale, - # stride - q_nope.stride(0), - q_nope.stride(1), - q_pe.stride(0), - q_pe.stride(1), - kv_c_cache.stride(-2), - k_pe_cache.stride(-2), - req_to_tokens.stride(0), - attn_logits.stride(0), - attn_logits.stride(1), - attn_logits.stride(2), - BLOCK_H=BLOCK_H, - BLOCK_N=BLOCK_N, - NUM_KV_SPLITS=num_kv_splits, - PAGE_SIZE=page_size, - HEAD_DIM_CKV=head_dim_ckv, - HEAD_DIM_KPE=head_dim_kpe, - num_stages=1, # 2 will oom in amd - ) - - -@triton.jit -def _mla_softmax_reducev_kernel( - Logits, - B_seq_len, - O, - stride_l_b, - stride_l_h, - stride_l_s, - stride_o_b, - stride_o_h, - NUM_KV_SPLITS: tl.constexpr, - HEAD_DIM_CKV: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - cur_batch_seq_len = tl.load(B_seq_len + cur_batch) - - offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) - - e_sum = 0.0 - e_max = -float("inf") - acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) - - offs_l = cur_batch * stride_l_b + cur_head * 
stride_l_h + offs_d_ckv - offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV - - for split_kv_id in range(0, NUM_KV_SPLITS): - kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) - split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - - if split_kv_end > split_kv_start: - logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) - logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) - - n_e_max = tl.maximum(logits_1, e_max) - old_scale = tl.exp(e_max - n_e_max) - acc *= old_scale - exp_logic = tl.exp(logits_1 - n_e_max) - acc += exp_logic * logits - - e_sum = e_sum * old_scale + exp_logic - e_max = n_e_max - - tl.store( - O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, - acc / e_sum, - ) - - -def _mla_softmax_reducev( - logits, - o, - b_seq_len, - num_kv_splits, -): - batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] - grid = (batch_size, head_num) - _mla_softmax_reducev_kernel[grid]( - logits, - b_seq_len, - o, - logits.stride(0), - logits.stride(1), - logits.stride(2), - o.stride(0), - o.stride(1), - NUM_KV_SPLITS=num_kv_splits, - HEAD_DIM_CKV=head_dim_ckv, - ) - - -def mla_decode_triton( - q_nope, - q_pe, - kv_c_cache, - k_pe_cache, - o, - req_to_tokens, - b_seq_len, - attn_logits, - num_kv_splits, - sm_scale, - page_size, -): - assert num_kv_splits == attn_logits.shape[2] - _mla_attn( - q_nope, - q_pe, - kv_c_cache, - k_pe_cache, - attn_logits, - req_to_tokens, - b_seq_len, - num_kv_splits, - sm_scale, - page_size, - ) - _mla_softmax_reducev( - attn_logits, - o, - b_seq_len, - num_kv_splits, - ) - - -@torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - - blocked_v = blocked_k[..., :dv] - - assert d > dv, "mla with rope dim should be larger than no rope dim" - q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() - - def flash_mla_triton(): - num_kv_splits = 32 - o = torch.empty([b * s_q, h_q, dv]) - attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) - mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) - return o.view([b, s_q, h_q, dv]) - - out_flash = flash_mla_triton() - t = triton.testing.do_bench(flash_mla_triton) - return out_flash, None, t - - -FUNC_TABLE = { - "torch": run_torch_mla, - "flash_mla_triton": run_flash_mla_triton, -} - - -def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) - device = torch.device("cuda:0") - torch.set_default_dtype(dtype) - torch.set_default_device(device) - torch.cuda.set_device(device) - torch.manual_seed(0) - random.seed(0) - assert baseline in FUNC_TABLE - assert target in FUNC_TABLE - baseline_func = FUNC_TABLE[baseline] - target_func = FUNC_TABLE[target] - - total_seqlens = cache_seqlens.sum().item() - max_seqlen = cache_seqlens.max().item() - max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 - # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") - - q = torch.randn(b, 
s_q, h_q, d) - block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) - blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - - torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" - if target not in ["flash_mla_triton"]: - # flash_mla_triton doesn't return lse - torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" - - FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) - return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b - - -def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) - torch.set_default_dtype(dtype) - device = torch.device("cuda:0") - torch.set_default_device(device) - torch.cuda.set_device(device) - torch.manual_seed(0) - random.seed(0) - assert target in FUNC_TABLE, f"target {target} not in {FUNC_TABLE}" - target_func = FUNC_TABLE[target] - - total_seqlens = cache_seqlens.sum().item() - max_seqlen = cache_seqlens.max().item() - max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 - # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") - - q = torch.randn(b, s_q, h_q, d) - block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) - blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - - FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) - return bytes / 10**6 / perf_b - - -available_targets = [ - "torch", - "flash_mla_triton", -] - -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]] - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--baseline", type=str, default="torch") - parser.add_argument("--target", type=str, default="torch") - parser.add_argument("--all", action="store_true") - parser.add_argument("--one", action="store_true") - parser.add_argument("--compare", action="store_true") - args = parser.parse_args() - return args - - -if __name__ == "__main__": - args = get_args() - 
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target - with open(f"{benchmark_type}_perf.csv", "w") as fout: - fout.write("name,batch,seqlen,head,bw\n") - for shape in shape_configs: - if args.all: - for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) - fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' - ) - elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) - fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' - ) - fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' - ) - elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) - fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' - ) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py index 644f97da1..e8c1006a0 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py @@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -91,8 +90,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -138,9 +136,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -306,24 +302,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, 
q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" if target not in ["flash_mla_triton"]: @@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.3f} TFLOPS, {bytes / 10**6 / perf_a:.3f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.3f} TFLOPS, {bytes / 10**6 / perf_b:.3f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") 
torch.set_default_device(device) @@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.3f} TFLOPS, {bytes / 10**6 / perf_b:.3f} GB/s") return bytes / 10**6 / perf_b @@ -426,26 +419,22 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [64, 128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [64, 128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] def get_args(): @@ -467,26 +456,54 @@ def get_args(): for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + 
f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/benchmark_mla.py b/examples/deepseek_mla/benchmark_mla.py index a542ff611..544b5e128 100644 --- a/examples/deepseek_mla/benchmark_mla.py +++ b/examples/deepseek_mla/benchmark_mla.py @@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -61,8 +60,7 @@ def ref_mla(): @torch.inference_mode() -def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): from flash_mla import flash_mla_with_kvcache, get_mla_metadata blocked_v = blocked_k[..., :dv] @@ -87,14 +85,13 @@ def flash_mla(): @torch.inference_mode() -def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, - h_q, h_kv, d, dv, causal, dtype): +def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): # pip install flashinfer-python import flashinfer + assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() kv_indptr = [0] kv_indices = [] @@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") mla_wrapper.plan( q_indptr, kv_indptr, @@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q ) def flashinfer(): - output, lse = mla_wrapper.run( - q_nope.view(-1, h_q, dv), - q_pe.view(-1, h_q, d - dv), - blocked_k_nope, - 
blocked_k_pe, - return_lse=True) + output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True) return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) out_flash, lse_flash = flashinfer() @@ -177,8 +168,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -224,9 +214,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -393,24 +381,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -419,13 +413,10 @@ def flash_mla_triton(): @torch.inference_mode() -def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() dpe = d - dv num_kv_splits = 1 @@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, 
dtype=dtype, device=q.device) - kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) def flash_mla_tilelang(): out = kernel( @@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" - if target not in ["flashinfer", "flash_mla_triton", "tilelang" - ] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: + if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: # flashinfer has a different lse return value # flash_mla_triton and flash_mla_tilelang doesn't return lse torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, 
max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -558,26 +538,22 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] + for head in [128] +] def get_args(): @@ -599,26 +575,54 @@ def get_args(): for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], 
- shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 417e319fd..d6d76e54e 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -8,27 +8,31 @@ @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, - softmax_scale): + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): scale = float(softmax_scale * 1.44269504) # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" - @T.macro - def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): + # flash_attn_split + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -36,6 +40,7 @@ def flash_attn( K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) O_shared = T.alloc_shared([block_H, dim], dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_o = T.alloc_fragment([block_H, dim], accum_dtype) scores_max = T.alloc_fragment([block_H], accum_dtype) scores_max_prev = T.alloc_fragment([block_H], accum_dtype) @@ -44,64 +49,87 @@ def flash_attn( logsum = T.alloc_fragment([block_H], accum_dtype) cur_kv_head = hid // (kv_group_num // block_H) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) + T.use_swizzle(10) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, 
-T.infinity(accum_dtype)) - loop_range = T.ceildiv(seqlen_kv, block_N) + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=2): - T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + kv_start = (seqlen_kv // num_split) * bz + k * block_N + kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N + T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, :]) + + # combine + with T.Kernel(heads, batch, threads=128) as (hid, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local = T.max(lse_max_local, glse[bz, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split = glse[bz, hid, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, hid, k, i] + lse_local_split = glse[bz, hid, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim): + Output[bz, hid, i] = o_accum_local[i] - @T.macro - def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: 
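# Reference semantics of the split-KV "combine" pass above, written as a PyTorch sketch (base-2
# exp/log to match T.exp2/T.log2; glse holds per-split log-sum-exp values of shape
# [batch, heads, num_split], Output_partial holds the per-split normalized outputs of shape
# [batch, heads, num_split, dim]; the function name is ours, not part of this patch):
import torch

def combine_splits(glse, out_partial):
    glse, out_partial = glse.float(), out_partial.float()
    lse_max = glse.max(dim=-1, keepdim=True).values
    lse_logsum = torch.log2(torch.exp2(glse - lse_max).sum(dim=-1, keepdim=True)) + lse_max
    weights = torch.exp2(glse - lse_logsum)                  # [batch, heads, num_split]
    return (weights.unsqueeze(-1) * out_partial).sum(dim=2)  # [batch, heads, dim]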
T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=256) as (bid, hid, bz): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): Q_shared = T.alloc_shared([block_H, dim], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -109,7 +137,6 @@ def flash_attn_split( K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) O_shared = T.alloc_shared([block_H, dim], dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_o = T.alloc_fragment([block_H, dim], accum_dtype) scores_max = T.alloc_fragment([block_H], accum_dtype) scores_max_prev = T.alloc_fragment([block_H], accum_dtype) @@ -118,118 +145,39 @@ def flash_attn_split( logsum = T.alloc_fragment([block_H], accum_dtype) cur_kv_head = hid // (kv_group_num // block_H) - T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=2): - kv_start = (seqlen_kv // num_split) * bz + k * block_N - kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N - T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) - T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) - T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) 
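# The attention loops above implement the standard online-softmax recurrence, with log2(e) folded
# into `scale` so the hardware exp2 can be used instead of exp: exp(x * softmax_scale) ==
# 2 ** (x * scale). One KV-tile update written out in PyTorch as an illustrative sketch (acc_o,
# logsum and the running max m are the loop state; the final acc_o / logsum and
# lse = log2(logsum) + m * scale happen after the loop):
import torch

def online_softmax_tile(acc_o, logsum, m, s, v, scale):
    # s: [block_H, block_N] raw scores for this KV tile, v: [block_N, dim]
    m_new = torch.maximum(m, s.max(dim=1).values)
    alpha = torch.exp2((m - m_new) * scale)            # rescale factor for the old accumulators
    p = torch.exp2((s - m_new[:, None]) * scale)       # unnormalized tile probabilities
    logsum = logsum * alpha + p.sum(dim=1)
    acc_o = acc_o * alpha[:, None] + p.to(v.dtype) @ v
    return acc_o, logsum, m_new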
T.copy(acc_s, S_shared) - T.copy(S_shared, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] - for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, :]) - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - with T.Kernel(heads, batch, threads=128) as (hid, bz): - po_local = T.alloc_fragment([dim], dtype) - o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, hid, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dim): - po_local[i] = Output_partial[bz, hid, k, i] - lse_local_split[0] = glse[bz, hid, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim): - Output[bz, hid, i] = o_accum_local[i] - - @T.prim_func - def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn(Q, Q_pe, KV, K_pe, Output) + T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) if num_split > 1: return main_split @@ -252,31 +200,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # 
[batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -294,10 +235,9 @@ def main( BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) num_split = 1 - softmax_scale = (dim + pe_dim)**-0.5 + softmax_scale = (dim + pe_dim) ** -0.5 - kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, - softmax_scale) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) latency = profiler.do_bench(warmup=500) @@ -305,14 +245,33 @@ def main( print(f"TFlops: {total_flops / latency * 1e-9} TFlops") +def run_regression_perf( + batch=1, + heads=128, + kv_heads=1, + kv_ctx=8192, + dim=512, + pe_dim=64, +): + BLOCK_N = 64 + BLOCK_H = min(64, heads // kv_heads) + num_split = 1 + softmax_scale = (dim + pe_dim) ** -0.5 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=132, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + 
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py index fe50d4d4f..2e1911028 100644 --- a/examples/deepseek_mla/example_mla_decode_paged.py +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -8,41 +8,36 @@ @tilelang.jit( - out_idx=[8], pass_configs={ + out_idx=[8], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def mla_decode_tilelang(batch, - h_q, - h_kv, - max_seqlen_pad, - dv, - dpe, - block_N, - block_H, - num_split, - block_size, - softmax_scale=None): + }, +) +def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None): if softmax_scale is None: - softmax_scale = (dv + dpe)**-0.5 + softmax_scale = (dv + dpe) ** -0.5 scale = float(softmax_scale * 1.44269504) # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = h_q // h_kv VALID_BLOCK_H = min(block_H, kv_group_num) assert h_kv == 1, "h_kv must be 1" assert block_size >= block_N and block_size % block_N == 0, "block_size must be larger than block_N and a multiple of block_N" - @T.macro - def flash_mla_kernel( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), - Output: T.Tensor([batch, h_q, dv], dtype), + @T.prim_func + def main_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): - with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): + # split kv + with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dv], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) @@ -50,6 +45,7 @@ def flash_mla_kernel( K_pe_shared = T.alloc_shared([block_N, dpe], dtype) O_shared = T.alloc_shared([block_H, dv], dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_o = T.alloc_fragment([block_H, dv], accum_dtype) scores_max = T.alloc_fragment([block_H], accum_dtype) scores_max_prev = T.alloc_fragment([block_H], accum_dtype) @@ -59,69 +55,94 @@ def flash_mla_kernel( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: 
tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N) - for kr in T.Pipelined(loop_range, num_stages=2): - k = loop_range - 1 - kr - kv_start = BLOCK_TABLE[bx, (k * block_N) // - block_size] * block_size + (k * block_N) % block_size - T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + total_blocks = T.ceildiv(cache_seqlens[bx], block_N) + blocks_per_split = T.floordiv(total_blocks, num_split) + remaining_blocks = T.floormod(total_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0) + start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N + + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) - if kr == 0: - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(start + k * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dv): acc_o[i, j] *= scores_scale[i] - T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) for i, j in T.Parallel(block_H, dv): acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) - - @T.macro - def flash_mla_split_kv_kernel( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, 
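# How the split-KV path above partitions a request's KV blocks across `num_split` program ids:
# each split gets floor(total_blocks / num_split) tiles, the first total_blocks % num_split splits
# get one extra, and `start` is the first KV token index owned by split bz. Plain-Python
# restatement of the same index math (helper name is illustrative only):
def split_kv_ranges(total_blocks, num_split, block_N):
    per, rem = divmod(total_blocks, num_split)
    ranges = []
    for bz in range(num_split):
        loop_range = per + (1 if bz < rem else 0)        # number of block_N tiles for this split
        start = (per * bz + min(bz, rem)) * block_N      # first KV token index for this split
        ranges.append((start, loop_range))
    return ranges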
h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) + + # combine + with T.Kernel(h_q, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dv], dtype) + o_accum_local = T.alloc_fragment([dv], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + for i in T.Parallel(dv): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dv): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dv): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): - with T.Kernel( - batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): Q_shared = T.alloc_shared([block_H, dv], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) @@ -129,7 +150,6 @@ def flash_mla_split_kv_kernel( K_pe_shared = T.alloc_shared([block_N, dpe], dtype) O_shared = T.alloc_shared([block_H, dv], dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_o = T.alloc_fragment([block_H, dv], accum_dtype) scores_max = T.alloc_fragment([block_H], accum_dtype) scores_max_prev = T.alloc_fragment([block_H], accum_dtype) @@ -139,129 +159,45 @@ def flash_mla_split_kv_kernel( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N) - blocks_per_split = T.floordiv(total_blocks, num_split) - 
remaining_blocks = T.floormod(total_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)) - start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N - - for k in T.Pipelined(loop_range, num_stages=2): - kv_start = BLOCK_TABLE[bx, (start + k * block_N) // - block_size] * block_size + (k * block_N) % block_size - T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + loop_range = T.ceildiv(cache_seqlens[bx], block_N) + for kr in T.Pipelined(loop_range, num_stages=2): + k = loop_range - 1 - kr + kv_start = block_table[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + if kr == 0: + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) T.copy(acc_s, S_shared) - T.copy(S_shared, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dv): acc_o[i, j] *= scores_scale[i] - T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) for i, j in T.Parallel(block_H, dv): acc_o[i, j] /= logsum[i] - for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) - - @T.macro - def combine( - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), - ): - with T.Kernel(h_q, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dv], dtype) - o_accum_local = T.alloc_fragment([dv], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in 
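# Element-wise restatement of the block-table addressing used above: logical KV token `tok` of
# request b lives at physical row block_table[b, tok // block_size] * block_size + tok % block_size
# of the flat [num_blocks * block_size, h_kv, d] cache. The kernel copies whole block_N-sized,
# block-aligned tiles, so it only needs one lookup per tile; this gather is a per-token sketch
# (helper name is ours, not part of the patch):
import torch

def gather_paged_kv(kv, block_table, b, tok_start, n_tok, block_size):
    toks = torch.arange(tok_start, tok_start + n_tok, device=kv.device)
    rows = block_table[b, toks // block_size].long() * block_size + toks % block_size
    return kv[rows]  # [n_tok, h_kv, d]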
T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dv): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dv): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dv): - Output[bz, by, i] = o_accum_local[i] - - @T.prim_func - def main_split( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), - ): - flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, - Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def main_no_split( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), - ): - flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) if num_split > 1: return main_split @@ -280,8 +216,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) - temp_mask = torch.ones( - s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -291,8 +226,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): # q: [b, s_q, h_q, d] # block_table: [b, max_seqlen_pad // block_size] # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] @@ -321,13 +255,10 @@ def ref_mla(): return out_torch -def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, - h_q, h_kv, d, dv, causal, dtype): - +def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., 
dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() dpe = d - dv num_kv_splits = 1 @@ -337,8 +268,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size, softmax_scale) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) def flash_mla_tilelang(): @@ -356,8 +286,7 @@ def flash_mla_tilelang(): out_flash = flash_mla_tilelang() t = do_bench(flash_mla_tilelang) - out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01) print("All close") return out_flash, t @@ -365,12 +294,12 @@ def flash_mla_tilelang(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--h_q', type=int, default=128, help='q heads number') - parser.add_argument('--h_kv', type=int, default=1, help='kv heads number') - parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length') - parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe') - parser.add_argument('--dv', type=int, default=512, help='value head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--h_q", type=int, default=128, help="q heads number") + parser.add_argument("--h_kv", type=int, default=1, help="kv heads number") + parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length") + parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe") + parser.add_argument("--dv", type=int, default=512, help="value head dim") args = parser.parse_args() b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv @@ -379,9 +308,7 @@ def flash_mla_tilelang(): s_q = 1 # for decode, s_q = 1 block_size = 64 - cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], - dtype=torch.int32, - device=device) + cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device) dpe = d - dv causal = True @@ -393,12 +320,11 @@ def flash_mla_tilelang(): total_flops = s_q * total_seqlens * h_q * d * 2 q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32, - device=device).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) - out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, 
block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_flash, latency = run_tilelang_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py b/examples/deepseek_mla/example_mla_decode_persistent.py index 3f57ea051..74d974fbb 100644 --- a/examples/deepseek_mla/example_mla_decode_persistent.py +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -9,13 +9,15 @@ @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" @@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_split_persistent( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(sm_num, threads=256) as (block_id): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -48,16 +50,11 @@ def main_split_persistent( logsum = T.alloc_fragment([block_H], accum_dtype) po_local = T.alloc_fragment([dim], dtype) o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - # O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + T.use_swizzle(10) total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split @@ -70,8 +67,8 @@ def main_split_persistent( cur_kv_head = hid // (kv_group_num // block_H) if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split: - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, 
hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -83,24 +80,15 @@ def main_split_persistent( T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -115,11 +103,9 @@ def main_split_persistent( acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid]) + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid]) # T.copy(acc_o, O_shared) - T.copy( - acc_o, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - sid, :]) + T.copy(acc_o, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid, :]) T.sync_grid() waves = T.ceildiv(heads * batch, sm_num) @@ -130,20 +116,20 @@ def main_split_persistent( if bid < batch and hid < heads: T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bid, hid, k]) + lse_max_local = T.max(lse_max_local, glse[bid, hid, k]) for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bid, hid, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + lse_local_split = glse[bid, hid, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): for i in T.Parallel(dim): po_local[i] = Output_partial[bid, hid, k, i] - lse_local_split[0] = glse[bid, hid, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bid, hid, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim): Output[bid, hid, i] = o_accum_local[i] @@ -165,42 +151,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, 
num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def main(): parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py index 6554d57de..32eb0d475 100644 --- a/examples/deepseek_mla/example_mla_decode_ws.py +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -13,30 +13,38 @@ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + 
"--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", ], ) -def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, - softmax_scale): +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): sm_scale = float(softmax_scale * 1.44269504) # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" - @T.macro - def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid): + # flash_attn_split + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=384) as (bid, hid, bz): Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -75,16 +83,16 @@ def flash_attn( tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) @@ -105,6 +113,8 @@ def flash_attn( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -137,6 +147,8 @@ def flash_attn( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -162,8 +174,8 @@ def flash_attn( for h_i in T.Parallel(block_H): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - 0:dim // 2]) + T.copy(O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, 0 : dim // 2]) + T.copy(sumexp, 
glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -193,8 +205,7 @@ def flash_attn( acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - dim // 2:dim]) + T.copy(O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, dim // 2 : dim]) elif tx >= 256: # producer @@ -203,59 +214,82 @@ def flash_attn( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) - @T.macro - def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + # combine + with 
T.Kernel(heads, batch, threads=128) as (hid, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local = T.max(lse_max_local, glse[bz, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split = glse[bz, hid, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, hid, k, i] + lse_local_split = glse[bz, hid, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim): + Output[bz, hid, i] = o_accum_local[i] + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=384) as (bid, hid, bz): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid): Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -294,16 +328,16 @@ def flash_attn_split( tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) @@ -323,7 +357,9 @@ def flash_attn_split( T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) T.copy(m_i, m_i_prev) - T.reduce_max(acc_s, m_i, dim=1, clear=False) + T.reduce_max(acc_s, out=m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -356,6 +392,8 @@ def flash_attn_split( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -381,10 +419,7 @@ def flash_attn_split( for h_i in T.Parallel(block_H): 
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy( - O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, 0:dim // 2]) - T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz]) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -414,9 +449,7 @@ def flash_attn_split( acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy( - O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, dim // 2:dim]) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim]) elif tx >= 256: # producer @@ -425,111 +458,43 @@ def flash_attn_split( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (seqlen_kv // num_split) * bz + ( - i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (seqlen_kv // num_split) * bz + ( - i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], 
dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - with T.Kernel(heads, batch, threads=128) as (hid, bz): - po_local = T.alloc_fragment([dim], dtype) - o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, hid, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dim): - po_local[i] = Output_partial[bz, hid, k, i] - lse_local_split[0] = glse[bz, hid, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim): - Output[bz, hid, i] = o_accum_local[i] - - @T.prim_func - def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn(Q, Q_pe, KV, K_pe, Output) - if num_split > 1: return main_split else: @@ -551,31 +516,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores 
= einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -593,10 +551,9 @@ def main( BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) num_split = 1 - softmax_scale = (dim + pe_dim)**-0.5 + softmax_scale = (dim + pe_dim) ** -0.5 - kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, - softmax_scale) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) latency = profiler.do_bench(warmup=500) @@ -606,12 +563,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=132, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py index 1b1447e88..e70c35349 100644 --- a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py +++ b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py @@ -8,25 +8,27 @@ @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) - dtype = "float16" - q_dtype = "float8_e4m3" - accum_dtype = "float" + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + q_dtype = T.float8_e4m3fn + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" 
@T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -46,34 +48,27 @@ def main_no_split( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) T.disable_warp_group_reg_alloc() loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=2): - T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared) - T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], qKV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.copy(qKV_shared, KV_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -88,7 +83,7 @@ def main_no_split( for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) return main_no_split @@ -106,42 +101,35 @@ def ref_program(q, q_pe, kv, k_pe): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, 
seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) diff --git a/examples/deepseek_mla/regression_example_mla_decode.py b/examples/deepseek_mla/regression_example_mla_decode.py new file mode 100644 index 000000000..64e1c436a --- /dev/null +++ b/examples/deepseek_mla/regression_example_mla_decode.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_mla_decode + + +def regression_example_mla_decode(): + tilelang.testing.process_func(example_mla_decode.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_mla/test_example_mla_decode.py b/examples/deepseek_mla/test_example_mla_decode.py index 66a750f7d..a269ea57a 100644 --- a/examples/deepseek_mla/test_example_mla_decode.py +++ b/examples/deepseek_mla/test_example_mla_decode.py @@ -1,5 +1,4 @@ import tilelang.testing - import example_mla_decode diff --git a/examples/deepseek_mla/torch_refs.py b/examples/deepseek_mla/torch_refs.py index 4b4c888cd..aae6c7cd2 100644 --- a/examples/deepseek_mla/torch_refs.py +++ b/examples/deepseek_mla/torch_refs.py @@ -11,7 +11,7 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): block_N = 64 seqlen_kv = KV.size(1) - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale 
= (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float) acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float) @@ -31,18 +31,20 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bhd,bkhd->bhk', Q_, - KV_[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + acc_s = torch.einsum( + "bhd,bkhd->bhk", + Q_, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] acc_s += torch.einsum( - 'bhd,bkhd->bhk', Q_pe_, - K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhd,bkhd->bhk", + Q_pe_, + K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] @@ -50,9 +52,10 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): acc_s = torch.exp2(acc_s - scores_max[:, :, None]) acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_o += torch.einsum( - 'bhk,bkhd->bhd', acc_s_cast, - KV_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhk,bkhd->bhd", + acc_s_cast, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum acc_o /= logsum[:, :, None] diff --git a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py index daee39865..ca98d01be 100644 --- a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py +++ b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py @@ -14,21 +14,44 @@ from fla.utils import autocast_custom_fwd, contiguous -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: 
tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -40,20 +63,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -66,7 +87,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -87,7 +108,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -100,8 +120,7 @@ def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -172,7 +191,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -195,7 +213,8 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -207,18 +226,20 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: 
Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -258,44 +279,44 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o -def naive_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -335,26 +356,24 @@ def naive_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - raise RuntimeError( - "Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") dtype = q.dtype G = q.shape[2] // k.shape[2] BS = block_size S = block_indices.shape[-1] - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) if isinstance(block_counts, torch.Tensor): - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) @@ -364,14 +383,11 @@ def naive_nsa(q: torch.Tensor, if cu_seqlens is None: varlen = False B, T = q.shape[:2] - cu_seqlens = torch.cat( - [block_indices.new_tensor(range(0, B * T, T)), - block_indices.new_tensor([B * T])]) + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) for i in range(len(cu_seqlens) - 1): if not varlen: - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[ - i], block_indices[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] if isinstance(block_counts, torch.Tensor): s_b = block_counts[i] else: @@ -379,10 +395,10 @@ def naive_nsa(q: torch.Tensor, else: T = cu_seqlens[i + 1] - cu_seqlens[i] q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]], - (q, k, v, g_slc, g_swa, block_indices)) + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]] + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] else: s_b = block_counts @@ -404,71 +420,58 @@ def naive_nsa(q: torch.Tensor, else: s_i = s_b # [S*BS, HQ, -1] - k_i_slc, v_i_slc = map( - lambda x: x.gather( - 0, - i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) # [S*BS, HQ] - attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( - torch.logical_or(i_i < 0, i_i > i_q) | - (c >= s_i if block_counts is not None else False), float('-inf')).softmax(0) + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) if not varlen: - o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * 
g_slc_i.unsqueeze(-1) + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) else: - o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) if window_size > 0: - k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], - (k_b, v_b)) - attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) if not varlen: - o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) else: - o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) if head_first: - o_slc = rearrange(o_slc, 'b t h d -> b h t d') - o_swa = rearrange(o_swa, 'b t h d -> b h t d') + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) def get_configs(): import itertools + iter_params = dict( block_T=[128, 256, 512], num_stages=[0, 1, 2, 4, 5], threads=[32, 64, 128, 256, 512], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def tilelang_sparse_attention(batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - block_T=128, - num_stages=2, - threads=32): + } +) +def tilelang_sparse_attention( + batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32 +): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -476,9 +479,9 @@ def tilelang_sparse_attention(batch, q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(block_T, tilelang.math.next_power_of_2(dim)) @@ -493,11 +496,11 @@ def tilelang_sparse_attention(batch, @T.prim_func def tilelang_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), ): 
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -514,13 +517,11 @@ def tilelang_sparse_attention( scores_sum = T.alloc_fragment([G], accum_dtype) logsum = T.alloc_fragment([G], accum_dtype) - T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)}) - i_t, i_v, i_bh = bx, by, bz i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -530,21 +531,15 @@ def tilelang_sparse_attention( i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -564,45 +559,33 @@ def tilelang_sparse_attention( acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return tilelang_sparse_attention def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size): """Generate random block indices for the benchmark.""" - block_indices = torch.full((batch, seq_len, heads, selected_blocks), - seq_len, - dtype=torch.long, - device='cuda') + block_indices = torch.full((batch, seq_len, heads, selected_blocks), seq_len, dtype=torch.long, device="cuda") for b in range(batch): for t in range(seq_len): for h in range(heads): i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i return block_indices.sort(-1)[0] -def benchmark_nsa(batch_size, - seq_len, - heads, - head_query, - dim, - selected_blocks, - block_size, - dtype, - scale, - warmup=10, - iterations=100, - validate=False): +def benchmark_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): """Benchmark the TileLang Sparse Attention implementation.""" # Set random seed for reproducibility @@ -628,14 +611,13 @@ def benchmark_nsa(batch_size, print(f"Profiler latency: {profiler_latency} ms") # Create input tensors - Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + 
V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") # Generate block indices - block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, - block_size).to(torch.int32) + block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size).to(torch.int32) # Warmup for _ in range(warmup): @@ -666,10 +648,9 @@ def benchmark_nsa(batch_size, # Validate result against reference if requested if validate: - g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - block_counts = torch.randint( - 1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda') + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") ref = naive_nsa( q=Q, @@ -700,22 +681,13 @@ def benchmark_nsa(batch_size, "head_query": head_query, "dim": dim, "selected_blocks": selected_blocks, - "block_size": block_size + "block_size": block_size, } -def benchmark_triton_nsa(batch_size, - seq_len, - heads, - head_query, - dim, - selected_blocks, - block_size, - dtype, - scale, - warmup=10, - iterations=100, - validate=False): +def benchmark_triton_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): """Benchmark the Triton-based TileLang Sparse Attention implementation.""" # Set random seed for reproducibility @@ -723,18 +695,17 @@ def benchmark_triton_nsa(batch_size, torch.random.manual_seed(0) # Create input tensors - Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") # Generate block indices block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size) - block_counts = torch.randint( - 1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda') - o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device='cuda') + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") + o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device="cuda") # Warmup for _ in range(warmup): @@ -750,7 +721,8 @@ def benchmark_triton_nsa(batch_size, block_counts=block_counts, block_size=block_size, window_size=0, 
- scale=scale) + scale=scale, + ) # Synchronize before timing torch.cuda.synchronize() @@ -770,7 +742,8 @@ def benchmark_triton_nsa(batch_size, block_counts=block_counts, block_size=block_size, window_size=0, - scale=scale) + scale=scale, + ) torch.cuda.synchronize() end_time = time.time() @@ -815,54 +788,28 @@ def benchmark_triton_nsa(batch_size, "head_query": head_query, "dim": dim, "selected_blocks": selected_blocks, - "block_size": block_size + "block_size": block_size, } -def run_benchmark_suite(impl='all'): +def run_benchmark_suite(impl="all"): """Run a suite of benchmarks with different configurations.""" # Define configurations to benchmark configs = [ # Small model config - Note: head_query must be a multiple of heads*16 for Triton - { - "batch_size": 2, - "seq_len": 1024, - "heads": 8, - "head_query": 8 * 16, - "dim": 64, - "selected_blocks": 8, - "block_size": 32 - }, - + {"batch_size": 2, "seq_len": 1024, "heads": 8, "head_query": 8 * 16, "dim": 64, "selected_blocks": 8, "block_size": 32}, # Medium model config - { - "batch_size": 2, - "seq_len": 2048, - "heads": 16, - "head_query": 16 * 16, - "dim": 64, - "selected_blocks": 16, - "block_size": 64 - }, - + {"batch_size": 2, "seq_len": 2048, "heads": 16, "head_query": 16 * 16, "dim": 64, "selected_blocks": 16, "block_size": 64}, # Large model config - { - "batch_size": 1, - "seq_len": 4096, - "heads": 32, - "head_query": 32 * 16, - "dim": 128, - "selected_blocks": 32, - "block_size": 128 - }, + {"batch_size": 1, "seq_len": 4096, "heads": 32, "head_query": 32 * 16, "dim": 128, "selected_blocks": 32, "block_size": 128}, ] results = [] for config in configs: print(f"Running benchmark with config: {config}") - if impl in ['all', 'tilelang']: + if impl in ["all", "tilelang"]: print("Benchmarking TileLang implementation:") result = benchmark_nsa( batch_size=config["batch_size"], @@ -874,12 +821,13 @@ def run_benchmark_suite(impl='all'): block_size=config["block_size"], dtype=torch.float16, scale=0.1, - validate=False) + validate=False, + ) results.append({"impl": "tilelang", **result}) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") - if impl in ['all', 'triton']: + if impl in ["all", "triton"]: print("Benchmarking Triton implementation:") result = benchmark_triton_nsa( batch_size=config["batch_size"], @@ -891,19 +839,24 @@ def run_benchmark_suite(impl='all'): block_size=config["block_size"], dtype=torch.float16, scale=0.1, - validate=False) + validate=False, + ) results.append({"impl": "triton", **result}) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") - if impl in ['all']: + if impl in ["all"]: # Print comparison if both implementations were run tilelang_result = next( - r for r in results if r["impl"] == "tilelang" and - r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]) + r + for r in results + if r["impl"] == "tilelang" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) triton_result = next( - r for r in results if r["impl"] == "triton" and - r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]) + r + for r in results + if r["impl"] == "triton" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) speedup = tilelang_result["avg_time_ms"] / triton_result["avg_time_ms"] print(f"Speedup (Triton vs TileLang): {speedup:.2f}x") @@ -921,8 +874,7 @@ def run_benchmark_suite(impl='all'): 
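# `generate_block_indices(...)` used by the benchmarks above is defined earlier in this
# file. For reference, a minimal sketch of an equivalent generator, mirroring the
# per-token loops in the example scripts changed later in this patch (the name
# `_generate_block_indices_sketch` and the `seq_len` padding sentinel are illustrative;
# the real helper may differ; assumes `torch` is imported and a CUDA device is available):
def _generate_block_indices_sketch(batch_size, seq_len, heads, selected_blocks, block_size):
    # Start from `seq_len` as an "invalid block" sentinel; the kernels skip blocks that
    # do not start at or before the query position.
    idx = torch.full((batch_size, seq_len, heads, selected_blocks), seq_len,
                     dtype=torch.long, device="cuda")
    for b in range(batch_size):
        for t in range(seq_len):
            for h in range(heads):
                # Only key blocks starting at or before token t are selectable (causal).
                i_i = torch.randperm(max(1, t // block_size))[:selected_blocks]
                idx[b, t, h, :len(i_i)] = i_i
    return idx.sort(-1)[0]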
parser.add_argument("--dim", type=int, default=128, help="Head dimension") parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks") parser.add_argument("--block_size", type=int, default=32, help="Block size") - parser.add_argument( - "--dtype", type=str, default="float16", help="Data type (float16 or float32)") + parser.add_argument("--dtype", type=str, default=T.float16, help="Data type (float16 or float32)") parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor") parser.add_argument("--iterations", type=int, default=100, help="Number of iterations") parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") @@ -933,7 +885,8 @@ def run_benchmark_suite(impl='all'): type=str, default="all", choices=["tilelang", "triton", "all"], - help="Implementation to benchmark (tilelang, triton, or all)") + help="Implementation to benchmark (tilelang, triton, or all)", + ) args = parser.parse_args() @@ -941,13 +894,12 @@ def run_benchmark_suite(impl='all'): if args.impl in ["triton", "all"] and args.head_query % (args.heads * 16) != 0: # Adjust head_query to nearest valid value args.head_query = ((args.head_query // (args.heads * 16)) + 1) * (args.heads * 16) - print( - f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") + print(f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") if args.suite: run_benchmark_suite(impl=args.impl) else: - dtype = torch.float16 if args.dtype == "float16" else torch.float32 + dtype = torch.float16 if args.dtype == T.float16 else torch.float32 if args.impl in ["tilelang", "all"]: print("Benchmarking TileLang implementation:") @@ -963,12 +915,14 @@ def run_benchmark_suite(impl='all'): scale=args.scale, warmup=args.warmup, iterations=args.iterations, - validate=args.validate) + validate=args.validate, + ) print("\nBenchmark Results (TileLang):") print( - f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + - f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + - f"block_size={args.block_size}") + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") @@ -986,11 +940,13 @@ def run_benchmark_suite(impl='all'): scale=args.scale, warmup=args.warmup, iterations=args.iterations, - validate=args.validate) + validate=args.validate, + ) print("\nBenchmark Results (Triton):") print( - f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + - f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + - f"block_size={args.block_size}") + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py index 8387d2271..3da285a9b 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -7,6 +7,7 @@ import triton import fla + if 
parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -22,7 +23,8 @@ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + } +) def tilelang_kernel_fwd( batch, heads, @@ -34,11 +36,10 @@ def tilelang_kernel_fwd( groups=1, selected_blocks=16, ): - from tilelang import language as T if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -48,9 +49,9 @@ def tilelang_kernel_fwd( o_slc_shape = [batch, seq_len, heads, dim] lse_slc_shape = [batch, seq_len, heads] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) @@ -67,12 +68,12 @@ def tilelang_kernel_fwd( @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - O_slc: T.Tensor(o_slc_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -93,7 +94,7 @@ def native_sparse_attention( i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -103,12 +104,11 @@ def native_sparse_attention( i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: - for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) @@ -124,21 +124,21 @@ def native_sparse_attention( T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=True) - for i in T.Parallel(G): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + for k in T.Parallel(G): + scores_scale[k] = T.exp2(scores_max_prev[k] * scale - scores_max[k] * scale) + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.exp2(acc_s[k, j] * scale - scores_max[k] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(G): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for k in T.Parallel(G): + logsum[k] = logsum[k] * scores_scale[k] + scores_sum[k] T.copy(acc_s, acc_s_cast) # Rescale - for i, j in T.Parallel(G, BV): - acc_o[i, j] *= scores_scale[i] + for k, j in T.Parallel(G, BV): + acc_o[k, j] *= scores_scale[k] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + 
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): @@ -146,18 +146,20 @@ def native_sparse_attention( T.copy(acc_o, O_shared) T.copy( O_shared, - O_slc[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV], + O_slc[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV], ) for i in T.Parallel(G): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, LSE_slc[i_b, i_t, i_h * G:(i_h + 1) * G]) + T.copy(logsum, LSE_slc[i_b, i_t, i_h * G : (i_h + 1) * G]) return native_sparse_attention -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def tilelang_kernel_bwd_dkv( batch, heads, @@ -168,11 +170,11 @@ def tilelang_kernel_bwd_dkv( block_size=64, groups=1, selected_blocks=16, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): if scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 else: sm_scale = scale @@ -207,15 +209,15 @@ def tilelang_kernel_bwd_dkv( @T.prim_func def flash_bwd_dkv( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(k_shape, dtype), - V: T.Tensor(v_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), - Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), - DO_slc: T.Tensor(do_slc_shape, dtype), - DK: T.Tensor(dk_shape, dtype), - DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, T.int32), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -238,31 +240,25 @@ def flash_bwd_dkv( i_b, i_h = i_bh // H, i_bh % H - T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared) - T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared) + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) # [BS, BK] T.clear(dk) # [BS, BV] T.clear(dv) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - loop_st = i_s * BS loop_ed = seq_len for i in T.Pipelined( - start=loop_st, - stop=loop_ed, - num_stages=0, + start=loop_st, + stop=loop_ed, + num_stages=0, ): b_m_slc = BlockMask[i_b, i, i_h, i_s] if b_m_slc != 0: # [G, BK] - T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.clear(qkT) # [BS, BK] @ [G, BK] -> [BS, G] T.gemm( @@ -273,7 +269,7 @@ def flash_bwd_dkv( policy=T.GemmWarpPolicy.FullRow, ) # [G] - T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared) + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) for _i, _j in T.Parallel(BS, G): qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) @@ -282,7 +278,7 @@ def flash_bwd_dkv( qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) # [G, BV] - T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do) + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) 
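# The rest of this loop body is the FlashAttention-style backward for one selected
# key/value block; in the notation of the forward kernel above (a reading of the code,
# with the right-hand sides kept transposed as [BS, G]):
#   P  = exp2(q @ k^T * scale - lse)     -> `qkT`
#   dP = dO @ v^T                        -> `dsT`
#   dS = P * (dP - delta) * sm_scale     -> `dsT_cast`, where `delta` = rowsum(O * dO)
#                                           comes from the preprocess kernel below
#   dv += P^T @ dO    and    dk += dS^T @ q
# `scale` already folds in log2(e), which is why exp2/log2 are used instead of exp/log.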
T.clear(dsT) # [BS, BV] @ [G, BV] -> [BS, G] T.gemm( @@ -296,7 +292,7 @@ def flash_bwd_dkv( # [BS, G] @ [G, BV] -> [BS, BV] T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) # [G] - T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta) + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) for i, j in T.Parallel(BS, G): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -305,8 +301,8 @@ def flash_bwd_dkv( T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV]) - T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK]) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) return flash_bwd_dkv @@ -321,9 +317,11 @@ def make_dq_layout(dQ): ) -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def tilelang_kernel_bwd_dqkv( batch, heads, @@ -334,11 +332,11 @@ def tilelang_kernel_bwd_dqkv( block_size=64, groups=1, selected_blocks=16, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): if scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 else: sm_scale = scale @@ -373,16 +371,16 @@ def tilelang_kernel_bwd_dqkv( @T.prim_func def flash_bwd_dqkv( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(k_shape, dtype), - V: T.Tensor(v_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), - Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), - DO_slc: T.Tensor(do_slc_shape, dtype), - DQ: T.Tensor(dq_shape, dtype), - DK: T.Tensor(dk_shape, dtype), - DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DQ: T.Tensor(dq_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, T.int32), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -406,31 +404,25 @@ def flash_bwd_dqkv( i_b, i_h = i_bh // H, i_bh % H - T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared) - T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared) + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) # [BS, BK] T.clear(dk) # [BS, BV] T.clear(dv) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - loop_st = i_s * BS loop_ed = seq_len for i in T.Pipelined( - start=loop_st, - stop=loop_ed, - num_stages=0, + start=loop_st, + stop=loop_ed, + num_stages=0, ): b_m_slc = BlockMask[i_b, i, i_h, i_s] if b_m_slc != 0: # [G, BK] - T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.clear(qkT) # [BS, BK] @ [G, BK] -> [BS, G] T.gemm( @@ -441,7 +433,7 @@ def flash_bwd_dqkv( policy=T.GemmWarpPolicy.FullRow, ) # [G] - T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared) + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) for _i, _j in T.Parallel(BS, G): qkT[_i, _j] = T.exp2(qkT[_i, _j] 
* scale - lse_shared[_j]) @@ -450,7 +442,7 @@ def flash_bwd_dqkv( qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) # [G, BV] - T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do) + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) T.clear(dsT) # [BS, BV] @ [G, BV] -> [BS, G] T.gemm( @@ -464,9 +456,9 @@ def flash_bwd_dqkv( # [BS, G] @ [G, BV] -> [BS, BV] T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) # [G] - T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta) - for i, j in T.Parallel(BS, G): - dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) + for _i, _j in T.Parallel(BS, G): + dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale # [BS, G] @ [G, BK] -> [BS, BK] T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow) @@ -480,23 +472,25 @@ def flash_bwd_dqkv( T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV]) - T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK]) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) return flash_bwd_dqkv @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def tilelang_kernel_preprocess( batch, heads, seq_len, dim, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, blk=32, ): from tilelang import language as T @@ -505,9 +499,9 @@ def tilelang_kernel_preprocess( @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -516,27 +510,29 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, by * blk:(by + 1) * blk, bx]) + T.copy(delta, Delta[bz, by * blk : (by + 1) * blk, bx]) return flash_bwd_prep @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def tilelang_kernel_block_mask( batch, heads, seq_len, selected_blocks, block_size, - dtype="int32", + dtype=T.int32, ): from tilelang import language as T @@ -551,9 +547,9 @@ def tilelang_kernel_block_mask( @T.prim_func def flash_bwd_block_mask( - BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore - BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore - BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore + BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore + BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore + BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore ): with 
T.Kernel(seq_len, batch, heads * S) as (bx, by, bz): i_t, i_b, i_hs = bx, by, bz @@ -603,9 +599,7 @@ def parallel_nsa_bwd( dk = torch.empty(NV, *k.shape, dtype=k.dtype, device=q.device) dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) - block_mask = tilelang_kernel_block_mask(B, H, T, S, - BS)(block_indices.to(torch.int32), - block_counts.to(torch.int32)).to(torch.bool) + block_mask = tilelang_kernel_block_mask(B, H, T, S, BS)(block_indices.to(torch.int32), block_counts.to(torch.int32)).to(torch.bool) fused_qkv_bwd_kernel = tilelang_kernel_bwd_dqkv( batch=B, @@ -618,8 +612,7 @@ def parallel_nsa_bwd( selected_blocks=S, scale=scale, ) - fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, - block_mask.to(torch.int32)) + fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, block_mask.to(torch.int32)) dq = dq.sum(0) dk = dk.sum(0) @@ -628,7 +621,6 @@ def parallel_nsa_bwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -773,23 +765,21 @@ def parallel_nsa( Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), - (q, k, v, block_indices)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): block_counts = rearrange(block_counts, "b h t -> b t h") - assert (q.shape[2] % (k.shape[2] * 16) == 0), "Group size must be a multiple of 16 in NSA" + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: @@ -814,7 +804,7 @@ def parallel_nsa( for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index 58f435509..381d92493 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -16,7 +16,8 @@ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def native_sparse_attention( batch, heads, @@ -25,18 +26,18 @@ def native_sparse_attention( scale=None, block_size=64, # Tile size for attention computation groups=1, # Grouped query attention (GQA) groups - selected_blocks=16 # Number of blocks to select per attention head + selected_blocks=16, # Number of blocks to select per attention head ): if scale is None: - scale = 
(1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups # Modified shapes for inference (q has seq_len=1)a q_shape = [batch, 1, heads, dim] # Changed seq_len to 1 kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1 - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) @@ -53,12 +54,11 @@ def native_sparse_attention( @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] - K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] - V: T.Tensor(kv_shape, dtype), # Same shape as K - BlockIndices: T.Tensor(block_indices_shape, - block_indices_dtype), # Selected block indices - Output: T.Tensor(q_shape, dtype), # Output attention tensor + Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] + K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] + V: T.Tensor(kv_shape, dtype), # Same shape as K + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), # Selected block indices + Output: T.Tensor(q_shape, dtype), # Output attention tensor ): with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz): # Shared memory allocations for tile storage @@ -82,7 +82,7 @@ def native_sparse_attention( NS = S # Copy Q for the single position - T.copy(Q[i_b, 0, i_h * G:(i_h + 1) * G, :], Q_shared) # Changed i_t to 0 + T.copy(Q[i_b, 0, i_h * G : (i_h + 1) * G, :], Q_shared) # Changed i_t to 0 T.fill(acc_o, 0) T.fill(logsum, 0) @@ -93,16 +93,11 @@ def native_sparse_attention( i_s = BlockIndices[i_b, 0, i_h, i] * BS # Get block offset if i_s >= 0: # Skip invalid/padding blocks # Load current key block to shared memory - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) # Compute QK^T attention scores T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Online softmax with numerical stability # 1. 
Compute max for scaling @@ -122,15 +117,14 @@ def native_sparse_attention( T.copy(acc_s, acc_s_cast) # Accumulate attention-weighted values - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) # Final normalization and output for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] # Normalize by logsum T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, 0, i_h * G:(i_h + 1) * G, - i_v * BV:(i_v + 1) * BV]) # Changed i_t to 0 + T.copy(O_shared, Output[i_b, 0, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) # Changed i_t to 0 return native_sparse_attention @@ -149,21 +143,21 @@ def main(): selected_blocks=S, ) - Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device='cuda') - DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda') + mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda") + DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda") - block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device='cuda') + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda") for b in range(B): for t in range(SEQ_LEN_Q): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda") out = kernel(Q, K, V, block_indices.to(torch.int32)) @@ -178,5 +172,38 @@ def main(): torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) +def run_regression_perf(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16 + groups = HQ // H + SEQ_LEN_Q = 1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + ) + + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN_Q): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, block_indices.to(torch.int32)) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index f8a7ebfb0..7b36d6e26 100644 --- 
a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -14,18 +14,11 @@ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def native_sparse_attention(batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16): + }, +) +def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -33,9 +26,9 @@ def native_sparse_attention(batch, q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) @@ -52,11 +45,11 @@ def native_sparse_attention(batch, @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -77,7 +70,7 @@ def native_sparse_attention( i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -87,21 +80,15 @@ def native_sparse_attention( i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -121,13 +108,13 @@ def native_sparse_attention( acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return native_sparse_attention @@ -148,21 +135,22 @@ def main(): ) print(kernel.get_kernel_source()) torch.random.manual_seed(0) - Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, 
device='cuda').requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda') + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") for b in range(B): for t in range(SEQ_LEN): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, SEQ_LEN, H), device='cuda') out = kernel(Q, K, V, block_indices.to(torch.int32)) @@ -183,5 +171,43 @@ def main(): torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) +def run_regression_perf(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + scale=scale, + ) + torch.random.manual_seed(0) + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() + block_indices = block_indices.sort(-1)[0] + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, block_indices.to(torch.int32)) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py index d365e7a5f..b52ebe42e 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -8,6 +8,7 @@ import tilelang.testing import fla + if 
parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -21,18 +22,11 @@ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def native_sparse_attention_varlen(batch, - heads, - c_seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16): + } +) +def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [c_seq_len, heads, dim] kv_shape = [c_seq_len, head_kv, dim] @@ -44,12 +38,12 @@ def native_sparse_attention_varlen(batch, block_counts_shape = [c_seq_len, head_kv] offsets_shape = [batch + 1] token_indices_shape = [c_seq_len, 2] - block_indices_dtype = "int32" - block_counts_dtype = "int32" - offsets_dtype = "int32" - token_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + block_counts_dtype = T.int32 + offsets_dtype = T.int32 + token_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) @@ -66,14 +60,14 @@ def native_sparse_attention_varlen(batch, @T.prim_func def native_sparse_attention_varlen( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - O_slc: T.Tensor(o_slc_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), - Offsets: T.Tensor(offsets_shape, offsets_dtype), - TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), + Offsets: T.Tensor(offsets_shape, offsets_dtype), + TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), ): with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -100,7 +94,7 @@ def native_sparse_attention_varlen( current_seq_len = eos - bos NS = BlockCounts[i_t, i_h] - T.copy(Q[bos + i_t, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -112,21 +106,15 @@ def native_sparse_attention_varlen( # [BS, BK] # Lei: may have some padding issues # we should learn from mha varlen templates to handle this - T.copy(K[bos + i_s:bos + i_s + BS, i_h, :BK], K_shared) + T.copy(K[bos + i_s : bos + i_s + BS, i_h, :BK], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -146,13 +134,13 @@ def native_sparse_attention_varlen( acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[bos + i_s:bos + i_s + BS, i_h, i_v * BV:(i_v + 
1) * BV], V_shared) + T.copy(V[bos + i_s : bos + i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, O_slc[bos + i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, O_slc[bos + i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return native_sparse_attention_varlen @@ -190,17 +178,20 @@ def parallel_nsa_fwd( o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device) kernel( - q.view(C_SEQ_LEN, HQ, D), k.view(C_SEQ_LEN, H, D), v.view(C_SEQ_LEN, H, D), + q.view(C_SEQ_LEN, HQ, D), + k.view(C_SEQ_LEN, H, D), + v.view(C_SEQ_LEN, H, D), o_slc.view(C_SEQ_LEN, HQ, V), block_indices.to(torch.int32).view(C_SEQ_LEN, H, S), - block_counts.to(torch.int32).view(C_SEQ_LEN, H), offsets.to(torch.int32), - token_indices.to(torch.int32)) + block_counts.to(torch.int32).view(C_SEQ_LEN, H), + offsets.to(torch.int32), + token_indices.to(torch.int32), + ) return o_slc @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): ctx.dtype = q.dtype @@ -221,22 +212,25 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) return o_slc.to(q.dtype) -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -276,29 +270,27 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
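        Example (illustrative; mirrors the test at the bottom of this file — note that
        `q.shape[2]` (HQ) must be a multiple of `k.shape[2] * 16`, as asserted below, and
        this varlen variant expects `cu_seqlens` for the packed sequences):
            >>> B, T, H, HQ, D, S, BS = 1, 64, 1, 16, 64, 1, 32
            >>> q = torch.randn(B, T, HQ, D, dtype=torch.float16, device="cuda")
            >>> k = torch.randn(B, T, H, D, dtype=torch.float16, device="cuda")
            >>> v = torch.randn(B, T, H, D, dtype=torch.float16, device="cuda")
            >>> g_slc = torch.ones(B, T, HQ, dtype=torch.float16, device="cuda")
            >>> g_swa = torch.zeros_like(g_slc)  # unused while window_size == 0
            >>> block_indices = torch.zeros(B, T, H, S, dtype=torch.long, device="cuda")
            >>> block_counts = torch.ones(B, T, H, dtype=torch.long, device="cuda")
            >>> cu_seqlens = torch.tensor([0, 32, 64], dtype=torch.long, device="cuda")
            >>> o = parallel_nsa(q, k, v, g_slc, g_swa, block_indices, block_counts,
            ...                  block_size=BS, cu_seqlens=cu_seqlens)
            >>> o.shape
            torch.Size([1, 64, 16, 64])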
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, - scale, cu_seqlens) + o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: assert False, "Window size is not supported yet" else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o @@ -306,41 +298,57 @@ def parallel_nsa(q: torch.Tensor, N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 torch.manual_seed(42) # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[:N - 1]], - torch.tensor([C_SEQ_LEN], dtype=torch.long) - ], 0).cuda().sort()[0] + offsets = ( + torch.cat( + [ + torch.tensor([0], dtype=torch.long), + torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[: N - 1]], + torch.tensor([C_SEQ_LEN], dtype=torch.long), + ], + 0, + ) + .cuda() + .sort()[0] + ) # seq-first required for inputs with variable lengths - perm_q = torch.randperm(C_SEQ_LEN, device='cuda') - perm_k = torch.randperm(C_SEQ_LEN, device='cuda') - perm_v = torch.randperm(C_SEQ_LEN, device='cuda') - q = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_q].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, HQ, - D).clone().requires_grad_(True) - k = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_k].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H, - D).clone().requires_grad_(True) - v = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_v].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H, - D).clone().requires_grad_(True) - g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device='cuda') + perm_q = torch.randperm(C_SEQ_LEN, device="cuda") + perm_k = torch.randperm(C_SEQ_LEN, device="cuda") + perm_v = torch.randperm(C_SEQ_LEN, device="cuda") + q = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_q] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, HQ, D) + .clone() + .requires_grad_(True) + ) + k = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_k] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + v = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, 
dtype=dtype, device="cuda")[perm_v] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device="cuda") token_indices = prepare_token_indices(offsets).tolist() - block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device='cuda') + block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device="cuda") for i in range(C_SEQ_LEN): _, t = token_indices[i] for h in range(H): i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S] - block_indices[0, i, h, :len(i_i)] = i_i + block_indices[0, i, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device='cuda') + block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device="cuda") ref = naive_nsa( q=q, @@ -351,7 +359,8 @@ def parallel_nsa(q: torch.Tensor, block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) tri = parallel_nsa( q=q, @@ -362,7 +371,8 @@ def parallel_nsa(q: torch.Tensor, block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) print("tri", tri) print("ref", ref) diff --git a/examples/deepseek_nsa/example_triton_nsa_bwd.py b/examples/deepseek_nsa/example_triton_nsa_bwd.py index e912794a4..af05bfa70 100644 --- a/examples/deepseek_nsa/example_triton_nsa_bwd.py +++ b/examples/deepseek_nsa/example_triton_nsa_bwd.py @@ -8,6 +8,7 @@ import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,21 +18,44 @@ from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc # else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), 
- (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -105,8 +126,7 @@ def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -134,7 +154,8 @@ def backward(ctx, do_slc, do_swa): window_size=ctx.window_size, scale=ctx.scale, offsets=ctx.offsets, - token_indices=ctx.token_indices) + token_indices=ctx.token_indices, + ) return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None @@ -199,37 +220,56 @@ def parallel_nsa_fwd( return o_slc, lse_slc, o_swa, lse_swa -@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) +@triton.heuristics({"USE_OFFSETS": lambda args: args["offsets"] is not None}) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) -@triton.jit(do_not_specialize=['T']) -def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, do_slc, do_swa, dk, - dv, block_mask, offsets, chunk_indices, scale, T, B: tl.constexpr, - H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, - V: tl.constexpr, M: tl.constexpr, BS: tl.constexpr, - WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dkv( + q, + k, + v, + lse_slc, + lse_swa, + delta_slc, + delta_swa, + do_slc, + do_swa, + dk, + dv, + block_mask, + offsets, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: 
tl.constexpr, + USE_OFFSETS: tl.constexpr, +): i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + - 1).to(tl.int32) + i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T - p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), - (1, 0)) - p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), - (BS, BV), (1, 0)) - p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), - (i_s * BS, 0), (BS, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), - (BS, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) # [BS, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) @@ -241,14 +281,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, for i in range(i_s * BS, T): b_m_slc = tl.load(block_mask + (bos + i) * H * M + i_h * M + i_s) if b_m_slc: - p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta_slc = delta_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] @@ -272,14 +310,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, if WS > 0: o_s = i_s * BS + tl.arange(0, BS) if max(i_s * BS, i - WS + 1) < min((i_s + 1) * BS, i + 1): - p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), - (G, BK), (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta_swa = delta_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] @@ -304,12 +340,19 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) -@triton.heuristics( - {'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)}) +@triton.heuristics({"USE_BLOCK_COUNTS": lambda 
args: isinstance(args["block_counts"], torch.Tensor)}) @triton.jit -def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.constexpr, - H: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, NS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_kernel_mask( + block_indices, + block_counts, + block_mask, + T: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + NS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_h, i_s = i_hs // S, i_hs % S @@ -320,31 +363,56 @@ def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.cons b_m = b_i * BS <= i_t if b_i < NS and b_i >= 0: - tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, - b_m.to(block_mask.dtype.element_ty)) + tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty)) -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) -@triton.jit(do_not_specialize=['T']) -def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, delta_swa, do_swa, dq, - scale, block_indices, block_counts, offsets, token_indices, T, - B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, - K: tl.constexpr, V: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, - WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dq( + q, + k, + v, + lse_slc, + delta_slc, + do_slc, + lse_swa, + delta_swa, + do_swa, + dq, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -449,27 +517,49 @@ def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, del tl.store(p_dq, (b_dq_slc + b_dq_swa).to(p_dq.dtype.element_ty), boundary_check=(0, 1)) -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) 
@triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -484,20 +574,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -510,7 +598,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -529,13 +617,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) if WS > 0: - p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_swa = tl.zeros([G, BV], dtype=tl.float32) - b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_swa = tl.zeros([G], dtype=tl.float32) for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * 
K), (0, i_s), (BK, BS), (0, 1)) @@ -546,7 +633,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) # [G, BS] b_s_swa = tl.dot(b_q, b_k_swa) - b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf')) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) # [G] b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa @@ -593,14 +680,8 @@ def parallel_nsa_block_mask( block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) parallel_nsa_kernel_mask[(T, B, H * S)]( - block_indices=block_indices, - block_counts=block_counts, - block_mask=block_mask, - T=T, - H=H, - S=S, - BS=BS, - NS=NS) + block_indices=block_indices, block_counts=block_counts, block_mask=block_mask, T=T, H=H, S=S, BS=BS, NS=NS + ) return block_mask @@ -676,7 +757,8 @@ def parallel_nsa_bwd( BS=BS, WS=WS, BK=BK, - BV=BV) + BV=BV, + ) dq = dq.sum(0) if offsets is not None: @@ -719,14 +801,14 @@ def parallel_nsa_bwd( BS=BS, WS=WS, BK=BK, - BV=BV) + BV=BV, + ) dk = dk.sum(0) return dq, dk, dv @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -749,7 +831,8 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -781,22 +864,25 @@ def backward(ctx, do_slc, do_swa): window_size=ctx.window_size, scale=ctx.scale, offsets=ctx.offsets, - token_indices=ctx.token_indices) + token_indices=ctx.token_indices, + ) return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -836,51 +922,49 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o if __name__ == "__main__": B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 torch.random.manual_seed(0) - q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda') + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") for b in range(B): for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") ref = naive_nsa( q=q, diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd.py b/examples/deepseek_nsa/example_triton_nsa_fwd.py index 2c740013a..c9ab28daa 100644 --- a/examples/deepseek_nsa/example_triton_nsa_fwd.py +++ b/examples/deepseek_nsa/example_triton_nsa_fwd.py @@ -8,6 +8,7 @@ import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,21 +18,44 @@ from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda 
args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc # else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -105,8 +126,7 @@ def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices 
= block_indices ctx.block_size = block_size @@ -177,7 +197,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -200,7 +219,8 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -212,18 +232,20 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -263,51 +285,49 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o if __name__ == "__main__": B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 torch.random.manual_seed(0) - q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) 
- g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda') + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") for b in range(B): for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") ref = naive_nsa( q=q, diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py index 9ccbff6a4..cb4eb6d7b 100644 --- a/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py @@ -8,6 +8,7 @@ import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,27 +18,49 @@ from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -52,20 +75,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc else: NS = S - p_q = 
tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -78,7 +99,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -97,13 +118,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) if WS > 0: - p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_swa = tl.zeros([G, BV], dtype=tl.float32) - b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_swa = tl.zeros([G], dtype=tl.float32) for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) @@ -114,7 +134,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) # [G, BS] b_s_swa = tl.dot(b_q, b_k_swa) - b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf')) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) # [G] b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa @@ -196,7 +216,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -219,7 +238,8 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -231,18 +251,20 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - 
block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -282,29 +304,27 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o @@ -312,38 +332,35 @@ def parallel_nsa(q: torch.Tensor, N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 torch.manual_seed(42) # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N - 1]], - torch.tensor([T], dtype=torch.long) - ], 0).cuda().sort()[0] + offsets = ( + torch.cat( + [torch.tensor([0], dtype=torch.long), torch.arange(16, T)[torch.randperm(T - 1)[: N - 1]], torch.tensor([T], dtype=torch.long)], + 0, + ) + .cuda() + .sort()[0] + ) # offsets.shape is [N+1] # seq-first required for inputs with variable lengths - perm_q = torch.randperm(T, device='cuda') - perm_k = torch.randperm(T, device='cuda') - perm_v = torch.randperm(T, device='cuda') - q = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) - k = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - v = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - g_slc = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((1, T, HQ, 
D), dtype=dtype, device='cuda') + perm_q = torch.randperm(T, device="cuda") + perm_k = torch.randperm(T, device="cuda") + perm_v = torch.randperm(T, device="cuda") + q = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) + k = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + v = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + g_slc = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, T, HQ, D), dtype=dtype, device="cuda") token_indices = prepare_token_indices(offsets).tolist() - block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='cuda') + block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device="cuda") for i in range(T): _, t = token_indices[i] for h in range(H): i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] - block_indices[0, i, h, :len(i_i)] = i_i + block_indices[0, i, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (1, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (1, T, H), device="cuda") ref = naive_nsa( q=q, @@ -354,7 +371,8 @@ def parallel_nsa(q: torch.Tensor, block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) tri = parallel_nsa( q=q, @@ -365,7 +383,8 @@ def parallel_nsa(q: torch.Tensor, block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) print("tri", tri) print("ref", ref) diff --git a/examples/deepseek_nsa/reference.py b/examples/deepseek_nsa/reference.py index 958d0c19e..58083108e 100644 --- a/examples/deepseek_nsa/reference.py +++ b/examples/deepseek_nsa/reference.py @@ -6,18 +6,20 @@ from einops import rearrange, repeat -def naive_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -57,26 +59,24 @@ def naive_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - raise RuntimeError( - "Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") dtype = q.dtype G = q.shape[2] // k.shape[2] BS = block_size S = block_indices.shape[-1] - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) if isinstance(block_counts, torch.Tensor): - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) @@ -86,14 +86,11 @@ def naive_nsa(q: torch.Tensor, if cu_seqlens is None: varlen = False B, T = q.shape[:2] - cu_seqlens = torch.cat( - [block_indices.new_tensor(range(0, B * T, T)), - block_indices.new_tensor([B * T])]) + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) for i in range(len(cu_seqlens) - 1): if not varlen: - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[ - i], block_indices[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] if isinstance(block_counts, torch.Tensor): s_b = block_counts[i] else: @@ -101,10 +98,10 @@ def naive_nsa(q: torch.Tensor, else: T = cu_seqlens[i + 1] - cu_seqlens[i] q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]], - (q, k, v, g_slc, g_swa, block_indices)) + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]] + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] else: s_b = block_counts @@ -126,34 +123,28 @@ def naive_nsa(q: torch.Tensor, else: s_i = s_b # [S*BS, HQ, -1] - k_i_slc, v_i_slc = map( - lambda x: x.gather( - 0, - i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) # [S*BS, HQ] - attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( - torch.logical_or(i_i < 0, i_i > i_q) | - (c >= s_i if block_counts is not None else False), float('-inf')).softmax(0) + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) if not varlen: - o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * 
g_slc_i.unsqueeze(-1) + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) else: - o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) if window_size > 0: - k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], - (k_b, v_b)) - attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) if not varlen: - o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) else: - o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) if head_first: - o_slc = rearrange(o_slc, 'b t h d -> b h t d') - o_swa = rearrange(o_swa, 'b t h d -> b h t d') + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) @@ -187,7 +178,7 @@ def naive_nsa_simple( o (torch.Tensor): Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 dtype = q.dtype HQ = q.shape[2] @@ -197,8 +188,8 @@ def naive_nsa_simple( BS = block_size S = block_indices.shape[-1] SELECTED_BLOCKS_SIZE = S * BS - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) o = torch.zeros_like(v) @@ -228,10 +219,10 @@ def naive_nsa_simple( v_i[t, h] = v_b[selected_block_index, h, :] # [S*BS, HQ] - attn = torch.einsum('h d, n h d -> n h', q_i, k_i) - attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float('-inf')) + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float("-inf")) attn = torch.softmax(attn, dim=0) - o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i) + o[i, i_q] = torch.einsum("n h, n h v -> h v", attn, v_i) return o.to(dtype) @@ -265,7 +256,7 @@ def naive_nsa_simple_inference( o (torch.Tensor): Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 dtype = q.dtype HQ = q.shape[2] @@ -275,8 +266,8 @@ def naive_nsa_simple_inference( BS = block_size S = block_indices.shape[-1] SELECTED_BLOCKS_SIZE = S * BS - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) o = torch.zeros_like(q) @@ -306,9 +297,9 @@ def naive_nsa_simple_inference( v_i[t, h] = v_b[selected_block_index, h, :] # [S*BS, HQ] - attn = torch.einsum('h d, n h d -> n h', q_i, k_i) - attn = attn.masked_fill((c >= s_i), float('-inf')) + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((c >= s_i), float("-inf")) attn = torch.softmax(attn, dim=0) - o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i) + o[i, 0] = torch.einsum("n h, n h v -> h v", attn, v_i) return o.to(dtype) diff --git a/examples/deepseek_nsa/regression_example_tilelang_nsa.py b/examples/deepseek_nsa/regression_example_tilelang_nsa.py new file mode 100644 index 000000000..1858f045a --- /dev/null +++ b/examples/deepseek_nsa/regression_example_tilelang_nsa.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_tilelang_nsa_fwd +import example_tilelang_nsa_decode + + +def regression_example_tilelang_nsa_fwd(): + tilelang.testing.process_func(example_tilelang_nsa_fwd.run_regression_perf) + + +def regression_example_tilelang_nsa_fwd_decode(): + tilelang.testing.process_func(example_tilelang_nsa_decode.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_nsa/requirements.txt b/examples/deepseek_nsa/requirements.txt index 777c2ad4c..e096dfd7d 100644 --- a/examples/deepseek_nsa/requirements.txt +++ b/examples/deepseek_nsa/requirements.txt @@ -1 +1 @@ -git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e \ No newline at end of file +git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e diff --git a/examples/deepseek_v32/README.md b/examples/deepseek_v32/README.md index 8457745b0..01a14b6b2 100644 --- a/examples/deepseek_v32/README.md +++ b/examples/deepseek_v32/README.md @@ -121,7 +121,7 @@ for i_i in T.Pipelined(NI, num_stages=num_stages): # ... compute attention over selected tokens ``` -This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid: +This reduces compute from O(seq_len *seq_len_kv) to O(seq_len* topk). 
The causal mask is enforced by checking whether each index position is valid: ```python for bi_i in T.Parallel(BI): @@ -193,10 +193,10 @@ for i_i in T.Pipelined(NI, num_stages=num_stages): # Load KV data for selected indices for bi_i, d_i in T.Parallel(BI, D): KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BI + bi_i], bz, d_i] - + # Recompute attention scores for backward T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - + # Apply softmax gradient: dP = P * (dP_raw - Delta) for h_i, bi_i in T.Parallel(padded_H, BI): acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale @@ -204,7 +204,7 @@ for i_i in T.Pipelined(NI, num_stages=num_stages): The key gradient computations are: - **dQ = dP @ K** (query gradients) -- **dK = dP^T @ Q** (key gradients) +- **dK = dP^T @ Q** (key gradients) - **dV = P^T @ dO** (value gradients) **3. Atomic Sparse Updates**: Uses atomic operations for dKV accumulation: @@ -212,7 +212,7 @@ The key gradient computations are: ```python # Atomically update dKV at selected indices for bi_i, d_i in T.Parallel(BI // split_store, D // 4): - T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4], + T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4]) ``` diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index 21baa8fa8..2f8857597 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -28,11 +28,11 @@ def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_rai if should_raise: assert False if not torch.isclose( - a.masked_fill(a_finite, 0), - b.masked_fill(b_finite, 0), - rtol=0, - atol=0, - equal_nan=True, + a.masked_fill(a_finite, 0), + b.masked_fill(b_finite, 0), + rtol=0, + atol=0, + equal_nan=True, ).all(): display_error_message(f"{tensor_name} Error: nonfinite value mismatch") if should_raise: @@ -55,13 +55,10 @@ def get_configs(): threads=[128, 256], block_Q=[1, 2, 4], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] class SupplyProg: - def __init__(self): self.tensors_dict = {} @@ -88,7 +85,8 @@ def supply_prog(self, params): @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - },) + }, +) def mqa_attn_return_logits( heads, index_dim, @@ -99,9 +97,9 @@ def mqa_attn_return_logits( ): if block_Q is None: block_Q = 128 // heads - dtype = "float8_e4m3" - accum_dtype = "float" - index_dtype = "int32" + dtype = T.float8_e4m3fn + accum_dtype = T.float32 + index_dtype = T.int32 seq_len = T.dynamic("seq_len") seq_len_kv = T.dynamic("seq_len_kv") @@ -113,46 +111,42 @@ def mqa_attn_return_logits( @T.prim_func def mqa_attn_return_logits_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore - IndexK: T.Tensor(index_k_shape, dtype), # type: ignore - IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore - Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore - Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + IndexQ: T.Tensor(index_q_shape, 
dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: - index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) index_k_shared = T.alloc_shared([block_N, index_dim], dtype) index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) - s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype) + s_reshaped = T.reshape(s, (block_N, block_Q, heads)) logits = T.alloc_fragment([block_N, block_Q], accum_dtype) weights = T.alloc_fragment([block_Q, heads], accum_dtype) seq_len_i = bx * block_Q - cu_k_s_min = T.alloc_local([1], index_dtype) - cu_k_e_max = T.alloc_local([1], index_dtype) + cu_k_s_min = T.alloc_var(index_dtype) + cu_k_e_max = T.alloc_var(index_dtype) - cu_k_s_min[0] = 2147483647 - cu_k_e_max[0] = -2147483648 + cu_k_s_min = 2147483647 + cu_k_e_max = -2147483648 for bq_i in T.serial(block_Q): - cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], - seq_len_kv)) + cu_k_s_min = T.min(cu_k_s_min, T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) for bq_i in T.serial(block_Q): - cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], - seq_len_kv)) + cu_k_e_max = T.max(cu_k_e_max, T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) T.copy(Weights[seq_len_i, 0], weights) - for nbn_i in T.Pipelined( - T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): - T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) - T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) + for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max - cu_k_s_min, block_N), num_stages=num_stages): + T.copy(IndexK[cu_k_s_min + nbn_i * block_N, 0], index_k_shared) + T.copy(IndexKScale[cu_k_s_min + nbn_i * block_N], index_k_scale_fragment) T.gemm( index_k_shared, @@ -164,15 +158,14 @@ def mqa_attn_return_logits_kernel( ) for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): - s_reshaped[bn_i, bq_i, - h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) * - weights[bq_i, h_i]) * index_k_scale_fragment[bn_i] + s_reshaped[bn_i, bq_i, h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i]) * index_k_scale_fragment[ + bn_i + ] T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) for bq_i, bn_i in T.Parallel(block_Q, block_N): - Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = ( - logits[bn_i, bq_i]) + Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] return mqa_attn_return_logits_kernel @@ -185,38 +178,30 @@ def clean_logits_( seq_len = T.dynamic("seq_len") seq_len_kv = T.dynamic("seq_len_kv") - dtype = "float" - indices_dtype = "int32" + dtype = T.float + indices_dtype = T.int32 @T.prim_func def clean_logits_kernel( - Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: 
ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore ): with T.Kernel(seq_len, threads=threads) as bx: tx = T.thread_binding(0, threads, thread="threadIdx.x") - cu_k_s = T.alloc_local([1], indices_dtype) - cu_k_e = T.alloc_local([1], indices_dtype) - cu_k_s[0] = CuSeqLenKS[bx] - cu_k_e[0] = CuSeqLenKE[bx] + cu_k_s = CuSeqLenKS[bx] + cu_k_e = CuSeqLenKE[bx] for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): for k_i in T.serial(block_K // threads): idx = n_i * block_K + k_i * threads + tx - if idx < cu_k_s[0] or idx >= cu_k_e[0]: + if idx < cu_k_s or idx >= cu_k_e: Logits[bx, idx] = -T.infinity(dtype) return clean_logits_kernel -def mqa_attn_return_logits_interface(q, - kv, - kv_scales, - weights, - cu_seqlen_ks, - cu_seqlen_ke, - clean_logits=True): +def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True): seq_len, heads, index_dim = q.shape seq_len_kv = kv.shape[0] @@ -238,57 +223,48 @@ def mqa_attn_return_logits_interface(q, return logits -def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): k = kv q = q.float() k = k.float() seq_len_kv = kv.shape[0] - mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] - mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] mask = mask_lo & mask_hi - score = torch.einsum('mhd,nd->hmn', q, k) + score = torch.einsum("mhd,nd->hmn", q, k) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float('-inf')) + logits = logits.masked_fill(~mask, float("-inf")) cost = mask.sum() return logits, cost def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): + # initial random seed to make the performance reproducible + torch.manual_seed(0) q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) weights = torch.randn(S, H, device="cuda", dtype=torch.float32) p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) - ks, ke = generate_random_cu_seqlens( - per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) - logits_ref, cost_ref = ref_fp8_mqa_logits( - q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) q_fp8 = q.to(torch.float8_e4m3fn) kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) - logits_tl = mqa_attn_return_logits_interface( - q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - diff = validate_tensor_match( - logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) + logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, 
tensor_name="logits", should_raise=False) print(f"diff: {diff}") from tilelang.profiler import do_bench def logits_fn(): - return mqa_attn_return_logits_interface( - q=q_fp8, - kv=kv_fp8, - kv_scales=kv_scales, - weights=weights, - cu_seqlen_ks=ks, - cu_seqlen_ke=ke) + return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: logits_fn() @@ -302,5 +278,23 @@ def logits_fn(): print(f"cost_ref: {cost_ref}") +def run_regression_perf(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): + torch.manual_seed(0) + q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + weights = torch.randn(S, H, device="cuda", dtype=torch.float32) + ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) + + from tilelang.profiler import do_bench + + def logits_fn(): + return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + return do_bench(logits_fn, backend="cupti") + + if __name__ == "__main__": test_fp8_lighting_indexer() diff --git a/examples/deepseek_v32/inference/README.md b/examples/deepseek_v32/inference/README.md index fe4cc21bb..60afe7ceb 100644 --- a/examples/deepseek_v32/inference/README.md +++ b/examples/deepseek_v32/inference/README.md @@ -11,4 +11,4 @@ Launch the interactive chat interface and start exploring DeepSeek's capabilitie ```bash export CONFIG=config_671B_v3.2.json torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive -``` \ No newline at end of file +``` diff --git a/examples/deepseek_v32/inference/config_671B_v3.2.json b/examples/deepseek_v32/inference/config_671B_v3.2.json index be88f1cca..375aa9aa2 100644 --- a/examples/deepseek_v32/inference/config_671B_v3.2.json +++ b/examples/deepseek_v32/inference/config_671B_v3.2.json @@ -23,4 +23,4 @@ "index_n_heads": 64, "index_head_dim": 128, "index_topk": 2048 -} \ No newline at end of file +} diff --git a/examples/deepseek_v32/inference/convert.py b/examples/deepseek_v32/inference/convert.py index df7943918..090be7145 100644 --- a/examples/deepseek_v32/inference/convert.py +++ b/examples/deepseek_v32/inference/convert.py @@ -42,7 +42,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): save_path (str): Path to the directory where the converted checkpoint files will be saved. n_experts (int): Total number of experts in the model. mp (int): Model parallelism factor. 
- + Returns: None """ diff --git a/examples/deepseek_v32/inference/kernel.py b/examples/deepseek_v32/inference/kernel.py index 262343536..4090d4beb 100644 --- a/examples/deepseek_v32/inference/kernel.py +++ b/examples/deepseek_v32/inference/kernel.py @@ -11,21 +11,21 @@ tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, } -FP8 = "float8_e4m3" -BF16 = "bfloat16" -FP32 = "float32" +FP8 = T.float8_e4m3fn +BF16 = T.bfloat16 +FP32 = T.float32 def fast_log2_ceil(x): - bits_x = T.reinterpret("uint32", x) + bits_x = T.reinterpret(x, T.uint32) exp_x = (bits_x >> 23) & 0xFF man_bits = bits_x & ((1 << 23) - 1) - return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + return T.cast(exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0), T.int32) def fast_pow2(x): bits_x = (x + 127) << 23 - return T.reinterpret("float32", bits_x) + return T.reinterpret(bits_x, T.float32) def fast_round_scale(amax, fp8_max_inv): @@ -107,8 +107,8 @@ def act_quant(x: torch.Tensor, @tilelang.jit(pass_configs=pass_configs) -def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): - assert out_dtype in [BF16, "float32"] +def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=T.float32): + assert out_dtype in [BF16, T.float32] M = T.dynamic("M") group_size = 128 diff --git a/examples/deepseek_v32/inference/requirements.txt b/examples/deepseek_v32/inference/requirements.txt index 604fed552..8c208a8b1 100644 --- a/examples/deepseek_v32/inference/requirements.txt +++ b/examples/deepseek_v32/inference/requirements.txt @@ -2,4 +2,4 @@ torch transformers safetensors fast_hadamard_transform -tilelang==0.1.6 \ No newline at end of file +tilelang==0.1.6 diff --git a/examples/deepseek_v32/regression_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/regression_tilelang_example_deepseek_v32.py new file mode 100644 index 000000000..0610002a6 --- /dev/null +++ b/examples/deepseek_v32/regression_tilelang_example_deepseek_v32.py @@ -0,0 +1,30 @@ +import tilelang.testing +import fp8_lighting_indexer +import sparse_mla_bwd +import sparse_mla_fwd +import sparse_mla_fwd_pipelined +import topk_selector + + +def regression_topk_selector(): + tilelang.testing.process_func(topk_selector.run_regression_perf) + + +def regression_fp8_lighting_indexer(): + tilelang.testing.process_func(fp8_lighting_indexer.run_regression_perf, S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1) + + +def regression_sparse_mla_fwd(): + tilelang.testing.process_func(sparse_mla_fwd.run_regression_perf, S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256) + + +def regression_sparse_mla_fwd_pipelined(): + tilelang.testing.process_func(sparse_mla_fwd_pipelined.run_regression_perf, S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256) + + +def regression_sparse_mla_bwd(): + tilelang.testing.process_func(sparse_mla_bwd.run_regression_perf, S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index e7f9c6093..527de22b3 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -13,18 +13,18 @@ def preprocess( D, block_ND=32, num_stages=5, - dtype="bfloat16", - accum_dtype="float", + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 shape = [B, S, H, D] @T.prim_func def preprocess_kernel( - O: T.Tensor(shape, dtype), - dO: 
T.Tensor(shape, dtype), - Delta: T.Tensor([B, S, H], accum_dtype), + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([B, S, H], accum_dtype), ): with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): o = T.alloc_fragment([block_ND, block_ND], accum_dtype) @@ -33,16 +33,12 @@ def preprocess_kernel( acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) T.clear(acc) for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): - T.copy( - O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - o) - T.copy( - dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - do) + T.copy(O[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) for i, j in T.Parallel(block_ND, block_ND): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, by * block_ND:(by + 1) * block_ND, bx]) + T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx]) return preprocess_kernel @@ -56,22 +52,22 @@ def postprocess( kv_group=1, block_N=64, threads=128, - dtype="bfloat16", - accum_dtype="float", + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 dkv_shape = [B, S_kv, kv_group, D + D_tail] @T.prim_func def postprocess_kernel( - dKV: T.Tensor(dkv_shape, accum_dtype), - dKV_out: T.Tensor(dkv_shape, dtype), + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), ): with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz): T.copy( - dKV[bz, bx * block_N:(bx + 1) * block_N, by, :], - dKV_out[bz, bx * block_N:(bx + 1) * block_N, by, :], + dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], ) return postprocess_kernel @@ -82,7 +78,9 @@ def postprocess_kernel( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, + }, +) def bwd( B, S, @@ -97,18 +95,18 @@ def bwd( block_size=32, num_stages=0, threads=256, - indices_dtype="int32", - dtype="bfloat16", - accum_dtype="float", + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert is_causal == True, 'non-casual is not supported now' - assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' - assert dtype == "bfloat16" - assert accum_dtype == "float" - assert indices_dtype == "int32" + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 if sm_scale is None: - sm_scale = (D + D_tail)**(-0.5) + sm_scale = (D + D_tail) ** (-0.5) sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) H_kv = H // kv_group @@ -118,12 +116,15 @@ def bwd( indices_shape = [B, S, kv_group, topk] delta_shape = [B, S, H] lse_shape = [B, S, H] - assert indices_dtype == "int32" - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 H = H_kv padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + block_H = 
min(64, padded_H) + assert padded_H % block_H == 0 + NH = padded_H // block_H BS = block_size NS = tilelang.cdiv(topk, block_size) @@ -131,122 +132,85 @@ def bwd( @T.prim_func def sparse_mla_bwd_kernel( - Q: T.Tensor(q_shape, dtype), - KV: T.Tensor(k_shape, dtype), - dO: T.Tensor(o_shape, dtype), - Indices: T.Tensor(indices_shape, indices_dtype), - Lse: T.Tensor(lse_shape, accum_dtype), - Delta: T.Tensor(delta_shape, accum_dtype), - dQ: T.Tensor(q_shape, dtype), - dKV: T.Tensor(k_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), ): - with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz): - Q_shared = T.alloc_shared([padded_H, D], dtype) - Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + with T.Kernel(S, B, kv_group * NH, threads=threads) as (s_i, by, bz): + Q_shared = T.alloc_shared([block_H, D], dtype) + Q_tail_shared = T.alloc_shared([block_H, D_tail], dtype) KV_shared = T.alloc_shared([BS, D], dtype) KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) - dO_shared = T.alloc_shared([padded_H, D], dtype) + dO_shared = T.alloc_shared([block_H, D], dtype) mask = T.alloc_fragment([BS], "bool") - P_shared_cast = T.alloc_shared([padded_H, BS], dtype) - dP_shared_cast = T.alloc_shared([padded_H, BS], dtype) - dQ_shared = T.alloc_shared([padded_H, D], dtype) - dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + P_shared_cast = T.alloc_shared([block_H, BS], dtype) + dP_shared_cast = T.alloc_shared([block_H, BS], dtype) + dQ_shared = T.alloc_shared([block_H, D], dtype) + dQ_tail_shared = T.alloc_shared([block_H, D_tail], dtype) - acc_p = T.alloc_fragment([padded_H, BS], accum_dtype) - acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype) - acc_dq = T.alloc_fragment([padded_H, D], accum_dtype) - acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) + acc_p = T.alloc_fragment([block_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([block_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([block_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([block_H, D_tail], accum_dtype) acc_dkv = T.alloc_fragment([BS, D], accum_dtype) acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) - acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) - acc_dkv_tail_shared = T.view( - KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) + acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype) + acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype) max_kv_i = s_i - T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared) - T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared) - T.copy(dO[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared) + T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, :D], Q_shared) + T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, D:], Q_tail_shared) + T.copy(dO[by, s_i, bz * block_H : (bz + 1) * block_H, :D], dO_shared) T.clear(acc_dq) T.clear(acc_dq_tail) - T.annotate_layout({ - dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), - dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), - }) - # Process each block of indices for i_i in T.Pipelined(NS, num_stages=num_stages): # Check which indices are valid for bi_i in 
T.Parallel(BS): - mask[bi_i] = Indices[by, s_i, bz, i_i * BS + bi_i] <= max_kv_i + mask[bi_i] = Indices[by, s_i, bz // NH, i_i * BS + bi_i] <= max_kv_i # Compute attention scores - for h_i, bi_i in T.Parallel(padded_H, BS): + for h_i, bi_i in T.Parallel(block_H, BS): acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) # Load KV, V for this block of indices for bi_i, d_i in T.Parallel(BS, D): - KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i] + KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i], bz // NH, d_i] - T.gemm( - Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) for bi_i, d_i in T.Parallel(BS, D_tail): - KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, - D + d_i] - T.gemm( - Q_tail_shared, - KV_tail_shared, - acc_p, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) - - for h_i, bi_i in T.Parallel(padded_H, BS): - acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - - Lse[by, s_i, bz * padded_H + h_i]) + KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i], bz // NH, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for h_i, bi_i in T.Parallel(block_H, BS): + acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - Lse[by, s_i, bz * block_H + h_i]) T.copy(acc_p, P_shared_cast) - T.gemm( - dO_shared, - KV_shared, - acc_dp, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) - for h_i, bi_i in T.Parallel(padded_H, BS): - acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * ( - acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale + for h_i, bi_i in T.Parallel(block_H, BS): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * block_H + h_i]) * sm_scale T.copy(acc_dp, dP_shared_cast) T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - dP_shared_cast, - Q_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - P_shared_cast, - dO_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) T.clear(acc_dkv_tail) - T.gemm( - dP_shared_cast, - Q_tail_shared, - acc_dkv_tail, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) for s in range(split_store): for bi_i, d_i in T.Parallel(BS, D): @@ -255,41 +219,32 @@ def sparse_mla_bwd_kernel( for bi_i, d_i in T.Parallel(BS, D_tail): if bi_i < BS // split_store: - acc_dkv_tail_shared[bi_i, - d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), - d_i] + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] for bi_i, d_i in T.Parallel(BS // split_store, D // 4): T.atomic_addx4( - dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], - bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4]) + dKV[by, Indices[by, s_i, bz 
// NH, i_i * BS + bi_i + s * (BS // split_store)], bz // NH, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) # Atomically update dKV, dKV_tail tensors for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): T.atomic_addx4( - dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], - bz, D + d_i * 4], acc_dkv_tail_shared[bi_i, d_i * 4]) + dKV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i + s * (BS // split_store)], bz // NH, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) # Store the accumulated dQ T.copy(acc_dq, dQ_shared) T.copy(acc_dq_tail, dQ_tail_shared) - T.copy(dQ_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D]) - T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:]) + T.copy(dQ_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, :D]) + T.copy(dQ_tail_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, D:]) return sparse_mla_bwd_kernel -def sparse_mla_bwd(q, - kv, - o, - do, - indices, - lse, - sm_scale=None, - is_casual=True, - return_kernel=False, - delta=None): +def sparse_mla_bwd(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None): assert q.is_contiguous() assert kv.is_contiguous() assert indices.is_contiguous() @@ -322,6 +277,7 @@ def sparse_mla_bwd(q, def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True): from sparse_mla_fwd import ref_sparse_mla_fwd_interface + q = q.detach().clone() kv = kv.detach().clone() q.requires_grad = True @@ -331,30 +287,22 @@ def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_c return q.grad, kv.grad -def test_sparse_mla_bwd(B=1, - S=4096, - SKV=8192, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True): +def test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True): # Prepare data - q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, S, H, DV), dtype=dtype, device='cuda') + q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda") - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") for b in range(B): for t in range(S): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i # Forward from sparse_mla_fwd import sparse_mla_fwd_interface + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse) @@ -365,13 +313,15 @@ def test_sparse_mla_bwd(B=1, assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") print("assert_tensors_similar passed") - per_token_flop = 2 * sum([ - H * DV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DV * topk, - ]) + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) from tilelang.profiler import do_bench def fn(): @@ -379,20 +329,44 @@ def fn(): ms = do_bench(fn, rep=100, warmup=250) print(f"Average time: {ms:.3f} ms") - print(f'bwd io bandwidth = ', - (B * S * max(DQKV * 
2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) - print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12) + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) + + +def run_regression_perf(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda") + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + from sparse_mla_fwd import sparse_mla_fwd_interface + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) + B, S, H, dim_plus_tail_dim = q.shape + _, S_kv, kv_group, _ = kv.shape + D = 512 + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + preprocess_kernel = preprocess(B, S, H, D) + bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, None, True) + delta = preprocess_kernel(tl_out, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + return bwd_kernel(q, kv, do, indices, tl_lse, delta, dkv) + + return do_bench(run_kernel_only, backend="cupti") if __name__ == "__main__": - test_sparse_mla_bwd( - B=1, - S=4096, - SKV=8192, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True) + test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index a39c72c40..ddde11f5b 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -25,15 +25,12 @@ def sparse_mla_fwd( num_stages=2, threads=256, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" assert is_causal == True, "non-casual is not supported" - assert (topk % - block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) else: sm_scale = sm_scale * 1.44269504 # log2(e) @@ -47,17 +44,17 @@ def sparse_mla_fwd( o_shape = [batch, seq_len, heads, dim] indices_shape = [batch, seq_len, kv_group, topk] lse_shape = [batch, seq_len, heads] - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 G = kv_group H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert ( - kv_group == 1 - ), "here we solve the H 
padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) D = dim @@ -73,18 +70,17 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( - bx, - by, - bz, - ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) KV_shared = T.alloc_shared([BI, D], dtype) @@ -118,16 +114,13 @@ def main( T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) for i_i in T.Pipelined(NI, num_stages=num_stages): - for bi_i in T.Parallel(BI): mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - d_i] + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - D + d_i] + K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) @@ -147,6 +140,8 @@ def main( ) T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(H_per_block): alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): @@ -166,23 +161,13 @@ def main( for h_i in T.Parallel(H_per_block): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale - T.copy(acc_o, O_shared) T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) - T.copy(sumexp, Lse_shared) T.copy(sumexp, Lse[b_i, s_i, H0:H1]) return main -def sparse_mla_fwd_interface(q, - kv, - indices, - sm_scale=None, - return_p_sum: bool = False, - d_v=512, - block_I=64, - num_stages=2, - threads=256): +def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=64, num_stages=2, threads=256): is_casual = True assert return_p_sum == False, "This kernel file is for fwd only" assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() @@ -199,16 +184,8 @@ def sparse_mla_fwd_interface(q, assert indices.shape == (batch, seq_len, kv_group, topk) kernel = sparse_mla_fwd( - heads, - dim, - tail_dim, - topk, - kv_group, - sm_scale, - is_casual, - block_I=block_I, - num_stages=num_stages, - threads=threads) + heads, dim, tail_dim, topk, 
kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) out, lse = kernel(q, kv, indices) return out, lse @@ -228,14 +205,14 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): b, _, _, dim_v = v.shape g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - 0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( - 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :1 - 1, 0] = True + mask[:, :, : 1 - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -250,19 +227,21 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): return o.to(torch.bfloat16) -def test_sparse_mla_fwd(B=1, - S=4096, - SKV=8192, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True, - block_I=64, - num_stages=2, - threads=256): +def test_sparse_mla_fwd( + B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): torch.random.manual_seed(0) q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) @@ -272,10 +251,9 @@ def test_sparse_mla_fwd(B=1, for t in range(S): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i - tl_out, tl_lse = sparse_mla_fwd_interface( - q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) if check_correctness: # otherwise may cause out of memory @@ -284,8 +262,7 @@ def test_sparse_mla_fwd(B=1, print("assert_tensors_similar passed") def fn(): - return sparse_mla_fwd_interface( - q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + return sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) from tilelang.profiler import do_bench @@ -299,6 +276,36 @@ def fn(): print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) +def run_regression_perf( + B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, block_I=64, num_stages=2, threads=256 +): + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + is_casual = True + _, _, heads, dim_plus_tail_dim = q.shape + _, _, kv_group, _ = kv.shape + dim = 512 + tail_dim = dim_plus_tail_dim - dim + _, _, _, topk = indices.shape + kernel = sparse_mla_fwd(heads, dim, tail_dim, topk, kv_group, None, is_casual, block_I=block_I, num_stages=num_stages, threads=threads) + 
+ def run_kernel_only(): + kernel(q, kv, indices) + + from tilelang.profiler import do_bench + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": test_sparse_mla_fwd( B=1, @@ -313,4 +320,5 @@ def fn(): check_correctness=True, block_I=64, num_stages=2, - threads=256) + threads=256, + ) diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 96dda7df5..7e664d11b 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -9,10 +9,16 @@ @tilelang.jit( out_idx=[-2, -1], compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", ], ) def sparse_mla_fwd( @@ -32,14 +38,12 @@ def sparse_mla_fwd( num_stages=0, threads=384, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" - assert is_causal == True, 'non-casual is not supported' - assert topk % block_I == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) else: sm_scale = sm_scale * 1.44269504 # log2(e) @@ -49,23 +53,25 @@ def sparse_mla_fwd( o_shape = [batch, seq_len, heads, dim] indices_shape = [batch, seq_len, kv_group, topk] lse_shape = [batch, seq_len, heads] - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 G = kv_group H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert kv_group == 1, 'here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)' + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) - assert NI % 2 == 0, 'NI should be a multiple of 2' + assert NI % 2 == 0, "NI should be a multiple of 2" D = dim D_tail = tail_dim KV_stride = kv_stride if head_kv > 64: - assert head_kv % 64 == 0, 'head_kv should be a multiple of 64' + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" REPLICATE_H = 
head_kv // 64 else: REPLICATE_H = 1 @@ -74,18 +80,14 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - q_start_index_s: T.Tensor(1, indices_dtype), - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - (seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, - batch, - kv_group, - threads=threads) as (bx, by, bz): + with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz): Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) @@ -110,7 +112,7 @@ def main( alpha_local = T.alloc_fragment([H_per_block], accum_dtype) m_i = T.alloc_fragment([H_per_block], accum_dtype) m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) - indices_local = T.alloc_local([1], indices_dtype) + indices_local = T.alloc_var(indices_dtype) # TODO: Multi buffer bar_q = T.alloc_barrier(arrive_count=384) @@ -122,8 +124,7 @@ def main( bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) b_i, g_i = by, bz - s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else ( - bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) q_i = q_start_index_s[0] + s_i max_kv_i = (q_i + 1 - KV_stride) // KV_stride @@ -132,26 +133,24 @@ def main( tx = T.get_thread_binding() - T.copy(Q[b_i, s_i, H0:H1, 0:D // 2], Q_shared_l) - T.copy(Q[b_i, s_i, H0:H1, D // 2:D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) for i_i in T.serial(T.ceildiv(NI, 2)): - # Buffer 0 T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - -T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) @@ -164,6 +163,8 @@ def main( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(H_per_block): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): @@ -185,8 +186,7 @@ def main( T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - 
-T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) @@ -198,6 +198,8 @@ def main( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(H_per_block): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): @@ -223,7 +225,7 @@ def main( for h_i in T.Parallel(H_per_block): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0:D // 2]) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -253,7 +255,7 @@ def main( acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2:D]) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D]) elif tx >= 256: # producer T.set_max_nreg(80, 0) @@ -261,70 +263,58 @@ def main( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2) * BI + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + indices_local = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, D + (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + indices_local = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], 
g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, D + (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) return main -def sparse_mla_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride, - sm_scale=None, - is_casual=True, - return_kernel=False, - print_kernel=False): +def sparse_mla_fwd_interface( + q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False +): assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() batch, seq_len, heads, dim_plus_tail_dim = q.shape _, seq_len_kv, kv_group, _ = kv.shape - assert dim_plus_tail_dim == 576, 'you should assign dim otherwise' + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" dim = 512 assert kv.shape[-1] == dim_plus_tail_dim @@ -334,29 +324,23 @@ def sparse_mla_fwd_interface(q, assert indices.shape == (batch, seq_len, kv_group, topk) if q_start_index_s != 0: - assert q_start_index_s > kv_stride, "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + assert q_start_index_s > kv_stride, ( + "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + ) CP0 = q_start_index_s == 0 - kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, - kv_group, sm_scale, is_casual, CP0) + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) if print_kernel: print(kernel.get_kernel_source()) - out, lse = kernel(q, kv, indices, - torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) if return_kernel: return kernel if q_start_index_s == 0 and kv_stride > 1: - out[:, :kv_stride - 1, :, :] = 0 + out[:, : kv_stride - 1, :, :] = 0 return out, lse -def ref_sparse_mla_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride=4, - sm_scale=None, - is_casual=True): +def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True): q = q.float() kv = kv.float() indices = indices.transpose(1, 2) @@ -365,7 +349,7 @@ def ref_sparse_mla_fwd_interface(q, if q_start_index_s is None: q_start_index_s = sk * kv_stride - sq - assert kv.shape[-1] == 576, 'you should assign dim otherwise' + assert kv.shape[-1] == 576, "you should assign dim otherwise" dim = 512 k = kv v = kv[..., :dim] @@ -374,15 +358,14 @@ def 
ref_sparse_mla_fwd_interface(q, num_kv_per_index = 1 g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - q_start_index_s, sq + q_start_index_s, dtype=torch.int32, - device="cuda").view(-1, 1) >= torch.arange( - kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view( + -1, 1 + ) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :kv_stride - 1, 0] = True + mask[:, :, : kv_stride - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -397,41 +380,32 @@ def ref_sparse_mla_fwd_interface(q, return o.to(torch.bfloat16) -def test_sparse_mla_fwd_pipelined(B=1, - S=4096, - SKV=8192, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - q_start_s_index=1024, - check_correctness=True): +def test_sparse_mla_fwd_pipelined( + B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024, check_correctness=True +): KV_stride = 1 torch.random.manual_seed(0) - q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 - kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") q.clamp_(-10, 10) kv.clamp_(-10, 10) - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") for b in range(B): for t in range(S): for h in range(HKV): i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i - kernel = sparse_mla_fwd_interface( - q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) def fn(): out, lse = kernel(q, kv, indices, q_start_s_index_t) if q_start_s_index == 0 and KV_stride > 1: - out[:, :KV_stride - 1, :, :] = 0 + out[:, : KV_stride - 1, :, :] = 0 return out, lse tl_out, tl_lse = fn() @@ -442,14 +416,46 @@ def fn(): torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3) from tilelang.profiler import do_bench + ms = do_bench( fn, rep=10, warmup=10, ) print(f"Average time: {ms:.3f} ms") - print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) - print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +def run_regression_perf(B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024): + KV_stride = 1 + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, 
device="cuda").requires_grad_(True) / 10 + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + dim = 512 + tail_dim = dim_plus_tail_dim - dim + CP0 = q_start_s_index == 0 + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, KV_stride, kv_group, None, True, CP0) + + def run_kernel_only(): + kernel(q, kv, indices, torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda")) + + from tilelang.profiler import do_bench + + return do_bench(run_kernel_only, backend="cupti") if __name__ == "__main__": @@ -460,5 +466,4 @@ def fn(): B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 else: B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 - test_sparse_mla_fwd_pipelined( - B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) + test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) diff --git a/examples/deepseek_v32/sparse_mla_fwd_seesaw.py b/examples/deepseek_v32/sparse_mla_fwd_seesaw.py new file mode 100644 index 000000000..5d155f851 --- /dev/null +++ b/examples/deepseek_v32/sparse_mla_fwd_seesaw.py @@ -0,0 +1,644 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +import argparse + + +@tilelang.jit( + out_idx=[-2, -1], + compile_flags=[ + "-O3", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ], +) +def sparse_mla_fwd( + batch, + seq_len, + seq_len_kv, + heads, + dim, + tail_dim, + topk, + kv_stride, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=0, + threads=384, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = "int32" + dtype = "bfloat16" + accum_dtype = "float" + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you " + "should handle Q copy and Output copy with your mask (when " + "kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would " + "be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, 
block_I) + assert NI % 2 == 0, "NI should be a multiple of 2" + D = dim + D_tail = tail_dim + KV_stride = kv_stride + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + # Increasing from 32->64 reduces the time spent reading kvcache. If num_query_head = 128 + # and num_kv_head = 1, the same kvcache originally needed to be read 4 times, but now only 2 times + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel( + # If CP0 is True (i.e., start of sequence), skip the first (KV_stride - 1) + # queries that cannot see any KV. Also be careful that seq_len < kv_stride could cause negative grid size + (max(0, seq_len - kv_stride + 1) if CP0 else seq_len) * REPLICATE_H, + batch, + kv_group, + threads=threads, + ) as (bx, by, bz): + Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) + Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + + KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype) + K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype) + K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype) + + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + + # Whether the kv in current BI is visible for this query + # Producer alternates writing to buf0 and buf1 masks. 
To avoid the situation + # where consumer0 is still reading buf0 mask when producer has already started + # writing buf1 mask, we use two buf masks + is_kv_valid = T.alloc_shared([2, BI], "bool", scope="shared") + + acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + + # WG0 computes S0(BI_2*i), WG1 computes S1(BI_2*i+1), shared via shared memory + + # Reuse K_tail_shared for S_shared to save memory when dimensions match + # Must reuse, otherwise H100 SM's shared mem is insufficient (> 228kb), this is shared mem bound + S_shared_0 = K_tail_shared_0 + S_shared_1 = K_tail_shared_1 + + # WG0 and WG1 exchange local max with each other, compare to compute global max, and rescale their O_L or O_R accordingly + row_max_shared_0 = T.alloc_shared([H_per_block], accum_dtype) + row_max_shared_1 = T.alloc_shared([H_per_block], accum_dtype) + + # Used to store sum of exps for even BI and odd BI respectively, which will be summed up for integration later + row_sum_shared_0 = T.alloc_shared([H_per_block], accum_dtype) + row_sum_shared_1 = T.alloc_shared([H_per_block], accum_dtype) + + # acc_s, sumexp, m_i each need to be allocated separately for consumer0 and consumer1 + acc_s_0 = T.alloc_fragment([H_per_block, BI], accum_dtype) + acc_s_1 = T.alloc_fragment([H_per_block, BI], accum_dtype) + + sumexp_0 = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i_0 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_0 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev_0 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_peer_0 = T.alloc_fragment([H_per_block], accum_dtype) + + sumexp_1 = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i_1 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_1 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev_1 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_peer_1 = T.alloc_fragment([H_per_block], accum_dtype) + + bar_q = T.alloc_barrier(arrive_count=384) + + # Producer -> Consumer Barriers + bar_k_0_ready = T.alloc_barrier(arrive_count=128) # Prod arrives + bar_k_1_ready = T.alloc_barrier(arrive_count=128) # Prod arrives + + # Consumer -> Producer Barriers (Both consumers must arrive) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + bar_k_1_free = T.alloc_barrier(arrive_count=256) + + # Inter-Consumer Barriers (Seesaw Sync) + bar_stats_0_ready = T.alloc_barrier(arrive_count=128) # Cons 0 arrives + bar_stats_1_ready = T.alloc_barrier(arrive_count=128) # Cons 1 arrives + + bar_S_0_ready = T.alloc_barrier(arrive_count=128) # Cons 0 arrives + bar_S_1_ready = T.alloc_barrier(arrive_count=128) # Cons 1 arrives + + b_i, g_i = by, bz + # If it's the first chunk, start computing directly from the (kv_stride - 1)-th token + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + q_i = q_start_index_s[0] + s_i + # Sometimes to reduce kvcache size, we may not store KV for every token, but store + # KV every KV_stride tokens (usually the last token in the stride window), + # so the kv range visible to the current query should be [0:max_kv_i] + max_kv_i = (q_i + 1 - KV_stride) // KV_stride + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + tx = T.get_thread_binding() + + T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + + # Non-blockingly increment the 
barrier's internal counter, producer threads can start loading kv ahead of time + T.barrier_arrive(bar_q) + + if tx >= 256: + # producer: prefetch kvcache to shared mem + T.set_max_nreg(72, 0) + + prefetch_indices_0 = T.alloc_fragment([4], indices_dtype) + prefetch_indices_1 = T.alloc_fragment([4], indices_dtype) + + # Prime the Pump! Prefetch indices for iter_0 + for r in T.serial(4): + # This read will cause a long scoreboard stall, but it only happens once before the loop starts + prefetch_indices_0[r] = Indices[b_i, s_i, g_i, r * 16 + (tx - 256) // 8] + prefetch_indices_1[r] = Indices[b_i, s_i, g_i, BI + r * 16 + (tx - 256) // 8] + + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + # Wait for both KV_shared_0_l and KV_shared_0_r to be done being used + + T.barrier_wait(bar_k_0_free[0], (i_i & 1)) + + # Block size `BI` is 64, loading is divided into 4 iterations, each processing 16 indices + # Producer has 128 threads total, 8 consecutive threads collaborate to load kv for one index + for r in T.serial(4): + # mitigate long scoreboard stall here + index = prefetch_indices_0[r] + is_kv_valid[0, r * 16 + (tx - 256) // 8] = index <= max_kv_i + if is_kv_valid[0, r * 16 + (tx - 256) // 8]: + # Here we assume dim = 512, tail_dim = 64 + with T.attr("default", "async_scope", 1): + # 8 threads collaborate to load one row of KV_dim (length 512), divided into 4 iterations + # In each iteration, each thread loads 8 consecutive elements for both KV_shared_0_l + # and KV_shared_0_r, 8 threads load 64 elements total for each + for u in T.serial(4): + for v in T.vectorized(8): + # (tx - 256) // 8 determines which row the thread is responsible for, + # (tx - 256) % 8 determines which part of the row the thread loads + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, index, g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, index, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + # tail_dim (length 64) only needs 8 threads collaborating in one iteration to complete loading + for v in T.vectorized(8): + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, index, g_i, D + (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + if i_i + 1 < T.ceildiv(NI, 2): + # Async prefetch indices needed for the next round of kv data loading, overlaps with current round to hide latency + for r in T.serial(4): + prefetch_indices_0[r] = Indices[b_i, s_i, g_i, ((i_i + 1) * 2) * BI + r * 16 + (tx - 256) // 8] + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], (i_i & 1)) + + for r in T.serial(4): + index = prefetch_indices_1[r] + is_kv_valid[1, r * 16 + (tx - 256) // 8] = index <= max_kv_i + if is_kv_valid[1, r * 16 + (tx - 256) // 8]: + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, index, g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, index, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, index, g_i, D + (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + if i_i + 1 < T.ceildiv(NI, 2): + for r in T.serial(4): + 
prefetch_indices_1[r] = Indices[b_i, s_i, g_i, ((i_i + 1) * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + + elif tx < 128: + # Check if 384 threads have already arrived at bar_q (phase0 completed), + # if not continue waiting, otherwise pass through directly + T.barrier_wait(bar_q, 0) + + # pre-arrive free barriers to indicate buffers are initially free + # At the beginning of phase0, tells producer it can load data into both buffers + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_k_1_free[0]) + + # Consumer 0 (WG0): Responsible for Even Blocks and O_L (Left Half) + T.set_max_nreg(216, 1) + T.fill(sumexp_0, 0) + for h_i in T.Parallel(H_per_block): + m_i_0[h_i] = -5e4 + T.fill(acc_o_l, 0) + + # Each iteration, two consumers cooperate to compute two BIs + for i_i in T.serial(T.ceildiv(NI, 2)): + # --- Step 1: Compute S0 = Q @ K0^T (Even Block) --- + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + + T.fill(acc_s_0, 0) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_0, acc_s_0, transpose_B=True, wg_wait=-1) + + T.copy(m_i_0, m_i_prev_0) + T.wait_wgmma(0) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + if not is_kv_valid[0, bi_i]: + acc_s_0[h_i, bi_i] = -5e4 + T.reduce_max(acc_s_0, m_i_0, dim=1, clear=False) + + # --- Step 2: Local Softmax Stats & Exchange --- + T.copy(m_i_0, row_max_shared_0) + T.barrier_arrive(bar_stats_0_ready) + # If consumer0 has received the local max from consumer1 at iter_i, this also means + # consumer1 has finished using S_0 passed by consumer0 at iter_i-1, + # so we can write to it directly without blocking below + T.barrier_wait(bar_stats_1_ready, (i_i & 1)) + T.copy(row_max_shared_1, m_i_peer_0) + + # Update global max and scale O + for h_i in T.Parallel(H_per_block): + m_i_0[h_i] = T.max(m_i_0[h_i], m_i_peer_0[h_i]) + + # Scale O_L + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= T.exp2((m_i_prev_0[h_i] - m_i_0[h_i]) * sm_scale) + + # Scale SumExp + for h_i in T.Parallel(H_per_block): + sumexp_0[h_i] *= T.exp2((m_i_prev_0[h_i] - m_i_0[h_i]) * sm_scale) + + # Compute P0 = exp(S0 - m_new) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s_0[h_i, bi_i] = T.exp2(acc_s_0[h_i, bi_i] * sm_scale - m_i_0[h_i] * sm_scale) + + # Update SumExp with P0 + T.reduce_sum(acc_s_0, sumexp_i_0, dim=1) + for h_i in T.Parallel(H_per_block): + sumexp_0[h_i] += sumexp_i_0[h_i] + + # --- Step 3: O_L += P0 @ V0_L (Self-Attention) --- + # Wait for S0 buffer to be free (consumed by peer in prev iter) + # T.barrier_wait(bar_S_0_free, (i_i & 1)) + T.copy(acc_s_0, S_shared_0) + T.barrier_arrive(bar_S_0_ready) + + T.gemm(S_shared_0, KV_shared_0_l, acc_o_l, transpose_B=False, wg_wait=-1) + + # --- Step 4: O_L += P1 @ V1_L (Cross-Attention) --- + # Wait for P1 (S1) from peer + T.barrier_wait(bar_S_1_ready, (i_i & 1)) + + T.gemm(S_shared_1, KV_shared_1_l, acc_o_l, transpose_B=False, wg_wait=-1) + + # NOTE: However, k_0 and k_1 are used by both consumer0 and consumer1, so this doesn't bring much performance improvement + # Except for the most recent async gemm (i.e., S_shared_1 @ KV_shared_1_k), all others need to wait to finish + T.wait_wgmma(1) + T.barrier_arrive(bar_k_0_free[0]) + # Wait for all async gemms to finish + T.wait_wgmma(0) + T.barrier_arrive(bar_k_1_free[0]) + + T.copy(sumexp_0, row_sum_shared_0) + T.barrier_arrive(bar_stats_0_ready) # Reuse barrier + T.barrier_wait(bar_stats_1_ready, T.ceildiv(NI, 2) & 1) + 
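The per-iteration rescaling above is the standard two-block online-softmax update, and the only cross-warpgroup traffic it needs is the row maxima (plus, after the loop, the row sums). A single-threaded PyTorch sketch of one consumer's step, with illustrative names rather than the kernel's buffers:

import torch

def consumer_softmax_step(acc_o, m_prev, sumexp, scores_own, m_peer, sm_scale):
    # One iteration of the update each consumer performs above, without barriers:
    # `scores_own` is this consumer's S block, `m_peer` the row max received from
    # the other warpgroup for its block (illustrative sketch only).
    m_own = torch.maximum(m_prev, scores_own.max(dim=-1).values)
    m_new = torch.maximum(m_own, m_peer)                 # global max over both blocks
    rescale = torch.exp2((m_prev - m_new) * sm_scale)
    acc_o = acc_o * rescale[:, None]                     # rescale the running output half
    sumexp = sumexp * rescale                            # rescale the running denominator
    p = torch.exp2(scores_own * sm_scale - (m_new * sm_scale)[:, None])
    sumexp = sumexp + p.sum(dim=-1)
    return acc_o, m_new, sumexp, p                       # p then feeds both P @ V halves

After the loop each consumer adds the peer's accumulated sum (exchanged through row_sum_shared_*) before dividing, so only the denominators, never the output halves, have to cross the warpgroup boundary.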
T.copy(row_sum_shared_1, sumexp_i_0) # Reuse sumexp_i buffer + + for h_i in T.Parallel(H_per_block): + sumexp_0[h_i] += sumexp_i_0[h_i] + + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] /= sumexp_0[h_i] + + for h_i in T.Parallel(H_per_block): + sumexp_0[h_i] = T.log2(sumexp_0[h_i]) + m_i_0[h_i] * sm_scale + + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2]) + T.copy(sumexp_0, Lse[b_i, s_i, H0:H1]) # Write LSE + + elif tx >= 128 and tx < 256: + T.barrier_wait(bar_q, 0) + + # pre-arrive free barriers to indicate buffers are initially free + # At the beginning of phase0, tells producer it can load data into both buffers + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_k_1_free[0]) + + # Consumer 1 (WG1): Responsible for Odd Blocks and O_R (Right Half) + # NOTE: 256 * 216 + 128 * 72 = 64,512 < 65536 (H100 SM RegFile Limit), + # setting more registers will cause a hang, all values must be multiples of 8 + T.set_max_nreg(216, 1) + T.fill(sumexp_1, 0) + for h_i in T.Parallel(H_per_block): + m_i_1[h_i] = -5e4 + T.fill(acc_o_r, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + # --- Step 1: Compute S1 = Q @ K1^T (Odd Block) --- + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + + T.fill(acc_s_1, 0) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_1, acc_s_1, transpose_B=True, wg_wait=-1) + + # --- Step 2: Local Softmax Stats & Exchange --- + T.copy(m_i_1, m_i_prev_1) + T.wait_wgmma(0) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + if not is_kv_valid[1, bi_i]: + acc_s_1[h_i, bi_i] = -5e4 + + T.reduce_max(acc_s_1, m_i_1, dim=1, clear=False) + T.copy(m_i_1, row_max_shared_1) + T.barrier_arrive(bar_stats_1_ready) + T.barrier_wait(bar_stats_0_ready, (i_i & 1)) + T.copy(row_max_shared_0, m_i_peer_1) + + for h_i in T.Parallel(H_per_block): + m_i_1[h_i] = T.max(m_i_1[h_i], m_i_peer_1[h_i]) + + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= T.exp2((m_i_prev_1[h_i] - m_i_1[h_i]) * sm_scale) + + for h_i in T.Parallel(H_per_block): + sumexp_1[h_i] *= T.exp2((m_i_prev_1[h_i] - m_i_1[h_i]) * sm_scale) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s_1[h_i, bi_i] = T.exp2(acc_s_1[h_i, bi_i] * sm_scale - m_i_1[h_i] * sm_scale) + + T.reduce_sum(acc_s_1, sumexp_i_1, dim=1) + for h_i in T.Parallel(H_per_block): + sumexp_1[h_i] += sumexp_i_1[h_i] + + # --- Step 3: O_R += P1 @ V1_R (Self-Attention) --- + T.copy(acc_s_1, S_shared_1) + + T.barrier_arrive(bar_S_1_ready) + + T.gemm(S_shared_1, KV_shared_1_r, acc_o_r, transpose_B=False, wg_wait=-1) + + # --- Step 4: O_R += P0 @ V0_R (Cross-Attention) --- + T.barrier_wait(bar_S_0_ready, (i_i & 1)) + + T.gemm(S_shared_0, KV_shared_0_r, acc_o_r, transpose_B=False, wg_wait=-1) + + T.wait_wgmma(1) + T.barrier_arrive(bar_k_1_free[0]) + T.wait_wgmma(0) + T.barrier_arrive(bar_k_0_free[0]) + + T.copy(sumexp_1, row_sum_shared_1) + T.barrier_arrive(bar_stats_1_ready) + T.barrier_wait(bar_stats_0_ready, T.ceildiv(NI, 2) & 1) + T.copy(row_sum_shared_0, sumexp_i_1) + + for h_i in T.Parallel(H_per_block): + sumexp_1[h_i] += sumexp_i_1[h_i] + + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] /= sumexp_1[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D]) + + return main + + +def sparse_mla_fwd_interface( + q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, 
return_kernel=False, print_kernel=False +): + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = 512 + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + if q_start_index_s != 0: + assert q_start_index_s > kv_stride, ( + "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + ) + CP0 = q_start_index_s == 0 + + # Compile the kernel + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) + + if print_kernel: + print(kernel.get_kernel_source()) + + if return_kernel: + return kernel + + ( + out, + lse, + ) = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + if q_start_index_s == 0 and kv_stride > 1: + # Set the output of the first (kv_stride - 1) positions to 0, since they cannot see any kv so no computation was performed + out[:, : kv_stride - 1, :, :] = 0 + return out, lse + + +def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=1, sm_scale=None, is_casual=True): + q = q.float() + kv = kv.float() + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + if q_start_index_s is None: + q_start_index_s = sk * kv_stride - sq + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + num_kv_per_index = 1 + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view( + -1, 1 + ) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : kv_stride - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd_pipelined( + B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + # Offset of query in global sequence position (or relative to kv) + q_start_s_index=2048, + check_correctness=True, + profile=False, +): + KV_stride = 1 + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") + + q.clamp_(-10, 10) + 
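With KV_stride > 1 only one compressed KV row is stored per stride window, so the kernel's validity test (index <= max_kv_i) and the reference's compressed_casual_mask reduce to the same per-query bound. A small restatement of that bound (hypothetical helper, for illustration only):

def num_visible_kv_rows(q_pos_global: int, kv_stride: int) -> int:
    # A query at global position q_pos_global may attend to compressed rows
    # 0 .. max_kv_i inclusive, with max_kv_i = (q_pos_global + 1 - kv_stride) // kv_stride,
    # i.e. only stride windows whose last token is at or before the query.
    return max(0, (q_pos_global + 1 - kv_stride) // kv_stride + 1)

assert num_visible_kv_rows(q_pos_global=3, kv_stride=4) == 1   # sees the row covering tokens 0..3
assert num_visible_kv_rows(q_pos_global=2, kv_stride=4) == 0   # the first kv_stride - 1 queries see nothing
assert num_visible_kv_rows(q_pos_global=7, kv_stride=1) == 8   # dense case: plain causal prefix

The zero-visibility case is also why the wrapper above zeroes the first kv_stride - 1 output rows when q_start_index_s == 0.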
kv.clamp_(-10, 10) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + # Add offset q_start_s_index to convert to global sequence position + i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + print("index generation finished") + + kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + + def fn(): + return kernel(q, kv, indices, q_start_s_index_t) + + if check_correctness: + tl_out, tl_lse = fn() + assert KV_stride == 1, "KV_stride > 1 not supported" + # if q_start_s_index == 0 and KV_stride > 1: + # tl_out[:, :KV_stride - 1, :, :] = 0 + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride) + print(f"tl_out: {tl_out}") + print(f"ref_out: {ref_out}") + torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3) + + if profile: + print("Profiling mode: running minimal iterations (1 warmup + 1 run)...") + fn() + torch.cuda.synchronize() + fn() + torch.cuda.synchronize() + return + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=20, + warmup=10, + ) + print(f"Average time: {ms:.3f} ms") + print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + tflops = (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12 + print(f"fwd tflops = {tflops:.2f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--test_correctness", action="store_true") + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + if args.test_correctness: + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=True, profile=args.profile) + else: + # Prefill Benchmark: long context + print(" --- Prefill Benchmark --- ") + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 2, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + test_sparse_mla_fwd_pipelined( + B, S, SKV, H, HKV, DQK, DV, topk, dtype, q_start_s_index=4096, check_correctness=False, profile=args.profile + ) + + # Decode Benchmark: large batch size, high throughput generation + print("\n --- Decode Benchmark --- ") + # Increase batch size to saturate h100 for decode + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 128 * 16, 2, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + test_sparse_mla_fwd_pipelined( + B, S, SKV, H, HKV, DQK, DV, topk, dtype, q_start_s_index=2048 + 4096, check_correctness=False, profile=args.profile + ) diff --git a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py index e10141b59..983798f9f 100644 --- a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py +++ b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py @@ -1,4 +1,5 @@ # ruff: noqa +import tilelang import tilelang.testing import topk_selector @@ -20,23 +21,23 @@ def test_example_fp8_lighting_indexer(): @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_fwd(): # small shapes for testing - sparse_mla_fwd.test_sparse_mla_fwd( - S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + sparse_mla_fwd.test_sparse_mla_fwd(S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda 
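The bandwidth and TFLOPs figures printed by the test driver above come from straightforward counting: every (batch, query, head) pair performs one DQK-wide dot product and one DV-wide accumulation per selected key, at 2 FLOPs per element, while roughly DQK elements of KV (2 bytes each for bf16) are streamed per (query, selected key). A helper restating those formulas (illustrative, mirrors the prints only):

def sparse_mla_fwd_cost(B, S, H, DQK, DV, topk, ms, bytes_per_elem=2):
    # FLOPs: 2 * (DQK + DV) per (query, head, selected key); bytes: DQK * bytes_per_elem
    # of KV per (query, selected key), ignoring Q/O/indices traffic.
    tflops = (B * S * H * (DQK + DV) * topk * 2) / (ms * 1e-3) / 1e12
    tb_per_s = (B * S * DQK * topk * bytes_per_elem) / (ms * 1e-3) / 1e12
    return tflops, tb_per_s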
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_fwd_pipelined(): # small shapes for testing - sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined( - S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_bwd(): + sparse_mla_bwd.test_sparse_mla_bwd(S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) sparse_mla_bwd.test_sparse_mla_bwd( - S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) + S=256, SKV=512, H=128, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False + ) # test for large H if __name__ == "__main__": diff --git a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py index 4a4b43277..8b29c6fd5 100644 --- a/examples/deepseek_v32/topk_selector.py +++ b/examples/deepseek_v32/topk_selector.py @@ -8,24 +8,24 @@ def convert_to_uint16(x): - hval = T.Cast("float16", x) - bits_uint = T.reinterpret("uint16", hval) + hval = T.cast(x, T.float16) + bits_uint = T.reinterpret(hval, T.uint16) bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000)) return bits_uint >> 8 def convert_to_uint32(x): - bits_uint = T.reinterpret("uint32", x) + bits_uint = T.reinterpret(x, T.uint32) bits_uint = T.if_then_else( x < 0, - ~bits_uint & T.Cast("uint32", (0xFFFFFFFF)), - bits_uint | T.Cast("uint32", (0x80000000)), + ~bits_uint & T.cast((0xFFFFFFFF), T.uint32), + bits_uint | T.cast((0x80000000), T.uint32), ) return bits_uint @tilelang.jit(pass_configs=pass_configs) -def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"): +def tl_topk_impl(topk, in_dtype=T.float32, out_dtype=T.int32): batch = T.dynamic("batch") seq_len = T.dynamic("seq_len") RADIX = 1 << 8 @@ -42,20 +42,22 @@ def tl_topk_kernel( with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): tx = T.get_thread_binding() - s_threshold_bin_id = T.alloc_shared([1], "int32") - s_histogram = T.alloc_shared([RADIX + 1], "int32") - s_num_input = T.alloc_shared([2], "int32") - s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32") - - l_threshold_bin_id = T.alloc_var("int32") - l_new_topk = T.alloc_var("int32") - l_num_input = T.alloc_var("int32") - l_bin_id32 = T.alloc_var("int32") - l_val = T.alloc_var("int32") - l_start_pos = T.alloc_var("int32") - l_start_idx = T.alloc_var("int32") - l_end_idx = T.alloc_var("int32") - l_out_pos = T.alloc_var("int32") + s_threshold_bin_id = T.alloc_shared([1], T.int32) + s_histogram = T.alloc_shared([RADIX + 1], T.int32) + s_num_input = T.alloc_shared([2], T.int32) + s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], T.int32) + + l_threshold_bin_id = T.alloc_var(T.int32) + l_new_topk = T.alloc_var(T.int32) + l_num_input = T.alloc_var(T.int32) + l_bin_id32 = T.alloc_var(T.int32) + l_val = T.alloc_var(T.int32) + l_start_pos = T.alloc_var(T.int32) + l_start_idx = T.alloc_var(T.int32) + l_end_idx = T.alloc_var(T.int32) + l_out_pos = T.alloc_var(T.int32) + + pos = T.alloc_var(T.int32) l_new_topk = topk l_start_idx = starts[bx] @@ -99,7 +101,7 @@ def tl_topk_kernel( input_idx = s * BLOCK_SIZE + tx if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: bin_id = convert_to_uint16(input[bx, input_idx]) - l_bin_id32 = T.Cast("int32", bin_id) + l_bin_id32 = T.cast(bin_id, 
T.int32) if l_bin_id32 > l_threshold_bin_id: # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1) pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) @@ -113,7 +115,7 @@ def tl_topk_kernel( # stage 2: tail pass for round in T.serial(4): if l_new_topk <= 0: - T.loop_break() + break r_idx = round % 2 l_start_pos = topk - l_new_topk @@ -127,9 +129,9 @@ def tl_topk_kernel( l_num_input = s_num_input[r_idx] for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast("int32", (( - convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF)) + l_bin_id32 = T.cast( + ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF), T.int32 + ) T.atomic_add(s_histogram[l_bin_id32], 1) T.sync_threads() # cumsum @@ -156,23 +158,20 @@ def tl_topk_kernel( for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): T.sync_threads() if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast("int32", (( - convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF)) + l_bin_id32 = T.cast( + ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF), T.int32 + ) if l_bin_id32 > l_threshold_bin_id: - pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: if round == 3: - l_out_pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos if l_out_pos < topk: index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] else: pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True) - s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, - s * BLOCK_SIZE + tx] + s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] return tl_topk_kernel @@ -186,10 +185,6 @@ def tl_topk(input, starts, ends, topk): def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): - - batch = 64 - seq_len = 32 * 1024 - topk = 2048 torch.manual_seed(1) input = torch.randn(batch, seq_len, dtype=torch.float32).cuda() starts = torch.zeros(batch, dtype=torch.int32).cuda() @@ -212,8 +207,7 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): set_ref = set(ref_np) set_trt = set(trt_np) intersection = set_ref & set_trt - print("selected/all:", len(intersection), "/", len(set_ref), "=", - len(intersection) / len(set_ref)) + print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) # Performance test with CUDA events @@ -245,5 +239,19 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms") +def run_regression_perf(batch=64, seq_len=32 * 1024, topk=2048): + torch.manual_seed(1) + input = torch.randn(batch, seq_len, dtype=torch.float32).cuda() + starts = torch.zeros(batch, dtype=torch.int32).cuda() + ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len + + from tilelang.profiler import do_bench + + def run_kernel_only(): + tl_topk(input, starts, ends, topk) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": test_topk_selector() diff --git a/examples/deepseek_v32/utils.py 
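The radix selection above relies on convert_to_uint16 / convert_to_uint32 producing keys whose unsigned order matches the float order, so the largest scores always land in the highest buckets; each pass then buckets on an 8-bit digit of the key, starting from the most significant. A PyTorch restatement of the 32-bit transform, for illustration:

import torch

def float32_to_radix_key_u32(x: torch.Tensor) -> torch.Tensor:
    # Order-preserving float32 -> 32-bit key: flip all bits of negatives, set the
    # sign bit of non-negatives; presented as unsigned values in an int64 tensor.
    bits = x.view(torch.int32)
    flipped = torch.where(x < 0, ~bits, bits | torch.tensor(-0x80000000, dtype=torch.int32))
    return flipped.to(torch.int64) & 0xFFFFFFFF

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=torch.float32)
k = float32_to_radix_key_u32(x)
assert torch.all(k[:-1] < k[1:])            # monotone, so radix select on k == top-k on x
assert int(k[-1] >> 24) > int(k[0] >> 24)   # first-pass bucket id = top 8 bits, as in the kernel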
b/examples/deepseek_v32/utils.py index 2ea34b14a..d7252e171 100644 --- a/examples/deepseek_v32/utils.py +++ b/examples/deepseek_v32/utils.py @@ -23,8 +23,7 @@ def _is_equal(a, b): if isinstance(a, torch.Tensor): return a is b # Whitelist of types that are safe to compare by value for caching. - if isinstance(a, (int, float, str, bool, type(None))) and isinstance( - b, (int, float, str, bool, type(None))): + if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))): return a == b # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check. return False @@ -58,9 +57,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): # For Tensors, check for object identity. For other types, check for equality. # Python caches small integers, so `is` works for them but not for large integers like 4096. - if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \ - set(kwargs.keys()) == set(last_kwargs.keys()) and \ - all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()): + if ( + all(_is_equal(a, b) for a, b in zip(args, last_args)) + and set(kwargs.keys()) == set(last_kwargs.keys()) + and all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()) + ): return last_result result = fn(*args, **kwargs) @@ -79,73 +80,68 @@ def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int): @tensor_cache -def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - seq_len: int) -> torch.IntTensor: - seq_idx_for_q = torch.full((seq_len,), - len(cu_seqlens_qs), - dtype=torch.int32, - device=cu_seqlens_qs.device) +def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor: + seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): - seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i + seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i return seq_idx_for_q @tensor_cache -def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor: +def cal_cu_seqlen_ks_for_q( + cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int +) -> torch.IntTensor: cu_seqlen_ks_for_each_q = torch.gather( - input=torch.cat([ - cu_seqlens_ks, - torch.full((1,), - torch.iinfo(torch.int32).max, - dtype=torch.int32, - device=cu_seqlens_qs.device) - ]), + input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) return cu_seqlen_ks_for_each_q.int() @tensor_cache -def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor, - q_start_idxs: torch.LongTensor, seq_len: int, - kv_stride: int) -> torch.IntTensor: +def cal_cu_seqlen_ke_for_q( + cu_seqlens_qs: torch.LongTensor, + cu_seqlens_qe: torch.LongTensor, + cu_seqlens_ks: torch.LongTensor, + cu_seqlens_ke: torch.LongTensor, + q_start_idxs: torch.LongTensor, + seq_len: int, + kv_stride: int, +) -> 
torch.IntTensor: cu_seqlen_ke_for_each_q = torch.gather( - input=torch.cat( - [cu_seqlens_ke, - torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), + input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) - casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), - dtype=torch.int32, - device=cu_seqlens_qs.device) + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): - casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange( - q_start_idxs[i], - q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], - dtype=torch.int32, - device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i] + casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = ( + torch.arange( + q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device + ) + + 1 + ) // kv_stride + cu_seqlens_ks[i] cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) return cu_seqlen_ke_for_each_q.int() @tensor_cache -def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, - cu_seqlens_k: torch.LongTensor = None, - offs_q: torch.LongTensor = None, - *, - seq_len: int, - kv_stride: int = 1, - cp_rank: int = 0, - cp_size: int = 1, - balanced_cp=False): - ''' +def cal_ks_ke_from_cu_seqlen_qk( + cu_seqlens_q: torch.LongTensor, + cu_seqlens_k: torch.LongTensor = None, + offs_q: torch.LongTensor = None, + *, + seq_len: int, + kv_stride: int = 1, + cp_rank: int = 0, + cp_size: int = 1, + balanced_cp=False, +): + """ seq_len: seq len per cp rank balanced cp slice assignment: 0 1 2 3 3 2 1 0 - ''' + """ n_seq = len(cu_seqlens_q) - 1 assert n_seq > 0 assert cu_seqlens_q.shape == (n_seq + 1,) @@ -170,10 +166,12 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, def f(x: torch.Tensor): chunks = x.chunk(cp_size * 2) - return torch.cat([ - chunks[cp_rank], - chunks[cp_size - cp_rank - 1], - ]) + return torch.cat( + [ + chunks[cp_rank], + chunks[cp_size - cp_rank - 1], + ] + ) ks = f(ks) ke = f(ke) @@ -189,8 +187,7 @@ def ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) -def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], - use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) sf = x_amax / 448.0 @@ -239,14 +236,18 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, total_seqlen - (cp_rank + 1) * per_chunk_seqlen, total_seqlen - cp_rank * per_chunk_seqlen, ) - ks = torch.cat([ - cu_seqlens_ks_for_each_q[slice_short], - cu_seqlens_ks_for_each_q[slice_long], - ]) - ke = torch.cat([ - cu_seqlens_ke_for_each_q[slice_short], - cu_seqlens_ke_for_each_q[slice_long], - ]) + ks = torch.cat( + [ + cu_seqlens_ks_for_each_q[slice_short], + cu_seqlens_ks_for_each_q[slice_long], + ] + ) + ke = torch.cat( + [ + cu_seqlens_ke_for_each_q[slice_short], + 
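The gather plus arange bookkeeping in cal_cu_seqlen_ke_for_q above reduces to a per-query scalar bound; a restatement with illustrative names:

def causal_ke_for_query(q_pos_in_seq: int, seq_ks: int, seq_ke: int, kv_stride: int) -> int:
    # Exclusive upper bound on the compressed-KV rows a query may read, as computed
    # above: a query at position q_pos_in_seq within its full sequence (its sequence's
    # q_start_idx plus the local offset) sees rows
    # [seq_ks, seq_ks + (q_pos_in_seq + 1) // kv_stride), clipped by the sequence's own end seq_ke.
    return min(seq_ks + (q_pos_in_seq + 1) // kv_stride, seq_ke)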
cu_seqlens_ke_for_each_q[slice_long], + ] + ) assert len(ks) == len(ke) == per_cp_seqlen return ks, ke @@ -302,11 +303,9 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): raise_assert: Whether to raise assertion error on failure """ sim = calculate_tensor_similarity(x, y, name) - diff = 1. - sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print( - f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m" - ) + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") if raise_assert: assert False # noqa: B011 @@ -316,11 +315,8 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) - cu_seqlens_qs = torch.cat( - [torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) - cu_seqlens_qe = torch.cat( - [cu_seqlens_cumsum, - torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) + cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) + cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) from tilelang.profiler import do_bench diff --git a/examples/dequantize_gemm/README.md b/examples/dequantize_gemm/README.md index 0c6116775..25ef617a2 100644 --- a/examples/dequantize_gemm/README.md +++ b/examples/dequantize_gemm/README.md @@ -19,7 +19,7 @@ def dequant_matmul( T.clear(Ct_local) for k in T.Pipelined( - T.ceildiv(K, block_K), + T.ceildiv(K, block_K), num_stages=num_stages ): T.copy(A[by * block_M, k * block_K], A_shared) diff --git a/examples/dequantize_gemm/dequantize_utils.py b/examples/dequantize_gemm/dequantize_utils.py index b14c0aee6..90a6265ff 100644 --- a/examples/dequantize_gemm/dequantize_utils.py +++ b/examples/dequantize_gemm/dequantize_utils.py @@ -39,12 +39,10 @@ def torch_convert_bit_twiddling(tensor): res0 = val_concat_expanded & mask res1 = (val_concat_expanded << 3) & mask res2 = (val_concat_expanded << 6) & mask - res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ( - (val_concat_expanded >> 7) & mask3) + res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ((val_concat_expanded >> 7) & mask3) # Select the correct result based on position - bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, - torch.where(pos == 2, res2, res3))) + bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3))) # Convert to uint16 for .view(torch.bfloat16) bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16) @@ -110,7 +108,7 @@ def print_bit(name, val): val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown. 
""" val_cpu = val.cpu().item() - binary_repr = f'{val_cpu:032b}' + binary_repr = f"{val_cpu:032b}" print(name, binary_repr) @@ -122,7 +120,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -132,21 +130,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): x_mask = torch.isfinite(x) y_mask = torch.isfinite(y) if not torch.all(x_mask == y_mask): - print_red_warning(f'{name} Error: isfinite mask mismatch') + print_red_warning(f"{name} Error: isfinite mask mismatch") if raise_assert: raise AssertionError - if not torch.isclose( - x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, - equal_nan=True).all(): - print_red_warning(f'{name} Error: nonfinite value mismatch') + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") if raise_assert: raise AssertionError x = x.masked_fill(~x_mask, 0) y = y.masked_fill(~y_mask, 0) sim = calc_sim(x, y, name) - diff = (1. - sim).item() - print(f'{diff=}') + diff = (1.0 - sim).item() + print(f"{diff=}") if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff=}') + print_red_warning(f"{name} Error: {diff=}") if raise_assert: raise AssertionError diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py index e30845b8d..36b32c0a8 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -24,6 +24,7 @@ def get_configs(): the parameter name to its chosen value. """ import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -32,65 +33,64 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit( out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - fast_dequant=True, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + fast_dequant=True, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. + + This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: + - A: dense input of shape (M, K) with dtype `in_dtype`. 
+ - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. + - C: output of shape (M, N) with dtype `out_dtype`. + + The generated kernel supports two dequantization paths: + - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. + - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. + + Important behavior and requirements: + - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. + - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. + - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. + - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. + - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. + + Parameters that alter kernel layout/behavior (brief): + - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. + - num_stages: number of software pipeline stages for the K-loop. + - threads: number of threads used per kernel block. + - split: extra K-splitting factor; K must be divisible by block_K * split. + - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. + + Returns: + A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. """ - Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. - - This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: - - A: dense input of shape (M, K) with dtype `in_dtype`. - - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. - - C: output of shape (M, N) with dtype `out_dtype`. - - The generated kernel supports two dequantization paths: - - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. - - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. - - Important behavior and requirements: - - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. - - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. - - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. - - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. - - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. - - Parameters that alter kernel layout/behavior (brief): - - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. - - num_stages: number of software pipeline stages for the K-loop. - - threads: number of threads used per kernel block. 
- - split: extra K-splitting factor; K must be divisible by block_K * split. - - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. - - Returns: - A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. - """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte @@ -121,7 +121,7 @@ def matmul(M, assert func_name is not None, "mxfp_intrin_info is not found" import_source = import_source - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin. @@ -131,13 +131,13 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel. Notes and preconditions: - - Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`. + - Asserts that `in_dtype == "fp4"` and `out_dtype == T.bfloat16`. - The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel. - The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly. - The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -189,12 +189,11 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared): # Finally, store the dequantized data to shared memory. for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16. @@ -205,7 +204,7 @@ def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the dequantized bfloat16 block into B_dequantize_shared. Constraints: - - Supports only in_dtype="fp4" and out_dtype="bfloat16". + - Supports only in_dtype="fp4" and out_dtype=T.bfloat16. - The helper assumes nbit == 4 and produces bfloat16 values. - The macro uses a fixed test-scale of 0 (no per-element scaling) as written. @@ -213,49 +212,49 @@ def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16. 
""" assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] - def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, - scale: tir.PrimExpr, dtype: str): + def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. - - This helper extracts the 4-bit field located at the bit position `pos` within the - byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an - exponent `scale` offset to align it with bfloat16 exponent bias, clamps the - resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. - - Parameters: - nbit (int): Number of bits in the packed element; must be 4. - val (tir.PrimExpr): A uint8 value containing packed FP4 elements. - pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. - scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. - dtype (str): Target dtype string; must be "bfloat16". - - Returns: - tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. - - Notes: - - The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8". - - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 - bit fields and clamps the computed exponent to fit into 8 bits. + Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. + + This helper extracts the 4-bit field located at the bit position `pos` within the + byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an + exponent `scale` offset to align it with bfloat16 exponent bias, clamps the + resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. + + Parameters: + nbit (int): Number of bits in the packed element; must be 4. + val (tir.PrimExpr): A uint8 value containing packed FP4 elements. + pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. + scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. + dtype (str): Target dtype string; must be T.bfloat16. + + Returns: + tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. + + Notes: + - The function asserts `nbit == 4`, `dtype == T.bfloat16`, and that `val.dtype` is T.uint8. + - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 + bit fields and clamps the computed exponent to fit into 8 bits. 
""" assert nbit == 4 - assert dtype == "bfloat16" - assert val.dtype == "uint8" - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 - e_bf16 = e_f4 + tir.const(126, "uint16") + e_bf16 = e_f4 + tir.const(126, T.uint16) # Scale is the exponential part, within the representation of uint8 # To handle the overflow, we use the max function to limit the exponential part to 8 bits - e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") + e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16)) + m_f4 = f4 & tir.const(1, T.uint16) val_bf16 = tir.reinterpret( - "bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) return val_bf16 @T.macro @@ -292,32 +291,32 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared): @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Kernel entry for the tiled, pipelined matmul used by the generated prim_func. - - This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: - - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. - - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. - - Pipelines over K in chunks of `block_K` for `num_stages` stages: - - Loads A and packed B tiles into shared memory. - - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. - - Performs a GEMM accumulating into C_local with B transposed. - - Stores the accumulated block from C_local back to the global output C via C_shared. - - Parameters: - - A: input tile of shape (M, K) with dtype `in_dtype`. - - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). - - C: output tensor of shape (M, N) with dtype `out_dtype`. - - Side effects: - - Writes the computed output block into the global tensor `C`. - - Uses and updates shared memory buffers and per-thread accumulators. - - No value is returned. + Kernel entry for the tiled, pipelined matmul used by the generated prim_func. + + This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: + - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. + - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. 
+ - Pipelines over K in chunks of `block_K` for `num_stages` stages: + - Loads A and packed B tiles into shared memory. + - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. + - Performs a GEMM accumulating into C_local with B transposed. + - Stores the accumulated block from C_local back to the global output C via C_shared. + + Parameters: + - A: input tile of shape (M, K) with dtype `in_dtype`. + - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). + - C: output tensor of shape (M, N) with dtype `out_dtype`. + + Side effects: + - Writes the computed output block into the global tensor `C`. + - Uses and updates shared memory buffers and per-thread accumulators. + + No value is returned. """ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -327,10 +326,6 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) - T.annotate_layout({ - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) - T.clear(C_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) @@ -344,7 +339,7 @@ def main( T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -363,7 +358,7 @@ def ref_program_twiddling(A, qB): Returns: torch.Tensor: Result matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -383,7 +378,7 @@ def ref_program_simple(A, qB): Returns: torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N). 
""" - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -409,16 +404,15 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): """ total_flops = 2 * m * n * k if tune: - kernel = matmul( - m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant) + kernel = matmul(m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, fast_dequant=fast_dequant) else: kernel = matmul( m, n, k, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=4, fast_dequant=fast_dequant, block_M=256, @@ -426,7 +420,8 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): block_K=128, num_stages=2, threads=256, - split=1) + split=1, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) if fast_dequant: profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) @@ -437,6 +432,27 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(m=4096, n=4096, k=4096, fast_dequant=True): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + fast_dequant=fast_dequant, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main(256, 256, 256, True) main(256, 256, 256, False) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py index ac1417aeb..cc37c8bc4 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -7,45 +7,45 @@ from dequantize_utils import torch_convert_bit_twiddling, torch_convert -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, - dtype: str): +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. - This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its - bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by - `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. - Parameters: - nbit (int): Number of bits in the packed field (must be 4). - val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. - pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). - scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). + Parameters: + nbit (int): Number of bits in the packed field (must be 4). 
+ val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be T.bfloat16). - Returns: - tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. - Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - """ + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8. + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ assert nbit == 4 - assert dtype == "bfloat16" - assert val.dtype == "uint8" - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 - e_bf16 = e_f4 + tir.const(126, "uint16") + e_bf16 = e_f4 + tir.const(126, T.uint16) # Scale is the exponential part, within the representation of uint8 # To handle the overflow, we may use the min function to limit the exponential part to 8 bits # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret("bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret( + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) return val_bf16 @@ -65,6 +65,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. 
""" import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -73,70 +74,74 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1],) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ - Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - - The generated kernel accepts: - - A: dense matrix with element type `in_dtype`. - - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). - - Scale: per-block scale/exponent information used to dequantize B. - The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: - - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. - - fast_dequant (False): uses a simple elementwise dequantization helper. - - Parameters: - M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). - in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). - accum_dtype (str): accumulation type used for the inner GEMM. - source_format (str, optional): format string passed to intrinsic selector (default "uint"). - num_bits (int, optional): number of bits per quantized element in B (default 4). - scale_size (int, optional): number of elements grouped per scale entry (default 32). - fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). - block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). - num_stages (int, optional): pipelining stages for K loop (default 2). - threads (int, optional): threads per block used by the kernel (default 256). - split (int, optional): split factor along K used by the scheduler (default 1). - with_bias (bool, optional): whether to add Bias to the output (default False). - - Returns: - A T.prim_func implementing the tiled, pipelined GEMM that: - - loads tiled blocks of A and packed B to shared memory, - - dequantizes B via the chosen path into a shared dequantized tile, - - performs a tiled GEMM accumulating into local fragments, - - writes the final MxN block to the global output tensor. + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. 
- Notes: - - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - - An assertion enforces that K % (block_K * split) == 0. + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., T.bfloat16). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte A_shape = (M, K) @@ -150,6 +155,7 @@ def matmul(M, assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -164,7 +170,7 @@ def matmul(M, assert func_name is not None, "mxfp_intrin_info is not found" import_source = import_source - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. 
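The quantized-storage shapes defined above (QK, Block_QK, Scale_shape) follow directly from the packing arithmetic. A worked example with this file's defaults (num_bits=4, scale_size=32, block_K=128) and the K=256 used in the __main__ demo:

# Worked packing arithmetic; K is the demo size, the rest are the defaults in this file.
num_bits = 4
num_elems_per_byte = 8 // num_bits          # 2 FP4 values per uint8 byte
K, block_K, scale_size = 256, 128, 32

QK = K // num_elems_per_byte                # B is stored as (N, 128) uint8
Block_QK = block_K // num_elems_per_byte    # each B tile in shared memory is (block_N, 64) bytes
scale_cols = K // scale_size                # one uint8 exponent per 32 weights -> Scale is (N, 8)

assert (QK, Block_QK, scale_cols) == (128, 64, 8)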
@@ -175,12 +181,12 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the scaled BF16 results into B_dequantize_shared. Notes: - - This factory only supports in_dtype="fp4" and out_dtype="bfloat16". + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -252,24 +258,23 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k): for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16. Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared. Notes: - - Only supports in_dtype="fp4" and out_dtype="bfloat16". + - Only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion. - Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper. 
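The per-thread sizes in the fast (twiddling) path come from the 128-bit transaction width. A quick sanity check of that arithmetic for the bf16 output path with the default tile (block_N=128, block_K=128, threads=256); plain Python, illustrative only:

# Per-thread sizing in the fast (twiddling) dequant path, bf16 output / fp4 input.
MAX_TRANSACTION_SIZE_BITS = 128
out_dtype_bits = 16                                         # bfloat16
num_elems_per_byte = 2                                      # two 4-bit values per byte

local_size = MAX_TRANSACTION_SIZE_BITS // out_dtype_bits    # 8 bf16 values per vectorized store
local_compress_size = local_size // num_elems_per_byte      # 4 packed uint8 bytes loaded per thread

block_N, block_K, threads = 128, 128, 256                   # default tile in this example
outer_iters = block_N * block_K // threads // local_size    # 8 outer iterations cover one B tile

assert (local_size, local_compress_size, outer_iters) == (8, 4, 8)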
""" assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] @T.macro def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): @@ -301,33 +306,32 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale[ - bx * block_N + i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + bx * block_N + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, - ) * T.shift_left( - 1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) + ) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) return simple_dequant_bf16_fp4 @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Scale: T.Tensor(Scale_shape, storage_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. - This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. + This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. - Parameters are self-descriptive in the signature; notable behaviors: - - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - - The function writes results in-place into C. + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). + - The function writes results in-place into C. 
""" with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -337,23 +341,24 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) if with_bias: - T.annotate_layout({ - Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), - }) + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) if threads == 512: T.disable_warp_group_reg_alloc() if with_bias: - T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], - Bias_shared) + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], Bias_shared) T.copy(Bias_shared, C_local) else: T.clear(C_local) @@ -368,7 +373,7 @@ def main( T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -387,9 +392,9 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): Returns: torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -410,9 +415,9 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): Returns: torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -434,9 +439,9 @@ def ref_program_simple(A, qB, Scale, Bias=None): No in-place modification is performed on inputs (a local floating copy of B is scaled). """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -462,9 +467,9 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): No in-place modification is performed on inputs (a local floating copy of B is scaled). 
""" - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -491,24 +496,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, if tune: kernel = matmul( - m, - n, - k, - "bfloat16", - "bfloat16", - "float32", - num_bits=4, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias) + m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) else: kernel = matmul( m, n, k, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=4, scale_size=scale_size, block_M=256, @@ -518,7 +515,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, threads=256, split=1, fast_dequant=fast_dequant, - with_bias=with_bias) + with_bias=with_bias, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) @@ -538,6 +536,29 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, fast_dequant=True, with_bias=False): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": M, N, K = 256, 256, 256 scale_size = 32 diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py deleted file mode 100644 index 7dad79597..000000000 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py +++ /dev/null @@ -1,563 +0,0 @@ -import tilelang -import tilelang.language as T -from tilelang import tvm as tvm -from tvm import DataType -from tvm import tir -import torch -from dequantize_utils import torch_convert_bit_twiddling, torch_convert - - -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, - dtype: str): - """ - Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. - - This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its - bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by - `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. - - Parameters: - nbit (int): Number of bits in the packed field (must be 4). - val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. - pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). - scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). - - Returns: - tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. - - Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". 
- - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - """ - assert nbit == 4 - assert dtype == "bfloat16" - assert val.dtype == "uint8" - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") - # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 - e_bf16 = e_f4 + tir.const(126, "uint16") - # Scale is the exponential part, within the representation of uint8 - # To handle the overflow, we may use the min function to limit the exponential part to 8 bits - # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret("bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) - return val_bf16 - - -def get_configs(): - """ - Generate a list of hyperparameter configuration dictionaries for tuning. - - Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K', - 'num_stages', 'threads', and 'split'. The function returns the Cartesian - product of the parameter value lists: - - block_M, block_N, block_K: tiling sizes (64, 128, 256) - - num_stages: pipeline stages (0, 2) - - threads: thread counts (128, 256, 512) - - split: K-splitting factor (1, 2) - - Returns: - List[dict]: A list of configuration dictionaries covering all combinations. - """ - import itertools - iter_params = dict( - block_M=[64, 128, 256], - block_N=[64, 128, 256], - block_K=[64, 128, 256], - num_stages=[0, 1, 2], - threads=[128, 256, 512], - split=[1, 2], - ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1],) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): - """ - Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - - The generated kernel accepts: - - A: dense matrix with element type `in_dtype`. - - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). - - Scale: per-block scale/exponent information used to dequantize B. - The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: - - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. - - fast_dequant (False): uses a simple elementwise dequantization helper. - - Parameters: - M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). - in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). - accum_dtype (str): accumulation type used for the inner GEMM. - source_format (str, optional): format string passed to intrinsic selector (default "uint"). - num_bits (int, optional): number of bits per quantized element in B (default 4). - scale_size (int, optional): number of elements grouped per scale entry (default 32). 
- fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). - block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). - num_stages (int, optional): pipelining stages for K loop (default 2). - threads (int, optional): threads per block used by the kernel (default 256). - split (int, optional): split factor along K used by the scheduler (default 1). - with_bias (bool, optional): whether to add Bias to the output (default False). - - Returns: - A T.prim_func implementing the tiled, pipelined GEMM that: - - loads tiled blocks of A and packed B to shared memory, - - dequantizes B via the chosen path into a shared dequantized tile, - - performs a tiled GEMM accumulating into local fragments, - - writes the final MxN block to the global output tensor. - - Notes: - - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - - An assertion enforces that K % (block_K * split) == 0. - """ - num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" - QK = K // num_elems_per_byte - Block_QK = block_K // num_elems_per_byte - A_shape = (M, K) - B_shape = (N, QK) - Bias_shape = (M, N) - Scale_shape = (N, K // scale_size) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, Block_QK) - Bias_shared_shape = (block_M, block_N) - B_dequantize_shared_shape = (block_N, block_K) - assert K % (block_K * split) == 0 - - from tilelang.quantize import get_mxfp_intrin_group - # fast_dequant_bf16_fp4_twiddling - mxfp_intrin_info = get_mxfp_intrin_group( - out_dtype=in_dtype, - source_format=source_format, - source_bit=num_bits, - storage_dtype=storage_dtype, - use_twiddling=True, - ) - import_source = mxfp_intrin_info["c_source"] - func_name = mxfp_intrin_info["func_name"] - assert import_source is not None, "mxfp_intrin_info is not found" - assert func_name is not None, "mxfp_intrin_info is not found" - import_source = import_source - - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - """ - Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. - - The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: - - Loads packed FP4 elements from B_shared into per-thread local registers. - - Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values. - - Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two). - - Writes the scaled BF16 results into B_dequantize_shared. - - Notes: - - This factory only supports in_dtype="fp4" and out_dtype="bfloat16". - - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. - - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. 
- """ - assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] - - # Some variables for dequantization in each thread - MAX_TRANSACTION_SIZE_BITS = 128 - local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits - local_compress_size = local_size // num_elems_per_byte - - @T.macro - def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k): - # import fast_dequantize plugin - """ - Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16 - in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4, - applying per-block scale factors from Scale. - - This routine is a tiled, thread-parallel helper that: - - Imports and calls an external dequantization function (via `import_source`/`func_name`) - to expand compressed uint8-packed FP4 values into BF16 fragments in-thread. - - Loads the corresponding per-block scale entry, interprets it as an exponent bias - (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor. - - Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place. - - Parameters: - - B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout). - - B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values. - - Scale: per-block scale tensor; entries are interpreted such that the multiplicative scale - = 2^(Scale - 127). - - k: block index along the K dimension used to select the appropriate Scale entries. - - Side effects: - - Mutates B_dequantize_shared in shared memory. - - Calls an external intrinsic function (must be provided by the environment via `import_source` - and `func_name`) to perform the low-level unpacking/dequantization. - """ - T.import_source(import_source) - - tx = T.get_thread_binding() - bx = T.get_block_binding(0) # noqa: F841 - - B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) - B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) - Scale_local_thread = T.alloc_local((1,), storage_dtype) - Scale_local_thread_exponent = T.alloc_local((1,), out_dtype) - - for i in T.serial(0, block_N * block_K // threads // local_size): - # First, load data from share memory to register. - # Prepare for dequant. - index_base = i * threads * local_compress_size + tx * local_compress_size - for v in T.vectorized(0, local_compress_size): - index = index_base + v - B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] - index_scale = index_base // (scale_size // num_elems_per_byte) - si = index_scale // (block_K // scale_size) - sj = index_scale % (block_K // scale_size) - Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj] - Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0])) - - # Then, dequant. - T.call_extern( - func_name, - T.address_of(B_local_thread[0]), - T.address_of(B_dequantize_local_thread[0]), - 1, - dtype=out_dtype, - ) - - # Finally, store the dequantized data to shared memory. 
- for v in T.Parallel(local_size): - B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0] - - for v in T.vectorized(0, local_size): - index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] - - return fast_dequant_bf16_fp4_twiddling - - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): - """ - Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16. - - Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared. - - Notes: - - Only supports in_dtype="fp4" and out_dtype="bfloat16". - - The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion. - - Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper. - """ - assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] - - @T.macro - def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): - """ - Dequantizes a packed 4-bit (FP4) block from B_shared into BF16 values in B_dequantize_shared using per-element scale exponents. - - Per-element behavior: - - Reads packed 4-bit entries from B_shared (uint8 storage, multiple nibbles per byte). - - Uses Scale to obtain an exponent term (stored as uint8) and reconstructs BF16 values via _tir_u8_to_f4_to_bf16. - - Writes the dequantized BF16 block into B_dequantize_shared. - - Parameters: - - B_shared: shared-memory buffer holding packed 4-bit values (uint8-packed layout). - - B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results. - - Scale: per-element exponent buffer; used to compute the scale factor for each dequantized element. - - k: current block index along the K dimension (used to select the appropriate slice of Scale). - - Side effects: - - Mutates B_dequantize_shared by storing the dequantized BF16 fragment. - """ - B_local = T.alloc_fragment(B_shared_shape, storage_dtype) - B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) - - bx = T.get_block_binding(0) # noqa: F841 - T.copy(B_shared, B_local) - for i, j in T.Parallel(block_N, block_K): - B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( - num_bits, - B_local[i, j // num_elems_per_byte], - j % num_elems_per_byte, - Scale_shared[ - i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 - dtype=out_dtype, - ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) - T.copy(B_dequantize_local, B_dequantize_shared) - - return simple_dequant_bf16_fp4 - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Scale: T.Tensor(Scale_shape, storage_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor((M, N), out_dtype), - ): - """ - Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. 
- - This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. - - Parameters are self-descriptive in the signature; notable behaviors: - - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - - The function writes results in-place into C. - """ - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, storage_dtype) - B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) - Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - # To use 1D TMA, the last dim of Scale_shared must have stride=1 - # May use much more shared memory than necessary - Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) - - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) - - if with_bias: - T.annotate_layout({ - Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), - }) - - if threads == 512: - T.disable_warp_group_reg_alloc() - - if with_bias: - # T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], - # Bias_shared) - # T.copy(Bias_shared, C_local) - T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], - C_local) - else: - T.clear(C_local) - - # Use 1D TMA to load Scale - T.copy(Scale[bx * block_N:(bx + 1) * block_N, :], Scale_shared) - - for k in T.Pipelined(K // block_K, num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) - if fast_dequant: - get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, - k) - else: - get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) - T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) - - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) - - return main - - -def ref_program_twiddling(A, qB, Scale, Bias=None): - """ - Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. - - Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16. - - Parameters: - A (torch.Tensor): Left operand with shape (M, K), used in floating precision. 
- qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling. - Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B. - - Returns: - torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. - """ - dtypeC = "bfloat16" - B = torch_convert_bit_twiddling(qB) - for i in range(B.shape[0]): - for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) - C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) - C = C.to(torch.__getattribute__(dtypeC)) - return C - - -def ref_program_twiddling_with_bias(A, qB, Scale, Bias): - """ - Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. - - Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16. - - Parameters: - A (torch.Tensor): Left operand with shape (M, K), used in floating precision. - qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling. - Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B. - Bias (torch.Tensor): Bias tensor with shape (M, N). - - Returns: - torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. - """ - dtypeC = "bfloat16" - B = torch_convert_bit_twiddling(qB) - for i in range(B.shape[0]): - for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) - C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias - C = C.to(torch.__getattribute__(dtypeC)) - return C - - -def ref_program_simple(A, qB, Scale, Bias=None): - """ - Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization. - - Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16. - - Parameters: - - A: 2D tensor representing the left operand (will be cast to float32 for the matmul). - - qB: Quantized representation of B accepted by `torch_convert`. - - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32. - - Returns: - - 2D bfloat16 tensor C containing the matrix product A · B^T. - - No in-place modification is performed on inputs (a local floating copy of B is scaled). - """ - dtypeC = "bfloat16" - B = torch_convert(qB) - for i in range(B.shape[0]): - for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) - C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) - C = C.to(torch.__getattribute__(dtypeC)) - return C - - -def ref_program_simple_with_bias(A, qB, Scale, Bias): - """ - Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization. - - Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16. - - Parameters: - - Returns: - - A: 2D tensor representing the left operand (will be cast to float32 for the matmul). 
- - qB: Quantized representation of B accepted by `torch_convert`. - - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32. - - Bias: 2D tensor representing the Bias (will be cast to float32 for the matmul). - - - Returns: - - 2D bfloat16 tensor C containing the matrix product A · B^T. - - No in-place modification is performed on inputs (a local floating copy of B is scaled). - """ - dtypeC = "bfloat16" - B = torch_convert(qB) - for i in range(B.shape[0]): - for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) - C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias - C = C.to(torch.__getattribute__(dtypeC)) - return C - - -def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False): - """ - Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS. - - Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS. - - Parameters: - m (int): Number of rows of A / output rows. Default 256. - n (int): Number of columns of B / output columns. Default 256. - k (int): Reduction dimension. Default 256. - scale_size (int): Size of the per-block scale vector used for dequantization. Default 32. - fast_dequant (bool): If True validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True. - tune (bool): If True obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False. 
- - Returns: - None - """ - total_flops = 2 * m * n * k - - if tune: - kernel = matmul( - m, - n, - k, - "bfloat16", - "bfloat16", - "float32", - num_bits=4, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias) - else: - kernel = matmul( - m, - n, - k, - "bfloat16", - "bfloat16", - "float32", - num_bits=4, - scale_size=scale_size, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1, - fast_dequant=fast_dequant, - with_bias=with_bias) - - profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) - - if fast_dequant: - if with_bias: - profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01) - else: - profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) - else: - if with_bias: - profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01) - else: - profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01) - print("All checks pass.") - latency = profiler.do_bench(warmup=500) - print("Tile-lang: {:.2f} ms".format(latency)) - print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) - - -if __name__ == "__main__": - M, N, K = 256, 256, 256 - scale_size = 32 - main(M, N, K, scale_size, fast_dequant=True, with_bias=True) - main(M, N, K, scale_size, fast_dequant=False, with_bias=True) - main(M, N, K, scale_size, fast_dequant=True, with_bias=False) - main(M, N, K, scale_size, fast_dequant=False, with_bias=False) diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index 727d6d3b6..37826874b 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -24,8 +24,9 @@ def matmul( num_bits=4, ): from tilelang.quantize import _tir_packed_to_unsigned_convert + num_elems_per_byte = 8 // num_bits - storage_dtype = "int8" + storage_dtype = T.int8 storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) A_shape = (M, K) @@ -39,9 +40,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -58,21 +59,19 @@ def main( T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) - for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): index = i * threads * local_size_compressed + tx * local_size_compressed + v vi = index // (block_K // num_elems_per_byte) vj = index % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] for v in T.serial(0, local_size): - B_dequantize_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit)( - num_bits, - B_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) + B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + B_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) for v in 
T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v vi = index // block_K @@ -121,9 +120,7 @@ def run_gemm( def ref_program(A, qB): import torch - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for i in range(B.shape[0]): for j in range(B.shape[1]): B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -146,25 +143,27 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ): from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitterWithLadderTransform,) + TensorCoreIntrinEmitterWithLadderTransform, + ) from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" num_bits = 4 num_elems_per_byte = 8 // num_bits - storage_dtype = "int8" + storage_dtype = T.int8 micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -183,7 +182,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles - block_K = 32 if in_dtype == "float16" else 64 + block_K = 32 if in_dtype == T.float16 else 64 chunk = block_K // reduce_k is_smooth_a = False @@ -192,8 +191,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( pad_factor = 8 A_shape = (M, K) - B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, - micro_size_k // num_elems_per_byte) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte) A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) B_shared_shape = ( block_N // micro_size_y, @@ -228,7 +226,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( chunk=chunk, reduce_k=reduce_k, transform_kind_b=transform_b, - num_elems_per_byte=num_elems_per_byte) + num_elems_per_byte=num_elems_per_byte, + ) vec_load_qb = 16 if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb: @@ -236,14 +235,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, - prelude=decode_i4_to_f16) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -255,40 +251,36 @@ def main( thread_binding = T.get_thread_binding(0) rk = T.get_thread_binding(1) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + } + ) 
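The `ref_program` earlier in example_dequant_gemm_fine_grained.py unpacks two unsigned 4-bit weights from each int8 byte with a shift-and-mask per element. The same unpack can be written in a vectorized form; this sketch uses hypothetical tensor sizes and is only a restatement of that scalar formula:

import torch

# Vectorized equivalent of the per-element unpack in ref_program (hypothetical sizes).
N, QK = 4, 8                                            # packed B is (N, QK), two weights per byte
qB = torch.randint(0, 127, (N, QK), dtype=torch.int8)

j = torch.arange(QK * 2)                                # output column index
B = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)

# Spot-check against the scalar formula used in ref_program.
assert B[0, 0] == float(qB[0, 0] & 0xF)
assert B[0, 1] == float((qB[0, 0].item() >> 4) & 0xF)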
T.use_swizzle(panel_size=10) T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, (block_K // reduce_k)): vk = rk * (block_K // reduce_k) + k A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk] # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load - for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // - (threads * vec_load_qb)): + for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)): for v in T.vectorized(0, vec_load_qb): t = thread_binding idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v vkk = idx % (micro_size_k // num_elems_per_byte) vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y - vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % ( - block_K // micro_size_k) - vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // - (block_K // micro_size_k)) % ( - block_N // micro_size_y) - B_shared[vj, vk, vjj, - vkk] = B[bx * (block_N // micro_size_y) + vj, - ko * (block_K // micro_size_k) + vk, vjj, vkk] + vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k) + vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % ( + block_N // micro_size_y + ) + B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk] for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -307,9 +299,13 @@ def main( for j in T.serial(warp_cols): local_size_b = mma_emitter.local_size_b - T.call_extern('handle', 'decode_i4u_to_f16', - T.address_of(B_local[j * local_size_b // num_elems_per_byte]), - T.address_of(B_dequantize_local[j * local_size_b]), 8) + T.call_extern( + "handle", + "decode_i4u_to_f16", + T.address_of(B_local[j * local_size_b // num_elems_per_byte]), + T.address_of(B_dequantize_local[j * local_size_b]), + 8, + ) mma_emitter.mma(A_local, B_dequantize_local, C_local) @@ -328,7 +324,8 @@ def main( reduced_accum_res[0], rk, dtype="handle", - )) + ) + ) if rk == 0: C_local[n] = reduced_accum_res[0] @@ -340,9 +337,9 @@ def main( for i, j in T.Parallel(block_M, (block_N // reduce_k)): vj = rk * (block_N // reduce_k) + j - C[by * block_M + i, - bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y, - i % micro_size_x, vj % micro_size_y] + C[by * block_M + i, bx * block_N + vj] = C_shared[ + i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y + ] return main @@ -357,8 +354,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct transform_b, ): import bitblas - matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( - M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) + + matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) kernel = tilelang.compile(matmul, out_idx=[2]) src_code = kernel.get_kernel_source() @@ -368,11 +365,10 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct assert src_code is not None num_bits = 4 num_elems_per_byte = 8 // num_bits - storage_dtype = "int8" + storage_dtype = T.int8 A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - qB = torch.randint( - 0, 127, (N, K // num_elems_per_byte), device="cuda", 
dtype=getattr(torch, storage_dtype)) + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) ladder_permutate_config = bitblas.ops.LadderPermutateConfig( @@ -407,9 +403,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct # Ensure that the latency is not None assert latency is not None - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for i in range(B.shape[0]): for j in range(B.shape[1]): B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -423,14 +417,13 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct @tilelang.testing.requires_package("bitblas") def test_run_dequantize_gemm(): - run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128) - run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128) + run_gemm(256, 256, 256, T.float16, T.float16, T.float16, 128, 128, 32, num_threads=128) + run_gemm(256, 256, 256, T.int8, T.int32, T.int32, 128, 128, 32, num_threads=128) @tilelang.testing.requires_package("bitblas") def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): - assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness( - 256, 1024, 512, "float16", "float16", "float16", 3) + assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, T.float16, T.float16, T.float16, 3) def main(): diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index c5588d516..2bdcbb068 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -9,30 +9,29 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 - assert dtype == "float16" - assert val.dtype == "uint8" + assert dtype == T.float16 + assert val.dtype == T.uint8 # e_f4 == 0 -> e_f16 = 0 # e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14 # s1e2m1 - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") - e_f16 = e_f4 + tir.const(14, "uint16") - m_f4 = f4 & tir.const(1, "uint16") + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + e_f16 = e_f4 + tir.const(14, T.uint16) + m_f4 = f4 & tir.const(1, T.uint16) m_f16 = m_f4 - val_f16 = tir.reinterpret("float16", - ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") - | m_f16 << tir.const(9, "uint16")).astype("uint16")) - # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) + val_f16 = tir.reinterpret( + T.float16, ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16) | m_f16 << tir.const(9, T.uint16)).astype(T.uint16) + ) + # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, T.float16), val_f16) return val_f16 def torch_convert(tensor): - def print_bit(name, 
val): val_cpu = val.cpu().item() - binary_repr = f'{val_cpu:032b}' + binary_repr = f"{val_cpu:032b}" print(name, binary_repr) def _convert(val, pos): @@ -61,15 +60,15 @@ def _convert(val, pos): @tilelang.jit(out_idx=[1]) def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 B_shape = (N, K // num_elems_per_byte) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) @T.prim_func def main( - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((N, K), in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared(B_shared_shape, storage_dtype) @@ -99,7 +98,7 @@ def test_fp4_fp16_convert_close(): K, block_N, block_K, - "float16", + T.float16, ) B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8) @@ -118,23 +117,15 @@ def get_configs(): splits = [1] _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits)) - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'block_K': c[2], - 'num_stages': c[3], - 'threads': c[4], - 'split': c[5] - } for c in _configs] + configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs] return configs def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): - @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 A_shape = (M, K) B_shape = (N, K // num_elems_per_byte) A_shared_shape = (block_M, block_K) @@ -145,29 +136,24 @@ def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): @T.prim_func def main_split( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - SplitC = T.alloc_buffer([ - split, (N + block_N - 1) // block_N * block_N, - (M + block_M - 1) // block_M * block_M - ], out_dtype) - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, - threads=threads) as (bx, by, bz): + SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype) + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) - Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // (block_K * split), num_stages=num_stages): @@ -183,8 +169,7 @@ def main_split( ) T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, 
transpose_B=True) - T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): acc = T.alloc_fragment((block_N, block_M), out_dtype) T.clear(acc) @@ -195,12 +180,11 @@ def main_split( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -209,10 +193,11 @@ def main( Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -229,8 +214,7 @@ def main( T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.copy(Ct_local, Ct_shared) - T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) if split == 1: return main @@ -241,12 +225,7 @@ def main( @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[2]) - def kernel(block_M=None, - block_N=None, - block_K=None, - num_stages=None, - threads=None, - split=None): + def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None): return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func return kernel() @@ -259,7 +238,7 @@ def kernel(block_M, block_N, block_K, num_stages, threads, split=1): def ref_program(A, qB): - dtypeC = "float16" + dtypeC = T.float16 B = torch_convert(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -269,10 +248,10 @@ def ref_program(A, qB): def main(m=256, n=256, k=256, tune=False): total_flops = 2 * m * n * k - if (not tune): - kernel = matmul( - m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)( - block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) + if not tune: + kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") @@ -283,7 +262,7 @@ def main(m=256, n=256, k=256, tune=False): print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune) + best_result = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune) best_latency = 
best_result.latency best_config = best_result.config print(f"Best latency: {best_latency}") @@ -291,12 +270,20 @@ def main(m=256, n=256, k=256, tune=False): print(f"Best config: {best_config}") +def run_regression_perf(m=4096, n=4096, k=4096): + kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=False)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--m', type=int, default=256, help='M') - parser.add_argument('--n', type=int, default=256, help='N') - parser.add_argument('--k', type=int, default=256, help='K') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--m", type=int, default=256, help="M") + parser.add_argument("--n", type=int, default=256, help="N") + parser.add_argument("--k", type=int, default=256, help="K") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() M, N, K = args.m, args.n, args.k main(M, N, K, args.tune) diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py index 52ee8216f..b1f8b1132 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py +++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -9,15 +9,15 @@ def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 - assert dtype == "int8" - assert val.dtype == "uint8" + assert dtype == T.int8 + assert val.dtype == T.uint8 - mask = tir.const((1 << nbit) - 1, "uint8") + mask = tir.const((1 << nbit) - 1, T.uint8) - i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask + i4 = (val >> (pos.astype(T.uint8) * tir.const(nbit, T.uint8))) & mask - i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8")) - i8 = i8_shifted >> tir.const(4, "int8") + i8_shifted = tir.reinterpret(T.int8, i4 << tir.const(4, T.uint8)) + i8 = i8_shifted >> tir.const(4, T.int8) return i8 @@ -35,15 +35,15 @@ def get_configs(): @tilelang.jit(out_idx=[1]) def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 B_shape = (N, K // num_elems_per_byte) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) @T.prim_func def main( - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((N, K), in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared(B_shared_shape, storage_dtype) @@ -66,13 +66,12 @@ def main( def torch_convert(tensor): - def _convert(val, pos): assert val.dtype == torch.uint8 val = val.view(torch.int8) mask = (1 << 4) - 1 - i4_shifted = ((val >> (pos * 4)) & mask) - i4 = ((i4_shifted << 4) >> 4) + i4_shifted = (val >> (pos * 4)) & mask + i4 = (i4_shifted << 4) >> 4 return i4.view(torch.int8) @@ -86,7 +85,7 @@ def _convert(val, pos): def ref_program(A, qB): - dtypeC = "int32" + dtypeC = T.int32 B = torch_convert(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -94,11 +93,10 @@ def ref_program(A, qB): def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): - @tilelang.jit(out_idx=[2]) def 
kernel_func(block_M, block_N, block_K, num_stages, threads): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 A_shape = (M, K) B_shape = (N, K // num_elems_per_byte) A_shared_shape = (block_M, block_K) @@ -109,12 +107,11 @@ def kernel_func(block_M, block_N, block_K, num_stages, threads): @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -123,10 +120,11 @@ def main( Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -143,8 +141,7 @@ def main( T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.copy(Ct_local, Ct_shared) - T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) return main @@ -167,10 +164,10 @@ def kernel(block_M, block_N, block_K, num_stages, threads): def main(m=128, n=256, k=256, tune=False): total_flops = 2 * m * n * k - if (not tune): - kernel = matmul_int8xint4( - m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)( - block_M=32, block_N=32, block_K=128, num_stages=1, threads=128) + if not tune: + kernel = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 + ) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2) print("All checks pass.") @@ -179,7 +176,7 @@ def main(m=128, n=256, k=256, tune=False): print(f"Tilelang: {latency} ms") else: - best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune) + best_result = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune) best_latency = best_result.latency best_config = best_result.config print(f"Bset latency: {best_latency}") @@ -187,6 +184,14 @@ def main(m=128, n=256, k=256, tune=False): print(f"Best tflops: {total_flops / best_latency * 1e-9}") +def run_regression_perf(m=4096, n=4096, k=4096): + kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=False)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 + ) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--m", type=int, default=512, help="Matrix dimension M") diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py index d3e90ec93..652ce3479 100644 --- 
a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -4,7 +4,8 @@ import torch from tilelang import DataType from tilelang.quantize import ( - _tir_packed_int_to_int_convert,) + _tir_packed_int_to_int_convert, +) @tilelang.jit @@ -16,7 +17,7 @@ def dequantize_gemv( out_dtype: str, accum_dtype: str, num_bits: int = 4, - storage_dtype: str = "int8", + storage_dtype: T.dtype = T.int8, source_format: str = "uint", n_partition: int = 4, reduce_thread: int = 32, @@ -26,11 +27,10 @@ def dequantize_gemv( group_size: int = -1, with_scaling: bool = False, ) -> Callable[..., Any]: - assert n_partition is not None, "n_partition must be provided" assert reduce_thread is not None, ( - "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" - "sch_outer_reduction_with_config is not implemented") + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) assert trans_A is False, "Dequantize only implement for trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" @@ -51,7 +51,7 @@ def dequantize_gemv( C_shape = (M, N) dp4a_size = 4 - use_dp4a = in_dtype == "int8" and accum_dtype == "int32" + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 import_source: Optional[str] = None func_name: str = "" @@ -81,12 +81,12 @@ def main( C: T.Tensor[C_shape, out_dtype], ): with T.Kernel( - T.ceildiv(N, n_partition), - M, - threads=(reduce_thread, n_partition), + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), ) as ( - bx, - by, + bx, + by, ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) @@ -107,8 +107,7 @@ def main( for v in T.vectorized(micro_size_k_compressed): B_quant_local[v] = B[ bx * n_partition + ni, - ko * (reduce_thread * micro_size_k_compressed) + - kr * micro_size_k_compressed + v, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, ] if fast_decoding: @@ -120,10 +119,9 @@ def main( ) else: for ki in T.serial(micro_size_k): - B_dequantize_local[ki] = _tir_packed_int_to_int_convert( - storage_type, - storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte], - ki % num_elems_per_byte, in_dtype) + B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype + ) if use_dp4a: for ki in T.serial(micro_size_k // dp4a_size): @@ -137,9 +135,9 @@ def main( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -149,7 +147,8 @@ def main( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: C[by, bx * n_partition + ni] = reduced_accum_res[0] @@ -160,11 +159,11 @@ def main() -> None: M = 1 N = 1024 K = 1024 - in_dtype = "float16" - out_dtype = "float16" - accum_dtype = "float16" + in_dtype = T.float16 + out_dtype = T.float16 + accum_dtype = T.float16 num_bits = 4 - storage_dtype = "int8" + storage_dtype = T.int8 source_format = "uint" n_partition = 4 reduce_thread = 32 @@ -174,26 +173,39 @@ def main() -> None: group_size = -1 
with_scaling = False - kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, - source_format, n_partition, reduce_thread, fast_decoding, trans_A, - trans_B, group_size, with_scaling) + kernel = dequantize_gemv( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + num_bits, + storage_dtype, + source_format, + n_partition, + reduce_thread, + fast_decoding, + trans_A, + trans_B, + group_size, + with_scaling, + ) storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) num_elems_per_byte = storage_nbit // num_bits A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() - qB = torch.randint( - 0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() if fast_decoding: from tilelang.quantize.utils import interleave_weight + qB = interleave_weight(qB, num_bits, in_dtype) kernel(A, qB, C) # int4 reference - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for j in range(B.shape[1]): B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -205,5 +217,62 @@ def main() -> None: torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1) +def run_regression_perf(): + M = 1 + N = 8192 + K = 8192 + in_dtype = "float16" + out_dtype = "float16" + accum_dtype = "float16" + num_bits = 4 + storage_dtype = "int8" + source_format = "uint" + n_partition = 4 + reduce_thread = 32 + fast_decoding = True + trans_A = False + trans_B = True + group_size = -1 + with_scaling = False + + kernel = dequantize_gemv( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + num_bits, + storage_dtype, + source_format, + n_partition, + reduce_thread, + fast_decoding, + trans_A, + trans_B, + group_size, + with_scaling, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() + + if fast_decoding: + from tilelang.quantize.utils import interleave_weight + + qB = interleave_weight(qB, num_bits, in_dtype) + kernel(A, qB, C) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(A, qB, C) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py index c4cf5fb50..6aad32bdb 100644 --- a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -25,6 +25,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. 
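        For example, one generated entry is {'block_M': 128, 'block_N': 64, 'block_K': 128, 'num_stages': 1, 'threads': 128, 'split': 1}.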
""" import itertools + iter_params = dict( block_M=[128], block_N=[64, 128, 256], @@ -33,33 +34,33 @@ def get_configs(): threads=[128, 256, 512], split=[1], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] @tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[-1]) -def matmul(M, - N, - K, - topk, - E, - padding_M, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=128, - block_N=256, - block_K=128, - num_stages=2, - threads=256, - split=1): +def matmul( + M, + N, + K, + topk, + E, + padding_M, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=128, + block_N=256, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype. @@ -82,8 +83,8 @@ def matmul(M, topk (int): number of experts selected per token. E (int): number of experts. padding_M (int): padded number of tokens after grouping and block alignment. - in_dtype (str): element type of A (e.g., "bfloat16"). - out_dtype (str): output tensor element type (e.g., "bfloat16"). + in_dtype (str): element type of A (e.g., T.bfloat16). + out_dtype (str): output tensor element type (e.g., T.bfloat16). accum_dtype (str): accumulation type used for the inner GEMM. source_format (str, optional): format string passed to intrinsic selector (default "uint"). num_bits (int, optional): number of bits per quantized element in B (default 4). @@ -110,16 +111,17 @@ def matmul(M, """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, Block_QK) - Bias_shared_shape = (block_N) + Bias_shared_shape = block_N B_dequantize_shared_shape = (block_N, block_K) assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -135,7 +137,7 @@ def matmul(M, import_source = import_source # the dequant part is the same as in dequant_gemm - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: @@ -145,12 +147,12 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the scaled BF16 results into B_dequantize_shared. Notes: - - This factory only supports in_dtype="fp4" and out_dtype="bfloat16". + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. 
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -221,19 +223,16 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): - + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] @T.macro def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): - B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) @@ -244,8 +243,8 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale_shared[ - i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) @@ -254,19 +253,17 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): @T.prim_func def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((E, N, QK), storage_dtype), - Scale: T.Tensor((E, N, K // scale_size), storage_dtype), - Bias: T.Tensor((E, N), out_dtype), - # Add fusedmoe tensors - topk_weights: T.Tensor((M * topk), out_dtype), - sorted_token_ids: T.Tensor((padding_M), "int32"), - expert_ids: T.Tensor((padding_M // block_M), "int32"), - C: T.Tensor((M, topk, N), out_dtype), + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((E, N, QK), storage_dtype), + Scale: T.Tensor((E, N, K // scale_size), storage_dtype), + Bias: T.Tensor((E, N), out_dtype), + # Add fusedmoe tensors + topk_weights: T.Tensor((M * topk), out_dtype), + sorted_token_ids: T.Tensor((padding_M), T.int32), + expert_ids: T.Tensor((padding_M // block_M), T.int32), + C: T.Tensor((M, topk, N), out_dtype), ): - - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) @@ -274,23 +271,23 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) topk_weights_shared = T.alloc_shared((block_M), out_dtype) - sorted_token_ids_shared = T.alloc_shared((block_M), "int32") - expert_id = T.alloc_local((1), "int32") # the expert id for the current block + sorted_token_ids_shared = T.alloc_shared((block_M), T.int32) + expert_id = T.alloc_local((1), 
T.int32) # the expert id for the current block # To use 1D TMA, the last dim of Scale_shared must have stride=1 # May use much more shared memory than necessary Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) T.use_swizzle(10) if threads == 512: T.disable_warp_group_reg_alloc() - T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared) + T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared) expert_id[0] = expert_ids[by] # Get the topk weights of each token in the current block @@ -300,11 +297,11 @@ def main( # Get bias and scale based on the expert id if with_bias: - T.copy(Bias[expert_id[0], bx * block_N:(bx + 1) * block_N], Bias_shared) + T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared) else: T.clear(Bias_shared) - T.copy(Scale[expert_id[0], bx * block_N:(bx + 1) * block_N, :], Scale_shared) + T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared) for i, j in T.Parallel(block_M, block_N): C_local[i, j] = Bias_shared[j] @@ -317,14 +314,13 @@ def main( base = copy_i * threads * 16 + tx * 16 if sorted_token_ids_shared[base // block_K] != -1: for copy_j in T.vectorized(16): - A_shared[base // block_K, base % block_K + - copy_j] = A[sorted_token_ids_shared[base // block_K] // topk, - k * block_K + base % block_K + copy_j] + A_shared[base // block_K, base % block_K + copy_j] = A[ + sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j + ] T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared) if fast_dequant: - get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, - k) + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) else: get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) @@ -338,16 +334,17 @@ def main( base = copy_i * threads * 16 + tx * 16 if sorted_token_ids_shared[base // block_N] != -1: for copy_j in T.vectorized(16): - C[sorted_token_ids_shared[base // block_N] // topk, - sorted_token_ids_shared[base // block_N] % topk, bx * block_N + - base % block_N + copy_j] = C_shared[base // block_N, - base % block_N + copy_j] + C[ + sorted_token_ids_shared[base // block_N] // topk, + sorted_token_ids_shared[base // block_N] % topk, + bx * block_N + base % block_N + copy_j, + ] = C_shared[base // block_N, base % block_N + copy_j] return main def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256): - dtypeC = "bfloat16" + dtypeC = T.bfloat16 M, K = A.shape E, N, QK = qB.shape topk = topk_weights.shape[0] // M @@ -355,7 +352,7 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc assert scale_size == 32 # MXFP4 # Initialize output tensor - C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device='cuda') + C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda") # Iterate over sorted_token_ids for idx in range(len(sorted_token_ids)): # padding_M @@ -370,14 +367,11 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc # Dequantize the expert weights B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K) - B *= 
2**( - Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to( - torch.bfloat16)) + B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16)) # Compute the output for this token-expert pair # token_embedding @ B.T + bias - output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to( - torch.bfloat16)) + Bias[expert_id] + output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id] output = output.to(torch.__getattribute__(dtypeC)) # Apply the topk weight @@ -391,14 +385,12 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc def get_data(m, n, k, qk, scale_size, topk, E, block_M): - A = torch.empty(m, k, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) - qB = torch.randint( - 0, 256, (E, n, qk), dtype=torch.uint8, - device='cuda') # Quantized weight tensor for E experts. - Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device='cuda') - Bias = torch.empty(E, n, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) - - weights = torch.empty(m, E, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) + A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts. + Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda") + Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + + weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) # topk_weights: Router weights for the top-k experts for each token. # Shape: (m, topk) # tokens_experts: A flattened tensor of expert assignments for each token. @@ -420,10 +412,7 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt if pad_len > 0: # -1 for padding (`M` instead in vLLM moe_align_block_size()) - group_token_ids = torch.cat([ - group_token_ids, - torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device='cuda') - ]) + group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")]) padded_token_ids.append(group_token_ids) expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M)) start = end @@ -431,21 +420,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): # sorted_token_ids: The final flattened and padded tensor of token indices. sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,) # expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`. 
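    # Worked example (illustrative): with block_M = 4, an expert that owns 6 entries of the
    # flattened (token, expert) assignment contributes [t0, t1, t2, t3, t4, t5, -1, -1] to
    # sorted_token_ids (padded up to a multiple of block_M), and its id appears twice in
    # expert_ids, once per block_M-sized block.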
- expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda') # (padding_M,) + expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M,) padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M -def main(m=256, - n=256, - k=256, - scale_size=32, - topk=4, - E=32, - fast_dequant=True, - with_bias=False, - tune=False): +def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): # Tunable parameters block_M, block_N, block_K = 128, 256, 128 # noqa: F841 num_stages = 1 # noqa: F841 @@ -456,8 +437,7 @@ def main(m=256, num_bits = 4 num_elems_per_byte = 8 // num_bits qk = k // num_elems_per_byte - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data( - m, n, k, qk, scale_size, topk, E, block_M) + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) if tune: with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): @@ -469,9 +449,9 @@ def main(m=256, topk, E, padding_M, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=num_bits, scale_size=scale_size, fast_dequant=fast_dequant, @@ -485,9 +465,9 @@ def main(m=256, topk, E, padding_M, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=num_bits, scale_size=scale_size, fast_dequant=fast_dequant, @@ -509,15 +489,11 @@ def main(m=256, sorted_token_ids, expert_ids, ) + print("Tilelang kernel run finished.") - print('Tilelang kernel run finished.') - - ref_output = ref_moe( - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, - block_M=block_M) # Maybe a little bit slow... + ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow... - latency = tilelang.profiler.do_bench( - lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) + latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) print("Tilelang: {:.2f} ms".format(latency)) print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) @@ -525,32 +501,72 @@ def main(m=256, max_val = diff.max() max_idx = diff.argmax() print(f"max abs diff: {max_val} at index: {max_idx}") - assert_similar( - output, ref_output, name="output", - eps=2e-5) # We care about the similarity rather than abs. difference + assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference print("All checks pass. 
✅") +def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): + block_M, block_N, block_K = 128, 256, 128 + num_stages = 1 + threads = 512 + split = 1 + num_bits = 4 + num_elems_per_byte = 8 // num_bits + qk = k // num_elems_per_byte + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) + + if tune: + with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + "bfloat16", + "bfloat16", + "float32", + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + else: + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + "bfloat16", + "bfloat16", + "float32", + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, + ) + + return tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm - parser.add_argument("--N", type=int, default=5760, help="N") - parser.add_argument("--K", type=int, default=2944, help="K") + parser.add_argument("--M", type=int, default=256, help="M") # From gpt-oss-20b MoE's first gemm + parser.add_argument("--N", type=int, default=256, help="N") + parser.add_argument("--K", type=int, default=256, help="K") parser.add_argument("--scale_size", type=int, default=32, help="scale size") - parser.add_argument( - "--topk", type=int, default=4, help="topk") # experts activated for each token + parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token parser.add_argument("--E", type=int, default=32, help="E") # number of experts parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main( - args.M, - args.N, - args.K, - args.scale_size, - topk=args.topk, - E=args.E, - fast_dequant=True, - with_bias=True, - tune=args.tune) + main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune) diff --git a/examples/dequantize_gemm/regression_example_dequantize_gemm.py b/examples/dequantize_gemm/regression_example_dequantize_gemm.py new file mode 100644 index 000000000..4ab03784f --- /dev/null +++ b/examples/dequantize_gemm/regression_example_dequantize_gemm.py @@ -0,0 +1,35 @@ +import tilelang.testing +import example_dequant_gemm_bf16_fp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper +import example_dequant_gemm_fp4_hopper +import example_dequant_gemm_w4a8 +import example_dequant_gemv_fp16xint4 +import example_dequant_groupedgemm_bf16_mxfp4_hopper + + +def regression_example_dequant_gemv_fp16xint4(): + tilelang.testing.process_func(example_dequant_gemv_fp16xint4.run_regression_perf) + + +def regression_example_dequant_gemm_fp4_hopper(): + tilelang.testing.process_func(example_dequant_gemm_fp4_hopper.run_regression_perf) + + +def regression_example_dequant_gemm_bf16_fp4_hopper(): + tilelang.testing.process_func(example_dequant_gemm_bf16_fp4_hopper.run_regression_perf) + + +def regression_example_dequant_gemm_bf16_mxfp4_hopper(): + 
tilelang.testing.process_func(example_dequant_gemm_bf16_mxfp4_hopper.run_regression_perf) + + +def regression_example_dequant_groupedgemm_bf16_mxfp4_hopper(): + tilelang.testing.process_func(example_dequant_groupedgemm_bf16_mxfp4_hopper.run_regression_perf) + + +def regression_example_dequant_gemm_w4a8(): + tilelang.testing.process_func(example_dequant_gemm_w4a8.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/dequantize_gemm/test_example_dequantize_gemm.py b/examples/dequantize_gemm/test_example_dequantize_gemm.py index 01bc40e6c..a2f777222 100644 --- a/examples/dequantize_gemm/test_example_dequantize_gemm.py +++ b/examples/dequantize_gemm/test_example_dequantize_gemm.py @@ -3,7 +3,6 @@ import example_dequant_gemv_fp16xint4 import example_dequant_gemm_fp4_hopper import example_dequant_gemm_bf16_mxfp4_hopper -import example_dequant_gemm_bf16_mxfp4_hopper_tma import example_dequant_groupedgemm_bf16_mxfp4_hopper import example_dequant_gemm_w4a8 @@ -25,12 +24,6 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper(): example_dequant_gemm_bf16_mxfp4_hopper.main() -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_example_dequant_gemm_bf16_mxfp4_hopper_tma(): - example_dequant_gemm_bf16_mxfp4_hopper_tma.main() - - @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_dequant_groupedgemm_bf16_mxfp4_hopper(): diff --git a/examples/dsa_sparse_finetune/dsa.py b/examples/dsa_sparse_finetune/dsa.py new file mode 100644 index 000000000..9fae8e5e3 --- /dev/null +++ b/examples/dsa_sparse_finetune/dsa.py @@ -0,0 +1,223 @@ +from typing import Optional +import torch +import torch.nn.functional as F +from indexer_topk_reducesum import indexer_topk_reducesum_interface +from indexer_bwd import indexer_bwd_interface +from sparse_mla_fwd import sparse_mla_fwd_interface +from sparse_mla_bwd import sparse_mla_bwd +from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface +from einops import einsum, repeat +from utils import get_abs_err, get_err_ratio + + +class RegsiterLossFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.save_for_backward(loss) + return x + + @staticmethod + def backward(ctx, grad): + loss = ctx.saved_tensors + return grad, torch.ones(1, dtype=loss[0].dtype, device=loss[0].device) + + +register_loss = RegsiterLossFunction.apply + + +def ref_deepseek_sparse_attention_innner( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + dtype = q.dtype + q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights)) + + index_sm_scale = index_q.shape[-1] ** -0.5 + b, s = index_q.shape[:2] + + # tl_topk_indices = tl_topk_indices.to(torch.int64) + # tl_topk_indices[tl_topk_indices == -1] = s + + casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2") + index_logits = F.relu(index_logits) + index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale + index_logits = torch.where(casual_mask, index_logits, float("-inf")) + topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices + topk_logits = torch.gather(F.pad(index_logits, (0, 1), 
value=float("-inf")), dim=-1, index=topk_indices) + topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32) + index_topk_score = topk_score + + if sm_scale is None: + sm_scale = kv.shape[-1] ** -0.5 + + h = q.shape[-2] + index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda").scatter_( + dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool) + )[:, :, :-1] + mask = repeat(casual_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h) + k, v = kv, kv[..., :dim_v] + logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d") + + attn_score = attn_score.sum(dim=-2) # [b, s1, s2] + attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices) + attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True) + + loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum") + o = register_loss(o, loss) + + return o.to(dtype), topk_indices + + +def ref_deepseek_sparse_attention( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + all_o, all_topk_indices = [], [] + for i in range(offsets.shape[0] - 1): + o, topk_indices = ref_deepseek_sparse_attention_innner( + q[None, offsets[i] : offsets[i + 1]], + kv[None, offsets[i] : offsets[i + 1]], + index_q[None, offsets[i] : offsets[i + 1]], + index_k[None, offsets[i] : offsets[i + 1]], + weights[None, offsets[i] : offsets[i + 1]], + topk, + dim_v, + sm_scale, + index_sm_scale, + ) + all_o.append(o.squeeze(0)) + all_topk_indices.append(topk_indices.squeeze(0)) + o = torch.cat(all_o, dim=0) + topk_indices = torch.cat(all_topk_indices, dim=0) + return o, topk_indices + + +class DSAFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + ): + # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk) + topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets) + o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v) + ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets) + ctx.topk = topk + ctx.dim_v = dim_v + ctx.sm_scale = sm_scale + return o, topk_indices + + @staticmethod + def backward( + ctx, + do: torch.Tensor, + _1: torch.Tensor, + ): + q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors + attn_score = sparse_mla_topk_reducesum_interface( + q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v + ).squeeze(-2) + dq, dkv = sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale) + dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets) + return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None + + +def deepseek_sparse_attention( + 
q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, +): + return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale) + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + index_D=128, + topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_() + index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_() + weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_() + index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_() + do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_() + offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda() + + o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) + o.backward(do) + q_grad, q.grad = q.grad, None + kv_grad, kv.grad = kv.grad, None + index_q_grad, index_q.grad = index_q.grad, None + index_k_grad, index_k.grad = index_k.grad, None + weights_grad, weights.grad = weights.grad, None + + ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) + ref_o.backward(do) + ref_q_grad, q.grad = q.grad, None + ref_kv_grad, kv.grad = kv.grad, None + ref_index_q_grad, index_q.grad = index_q.grad, None + ref_index_k_grad, index_k.grad = index_k.grad, None + ref_weights_grad, weights.grad = weights.grad, None + + print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}") + print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}") + print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}") + print( + f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}" + ) + print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}") + print(f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}") + + intersections = [] + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + mask = trt_np != -1 + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + intersections.append(len(intersection) / len(set_ref)) + print("average intersections: {:.4f}".format(sum(intersections) / len(intersections))) + + +test_kernel() diff --git a/examples/dsa_sparse_finetune/index.py b/examples/dsa_sparse_finetune/index.py new file mode 100644 index 000000000..5e4800411 --- /dev/null +++ b/examples/dsa_sparse_finetune/index.py @@ -0,0 +1,82 @@ +# Modified from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py +import torch +import torch.nn.functional as F +import functools +from typing import Callable, Any + + +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. 
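+    Matching is by object identity (`a is b`), not by tensor value: given `cached_fn = tensor_cache(fn)`,
+    calling `cached_fn(x)` twice with the same tensor object returns the cached result, while
+    `cached_fn(x.clone())` recomputes even though the values are equal.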
+ If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if ( + (last_args is not None and last_kwargs is not None) + and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) + and all(a is b for a, b in zip(args, last_args, strict=False)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_cu_seqlens_from_lens( + lens: torch.LongTensor, + dtype: torch.dtype | None = torch.int32, +) -> torch.LongTensor: + return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0)) + + +@tensor_cache +def prepare_lens_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()]) + + +@tensor_cache +def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1 + + +@tensor_cache +def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + position_ids = prepare_position_ids(cu_seqlens) + return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens) diff --git a/examples/dsa_sparse_finetune/indexer_bwd.py b/examples/dsa_sparse_finetune/indexer_bwd.py new file mode 100644 index 000000000..68508ad4e --- /dev/null +++ b/examples/dsa_sparse_finetune/indexer_bwd.py @@ -0,0 +1,254 @@ +import torch +import torch.nn.functional as F +from einops import einsum, repeat + +import tilelang as tl +import tilelang.language as T +from typing import Optional +from index import prepare_token_indices + +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_bwd_impl( + heads: int, + dim: int, + topk: int, + sm_scale: Optional[float] = None, + block_I: int = 32, + num_stages: int = 0, + num_threads: int = 128, +): + assert num_stages == 0 + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_I == 0 + assert heads <= 64 and heads % 8 == 0 + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + dtype: str = BF16 + accum_dtype: str = FP32 + index_q_shape = [seq_len, heads, dim] + weights_shape = [seq_len, heads] + index_k_shape = [seq_len, dim] + shape_p = [seq_len, topk] + topk_indices_shape = [seq_len, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + if sm_scale is None: + sm_scale = dim**-0.5 + + @T.prim_func + def 
tl_indexer_bwd_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + dIndexQ: T.Tensor(index_q_shape, dtype), + dWeights: T.Tensor(weights_shape, dtype), + dIndexK: T.Tensor(index_k_shape, dtype), + AttnScore: T.Tensor(shape_p, FP32), + IndexScore: T.Tensor(shape_p, FP32), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), + ): + with T.Kernel(seq_len, threads=num_threads) as (bx): + i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] + bos = Offsets[i_b] + num_blocks = T.ceildiv(topk, block_I) + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + weights_shared = T.alloc_shared([heads], dtype=dtype) + + d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype) + d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype) + + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.copy(Weights[bos + i_t, :], weights_shared) + T.fill(d_index_q_frag, 0) + T.fill(d_weights_frag, 0) + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + + for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): + i_st = bi_i * block_I + i_ed = (bi_i + 1) * block_I + + indices_shared = T.alloc_shared([block_I], dtype=INT32) + T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared) + + index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype) + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0) + + attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + for i in T.Parallel(block_I): + attn_score_shared[i] = AttnScore[bos + i_t, i_st + i] + index_score_shared[i] = IndexScore[bos + i_t, i_st + i] + + logits = T.alloc_fragment((block_I, heads), accum_dtype) + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + for i, j in T.Parallel(block_I, heads): + logits[i, j] = T.max(logits[i, j], 0) + + # dw + d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype) + for i, j in T.Parallel(block_I, heads): + d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j] + T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) + + d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype) + d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype) + d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype) + + for i, j in T.Parallel(block_I, heads): + d_relu = T.alloc_var(accum_dtype) + if logits[i, j] > 0: + d_relu = 1.0 + else: + d_relu = 0.0 + d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j] + + # dq + T.copy(d_logits_qk, d_logits_qk_cast1) + T.gemm( + d_logits_qk_cast1, # [BS, HQ] + index_k_shared, # [BS, K] + d_index_q_frag, # [HQ, K] + transpose_A=True, + transpose_B=False, + clear_accum=False, + ) + + # dk + T.copy(d_logits_qk, d_logits_qk_cast2) + d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype) + T.gemm( + d_logits_qk_cast2, # [BS, HQ] + index_q_shared, # [HQ, K] + d_index_k_frag, # [BS, K] + transpose_A=False, + transpose_B=False, + clear_accum=True, + ) + + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + if (pos > -1) & (pos <= i_t): + T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j]) + + 
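+                # Note: dIndexK above is scattered with T.atomic_add because different query tokens
+                # (one thread block per token) can select the same key position in their top-k sets.
+
+            # d_index_q_frag holds the gradient w.r.t. the scaled query (index_q_shared was
+            # multiplied by sm_scale above), so apply the chain-rule factor sm_scale before
+            # copying the result back to dIndexQ.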
for i, j in T.Parallel(heads, dim): + d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale + + T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :]) + T.copy(d_weights_frag, dWeights[bos + i_t, :]) + + return tl_indexer_bwd_kernel + + +def indexer_bwd_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + attn_score: torch.Tensor, + index_score: torch.Tensor, + topk_indices: torch.Tensor, + offsets: torch.Tensor, +): + _, heads, dim, topk = *q.shape, topk_indices.shape[-1] + token_indices = prepare_token_indices(offsets) + dq = torch.zeros_like(q) + dweights = torch.zeros_like(weights) + dk = torch.zeros_like(k) + kernel = tl_indexer_bwd_impl(heads, dim, topk) + kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices) + return dq, dweights, dk + + +def ref_indexer_bwd( + Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor +) -> torch.Tensor: + Q.requires_grad_(True) + Weights.requires_grad_(True) + K.requires_grad_(True) + softmax_scale = Q.shape[-1] ** -0.5 + all_loss = [] + all_log_topk_prob = [] + for i in range(offsets.shape[0] - 1): + assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1] + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + attn_score = AttnScore[offsets[i] : offsets[i + 1]] + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale + logits = F.relu(logits) + score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) + score = torch.where(mask, score, float("-inf")) + topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64)) + log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32) + loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum") + all_loss.append(loss) + all_log_topk_prob.append(log_topk_prob) + loss = torch.stack(all_loss).sum() + loss.backward() + log_topk_prob = torch.cat(all_log_topk_prob, dim=0) + return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad + + +def test_kernel( + B=1, + S=2048, + H=16, + D=128, + topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D)).cuda().bfloat16() + w = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + all_attn_score = [] + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device) + logits = torch.ones(seq_len, topk).cuda() + logits = torch.where(mask, logits, float("-inf")) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + all_attn_score.append(attn_score) + attn_score = torch.cat(all_attn_score, dim=0) + + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() + index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets) + + dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets) + + print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}") + print(f"dq err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}") + print(f"dq err: 
{get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py new file mode 100644 index 000000000..d76eb0272 --- /dev/null +++ b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py @@ -0,0 +1,273 @@ +import math +import torch +import torch.nn.functional as F +from einops import einsum + +import tilelang as tl +import tilelang.language as T +from typing import Optional +from index import prepare_token_indices + +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_topk_reducesum_impl( + heads: int, + dim: int, + topk: int, + sm_scale: Optional[float] = None, + block_K: int = 32, + dtype: str = FP32, + num_stages: int = 0, + num_threads: int = 128, +): + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_K == 0 + assert heads <= 64 and heads % 8 == 0 + assert num_stages == 0 + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + + index_q_shape = [seq_len, heads, dim] + weights_shape = [seq_len, heads] + index_k_shape = [seq_len, dim] + topk_indices_shape = [seq_len, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + + N = 2 * topk + num_iters = int(round(math.log2(N))) + if sm_scale is None: + sm_scale = dim**-0.5 + + @T.macro + def bitonic_sort( + topk_index_shared: T.SharedBuffer([N], dtype=INT32), + topk_value_shared: T.SharedBuffer([N], dtype=FP32), + ): + T.sync_threads() + for i1 in T.serial(num_iters): + for i2 in T.serial(i1 + 1): + for i in T.Parallel(N): + ascending = (i & (1 << (i1 + 1))) != 0 + j = i ^ (1 << (i1 - i2)) + if i < j and ( + (ascending and topk_value_shared[i] > topk_value_shared[j]) + or (not ascending and topk_value_shared[i] < topk_value_shared[j]) + ): + val = topk_value_shared[i] + topk_value_shared[i] = topk_value_shared[j] + topk_value_shared[j] = val + idx = topk_index_shared[i] + topk_index_shared[i] = topk_index_shared[j] + topk_index_shared[j] = idx + T.sync_threads() + + @T.prim_func + def tl_indexer_topk_reducesum_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + ReduceSum: T.Tensor(topk_indices_shape, FP32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), + ): + with T.Kernel(seq_len, threads=num_threads) as (bx): + i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] + bos, eos = Offsets[i_b], Offsets[i_b + 1] + num_blocks = T.ceildiv(i_t + 1, block_K) + + topk_index_shared = T.alloc_shared([N], dtype=INT32) + topk_value_shared = T.alloc_shared([N], dtype=FP32) + + T.fill(topk_index_shared, -1) + T.fill(topk_value_shared, float("-inf")) + T.sync_threads() + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.sync_threads() + + weights_frag = T.alloc_shared([heads], dtype=dtype) + T.copy(Weights[bos + i_t, :], weights_frag) + T.sync_threads() + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + T.sync_threads() + + for bk_i in 
T.Pipelined(num_blocks, num_stages=num_stages): + k_st = bk_i * block_K + k_ed = T.min((bk_i + 1) * block_K, eos - bos) + + index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype) + for i, j in T.Parallel(block_K, dim): + index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0) + T.sync_threads() + + logits = T.alloc_fragment((block_K, heads), FP32) + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + T.sync_threads() + + for i, j in T.Parallel(block_K, heads): + logits[i, j] = T.max(logits[i, j], 0) * weights_frag[j] + T.sync_threads() + + logits_sum = T.alloc_fragment(block_K, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + T.sync_threads() + + offset = T.alloc_var(INT32) + if k_st >= topk: + offset = topk + (k_st % topk) + else: + offset = k_st + T.sync_threads() + for i in T.Parallel(block_K): + if k_st + i > i_t: + logits_sum[i] = float("-inf") + j = offset + i + topk_index_shared[j] = k_st + i + topk_value_shared[j] = logits_sum[i] + T.sync_threads() + + if k_ed > topk and k_ed % topk == 0: + bitonic_sort(topk_index_shared, topk_value_shared) + + bitonic_sort(topk_index_shared, topk_value_shared) + + logits_max_frag = T.alloc_fragment([1], dtype=FP32) + logits_frag = T.alloc_fragment([topk], dtype=FP32) + reducesum_shared = T.alloc_shared([topk], dtype=FP32) + + T.copy(topk_value_shared[:topk], logits_frag) + T.sync_threads() + + T.reduce_max(logits_frag, logits_max_frag, dim=-1) + T.sync_threads() + + for i in T.Parallel(topk): + logits_frag[i] = T.exp(logits_frag[i] - logits_max_frag[0]) + T.sync_threads() + + lse_frag = T.alloc_fragment([1], dtype=FP32) + T.reduce_sum(logits_frag, lse_frag) + T.sync_threads() + + for i in T.Parallel(topk): + reducesum_shared[i] = logits_frag[i] / lse_frag[0] + T.sync_threads() + + # for i in T.Parallel(topk): + # reducesum_shared[i] = logits_frag[i] + # T.sync_threads() + + for i in T.Parallel(topk): + if topk_index_shared[i] > i_t: + topk_index_shared[i] = -1 + T.sync_threads() + + T.copy(topk_index_shared[:topk], TopkIndices[bos + i_t, :]) + T.copy(reducesum_shared[:topk], ReduceSum[bos + i_t, :]) + + return tl_indexer_topk_reducesum_kernel + + +def indexer_topk_reducesum_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + topk: int, + offsets: torch.Tensor, + dtype: str = BF16, +): + seq_len, heads, dim = q.shape + kernel = tl_indexer_topk_reducesum_impl(heads=heads, dim=dim, topk=topk, dtype=dtype) + token_indices = prepare_token_indices(offsets) + topk_indices = torch.zeros((seq_len, topk), device=q.device, dtype=torch.int32) + topk_score = torch.zeros((seq_len, topk), device=q.device, dtype=torch.float32) + kernel(q, weights, k, topk_indices, topk_score, offsets, token_indices) + return topk_indices, topk_score + + +def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor: + all_topk_indices = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + assert (offsets[i + 1] - offsets[i]).item() >= topk + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + softmax_scale = q.shape[-1] ** -0.5 + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") + logits = F.relu(logits) + logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale + logits = 
torch.where(mask, logits, float("-inf")) + topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1) + topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32) + all_topk_indices.append(topk_indices) + all_topk_score.append(topk_score) + topk_indices = torch.cat(all_topk_indices, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return topk_indices, topk_score + + +def test_kernel( + B=1, + S=2048, + H=64, + D=128, + topk=64, +): + torch.manual_seed(42) + + q = torch.randn((S, H, D)).cuda().bfloat16() + weights = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, S], dtype=torch.int32).cuda() + + ref_topk_indices, ref_topk_score = ref_index_score(q, weights, k, topk, offsets) + + topk_indices, topk_score = indexer_topk_reducesum_interface(q, weights, k, topk, offsets) + + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + ref_np_val = ref_topk_score[j] + trt_np_val = topk_score[j] + + mask = (ref_np_val > 0).cpu().numpy() + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + + print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) + + print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py new file mode 100644 index 000000000..06eaa8eb3 --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -0,0 +1,347 @@ +# ruff: noqa +import tilelang +from tilelang import language as T +import torch +from index import prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit(out_idx=[-1]) +def preprocess( + H, + D, + block_ND=32, + num_stages=5, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + S = T.symbolic("S") + + shape = [S, H, D] + + @T.prim_func + def preprocess_kernel( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([S, H], accum_dtype), + ): + with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by): + o = T.alloc_fragment([block_ND, block_ND], accum_dtype) + do = T.alloc_fragment([block_ND, block_ND], accum_dtype) + delta = T.alloc_fragment([block_ND], accum_dtype) + acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) + T.clear(acc) + for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx]) + + return preprocess_kernel + + +@tilelang.jit(out_idx=[-1]) +def postprocess( + D, + D_tail, + kv_group=1, + block_N=64, + threads=128, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + S_kv = T.symbolic("S_kv") + + dkv_shape = [S_kv, kv_group, D + D_tail] + + @T.prim_func + def postprocess_kernel( + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), + ): + with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by): + T.copy( + 
dKV[bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bx * block_N : (bx + 1) * block_N, by, :], + ) + + return postprocess_kernel + + +@tilelang.jit( + out_idx=[-2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def bwd( + H, + D, + D_tail, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_size=32, + num_stages=0, + threads=128, + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 + + if sm_scale is None: + sm_scale = (D + D_tail) ** (-0.5) + + B_plus_one = T.symbolic("B_plus_one") + S = T.symbolic("S") + + H_kv = H // kv_group + q_shape = [S, H, D + D_tail] + k_shape = [S, kv_group, D + D_tail] + o_shape = [S, H, D] + indices_shape = [S, kv_group, topk] + delta_shape = [S, H] + lse_shape = [S, H] + offsets_shape = [B_plus_one] + token_indices_shape = [S, 2] + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + H = H_kv + padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + BS = block_size + NS = tilelang.cdiv(topk, block_size) + + split_store = 2 + + @T.prim_func + def sparse_mla_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + Offsets: T.Tensor(offsets_shape, indices_dtype), + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), + ): + with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz): + Q_shared = T.alloc_shared([padded_H, D], dtype) + Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + KV_shared = T.alloc_shared([BS, D], dtype) + KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) + dO_shared = T.alloc_shared([padded_H, D], dtype) + mask = T.alloc_fragment([BS], "bool") + + P_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dP_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dQ_shared = T.alloc_shared([padded_H, D], dtype) + dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + + acc_p = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([padded_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) + acc_dkv = T.alloc_fragment([BS, D], accum_dtype) + acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) + acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) + acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) + + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + + max_kv_i = s_i + + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) + + T.clear(acc_dq) + T.clear(acc_dq_tail) + + # Process each block of indices + for i_i in T.Pipelined(NS, num_stages=num_stages): + # Check which indices are valid + for bi_i in T.Parallel(BS): 
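+                    # An index is valid only if it lies inside the causal window (<= max_kv_i) and is not the -1 padding sentinel.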
+ mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1) + + # Compute attention scores + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) + + # Load KV, V for this block of indices + for bi_i, d_i in T.Parallel(BS, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i] + + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for bi_i, d_i in T.Parallel(BS, D_tail): + KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i]) + + T.copy(acc_p, P_shared_cast) + + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale + + T.copy(acc_dp, dP_shared_cast) + T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) + + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + T.clear(acc_dkv_tail) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + for s in range(split_store): + for bi_i, d_i in T.Parallel(BS, D): + if bi_i < BS // split_store: + acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS, D_tail): + if bi_i < BS // split_store: + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS // split_store, D): + T.atomic_add( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i], + acc_dkv_shared[bi_i, d_i], + ) + + # Atomically update dKV, dKV_tail tensors + for bi_i, d_i in T.Parallel(BS // split_store, D_tail): + T.atomic_add( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i], + acc_dkv_tail_shared[bi_i, d_i], + ) + + # Store the accumulated dQ + T.copy(acc_dq, dQ_shared) + T.copy(acc_dq_tail, dQ_tail_shared) + + T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:]) + + return sparse_mla_bwd_kernel + + +def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True, return_kernel=False, delta=None): + assert q.is_contiguous() + assert kv.is_contiguous() + assert indices.is_contiguous() + assert lse.is_contiguous() + S, H, dim_plus_tail_dim = q.shape + S_kv, kv_group, _ = kv.shape + assert kv.shape[-1] == dim_plus_tail_dim + assert S == S_kv + # dim should be assigned + D = 512 + + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + assert indices.shape == (S, kv_group, topk) + assert lse.shape == (S, H) + + token_indices = prepare_token_indices(offsets) + + # Get kernels + preprocess_kernel = preprocess(H, D) + bwd_kernel = bwd(H, D, D_tail, topk, kv_group, sm_scale, is_casual) + 
postprocess_kernel = postprocess(D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, do, indices, lse, delta, offsets, token_indices, dkv) + dkv = postprocess_kernel(dkv) + + return dq, dkv + + +def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True): + from sparse_mla_fwd import ref_sparse_mla_fwd_interface + + q = q.detach().clone() + kv = kv.detach().clone() + q.requires_grad = True + kv.requires_grad = True + o = ref_sparse_mla_fwd_interface(q, kv, indices, offsets, sm_scale, is_casual) + o.backward(do) + return q.grad, kv.grad + + +def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True): + # Prepare data + q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((S, H, DV), dtype=dtype, device="cuda") + offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, : len(i_i)] = i_i + + # Forward + from sparse_mla_fwd import sparse_mla_fwd_interface + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets) + + tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None, offsets) + + if check_correctness: + assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq") + assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") + print("assert_tensors_similar passed") + + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) + from tilelang.profiler import do_bench + + def fn(): + return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + + ms = do_bench(fn, rep=100, warmup=250) + print(f"Average time: {ms:.3f} ms") + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/examples/dsa_sparse_finetune/sparse_mla_fwd.py new file mode 100644 index 000000000..d87523695 --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -0,0 +1,310 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from index import prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=32, + num_stages=2, + threads=128, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding 
correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 + else: + sm_scale = sm_scale + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + + head_kv = heads // kv_group + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len, kv_group, dim + tail_dim] + o_shape = [seq_len, heads, dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, 
i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, Output[bos + s_i, H0:H1, :]) + T.copy(sumexp, Lse[bos + s_i, H0:H1]) + + return main + + +def sparse_mla_fwd_interface( + q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128 +): + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + seq_len, heads, dim_plus_tail_dim = q.shape + seq_len_kv, kv_group, _ = kv.shape + assert seq_len == seq_len_kv + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + _, _, topk = indices.shape + assert indices.shape == (seq_len, kv_group, topk) + + token_indices = prepare_token_indices(offsets) + + kernel = sparse_mla_fwd( + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) + out, lse = kernel(q, kv, indices, offsets, token_indices) + return out, lse + + +def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casual=True): + Q = Q.float() + KV = KV.float() + all_o = [] + for i in range(offsets.shape[0] - 1): + q = Q[None, offsets[i] : offsets[i + 1]] + kv = KV[None, offsets[i] : offsets[i + 1]] + indices = Indices[None, offsets[i] : offsets[i + 1]].clone() + + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) + + indices[indices > sk] = sk + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : 1 - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = 
score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + all_o.append(o.squeeze(0)) + o = torch.cat(all_o, dim=0) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): + torch.random.manual_seed(0) + q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + offsets = torch.tensor([0, S // 2 - 1, S], dtype=torch.int32, device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, : len(i_i)] = i_i + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + if check_correctness: + # otherwise may cause out of memory + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, offsets) + assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out") + print("assert_tensors_similar passed") + + def fn(): + return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=100, + warmup=250, + ) + print(f"Average time: {ms:.3f} ms") + print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=1024, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, + ) diff --git a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py new file mode 100644 index 000000000..a03bc74f5 --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -0,0 +1,226 @@ +# ruff: noqa +import torch +import torch.nn as nn +import torch.nn.functional as F +import tilelang +from tilelang import language as T +from einops import repeat, rearrange, einsum +from index import prepare_token_indices +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tilelang.jit(pass_configs=pass_configs) +def tl_sparse_mla_topk_reducesum_impl( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + block_I=32, + num_stages=2, + threads=128, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + seq_len_kv = T.symbolic("seq_len_kv") + + 
head_kv = heads // kv_group + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len_kv, kv_group, dim + tail_dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + + @T.prim_func + def tl_sparse_mla_topk_reducesum_kernel( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + reducesum = T.alloc_fragment([BI], accum_dtype) + lse = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(lse, 0) + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + r_i = bx % REPLICATE_H + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + T.copy(Lse[bos + s_i, H0:H1], lse) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i]) + T.reduce_sum(acc_s, reducesum, dim=0) + T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + 
BI]) + + return tl_sparse_mla_topk_reducesum_kernel + + +def sparse_mla_topk_reducesum_interface( + q: torch.Tensor, + kv: torch.Tensor, + topk_indices: torch.Tensor, + lse: torch.Tensor, + offsets: torch.Tensor, + dim_v: int, +): + assert kv.shape[-2] == 1 + seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1] + REPLICATE_H = max(heads // 64, 1) + tail_dim = dim_plus_tail_dim - dim_v + token_indices = prepare_token_indices(offsets) + + reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device) + kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, dim=dim_v, tail_dim=tail_dim, topk=topk) + kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum) + reducesum = reducesum.sum(dim=-2) # [batch, seq_len, 1, RH, topk] -> [batch, seq_len, 1, topk] + attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True) + + return attn_score + + +def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor): + # q: [batch, seq_len, heads, dim] + # k: [batch, seq_len, dim] + sm_scale = Q.shape[-1] ** -0.5 + all_lse = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + q = Q[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + seq_len = q.shape[0] + mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() + logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) + score = F.softmax(logits, dim=-1, dtype=torch.float32) + score_sum = score.sum(dim=-2) + topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64)) + topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True) + max_logits = logits.amax(dim=-1).to(torch.float32) + lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits + all_lse.append(lse) + all_topk_score.append(topk_score) + lse = torch.cat(all_lse, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return lse, topk_score + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + topk=128, +): + torch.manual_seed(42) + + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() + + lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets) + + kv = kv.unsqueeze(-2) + topk_indices = topk_indices.unsqueeze(-2) + + attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) + print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/examples/dsa_sparse_finetune/utils.py b/examples/dsa_sparse_finetune/utils.py new file mode 100644 index 000000000..96afd064d --- /dev/null +++ b/examples/dsa_sparse_finetune/utils.py @@ -0,0 +1,73 @@ +import torch + + +def get_abs_err(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + return (x - y).flatten().abs().max().item() + + +def get_err_ratio(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + err = (x - y).flatten().square().mean().sqrt().item() + base = (x).flatten().square().mean().sqrt().item() + return err / base + + +def calculate_tensor_similarity(x, y, 
name="tensor"): + """ + Calculate similarity between two tensors using a normalized dot product metric. + + Unlike torch.testing.assert_close which uses absolute/relative tolerance based on + element-wise differences, this function computes a global similarity score: + sim = 2 * / (||x||^2 + ||y||^2) + + This metric is scale-invariant and measures the cosine-like similarity normalized + by the magnitude of both tensors. It returns 1 for identical tensors and values + closer to 0 for dissimilar ones. This is particularly useful for comparing tensors + with varying magnitudes where relative errors matter more than absolute differences. + + Args: + x: First tensor to compare + y: Second tensor to compare + name: Name of the tensor for logging purposes + + Returns: + Similarity score in range [0, 1] where 1 means identical + """ + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print(f"\033[33mWARNING: {name} all zero\033[0m") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): + """ + Assert that two tensors are similar using a global similarity metric. + + Key differences from torch.testing.assert_close: + - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking + that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers + and requires all elements to satisfy the tolerance. + - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the + normalized dot product. It's more robust to outliers and focuses on overall + tensor similarity rather than element-wise precision. This is better suited for + comparing large tensors where a few outlier elements shouldn't fail the test. 
+ + Args: + x: First tensor to compare + y: Second tensor to compare + eps: Maximum allowed difference (1 - similarity), default 1e-8 + name: Name of the tensor for error messages + raise_assert: Whether to raise assertion error on failure + """ + sim = calculate_tensor_similarity(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") + if raise_assert: + assert False # noqa: B011 diff --git a/examples/dynamic_shape/example_dynamic.py b/examples/dynamic_shape/example_dynamic.py index be018c8b7..e338d76ca 100644 --- a/examples/dynamic_shape/example_dynamic.py +++ b/examples/dynamic_shape/example_dynamic.py @@ -1,10 +1,9 @@ import tilelang import tilelang.language as T import tilelang.testing -from tilelang import tvm as tvm -@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}) +@tilelang.jit def matmul_dynamic_mnk( block_M, block_N, @@ -17,9 +16,9 @@ def matmul_dynamic_mnk( num_stages, threads, ): - M = tvm.te.var("m") - N = tvm.te.var("n") - K = tvm.te.var("k") + M = T.dynamic("m") + N = T.dynamic("n") + K = T.dynamic("k") A_shape = (K, M) if trans_A else (M, K) B_shape = (N, K) if trans_B else (K, N) @@ -29,9 +28,9 @@ def matmul_dynamic_mnk( @T.prim_func def dynamic_matmul( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -53,15 +52,14 @@ def dynamic_matmul( return dynamic_matmul -def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads): +def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads): print( f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}" ) - kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) + kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) import torch + if trans_A: A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) else: @@ -103,8 +101,30 @@ def main(M=16384, N=16384, K=16384): accum_dtype = "float32" num_stages = 3 threads = 128 - matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) + matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) + + +def run_regression_perf(M=4096, N=4096, K=4096): + block_M, block_N, block_K = 128, 128, 32 + trans_A, trans_B = False, False + in_dtype, out_dtype = "float16", "float16" + accum_dtype = "float32" + num_stages = 3 + threads = 128 + kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) + import torch + + if trans_A: + A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + if 
trans_B: + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + else: + B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(input_tensors=[A, B, C], backend="cupti") if __name__ == "__main__": diff --git a/examples/dynamic_shape/regression_example_dynamic.py b/examples/dynamic_shape/regression_example_dynamic.py new file mode 100644 index 000000000..958695990 --- /dev/null +++ b/examples/dynamic_shape/regression_example_dynamic.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_dynamic + + +def regression_example_dynamic(): + tilelang.testing.process_func(example_dynamic.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/dynamic_shape/test_example_dynamic.py b/examples/dynamic_shape/test_example_dynamic.py deleted file mode 100644 index 36a3743f1..000000000 --- a/examples/dynamic_shape/test_example_dynamic.py +++ /dev/null @@ -1,10 +0,0 @@ -import tilelang.testing -import example_dynamic - - -def test_example_dynamic(): - example_dynamic.main(M=1024, N=1024, K=1024) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/examples/eager_jit/eagerjit.en.ipynb b/examples/eager_jit/eagerjit.en.ipynb new file mode 100644 index 000000000..6a2bf8453 --- /dev/null +++ b/examples/eager_jit/eagerjit.en.ipynb @@ -0,0 +1,977 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e0deecc", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", + "import tilelang\n", + "import torch\n", + "import tilelang.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "1ca2c56d", + "metadata": {}, + "source": [ + "# Tilelang Eager JIT" + ] + }, + { + "cell_type": "markdown", + "id": "156e7370", + "metadata": {}, + "source": [ + "## Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "b070c109", + "metadata": {}, + "source": [ + "Tilelang Eager JIT merges JIT kernel generation and invocation into a single workflow.\n", + "\n", + "The function signature looks similar to Triton, but we add many enhancements; the most important one is allowing rich Tensor annotations:\n", + "\n", + "* If a Tensor has complex shape constraints, we can move its annotation into the function body.\n", + "* Use `T.const` or `T.dynamic` to create shape variables, then annotate complex Tensors with `T.Tensor`.\n", + "* Use `T.empty` to declare return tensors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60bf8954", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm(\n", + " A,\n", + " B,\n", + " out_dtype: T.dtype = T.float32,\n", + " block_M: int = 128,\n", + " block_N: int = 128,\n", + " block_K: int = 32,\n", + "):\n", + " M, N, K = T.const(\"M, N, K\")\n", + "\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + "\n", + " C = T.empty((M, N), out_dtype)\n", + "\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "28f868fe", + "metadata": {}, + "source": [ + "Calling the function with Tensors directly triggers the full JIT compile-and-run pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ee13394a", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B)\n", + "\n", + "# check output is correct\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "c6705091", + "metadata": {}, + "source": [ + "Changing the call arguments may trigger a recompilation when compilation parameters change:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8aab5b7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B, block_M=64, block_N=64)" + ] + }, + { + "cell_type": "markdown", + "id": "ce6b7391", + "metadata": {}, + "source": [ + "You can also explicitly call the `compile` method to build the kernel.\n", + "\n", + "1. `ker.compile` compiles the kernel\n", + "2. `ker.get_tir` retrieves the TIR\n", + "3. `ker.par_compile` compiles in parallel" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3cf3a2d", + "metadata": {}, + "outputs": [], + "source": [ + "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", + "C = kernel(A, B)" + ] + }, + { + "cell_type": "markdown", + "id": "921761b5", + "metadata": {}, + "source": [ + "## More Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "4539e54e", + "metadata": {}, + "source": [ + "### Use macros to separate implementation" + ] + }, + { + "cell_type": "markdown", + "id": "ad96ba65", + "metadata": {}, + "source": [ + "Next, we implement a simple GEMM in several different ways. 
For convenience, we first write a macro that contains the core GEMM logic:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "171d4fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "446a1acd", + "metadata": {}, + "source": [ + "### Use `T.dynamic` to mark dynamic shapes\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6a38aa95", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm_dyn_K(A, B):\n", + " M, N, K = T.dynamic(\"M, N, K\")\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "fe6cfdc8", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_dyn_K(A, B)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee97bf7", + "metadata": {}, + "source": [ + "### Use `T.StridedTensor` to annotate tensors with strides\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9dde1dae", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def as_contingious(A):\n", + " M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n", + " A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n", + " B = T.empty((M, N), A.dtype)\n", + " block_M = 128\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " T.copy(\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " )\n", + " return B" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "dec2c0a7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 1024, device=\"cuda\")\n", + "B = as_contingious(A.T)\n", + "B_ref = A.T.contiguous()\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f5fb20d6", + "metadata": {}, + "source": [ + "## More Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "890df0a2", + "metadata": {}, + "source": [ + "### Use parameters directly as annotations" + ] + }, + { + "cell_type": "markdown", + "id": "e9a47d42", + "metadata": {}, + "source": [ + "You can directly use function parameters in the annotations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "0fc17af6", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm_ptr(\n", + " A,\n", + " B,\n", + " M,\n", + " N,\n", + " K,\n", + "):\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8e52a554", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "6b19ef90", + "metadata": {}, + "source": [ + "### Annotations for runtime variables" + ] + }, + { + "cell_type": "markdown", + "id": "bba5f27f", + "metadata": {}, + "source": [ + "Runtime variables work the same; if the function annotation becomes too long, you can move it into the function body." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c1e7598a", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm_ptr_dyn(A, B, M, N, K):\n", + " M: T.int32\n", + " N: T.int32\n", + " K: T.int32\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9e9a4c88", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "81427765", + "metadata": {}, + "source": [ + "### Constraints for constants" + ] + }, + { + "cell_type": "markdown", + "id": "4d6b084b", + "metadata": {}, + "source": [ + "A constant annotation created by `T.const` must be used directly at least once, otherwise an error is raised." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c90dd24f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Constexpr variable `M` is not used in any buffer shape or stride.\n", + "At least one **DIRECT** usage is required. Please check:\n", + "(1) the variable is not used\n", + "(2) all uses are indirect, e.g. M * 2, M * 3. 
(you can replace them with separate constexpr variables)\n", + "Buffer shapes: {A: [M * 2, M * 3]}\n", + "Buffer strides: {A: [M * 3, 1]}\n" + ] + } + ], + "source": [ + "@tilelang.jit\n", + "def example_wrong_kernel(A):\n", + " M = T.const(\"M\")\n", + " A: T.Tensor[[M * 2, M * 3], T.float32]\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + "\n", + "\n", + "try:\n", + " A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + " example_wrong_kernel(A)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "e07e762b", + "metadata": {}, + "source": [ + "### Dynamic dimensions" + ] + }, + { + "cell_type": "markdown", + "id": "f48e5d7a", + "metadata": {}, + "source": [ + "If you want certain parameters in a Tensor annotation to change, it is recommended to switch to the `T.ptr` + `T.match_buffer` style." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1d050321", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@tilelang.jit\n", + "def dyn_annot(\n", + " A: T.ptr, # 1. T.ptr type annotation\n", + " is_2d=False,\n", + "):\n", + " if is_2d:\n", + " M, N = T.const(\"M, N\")\n", + " # 2. dynamic shape annotation inside function body\n", + " A = T.match_buffer(A, [M, N], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + " else:\n", + " L = T.const(\"L\")\n", + " A = T.match_buffer(A, [L], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0]\n", + "\n", + "\n", + "A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + "dyn_annot(A, is_2d=True)" + ] + }, + { + "cell_type": "markdown", + "id": "2e9f1bb3", + "metadata": {}, + "source": [ + "### Default arguments" + ] + }, + { + "cell_type": "markdown", + "id": "f7fc9917", + "metadata": {}, + "source": [ + "Scalar annotations like `T.float32` can carry default values." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "42ec86a1", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def add_one(X, data: T.float32 = 1):\n", + " M, N = T.const(\"M, N\")\n", + " X: T.Tensor[[M, N], T.float32]\n", + " Y = T.empty((M, N), T.float32)\n", + " with T.Kernel(T.ceildiv(M, 128), threads=128) as bx:\n", + " for i, j in T.Parallel(128, N):\n", + " Y[bx * 128 + i, j] = X[bx * 128 + i, j] + data\n", + " return Y" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d49e1120", + "metadata": {}, + "outputs": [], + "source": [ + "X = torch.randn(1024, 1024, dtype=torch.float32, device=\"cuda\")\n", + "Y = add_one(X)\n", + "torch.testing.assert_close(Y, X + 1)" + ] + }, + { + "cell_type": "markdown", + "id": "a02baedc", + "metadata": {}, + "source": [ + "## Overhead of argument matching" + ] + }, + { + "cell_type": "markdown", + "id": "860a2972", + "metadata": {}, + "source": [ + "EagerJIT has very small overhead; each additional constant annotation costs about 200 ns.\n", + "* 200 ns is roughly the cost of an FFI call that reads parameters from a `torch.Tensor`'s shape/stride." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc676e33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Kernel call : 7.68 us\n", + "Parse cache key: 0.41 us\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "A = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "\n", + "\n", + "@tilelang.jit\n", + "def dummy_kernel(A, B):\n", + " M, N = T.const(\"M, N\")\n", + " A: T.Tensor[[M, N], T.float16]\n", + " B: T.Tensor[[M, N], T.float16]\n", + " with T.Kernel(1) as _:\n", + " pass\n", + "\n", + "\n", + "# compile it first\n", + "dummy_kernel(A, B)\n", + "\n", + "\n", + "def eval_overhead(f):\n", + " start = time.perf_counter_ns()\n", + " for _ in range(10000):\n", + " f()\n", + " stop = time.perf_counter_ns()\n", + " return (stop - start) / 10000 / 1000\n", + "\n", + "\n", + "kernel_call_overhead = eval_overhead(lambda: dummy_kernel(A, B))\n", + "parse_cache_key_overhead = eval_overhead(lambda: dummy_kernel.parse_cache_key(A, B))\n", + "\n", + "print(f\"Kernel call : {kernel_call_overhead:.2f} us\")\n", + "print(f\"Parse cache key: {parse_cache_key_overhead:.2f} us\")" + ] + }, + { + "cell_type": "markdown", + "id": "39166cb4", + "metadata": {}, + "source": [ + "## Compilation and parallel compilation" + ] + }, + { + "cell_type": "markdown", + "id": "8c6fbe08", + "metadata": {}, + "source": [ + "Both EagerJIT and the original `jit` (i.e. LazyJIT) support parallel compilation.\n", + "\n", + "To avoid wasting memory on temporary `torch.Tensor` objects, you can use `T.Tensor` to create placeholders." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7222e57b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a4e4eb3cd4445bda6e8693da31ef3b8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Elaborating: 0%| | 0/8 [00:00,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "\n", + "def get_configs():\n", + " return [\n", + " {\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", + " }\n", + " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", + " ]\n", + "\n", + "\n", + "gemm.par_compile(get_configs())" + ] + }, + { + "cell_type": "markdown", + "id": "5160d2cc", + "metadata": {}, + "source": [ + "## More convenient macros" + ] + }, + { + "cell_type": "markdown", + "id": "be44afc4", + "metadata": {}, + "source": [ + "tilelang's macros have been improved:\n", + "\n", + "1. Allow using `T.Ref` as an annotation, similar to C++ references.\n", + "2. Allow returning multiple values.\n", + "3. Allow nesting and recursion." + ] + }, + { + "cell_type": "markdown", + "id": "79575972", + "metadata": {}, + "source": [ + "### Passing references with `T.Ref`\n", + "\n", + "A `T.Ref` reference can point to a scalar variable or to an element of a buffer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "90eaa6e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo(x_handle: T.handle):\n", + " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", + " # with T.block(\"root\"):\n", + " bx = T.launch_thread(\"blockIdx.x\", 1)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", + " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", + " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", + " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", + " x[1] = T.float32(1.0)\n", + " _tmp: T.int32 = idx[0]\n", + " x[_tmp] = T.float32(1.0)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def macro_with_ref(x: T.Ref):\n", + " x = 1 # noqa: F841\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo(x: T.Tensor((2,))):\n", + " with T.Kernel(1) as _:\n", + " # Supports constant indices\n", + " macro_with_ref(x[1])\n", + "\n", + " # Also supports variable indices\n", + " idx = T.alloc_var(T.int32, 0)\n", + " macro_with_ref(x[idx])\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "markdown", + "id": "7bb447a2", + "metadata": {}, + "source": [ + "### Pass macros as arguments\n", + "\n", + "You can pass a macro as a function argument." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "dc7bb779", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def element_wise(A, fn):\n", + " N = T.dynamic(\"N\")\n", + " A: T.Tensor[[N], T.float32]\n", + " B = T.empty((N,), dtype=A.dtype)\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", + " for i in T.Parallel(block_N):\n", + " idx = bx * block_N + i\n", + " B[idx] = fn(A[idx])\n", + " return B\n", + "\n", + "\n", + "@T.macro\n", + "def add_one(x):\n", + " return x + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a89fdb44", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, device=\"cuda\")\n", + "B = element_wise(A, add_one)\n", + "B_ref = A + 1\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "ef6e403a", + "metadata": {}, + "source": [ + "### Recursive macros\n", + "\n", + "You may not need this often, but macros can be recursive as long as the termination condition is known at compile time." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7703cab5", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def n31(x, var: T.Ref):\n", + " if x == 1:\n", + " pass\n", + " elif x % 2 == 0:\n", + " var = var // 2\n", + " n31(x // 2, var)\n", + " else:\n", + " var = var * 3 + 1\n", + " n31(x * 3 + 1, var)\n", + "\n", + "\n", + "@tilelang.jit\n", + "def foo(A: T.Tensor[[1], T.int32], n: int):\n", + " with T.Kernel(1) as _:\n", + " n31(n, A[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "542ddd4e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([18], device='cuda:0', dtype=torch.int32)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", + "foo(A, 5)\n", + "A" + ] + }, + { + "cell_type": "markdown", + "id": "dc30c2d2", + "metadata": {}, + "source": [ + "### Macros returning multiple values" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "d5a2388f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " # with T.block(\"root\"):\n", + " x = T.launch_thread(\"blockIdx.x\", 32)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " T.writes()\n", + " s: T.int32 = T.sin(x)\n", + " c: T.int32 = T.cos(x)\n", + " a: T.int32 = s + c\n", + " b: T.int32 = s - c\n", + " T.evaluate(0)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def sincos(x):\n", + " return T.sin(x), T.cos(x)\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " with T.Kernel(32) as x:\n", + " s, c = sincos(x)\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd83fea7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tilelang-dev_0", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/eager_jit/eagerjit.zh.ipynb b/examples/eager_jit/eagerjit.zh.ipynb new file mode 100644 index 000000000..0f7c9be99 --- /dev/null +++ b/examples/eager_jit/eagerjit.zh.ipynb @@ -0,0 +1,977 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e0deecc", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", + "import tilelang\n", + "import torch\n", + "import tilelang.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "1ca2c56d", + "metadata": {}, + "source": [ + "# Tilelang Lazy JIT" + ] + }, + { + "cell_type": "markdown", + "id": "156e7370", + "metadata": {}, + "source": [ + "## Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "b070c109", + "metadata": {}, + "source": [ + "Tilelang Lazy JIT 将 jit 
生成和调用的逻辑合并到一起\n", + "\n", + "函数签名的写法与 triton 相似,但做了大量增强,最主要的增强是允许对 Tensor 的标注:\n", + "\n", + "* 如果一个 Tensor 有复杂的 shape 约束,我们可以把它的标注移动到函数内部\n", + "* 通过 `T.const` 或 `T.dynamic` 来建立一些 shape 变量,然后用 `T.Tensor` 标注复杂的 Tensor\n", + "* 用 `T.empty` 来声明返回值" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60bf8954", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm(\n", + " A,\n", + " B,\n", + " out_dtype: T.dtype = T.float32,\n", + " block_M: int = 128,\n", + " block_N: int = 128,\n", + " block_K: int = 32,\n", + "):\n", + " M, N, K = T.const(\"M, N, K\")\n", + "\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + "\n", + " C = T.empty((M, N), out_dtype)\n", + "\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "28f868fe", + "metadata": {}, + "source": [ + "直接将 Tensor 作为参数调用,即可触发完整的 jit 编译运行流程:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ee13394a", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B)\n", + "\n", + "# check output is correct\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "c6705091", + "metadata": {}, + "source": [ + "更改调用的参数,如果编译器参数发生了变化,会触发重新编译:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8aab5b7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B, block_M=64, block_N=64)" + ] + }, + { + "cell_type": "markdown", + "id": "ce6b7391", + "metadata": {}, + "source": [ + "你也可以手动调用 compile 函数编译 kernel\n", + "\n", + "1. `ker.compile` 编译 kernel\n", + "2. `ker.get_tir` 获取 tir\n", + "3. 
`ker.par_compile` 并行编译" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3cf3a2d", + "metadata": {}, + "outputs": [], + "source": [ + "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", + "C = kernel(A, B)" + ] + }, + { + "cell_type": "markdown", + "id": "921761b5", + "metadata": {}, + "source": [ + "## More Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "4539e54e", + "metadata": {}, + "source": [ + "### 用 macro 来分离实现" + ] + }, + { + "cell_type": "markdown", + "id": "ad96ba65", + "metadata": {}, + "source": [ + "接下来,我们会用各种方式来实现一个简单的 gemm,为了方便,我们先写一个 macro 把 gemm 的主要逻辑写出来:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "171d4fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "446a1acd", + "metadata": {}, + "source": [ + "### 用 T.dynamic 标记动态 Shape\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6a38aa95", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm_dyn_K(A, B):\n", + " M, N, K = T.dynamic(\"M, N, K\")\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "fe6cfdc8", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_dyn_K(A, B)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee97bf7", + "metadata": {}, + "source": [ + "### 用 T.StridedTensor 标记带 stride 的 Tensor\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9dde1dae", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def as_contingious(A):\n", + " M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n", + " A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n", + " B = T.empty((M, N), A.dtype)\n", + " block_M = 128\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " T.copy(\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " )\n", + " return B" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "dec2c0a7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 1024, device=\"cuda\")\n", + "B = as_contingious(A.T)\n", + "B_ref = A.T.contiguous()\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f5fb20d6", + "metadata": {}, + "source": [ + "## More Annotation" + ] + }, 
+ { + "cell_type": "markdown", + "id": "890df0a2", + "metadata": {}, + "source": [ + "### 直接用参数当 annotation" + ] + }, + { + "cell_type": "markdown", + "id": "e9a47d42", + "metadata": {}, + "source": [ + "可以直接把函数参数写到 annotation 里面" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "0fc17af6", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm_ptr(\n", + " A,\n", + " B,\n", + " M,\n", + " N,\n", + " K,\n", + "):\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8e52a554", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "6b19ef90", + "metadata": {}, + "source": [ + "### 对运行时变量的 annotation" + ] + }, + { + "cell_type": "markdown", + "id": "bba5f27f", + "metadata": {}, + "source": [ + "运行时变量也是一样,如果嫌函数 annotation 太长,可以放到函数体里面" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c1e7598a", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm_ptr_dyn(A, B, M, N, K):\n", + " M: T.int32\n", + " N: T.int32\n", + " K: T.int32\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9e9a4c88", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "81427765", + "metadata": {}, + "source": [ + "### 常量的约束" + ] + }, + { + "cell_type": "markdown", + "id": "4d6b084b", + "metadata": {}, + "source": [ + "`T.const` 创建的常量 annotation 只要要被直接使用一次,否则会报错" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c90dd24f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Constexpr variable `M` is not used in any buffer shape or stride.\n", + "At least one **DIRECT** usage is required. Please check:\n", + "(1) the variable is not used\n", + "(2) all uses are indirect, e.g. M * 2, M * 3. 
(you can replace them with separate constexpr variables)\n", + "Buffer shapes: {A: [M * 2, M * 3]}\n", + "Buffer strides: {A: [M * 3, 1]}\n" + ] + } + ], + "source": [ + "@tilelang.jit\n", + "def example_wrong_kernel(A):\n", + " M = T.const(\"M\")\n", + " A: T.Tensor[[M * 2, M * 3], T.float32]\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + "\n", + "\n", + "try:\n", + " A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + " example_wrong_kernel(A)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "e07e762b", + "metadata": {}, + "source": [ + "### 动态维度" + ] + }, + { + "cell_type": "markdown", + "id": "f48e5d7a", + "metadata": {}, + "source": [ + "如果想让 Tensor 的 annotation 中某个参数可以变化,建议改成 T.ptr + T.match_buffer 格式。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1d050321", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@tilelang.jit\n", + "def dyn_annot(\n", + " A: T.ptr, # 1. T.ptr type annotation\n", + " is_2d=False,\n", + "):\n", + " if is_2d:\n", + " M, N = T.const(\"M, N\")\n", + " # 2. dynamic shape annotation inside function body\n", + " A = T.match_buffer(A, [M, N], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + " else:\n", + " L = T.const(\"L\")\n", + " A = T.match_buffer(A, [L], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0]\n", + "\n", + "\n", + "A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + "dyn_annot(A, is_2d=True)" + ] + }, + { + "cell_type": "markdown", + "id": "2e9f1bb3", + "metadata": {}, + "source": [ + "### 带默认参数的" + ] + }, + { + "cell_type": "markdown", + "id": "f7fc9917", + "metadata": {}, + "source": [ + "类似 `T.float32` 标注的标量可以带默认参数" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "42ec86a1", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def add_one(X, data: T.float32 = 1):\n", + " M, N = T.const(\"M, N\")\n", + " X: T.Tensor[[M, N], T.float32]\n", + " Y = T.empty((M, N), T.float32)\n", + " with T.Kernel(T.ceildiv(M, 128), threads=128) as bx:\n", + " for i, j in T.Parallel(128, N):\n", + " Y[bx * 128 + i, j] = X[bx * 128 + i, j] + data\n", + " return Y" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d49e1120", + "metadata": {}, + "outputs": [], + "source": [ + "X = torch.randn(1024, 1024, dtype=torch.float32, device=\"cuda\")\n", + "Y = add_one(X)\n", + "torch.testing.assert_close(Y, X + 1)" + ] + }, + { + "cell_type": "markdown", + "id": "a02baedc", + "metadata": {}, + "source": [ + "## 参数匹配的 Overhead" + ] + }, + { + "cell_type": "markdown", + "id": "860a2972", + "metadata": {}, + "source": [ + "EagerJIT overhead 很小,每个 constant 添加约 200ns 的 overhead\n", + "* 200ns 大约是从 torch.Tensor 的 shape/stride 中拿参数的 ffi call 的代价" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc676e33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Kernel call : 7.68 us\n", + "Parse cache key: 0.41 us\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "A = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "\n", + "\n", + "@tilelang.jit\n", + "def dummy_kernel(A, B):\n", + " M, N = T.const(\"M, N\")\n", + " A: T.Tensor[[M, N], T.float16]\n", + " B: T.Tensor[[M, N], T.float16]\n", + " with 
T.Kernel(1) as _:\n", + " pass\n", + "\n", + "\n", + "# compile it first\n", + "dummy_kernel(A, B)\n", + "\n", + "\n", + "def eval_overhead(f):\n", + " start = time.perf_counter_ns()\n", + " for _ in range(10000):\n", + " f()\n", + " stop = time.perf_counter_ns()\n", + " return (stop - start) / 10000 / 1000\n", + "\n", + "\n", + "kernel_call_overhead = eval_overhead(lambda: dummy_kernel(A, B))\n", + "parse_cache_key_overhead = eval_overhead(lambda: dummy_kernel.parse_cache_key(A, B))\n", + "\n", + "print(f\"Kernel call : {kernel_call_overhead:.2f} us\")\n", + "print(f\"Parse cache key: {parse_cache_key_overhead:.2f} us\")" + ] + }, + { + "cell_type": "markdown", + "id": "39166cb4", + "metadata": {}, + "source": [ + "## 编译与并行编译" + ] + }, + { + "cell_type": "markdown", + "id": "8c6fbe08", + "metadata": {}, + "source": [ + "Eager JIT 和原来的 jit(即 LazyJIT) 都支持并行编译\n", + "\n", + "为了防止 torch.tensor 白白浪费内存,可以使用 T.Tensor 来创建 placeholder" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7222e57b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a4e4eb3cd4445bda6e8693da31ef3b8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Elaborating: 0%| | 0/8 [00:00,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "\n", + "def get_configs():\n", + " return [\n", + " {\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", + " }\n", + " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", + " ]\n", + "\n", + "\n", + "gemm.par_compile(get_configs())" + ] + }, + { + "cell_type": "markdown", + "id": "5160d2cc", + "metadata": {}, + "source": [ + "## 更便利的 Macro" + ] + }, + { + "cell_type": "markdown", + "id": "be44afc4", + "metadata": {}, + "source": [ + "tilelang 的 macro 现在已经升级:\n", + "\n", + "1. 允许用 `T.Ref` 作为 annotation,这类似与 C++ 的引用传递\n", + "2. 允许返回多个值\n", + "3. 
允许嵌套,递归" + ] + }, + { + "cell_type": "markdown", + "id": "79575972", + "metadata": {}, + "source": [ + "### T.Ref 传递引用\n", + "\n", + "T.Ref 传递的引用可以 var 也可以是 Buffer 的索引" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "90eaa6e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo(x_handle: T.handle):\n", + " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", + " # with T.block(\"root\"):\n", + " bx = T.launch_thread(\"blockIdx.x\", 1)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", + " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", + " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", + " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", + " x[1] = T.float32(1.0)\n", + " _tmp: T.int32 = idx[0]\n", + " x[_tmp] = T.float32(1.0)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def macro_with_ref(x: T.Ref):\n", + " x = 1 # noqa: F841\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo(x: T.Tensor((2,))):\n", + " with T.Kernel(1) as _:\n", + " # 支持常量 index\n", + " macro_with_ref(x[1])\n", + "\n", + " # 也支持变量 index\n", + " idx = T.alloc_var(T.int32, 0)\n", + " macro_with_ref(x[idx])\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "markdown", + "id": "7bb447a2", + "metadata": {}, + "source": [ + "### 当作参数传递\n", + "\n", + "你可以把 macro 当做参数传递" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "dc7bb779", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def element_wise(A, fn):\n", + " N = T.dynamic(\"N\")\n", + " A: T.Tensor[[N], T.float32]\n", + " B = T.empty((N,), dtype=A.dtype)\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", + " for i in T.Parallel(block_N):\n", + " idx = bx * block_N + i\n", + " B[idx] = fn(A[idx])\n", + " return B\n", + "\n", + "\n", + "@T.macro\n", + "def add_one(x):\n", + " return x + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a89fdb44", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, device=\"cuda\")\n", + "B = element_wise(A, add_one)\n", + "B_ref = A + 1\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "ef6e403a", + "metadata": {}, + "source": [ + "### Macro 递归\n", + "\n", + "虽然不知道有没有这种需求,但 macro 是可以递归的,终止条件要求编译期间确定" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7703cab5", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def n31(x, var: T.Ref):\n", + " if x == 1:\n", + " pass\n", + " elif x % 2 == 0:\n", + " var = var // 2\n", + " n31(x // 2, var)\n", + " else:\n", + " var = var * 3 + 1\n", + " n31(x * 3 + 1, var)\n", + "\n", + "\n", + "@tilelang.jit\n", + "def foo(A: T.Tensor[[1], T.int32], n: int):\n", + " with T.Kernel(1) as _:\n", + " n31(n, A[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "542ddd4e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([18], device='cuda:0', dtype=torch.int32)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", + "foo(A, 5)\n", + "A" + ] + }, + { + "cell_type": "markdown", + "id": "dc30c2d2", + "metadata": {}, + "source": [ + "### Macro 返回多个值" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "d5a2388f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " # with T.block(\"root\"):\n", + " x = T.launch_thread(\"blockIdx.x\", 32)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " T.writes()\n", + " s: T.int32 = T.sin(x)\n", + " c: T.int32 = T.cos(x)\n", + " a: T.int32 = s + c\n", + " b: T.int32 = s - c\n", + " T.evaluate(0)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def sincos(x):\n", + " return T.sin(x), T.cos(x)\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " with T.Kernel(32) as x:\n", + " s, c = sincos(x)\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd83fea7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tilelang-dev_0", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index bc9bb4df5..3d142ed54 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -1,9 +1,7 @@ import argparse -import itertools import torch import tilelang import tilelang.language as T -from tilelang.autotuner import AutoTuner def ref_program(x, y): @@ -12,10 +10,8 @@ def ref_program(x, y): @tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): - @T.prim_func - def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), in_dtype) B_shared = T.alloc_shared((block_M, block_N), in_dtype) @@ -24,7 +20,7 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T. T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(B[by * block_M, bx * block_N], B_shared) - for (local_y, local_x) in T.Parallel(block_M, block_N): + for local_y, local_x in T.Parallel(block_M, block_N): C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) @@ -32,53 +28,34 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T. 
return elem_add -def get_configs(M, N): - block_M = [64, 128, 256] - block_N = [64, 128, 256] - threads = [64, 128, 256] - configs = list(itertools.product(block_M, block_N, threads)) - return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] - - -def get_best_config(M, N): +def main(M=1024, N=1024, use_autotune=False): + a = torch.randn(M, N, dtype=torch.float32, device="cuda") + b = torch.randn(M, N, dtype=torch.float32, device="cuda") - def kernel(block_M=None, block_N=None, threads=None): - return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads) + kernel = elementwise_add(M, N, block_M=32, block_N=32, threads=128, in_dtype=T.float32, out_dtype=T.float32) - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N)).set_compile_args( - out_idx=[-1], - target="cuda", - ).set_profile_args( - supply_type=tilelang.TensorSupplyType.Auto, - ref_prog=ref_program, - skip_check=False, - ) - return autotuner.run(warmup=3, rep=20) + out = kernel(a, b) + torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) -def main(): +def run_regression_perf(): parser = argparse.ArgumentParser() - parser.add_argument("--m", type=int, default=1024) - parser.add_argument("--n", type=int, default=1024) - parser.add_argument("--use_autotune", action="store_true", default=False) + parser.add_argument("--m", type=int, default=4096) + parser.add_argument("--n", type=int, default=4096) args, _ = parser.parse_known_args() M, N = args.m, args.n a = torch.randn(M, N, dtype=torch.float32, device="cuda") b = torch.randn(M, N, dtype=torch.float32, device="cuda") + config = {"block_M": 32, "block_N": 32, "threads": 128} + kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") + from tilelang.profiler import do_bench - if args.use_autotune: - result = get_best_config(M, N) - kernel = result.kernel - else: - # Default config - config = {"block_M": 32, "block_N": 32, "threads": 128} - kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") - - out = kernel(a, b) - torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) + return do_bench(lambda: kernel(a, b), backend="cupti") if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=1024) + parser.add_argument("--n", type=int, default=1024) + args, _ = parser.parse_known_args() + main(args.m, args.n) diff --git a/examples/elementwise/regression_example_elementwise.py b/examples/elementwise/regression_example_elementwise.py new file mode 100644 index 000000000..261202a56 --- /dev/null +++ b/examples/elementwise/regression_example_elementwise.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_elementwise_add + + +def regression_example_elementwise_add(): + tilelang.testing.process_func(example_elementwise_add.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/flash_attention/README.md b/examples/flash_attention/README.md index be11a8dc6..355ed7325 100644 --- a/examples/flash_attention/README.md +++ b/examples/flash_attention/README.md @@ -34,8 +34,6 @@ def flash_attention( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - # Annotate layout for Q_shared, e.g., use a swizzled layout to optimize memory access - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) # Copy a block of Q from global memory to Q_shared T.copy(Q[bz, bx * 
block_M : (bx + 1) * block_M, by, :], Q_shared) @@ -77,6 +75,8 @@ def flash_attention( # Compute the maximum value per row on dimension 1 (block_N) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # Compute the factor by which we need to rescale previous partial sums for i in T.Parallel(block_M): @@ -106,4 +106,4 @@ def flash_attention( # Write back the final output block from acc_o to the Output buffer T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) -``` \ No newline at end of file +``` diff --git a/examples/flash_attention/bert_padding.py b/examples/flash_attention/bert_padding.py index 7058fd773..15c4097ce 100644 --- a/examples/flash_attention/bert_padding.py +++ b/examples/flash_attention/bert_padding.py @@ -6,7 +6,6 @@ class IndexFirstAxis(torch.autograd.Function): - @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) @@ -15,9 +14,7 @@ def forward(ctx, input, indices): second_dim = other_shape.numel() # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. # return input[indices] - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, - repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) @staticmethod def backward(ctx, grad_output): @@ -40,14 +37,12 @@ def backward(ctx, grad_output): class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod def forward(ctx, values, indices, first_axis_dim): ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 - output = torch.zeros( - first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. output[indices] = values # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) @@ -66,7 +61,6 @@ def backward(ctx, grad_output): class IndexFirstAxisResidual(torch.autograd.Function): - @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) @@ -128,7 +122,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng """ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). 
- + For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: ``` [ @@ -177,9 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng """ length = attention_mask_in_length.sum(dim=-1) seqlen = attention_mask_in_length.size(-1) - attention_mask_2d = torch.arange( - seqlen, device=length.device, dtype=length.dtype).expand(len(length), - seqlen) < length.unsqueeze(1) + attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index 907a121d2..801927faf 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -6,25 +6,27 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -39,26 +41,25 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, 
j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): @@ -72,29 +73,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_v] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -103,81 +106,74 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_qk] blk = 64 @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with 
T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -197,35 +193,35 @@ def flash_bwd( dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * 
block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -237,49 +233,41 @@ def flash_bwd( for i, j in T.Parallel(block_N, dim_qk): T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = 
T.alloc_shared([block_M, dim_qk], dtype) @@ -299,37 +287,35 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim_v], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -342,16 +328,15 @@ def flash_bwd( T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -369,7 +354,10 @@ def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -386,17 +374,8 @@ def maybe_contiguous(x): if ctx.use_atomic: kernel = flashattn_bwd_atomic_add( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k 
= [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -409,17 +388,8 @@ def maybe_contiguous(x): dv = dv.to(torch.float16) else: kernel = flashattn_bwd_split( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel @@ -441,53 +411,45 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -504,7 +466,7 @@ def main(BATCH: int = 1, 
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -522,19 +484,61 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + head_kv = H // groups + Q = torch.randn(BATCH, N_CTX, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + kernel = flashattn_bwd_split( + BATCH, + H, + N_CTX, + D_HEAD_QK, + D_HEAD_V, + causal, + block_M=128, + block_N=32, + threads=256, + num_stages=2, + groups=groups, + ) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(groups, BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.float16) + dV = torch.zeros(groups, BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Handle backward compatibility and logic @@ -546,5 +550,4 @@ def run1(): # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, 
args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py index 615c2e191..4920d8cf0 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -5,29 +5,29 @@ from tilelang.contrib import nvcc import argparse -tilelang.disable_cache() - @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -42,28 +42,27 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops # We should set it to negative large number instead - T.fill(scores_max, T.Cast(accum_dtype, -1e30)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + T.fill(scores_max, T.cast(-1e30, accum_dtype)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - T.Cast(accum_dtype, -1e30)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.cast(-1e30, accum_dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], 
V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): @@ -77,29 +76,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_v] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -108,12 +109,12 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep @@ -124,12 +125,14 @@ def make_dq_layout(dQ): @tilelang.jit( - out_idx=[3, 4, 5], pass_configs={ + out_idx=[3, 4, 5], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] @@ -137,64 +140,55 @@ def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(q_shape, dtype), # type: ignore - dK_out: T.Tensor(k_shape, dtype), # type: ignore - dV_out: T.Tensor(v_shape, dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: 
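These kernels evaluate softmax with exp2 rather than exp, folding log2(e) ≈ 1.44269504 into the score scale and keeping lse in the log2 domain (note the T.exp2 / T.log2 pairs above). A minimal PyTorch sketch, with arbitrary illustrative sizes, of why this is equivalent to the usual softmax:

import torch

torch.manual_seed(0)
d = 64
q = torch.randn(d)
k = torch.randn(16, d)

sm_scale = (1.0 / d) ** 0.5
scale = sm_scale * 1.44269504                # sm_scale * log2(e)

s = k @ q                                    # raw scores for one query
ref = torch.softmax(s * sm_scale, dim=0)

m = (s * scale).max()
p = torch.exp2(s * scale - m)                # exp(x * sm_scale) == exp2(x * scale)
lse = torch.log2(p.sum()) + m                # log-sum-exp stored in base 2
out = torch.exp2(s * scale - lse)            # normalised probabilities
assert torch.allclose(out, ref, atol=1e-6)

The backward kernels reuse the stored lse the same way: qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) recomputes exactly these probabilities.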
T.Tensor(v_shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) - T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk, - by, :]) + T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :]) with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz): - T.annotate_layout({ - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - }) - T.copy(dK[bz, bx * blk:(bx + 1) * blk, by, :], dK_out[bz, bx * blk:(bx + 1) * blk, - by, :]) - T.copy(dV[bz, bx * blk:(bx + 1) * blk, by, :], dV_out[bz, bx * blk:(bx + 1) * blk, - by, :]) + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :]) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -215,37 +209,29 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // 
groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -255,53 +241,43 @@ def flash_bwd( T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True) T.copy(dv, dv_shared) - T.atomic_add( - dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) T.copy(dk, dk_shared) - T.atomic_add( - dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split_novarlen(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], 
dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -321,37 +297,35 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim_v], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -364,16 +338,15 @@ def flash_bwd( T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, 
:]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -391,7 +364,10 @@ def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -403,22 +379,12 @@ def maybe_contiguous(x): block_M = 128 block_N = 32 mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) - mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V) delta = mod_prep(o, do) if ctx.use_atomic: kernel = flashattn_bwd_atomic_add( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -426,20 +392,11 @@ def maybe_contiguous(x): dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) kernel(q, k, v, do, lse, delta, dq, dk, dv) - dq, dk, dv = mod_post(dq, dk, dv) else: kernel = flashattn_bwd_split_novarlen( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) + mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel @@ -447,8 +404,7 @@ def maybe_contiguous(x): dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) kernel(q, k, v, do, lse, delta, dq, dk, dv) - dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), - torch.zeros_like(v, dtype=torch.float32)) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) dk, dv = dk.sum(0), dv.sum(0) return dq, dk, dv, None, None, None @@ -462,53 +418,45 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, 
float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -525,7 +473,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -548,17 +496,15 @@ def run1(): print(f"Detected GPU compute capability: {arch}") assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, 
help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Handle backward compatibility and logic @@ -570,5 +516,4 @@ def run1(): # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index 88f2d81e1..b09eec00c 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -7,57 +7,44 @@ from einops import rearrange, repeat from bert_padding import pad_input, unpad_input -# tilelang.disable_cache() - def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): assert mode in ["full", "random", "third"] if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": - lengths = torch.randint( - max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) elif mode == "third": lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) - padding_mask = ( - repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths return padding_mask @tilelang.jit( - out_idx=[5, 6], pass_configs={ + out_idx=[5, 6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_fwd(batch, - total_q, - total_kv, - N_CTX, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] v_shape = [total_kv, head_kv, dim_v] o_shape = [total_q, heads, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: 
ignore ): with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -79,8 +66,6 @@ def flash_fwd( q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - for i, d in T.Parallel(block_M, dim_qk): if bx * block_M + i < q_current_seqlen: Q_shared[i, d] = Q[q_start_idx + bx * block_M + i, by, d] @@ -91,7 +76,7 @@ def flash_fwd( T.fill(logsum, 0.0) # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops # We should set it to negative large number instead - T.fill(scores_max, T.Cast(accum_dtype, -1e30)) + T.fill(scores_max, T.cast(-1e30, accum_dtype)) loop_range = T.ceildiv(k_current_seqlen, block_N) for k in T.Pipelined(loop_range, num_stages=1): for i, d in T.Parallel(block_N, dim_qk): @@ -102,15 +87,17 @@ def flash_fwd( if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and - (bx * block_M + i < q_current_seqlen and - k * block_N + j < k_current_seqlen), 0, - T.Cast(accum_dtype, -1e30)) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= k * block_N + j) + and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen), + 0, + T.cast(-1e30, accum_dtype), + ) else: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( - bx * block_M + i < q_current_seqlen and - k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30)) + bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.cast(-1e30, accum_dtype) + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, d in T.Parallel(block_N, dim_v): if k * block_N + i < k_current_seqlen: @@ -119,6 +106,8 @@ def flash_fwd( V_shared[i, d] = 0.0 T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): @@ -146,21 +135,23 @@ def flash_fwd( @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [total_q, heads, dim_v] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -199,12 +190,14 @@ def make_dq_layout(dQ): @tilelang.jit( - out_idx=[3, 4, 5], pass_configs={ + out_idx=[3, 4, 5], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): - dtype = "float16" - 
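The scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) line added after T.reduce_max keeps the running row maximum from decreasing when a tile contributes only masked (-1e30) entries. A plain PyTorch sketch of the streaming-softmax update this protects, for one query row with made-up sizes (exp is used instead of exp2 for brevity):

import torch

torch.manual_seed(0)
scores = torch.randn(128)            # one query row of attention scores
values = torch.randn(128)            # matching values (d_v = 1 for simplicity)

m = torch.tensor(-1e30)              # running row max, kept monotone
l = torch.tensor(0.0)                # running sum of exp(score - m)
acc = torch.tensor(0.0)              # running sum of exp(score - m) * value
for s_blk, v_blk in zip(scores.split(32), values.split(32)):
    m_prev = m
    m = torch.maximum(m_prev, s_blk.max())   # never let the max go down
    rescale = torch.exp(m_prev - m)          # shrink earlier partial sums
    p = torch.exp(s_blk - m)
    l = l * rescale + p.sum()
    acc = acc * rescale + (p * v_blk).sum()

ref = (torch.softmax(scores, 0) * values).sum()
assert torch.allclose(acc / l, ref, atol=1e-5)

Keeping m monotone guarantees the rescale factor exp(m_prev - m) never exceeds 1, which is the numerically safe direction for the accumulators.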
accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] v_shape = [total_kv, head_kv, dim_v] @@ -212,70 +205,62 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(q_shape, dtype), # type: ignore - dK_out: T.Tensor(k_shape, dtype), # type: ignore - dV_out: T.Tensor(v_shape, dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by): T.annotate_layout({dQ: make_dq_layout(dQ)}) - T.copy(dQ[bx * blk:(bx + 1) * blk, by, :], dQ_out[bx * blk:(bx + 1) * blk, by, :]) + T.copy(dQ[bx * blk : (bx + 1) * blk, by, :], dQ_out[bx * blk : (bx + 1) * blk, by, :]) with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by): - T.annotate_layout({ - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - }) - T.copy(dK[bx * blk:(bx + 1) * blk, by, :], dK_out[bx * blk:(bx + 1) * blk, by, :]) - T.copy(dV[bx * blk:(bx + 1) * blk, by, :], dV_out[bx * blk:(bx + 1) * blk, by, :]) + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bx * blk : (bx + 1) * blk, by, :], dK_out[bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bx * blk : (bx + 1) * blk, by, :], dV_out[bx * blk : (bx + 1) * blk, by, :]) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - total_q, - total_kv, - N_CTX, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] v_shape = [total_kv, head_kv, dim_v] do_shape = [total_q, heads, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor(do_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: 
T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): - with T.Kernel( - heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -301,58 +286,45 @@ def flash_bwd( q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout({ - dQ: make_dq_layout(dQ), - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) - - T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - K_shared) - T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - V_shared) + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) - loop_st = T.min( - T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, - block_N)) if is_causal else 0 + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 loop_ed = T.ceildiv(q_current_seqlen, block_N) for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy( - Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - q) + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and - (by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen), - qkT[i, j], 0) + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) else: for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else( - by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) - T.copy( - dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - do) + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) T.clear(dsT) # dsT: (block_kv, block_q) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - 
T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta) + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) @@ -362,49 +334,40 @@ def flash_bwd( T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.copy(dq, dq_shared) T.atomic_add( - dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N, - bx, :], + dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :], dq_shared, memory_order="relaxed", - use_tma=True) + use_tma=True, + ) T.copy(dv, dv_shared) T.atomic_add( - dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :], + dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], dv_shared, memory_order="relaxed", - use_tma=True) + use_tma=True, + ) T.copy(dk, dk_shared) T.atomic_add( - dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :], + dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], dk_shared, memory_order="relaxed", - use_tma=True) + use_tma=True, + ) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split(batch, - total_q, - total_kv, - N_CTX, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] @@ -412,25 +375,24 @@ def flashattn_bwd_split(batch, do_shape = [total_q, heads, dim_v] dk_shape = [groups, total_kv, head_kv, dim_qk] # sum after kernel dv_shape = [groups, total_kv, head_kv, dim_v] # sum after kernel - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor(do_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: 
ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): - with T.Kernel( - heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -455,59 +417,52 @@ def flash_bwd( q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) - T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - K_shared) - T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - V_shared) + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) - loop_st = T.min( - T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, - block_N)) if is_causal else 0 + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 loop_ed = T.ceildiv(q_current_seqlen, block_N) for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): # Note: The padding zero of varlen should be considered in T.copy - T.copy( - Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - q) + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy( - dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - do) + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and - (by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen), - qkT[i, j], 0) + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) else: for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else( - by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta) + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -518,62 +473,37 @@ def flash_bwd( 
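The dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale update relies on the softmax-backward identity dS = P * (dP - rowsum(dO * O)); the Delta produced by the preprocess kernel is exactly that row sum. A small float64 PyTorch check of the unscaled identity against autograd, with a single head and hypothetical sizes:

import torch

torch.manual_seed(0)
Tq, Tk, d = 8, 16, 32
s = torch.randn(Tq, Tk, dtype=torch.float64, requires_grad=True)
V = torch.randn(Tk, d, dtype=torch.float64)
dO = torch.randn(Tq, d, dtype=torch.float64)

p = torch.softmax(s, dim=-1)
o = p @ V
o.backward(dO)                               # autograd gradient w.r.t. the scores

dp = dO @ V.t()                              # dsT in the kernel, before the delta term
delta = (dO * o).sum(dim=-1, keepdim=True)   # rowsum(dO * O), the preprocess output
ds = p * (dp - delta)
assert torch.allclose(ds, s.grad)

The extra * sm_scale in the kernel then accounts for the scores having been multiplied by sm_scale before the softmax in the forward pass.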
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) for i, j in T.Parallel(block_N, dim_qk): if k_base * block_N + i < q_current_seqlen: - T.atomic_add( - dQ[q_start_idx + k_base * block_N + i, bx, j], - dq[i, j], - memory_order="relaxed") + T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j], memory_order="relaxed") T.copy(dv, dv_shared) - T.copy( - dv_shared, - dV[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy( - dk_shared, - dK[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :]) + T.copy(dk_shared, dK[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod - def forward(ctx, - q, - k, - v, - seqlens_q, - seqlens_k, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - causal, - groups=1, - use_atomic=True): + def forward( + ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups=1, use_atomic=True + ): BATCH, N_CTX, H, D_HEAD_QK = q.shape D_HEAD_V = v.shape[-1] block_M = 128 block_N = 64 - q_unpad, indices_q, _, _ = unpad_input( - q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) - k_unpad, indices_k, _, _ = unpad_input( - k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) - v_unpad, _, _, _ = unpad_input( - v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) total_q = q_unpad.shape[0] total_kv = k_unpad.shape[0] - mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, - causal, block_M, block_N, groups) - o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) + kernel = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o_unpad, lse = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) o = pad_input(o_unpad, indices_q, BATCH, N_CTX) - ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, - cu_seqlens_q, cu_seqlens_k) + ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k) ctx.batch = BATCH ctx.causal = causal ctx.use_atomic = use_atomic @@ -588,8 +518,7 @@ def backward(ctx, do): N_CTX = do.shape[1] q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors # lse_clone = lse.clone() - do_unpad, _, _, _ = unpad_input( - do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + do_unpad, _, _, _ = unpad_input(do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) total_q, H, D_HEAD_QK = q.shape total_kv, HEAD_KV, D_HEAD_V = v.shape groups = H // HEAD_KV @@ -604,7 +533,6 @@ def maybe_contiguous(x): block_M = 128 block_N = 32 mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, 
ctx.max_seqlen_q, D_HEAD_V) - mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V) delta = mod_prep(o, do, cu_seqlens_q) if ctx.use_atomic: @@ -622,12 +550,12 @@ def maybe_contiguous(x): block_N, threads=256, num_stages=2, - groups=groups) + groups=groups, + ) dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.zeros_like(k, dtype=torch.float32) dv = torch.zeros_like(v, dtype=torch.float32) kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) - dq, dk, dv = mod_post(dq, dk, dv) else: kernel = flashattn_bwd_split( BATCH, @@ -643,13 +571,14 @@ def maybe_contiguous(x): block_N, threads=256, num_stages=2, - groups=groups) + groups=groups, + ) + mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V) dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device) dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device) kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) - dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), - torch.zeros_like(v, dtype=torch.float32)) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) dk, dv = dk.sum(0), dv.sum(0) dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX) @@ -668,15 +597,13 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): # HQ = HKV * groups # To handle precision issue Q, K, V = Q.float(), K.float(), V.float() - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if padding_mask is not None: scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf")) @@ -684,41 +611,35 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) if padding_mask is not None: output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - 
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random") seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32) cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0)) @@ -727,8 +648,7 @@ def main(BATCH: int = 1, # In training backward pass, seqlens_k should be the same as seqlens_q seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q - O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, - max_seqlen_k, causal, groups, use_atomic) + O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None dK, K.grad = K.grad.clone(), None @@ -765,22 +685,72 @@ def run1(): ) +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + total_q = BATCH * N_CTX + total_kv = BATCH * N_CTX + head_kv = H // groups + Q = torch.randn(total_q, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(total_kv, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(total_kv, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(total_q, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(total_q, H, D_HEAD_V, device=device, dtype=torch.half) + cu_seqlens_q = torch.arange(0, (BATCH + 1) * N_CTX, N_CTX, device=device, dtype=torch.int32) + cu_seqlens_k = cu_seqlens_q + max_seqlen_q = N_CTX + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, max_seqlen_q, D_HEAD_V) + kernel = flashattn_bwd_split( + BATCH, + total_q, + total_kv, + N_CTX, + H, + max_seqlen_q, + D_HEAD_QK, + D_HEAD_V, + causal, + block_M=128, + block_N=32, + threads=256, + num_stages=2, + groups=groups, + ) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(groups, total_kv, head_kv, D_HEAD_QK, device=device, dtype=torch.float16) + dV = torch.zeros(groups, total_kv, head_kv, D_HEAD_V, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO, cu_seqlens_q) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, cu_seqlens_q, cu_seqlens_k, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": arch = nvcc.get_target_compute_version() print(f"Detected GPU compute capability: {arch}") assert float(arch) >= 9.0, "This example only supports GPU with 
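For the varlen path, cu_seqlens_q prepends a zero to the cumulative sum of the per-sequence lengths, so batch element b owns rows cu_seqlens[b]:cu_seqlens[b + 1] of the packed tensors produced by unpad_input, and max_seqlen sizes the T.ceildiv(max_seq_len, block_M) grid. A short illustrative sketch with made-up lengths (not values from this example):

import torch
import torch.nn.functional as F

seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)   # hypothetical lengths, batch of 3

# Leading zero + cumulative sum: sequence b spans rows cu_seqlens[b] : cu_seqlens[b + 1].
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens.tolist())              # [0, 5, 8, 15]

total = int(cu_seqlens[-1])             # 15 packed rows, i.e. total_q / total_kv
max_seqlen = int(seqlens.max())         # 7

x_packed = torch.randn(total, 4)        # stand-in for a packed [total, H, D] tensor
start, end = cu_seqlens[1].item(), cu_seqlens[2].item()
assert x_packed[start:end].shape[0] == int(seqlens[1])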
compute capability >= 9.0" parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Can be set to True/False for testing args.causal = True @@ -794,5 +764,4 @@ def run1(): # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index ed07e7d9d..2da64472c 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -6,25 +6,27 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, 
bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -39,26 +41,25 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): @@ -72,29 +73,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_v] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -103,50 +106,42 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + 
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -167,45 +162,30 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm( - K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(lse[bz, bx, k * block_N:(k + 1) * 
block_N], lse_shared) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm( - V_shared, - do, - dsT, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - wg_wait=-1) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.wait_wgmma(1) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -217,18 +197,17 @@ def flash_bwd( T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.wait_wgmma(0) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -246,7 +225,10 @@ def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -260,18 +242,7 @@ def maybe_contiguous(x): mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) delta = mod_prep(o, do) - kernel = flashattn_bwd( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -294,52 +265,36 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = 
V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False): +def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -356,7 +311,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -374,15 +329,34 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False +): + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + O = attention(Q, K, V, causal, groups) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + return do_bench(run1, warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch 
size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal) diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index 4d9d06a4f..e884a8158 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -9,7 +9,6 @@ class FlashAttentionTuneSpace: - def __init__( self, block_sizes=(64, 128, 256), @@ -40,7 +39,7 @@ def get_configs(user_config=None): warp_M = block_M // warp_count warp_N = block_N // warp_count - if (warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0): + if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0: continue shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N) @@ -48,114 +47,38 @@ def get_configs(user_config=None): continue for num_stages in config.num_stages_range: - valid_configs.append({ - "block_M": block_M, - "block_N": block_N, - "num_stages": num_stages, - "threads": threads, - }) + valid_configs.append( + { + "block_M": block_M, + "block_N": block_N, + "num_stages": num_stages, + "threads": threads, + } + ) return valid_configs @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - groups=1, - block_M=64, - block_N=64, - num_stages=0, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" - - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, 
policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -171,25 +94,49 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + T.copy(K[bz, k * 
block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main @@ -199,50 +146,34 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D] # V: [B, T, HV, D] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(batch: int = 1, - heads: int = 64, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 16, - tune: bool = False): +def main( + batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False +): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - groups=groups, - block_M=64, - block_N=64, - num_stages=2, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128) 
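
Editor's aside: the hunk above inlines the former MMA0/Softmax/Rescale/MMA1 macros into the kernel body, but the recurrence it computes is unchanged: a streaming (online) softmax that keeps a running row maximum and running denominator, rescales the partial output whenever the maximum grows, and uses exp2 with a baked-in log2(e) factor so exp((x - m)/sqrt(d)) becomes a single exp2 of an FMA. The sketch below is a plain PyTorch reimplementation of that recurrence for one (batch, head) slice, checked against a direct softmax; the names (q, k, v, block_n) and the CPU/float32 setup are chosen for illustration only and are not the kernel's API.

import torch

def online_softmax_attention(q, k, v, block_n=64):
    """Streaming softmax-attention for one (batch, head) slice; q: [M, d], k/v: [N, d]."""
    M, d = q.shape
    scale = (1.0 / d) ** 0.5 * 1.44269504  # 1/sqrt(d) * log2(e), matching the kernel's exp2 trick
    acc_o = torch.zeros(M, v.shape[1])
    logsum = torch.zeros(M)                       # running softmax denominator
    row_max = torch.full((M,), float("-inf"))     # running max of the raw QK^T scores
    for start in range(0, k.shape[0], block_n):
        s = q @ k[start:start + block_n].T        # raw scores for this key block
        prev_max = row_max
        row_max = torch.maximum(prev_max, s.max(dim=1).values)
        # exp2(x*scale - m*scale) == exp((x - m)/sqrt(d)); rescale is 0 on the first block
        p = torch.exp2(s * scale - (row_max * scale).unsqueeze(1))
        rescale = torch.exp2(prev_max * scale - row_max * scale)
        logsum = logsum * rescale + p.sum(dim=1)
        acc_o = acc_o * rescale.unsqueeze(1) + p @ v[start:start + block_n]
    return acc_o / logsum.unsqueeze(1)

torch.manual_seed(0)
q, k, v = (torch.randn(128, 64) for _ in range(3))
ref = torch.softmax(q @ k.T / 64 ** 0.5, dim=-1) @ v
torch.testing.assert_close(online_softmax_attention(q, k, v), ref, rtol=1e-4, atol=1e-4)

The non-causal masking added in the hunk (initializing out-of-bounds key columns to -infinity before the GEMM) plays the same role as limiting the loop to k.shape[0] here: padded columns contribute exp2(-inf) = 0 to both the numerator and the denominator.
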
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -264,14 +195,22 @@ def main(batch: int = 1, print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False +): + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 1c1fc12d2..73a725d9f 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -24,9 +24,11 @@ def get_configs(): rep=10, ) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( batch, heads, @@ -39,90 +41,19 @@ def flashattn( num_stages=0, threads=128, ): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" - - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: 
T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -138,30 +69,55 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + loop_range, + num_stages=num_stages, + 
order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main @@ -171,23 +127,21 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D] # V: [B, T, HV, D] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -205,18 +159,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - groups=groups, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, 
atol=0.01) @@ -238,14 +182,28 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, + heads: int = 64, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + groups: int = 16, +): + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_varlen.py b/examples/flash_attention/example_gqa_fwd_varlen.py index db16e1586..0e8e21c43 100644 --- a/examples/flash_attention/example_gqa_fwd_varlen.py +++ b/examples/flash_attention/example_gqa_fwd_varlen.py @@ -4,91 +4,36 @@ import tilelang import tilelang.language as T import tilelang.testing -from einops import rearrange, repeat from tilelang.profiler import do_bench from varlen_utils import generate_random_padding_mask, generate_qkv -def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - causal=False, - window_size=(-1, -1), - upcast=True, -): - if causal: - window_size = (window_size[0], 0) - dtype_og = q.dtype - if upcast: - q, k, v = q.float(), k.float(), v.float() - b, T, Hq, D = q.shape - S = k.shape[1] - scale = (1.0 / D)**0.5 - k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2]) - v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2]) - scores = torch.einsum("bthd,bshd->bhts", q, k) - left, right = window_size - left = S if left is None or left < 0 else int(left) - right = S if right is None or right < 0 else int(right) - t_idx = torch.arange(T, device=scores.device)[:, None] - s_idx = torch.arange(S, device=scores.device)[None, :] - visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right)) - visible_mask = visible_ts.unsqueeze(0).unsqueeze(0) - if key_padding_mask is not None: - k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s") - visible_mask = visible_mask & k_keep - neg_inf = torch.finfo(scores.dtype).min - scores = scores * scale - scores = scores.masked_fill(~visible_mask, neg_inf) - attention = torch.softmax(scores, dim=-1).to(v.dtype) - if query_padding_mask is not None: - q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1") - attention = attention.masked_fill(~q_keep, 0.0) - output = 
torch.einsum("bhts,bshd->bthd", attention, v) - if query_padding_mask is not None: - output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) - - @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch_size, - groups, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [UQ, heads, dim] kv_shape = [UKV, head_kv, dim] o_shape = [UQ, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(kv_shape, dtype), - V_unpad: T.Tensor(kv_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -102,11 +47,6 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - }) - batch_idx = bz head_idx = by kv_head_idx = head_idx // groups @@ -119,43 +59,42 @@ def main( q_current_seqlen = q_end_idx - q_start_idx kv_current_seqlen = k_end_idx - kv_start_idx - T.copy( - Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], - Q_shared) + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) + offset = kv_current_seqlen - q_current_seqlen # always align on the right + max_visible_k_idx = offset + (bx + 1) * block_M loop_range = ( - T.min( - T.ceildiv(q_current_seqlen + - (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) - if is_causal else T.ceildiv(kv_current_seqlen, block_N)) + T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + if is_causal + else T.ceildiv(kv_current_seqlen, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy( - K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, - kv_head_idx, :], K_shared) + T.copy(K_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, - j] = T.if_then_else((bx * block_M + i < k * block_N + j) or - (bx * block_M + i >= q_current_seqlen 
or - k * block_N + j >= kv_current_seqlen), -1e9, 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i + offset < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), + -1e9, + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= kv_current_seqlen), -1e9, - 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0 + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_M): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) @@ -171,16 +110,15 @@ def main( for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] - T.copy( - V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, - kv_head_idx, :], V_shared) + T.copy(V_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) + # When sq > skv, some tokens can see nothing + acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i] + T.copy(acc_o, O_shared) for i, d in T.Parallel(block_M, dim): if bx * block_M + i < q_current_seqlen: Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] @@ -188,13 +126,9 @@ def main( return main -def main(batch: int = 1, - heads: int = 64, - q_seqlen: int = 2048, - k_seqlen: int = 2048, - dim: int = 128, - groups: int = 16, - is_causal: bool = False): +def main( + batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False +): assert heads % groups == 0, "heads must be divisible by groups" flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim @@ -232,55 +166,46 @@ def main(batch: int = 1, output_pad_fn, _, _, - ) = generate_qkv( - q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) UQ = q_unpad.shape[0] UKV = k_unpad.shape[0] - kernel = flashattn( - batch, - groups, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) out = output_pad_fn(out_unpad) - out_ref, _ = attention_ref( - q, - k, - v, - query_padding_mask=query_padding_mask, - key_padding_mask=key_padding_mask, + import flash_attn + + fa_out_unpad = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, causal=is_causal, ) - torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) + fa_out = output_pad_fn(fa_out_unpad) + torch.testing.assert_close(out, fa_out, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") - latency = do_bench( - lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), - _n_warmup=5, - _n_repeat=5) + latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), 
_n_warmup=5, _n_repeat=5) print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='query heads') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length') - parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--is_causal', action='store_true', help='causal attention') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="head dim") + parser.add_argument("--is_causal", action="store_true", help="causal attention") args = parser.parse_args() - main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, - args.is_causal) + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal) diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py index 1595ae764..34e8fefc5 100644 --- a/examples/flash_attention/example_mha_bwd_bhsd.py +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -7,22 +7,24 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -38,29 +40,28 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) # T.copy(Q_shared, Q_local) # for i, j in T.Parallel(block_M, dim): # Q_local[i, j] *= scale - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else 
T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -74,29 +75,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -105,68 +108,71 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) 
+ (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 64 @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -190,36 +196,36 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], 
q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -232,14 +238,13 @@ def flash_bwd( T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j]) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) - T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, H, N_CTX, D_HEAD = q.shape @@ -281,15 +286,15 @@ def maybe_contiguous(x): def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(2) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -304,9 +309,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -345,12 +348,43 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 16 + N_CTX = 512 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(0) + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, 
N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float16) + dV = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd.py b/examples/flash_attention/example_mha_bwd_bshd.py similarity index 65% rename from examples/flash_attention/example_mha_bwd.py rename to examples/flash_attention/example_mha_bwd_bshd.py index 543c2c0e7..fc8328fa4 100644 --- a/examples/flash_attention/example_mha_bwd.py +++ b/examples/flash_attention/example_mha_bwd_bshd.py @@ -7,22 +7,24 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -38,25 +40,25 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * 
block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -70,29 +72,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -101,68 +105,71 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, 
heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 64 @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -186,33 +193,36 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - }) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - 
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -225,14 +235,13 @@ def flash_bwd( T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) - T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, N_CTX, H, D_HEAD = q.shape @@ -274,15 +283,15 @@ def maybe_contiguous(x): def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -297,9 +306,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -336,12 +343,43 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 16 + N_CTX = 512 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(42) + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float16) + dV = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + 
kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py similarity index 64% rename from examples/flash_attention/example_mha_bwd_wgmma_pipelined.py rename to examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py index 7ad417ef5..c0fe4e33d 100644 --- a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -7,22 +7,24 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -37,27 +39,26 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * 
block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -71,29 +72,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -102,37 +105,39 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: 
T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -157,47 +162,34 @@ def flash_bwd( dk_shared = T.alloc_shared([block_M, dim], dtype) dq_shared = T.alloc_shared([block_N, dim], accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm( - K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm( - V_shared, - do, - dsT, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - wg_wait=-1) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.wait_wgmma(1) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. 
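The inner loop here recomputes the attention probabilities from the stored log-sum-exp and then forms dS = P * (dP - Delta), where Delta = rowsum(dO * O) is exactly what flashattn_bwd_preprocess produces. A minimal PyTorch sketch of that identity for a single (batch, head) slice, with illustrative tensor names and the same base-2 convention for lse as the forward kernel:

import torch

def flash_bwd_reference(q, k, v, do, lse, sm_scale):
    # q, k, v, do: [seq, dim]; lse: [seq], holding log2 of the softmax denominator
    # plus the running max, in the same scaled base-2 units the forward stores.
    log2e = 1.44269504
    s = q @ k.T                                          # raw scores (acc_s / qkT)
    p = torch.exp2(s * sm_scale * log2e - lse[:, None])  # softmax probs recovered from lse
    o = p @ v
    delta = (do * o).sum(dim=-1)                         # rowsum(dO * O): what Delta holds
    dv = p.T @ do                                        # dV = P^T dO
    dp = do @ v.T                                        # dP = dO V^T
    ds = p * (dp - delta[:, None]) * sm_scale            # dsT_cast in the kernel
    dq = ds @ k                                          # accumulated into dQ via atomic_add
    dk = ds.T @ q
    return dq, dk, dv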
T.wait_wgmma(0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -208,17 +200,16 @@ def flash_bwd( T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.wait_wgmma(0) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) - T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, N_CTX, H, D_HEAD = q.shape @@ -260,15 +251,15 @@ def maybe_contiguous(x): def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -283,9 +274,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -305,7 +294,7 @@ def main( assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -321,12 +310,44 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(0) + block_M = 128 + block_N = 128 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros_like(Q, dtype=torch.float16) + dV = torch.zeros_like(Q, dtype=torch.float16) + Delta = mod_prep(O, dO) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, 
Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index f07f7a618..400736541 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -15,107 +15,27 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: 
T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -131,43 +51,69 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, 
block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -185,18 +131,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -219,14 +155,28 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 64, + is_causal: bool = False, + tune: bool = False, +): + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=1, help='heads') - parser.add_argument('--seq_q', type=int, default=256, help='query sequence length') - parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=1, help="heads") + parser.add_argument("--seq_q", type=int, default=256, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal", default=False) + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git 
a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index 26167b34b..90514f762 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -15,107 +15,27 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
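The 1.44269504 constant used throughout these kernels is log2(e): exp(x - m) is evaluated as exp2(x * log2(e) - m * log2(e)) so the scale multiply and the max subtraction fold into a single FFMA per element, and the softmax scale 1/sqrt(dim) is folded into the same constant. A small sketch of the equivalence (illustrative values only):

import math
import torch

x = torch.randn(4, 8)
m = x.max(dim=-1, keepdim=True).values
log2e = math.log2(math.e)                 # 1.4426950408889634
ref = torch.exp(x - m)
fast = torch.exp2(x * log2e - m * log2e)  # what T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) computes
assert torch.allclose(ref, fast, atol=1e-6)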
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -131,48 +51,75 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, 
by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -190,18 +137,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -224,14 +161,28 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length') - parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=4096, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index 6a1f707e5..e584971c0 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -15,100 +15,23 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, 
- threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" - - @T.macro - def MMA0( - K: T.Tensor(shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
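The macros removed here are inlined in the hunks that follow, but the algorithm is unchanged: an online softmax that keeps a running row max and running denominator per query row and rescales the output accumulator whenever the max grows. A compact PyTorch sketch of that update, processing keys block by block (shapes and block size are illustrative):

import torch

def online_softmax_attention(q, k, v, block_n=64):
    # q: [M, d]; k, v: [N, d]. Mirrors acc_s / scores_max / scores_scale / logsum / acc_o.
    m_rows, d = q.shape
    scale = (1.0 / d) ** 0.5 * 1.44269504                        # sm_scale * log2(e)
    acc_o = torch.zeros(m_rows, d)
    logsum = torch.zeros(m_rows)
    row_max = torch.full((m_rows,), float("-inf"))
    for start in range(0, k.shape[0], block_n):
        s = q @ k[start:start + block_n].T                       # acc_s for this key block
        prev_max = row_max
        row_max = torch.maximum(prev_max, s.max(dim=-1).values)
        rescale = torch.exp2(prev_max * scale - row_max * scale) # scores_scale
        p = torch.exp2(s * scale - row_max[:, None] * scale)
        logsum = logsum * rescale + p.sum(dim=-1)
        acc_o = acc_o * rescale[:, None] + p @ v[start:start + block_n]
    return acc_o / logsum[:, None]                               # softmax(q k^T / sqrt(d)) @ v

As in the kernels, the normalization by logsum is deferred until after the loop, so each iteration only needs the cheap per-row rescale.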
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -124,40 +47,64 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if 
is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -174,17 +121,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=1, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -206,13 +144,19 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf(batch: int = 8, heads: int = 32, seq_len: int = 4096, dim: int = 128, is_causal: bool = False): + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index 3928db4c3..d6e1490c9 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -15,100 +15,23 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" - - @T.macro - def MMA0( - K: T.Tensor(shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], 
accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
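# A minimal standalone sketch (not part of the kernel above) of the identity the
# comment relies on: with scale = (1.0 / dim) ** 0.5 * log2(e),
# exp((x - m) / sqrt(dim)) == exp2(x * scale - m * scale), so the per-element
# scaling and max subtraction fold into a single multiply-add. The concrete
# values below are illustrative assumptions only.
import math

dim = 128
x, m = 3.0, 5.0  # a raw score and its running row max
scale = (1.0 / dim) ** 0.5 * math.log2(math.e)
reference = math.exp((x - m) / math.sqrt(dim))
fused = 2.0 ** (x * scale - m * scale)
assert abs(reference - fused) < 1e-9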
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -124,45 +47,70 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + 
T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -179,17 +127,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -211,13 +150,19 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf(batch: int = 8, heads: int = 32, seq_len: int = 4096, dim: int = 128, is_causal: bool = False): + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py index f381e900a..0f3610b11 100644 --- a/examples/flash_attention/example_mha_fwd_varlen.py +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -4,109 +4,51 @@ import tilelang.language as T import tilelang.testing import argparse +from tilelang.profiler import do_bench +from tilelang.autotuner import set_autotune_inputs, autotune import torch -from einops import rearrange, repeat from varlen_utils import generate_random_padding_mask, generate_qkv +import itertools -def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - 
upcast=True, -): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) - attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) - dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) - causal: whether to apply causal masking - window_size: (int, int), left and right window size - upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - output back to fp16/bf16. - reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) - without changing the math. This is to estimate the numerical error from operation - reordering. - Output: - output: (batch_size, seqlen_q, nheads, head_dim) - attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout - """ - if causal: - window_size = (window_size[0], 0) - dtype_og = q.dtype - if upcast: - q, k, v = q.float(), k.float(), v.float() - dim = q.shape[-1] - scale = (1.0 / dim)**0.5 # log2(e) - k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) - v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) - scores = torch.einsum("bthd,bshd->bhts", q, k) - if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - # scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0) - scores = scores * scale - attention = torch.softmax(scores, dim=-1).to(v.dtype) - - # We want to mask here so that the attention matrix doesn't have any NaNs - # Otherwise we'll get NaN in dV - if query_padding_mask is not None: - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) - output = torch.einsum("bhts,bshd->bthd", attention, v) - if query_padding_mask is not None: - output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) +def get_configs(): + iter_params = dict(block_M=[64, 128], block_N=[64, 128], num_stages=[0, 1, 2, 3], threads=[128, 256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] +@autotune(configs=get_configs()) @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch_size, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=0, - threads=32): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [UQ, heads, dim] k_shape = [UKV, heads, dim] v_shape = [UKV, heads, dim] o_shape = [UQ, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(k_shape, dtype), - V_unpad: T.Tensor(v_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(k_shape, dtype), + V_unpad: T.Tensor(v_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + 
max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, - threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype, "shared") - K_shared = T.alloc_shared([block_N, dim], dtype, "shared") - V_shared = T.alloc_shared([block_N, dim], dtype, "shared") - O_shared = T.alloc_shared([block_M, dim], dtype, "shared") + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) acc_o = T.alloc_fragment([block_M, dim], accum_dtype) @@ -120,46 +62,46 @@ def main( head_idx = by q_start_idx = cu_seqlens_q[batch_idx] - k_start_idx = cu_seqlens_k[batch_idx] - v_start_idx = cu_seqlens_k[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] q_end_idx = cu_seqlens_q[batch_idx + 1] - k_end_idx = cu_seqlens_k[batch_idx + 1] - v_end_idx = cu_seqlens_k[batch_idx + 1] + kv_end_idx = cu_seqlens_k[batch_idx + 1] q_current_seqlen = q_end_idx - q_start_idx - k_current_seqlen = k_end_idx - k_start_idx - v_current_seqlen = v_end_idx - v_start_idx + kv_current_seqlen = kv_end_idx - kv_start_idx - for i, d in T.Parallel(block_M, dim): - if bx * block_M + i < q_current_seqlen: - Q_shared[i, d] = Q_unpad[q_start_idx + bx * block_M + i, head_idx, d] - else: - Q_shared[i, d] = 0 + T.copy( + Q_unpad[q_start_idx + bx * block_M : q_start_idx + bx * block_M + block_M, head_idx, :], Q_shared + ) # OOB positions will be handled below T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv(k_current_seqlen, block_N) + offset = kv_current_seqlen - q_current_seqlen # always align on the right + loop_range = ( + T.min(T.ceildiv(offset + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) + if is_causal + else T.ceildiv(kv_current_seqlen, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): # Q * K - for i, d in T.Parallel(block_N, dim): - if k * block_N + i < k_current_seqlen: - K_shared[i, d] = K_unpad[k_start_idx + k * block_N + i, head_idx, d] - else: - K_shared[i, d] = 0 + T.copy( + K_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], K_shared + ) # OOB positions will be handled below if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and - (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i + offset < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), + -1e9, + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0 + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -167,6 +109,8 @@ def main( T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, 
clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. @@ -189,18 +133,17 @@ def main( acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - for i, d in T.grid(block_N, dim): - if k * block_N + i < v_current_seqlen: - V_shared[i, d] = V_unpad[v_start_idx + k * block_N + i, head_idx, d] - else: - V_shared[i, d] = 0 + T.copy( + V_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], V_shared + ) # OOB positions' weights are 0 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) + # When sq > skv, some tokens can see nothing + acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i] + T.copy(acc_o, O_shared) for i, d in T.Parallel(block_M, dim): if bx * block_M + i < q_current_seqlen: Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] @@ -208,19 +151,17 @@ def main( return main -def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): +def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False, tune: bool = False): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul tilelang.testing.set_random_seed(0) - causal = False if causal: total_flops *= 0.5 dtype = torch.float16 device = torch.device("cuda") - window_size = (-1, -1) q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) @@ -240,30 +181,23 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): k, v, output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) = generate_qkv( - q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) UQ = q_unpad.shape[0] # unpadded query length - UK = k_unpad.shape[0] # unpadded key length UKV = k_unpad.shape[0] # unpadded query key length - kernel = flashattn(batch, UQ, UKV, heads, dim, causal) + if tune: + with set_autotune_inputs(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q): + kernel = flashattn(batch, UQ, UKV, heads, dim, causal) + else: + kernel = flashattn(batch, UQ, UKV, heads, dim, causal, block_M=64, block_N=64, num_stages=1, threads=128) + # NOTE: (128, 128, 2or3, 256) is recommended for Hopper out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) out = output_pad_fn(out_unpad) - out_ref, _ = attention_ref( - q, - k, - v, - query_padding_mask, - key_padding_mask, - causal=causal, - ) - torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) - import flash_attn fla_out_unpad = flash_attn.flash_attn_varlen_func( @@ -282,13 +216,67 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): print("All checks passed.✅") + # benchmark + t = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)) + print(f"Tilelang time: {t} ms") + print(f"Tilelang: {total_flops / t * 1e-9} TFlops") + t = do_bench( + lambda: flash_attn.flash_attn_varlen_func( + q_unpad, k_unpad, v_unpad, cu_seqlens_q, 
cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, causal=causal + ) + ) + print(f"FA2 time: {t} ms") + print(f"FA2: {total_flops / t * 1e-9} TFlops") + + +def run_regression_perf(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + tilelang.testing.set_random_seed(0) + if causal: + total_flops *= 0.5 + dtype = torch.float16 + device = torch.device("cuda") + q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + UQ = q_unpad.shape[0] + UKV = k_unpad.shape[0] + kernel = flashattn(batch, UQ, UKV, heads, dim, causal, block_M=128, block_N=128, num_stages=2, threads=256) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) + + return do_bench(run_kernel_only, backend="cupti") + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=2048, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=2048, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", default=False, help="causal attention") + parser.add_argument("--tune", action="store_true", default=False, help="tune the kernel") args = parser.parse_args() - main(args.batch, args.heads, args.seq_len, args.dim) + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/regression_example_flash_attention.py b/examples/flash_attention/regression_example_flash_attention.py new file mode 100644 index 000000000..8710bbb6e --- /dev/null +++ b/examples/flash_attention/regression_example_flash_attention.py @@ -0,0 +1,74 @@ +import tilelang.testing +import example_gqa_fwd_bshd +import example_gqa_fwd_bshd_wgmma_pipelined +import example_mha_fwd_bhsd +import example_mha_fwd_bhsd_wgmma_pipelined +import example_mha_fwd_bshd +import example_mha_fwd_bshd_wgmma_pipelined +import example_mha_fwd_varlen +import example_gqa_bwd_tma_reduce_varlen +import example_gqa_bwd +import example_gqa_bwd_wgmma_pipelined +import example_mha_bwd_bshd +import example_mha_bwd_bhsd +import example_mha_bwd_bshd_wgmma_pipelined + + +def regression_example_gqa_bwd_tma_reduce_varlen(): + tilelang.testing.process_func(example_gqa_bwd_tma_reduce_varlen.run_regression_perf) + + +def regression_example_gqa_bwd(): + 
tilelang.testing.process_func(example_gqa_bwd.run_regression_perf) + + +def regression_example_gqa_bwd_wgmma_pipelined(): + tilelang.testing.process_func(example_gqa_bwd_wgmma_pipelined.run_regression_perf) + + +def regression_example_mha_bwd_bshd(): + tilelang.testing.process_func(example_mha_bwd_bshd.run_regression_perf) + + +def regression_example_mha_bwd_bhsd(): + tilelang.testing.process_func(example_mha_bwd_bhsd.run_regression_perf) + + +def regression_example_mha_bwd_bshd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_bwd_bshd_wgmma_pipelined.run_regression_perf) + + +def regression_example_gqa_fwd_bshd_wgmma_pipelined(): + tilelang.testing.process_func( + example_gqa_fwd_bshd_wgmma_pipelined.run_regression_perf, batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16 + ) + + +def regression_example_gqa_fwd_bshd(): + tilelang.testing.process_func( + example_gqa_fwd_bshd.run_regression_perf, batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16 + ) + + +def regression_example_mha_fwd_bhsd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_fwd_bhsd_wgmma_pipelined.run_regression_perf) + + +def regression_example_mha_fwd_bhsd(): + tilelang.testing.process_func(example_mha_fwd_bhsd.run_regression_perf) + + +def regression_example_mha_fwd_bshd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_fwd_bshd_wgmma_pipelined.run_regression_perf, batch=1, heads=32, seq_len=256) + + +def regression_example_mha_fwd_bshd(): + tilelang.testing.process_func(example_mha_fwd_bshd.run_regression_perf, batch=1, seq_len=256) + + +def regression_example_mha_fwd_varlen(): + tilelang.testing.process_func(example_mha_fwd_varlen.run_regression_perf, batch=4, heads=16, seq_len=512, dim=64) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py index f4932aee9..a74bf071b 100644 --- a/examples/flash_attention/test_example_flash_attention.py +++ b/examples/flash_attention/test_example_flash_attention.py @@ -2,7 +2,7 @@ import example_gqa_bwd import example_gqa_bwd_wgmma_pipelined -import example_mha_bwd +import example_mha_bwd_bshd import example_mha_bwd_bhsd import example_mha_fwd_bhsd_wgmma_pipelined import example_gqa_fwd_bshd @@ -10,9 +10,10 @@ import example_gqa_fwd_bshd_wgmma_pipelined import example_mha_fwd_bshd_wgmma_pipelined import example_mha_fwd_varlen -import example_mha_bwd_wgmma_pipelined +import example_mha_bwd_bshd_wgmma_pipelined import example_mha_fwd_bhsd import example_gqa_bwd_tma_reduce_varlen +import example_gqa_fwd_varlen @tilelang.testing.requires_cuda @@ -33,7 +34,7 @@ def test_example_gqa_bwd_wgmma_pipelined(): @tilelang.testing.requires_cuda def test_example_mha_bwd(): - example_mha_bwd.main( + example_mha_bwd_bshd.main( BATCH=1, H=16, N_CTX=512, @@ -56,20 +57,18 @@ def test_example_mha_bwd_bhsd(): @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_mha_bwd_wgmma_pipelined(): - example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False) + example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_gqa_fwd_bshd_wgmma_pipelined(): - example_gqa_fwd_bshd_wgmma_pipelined.main( - batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + 
example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) @tilelang.testing.requires_cuda def test_example_gqa_fwd_bshd(): - example_gqa_fwd_bshd.main( - batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) @tilelang.testing.requires_cuda @@ -96,7 +95,14 @@ def test_example_mha_fwd_bshd(): @tilelang.testing.requires_cuda def test_example_mha_fwd_varlen(): - example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64) + example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64, causal=False) + example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64, causal=True) + + +@tilelang.testing.requires_cuda +def test_example_gqa_fwd_varlen(): + example_gqa_fwd_varlen.main(batch=4, heads=16, q_seqlen=512, k_seqlen=512, dim=64, is_causal=False) + example_gqa_fwd_varlen.main(batch=4, heads=16, q_seqlen=512, k_seqlen=512, dim=64, is_causal=True) if __name__ == "__main__": diff --git a/examples/flash_attention/varlen_utils.py b/examples/flash_attention/varlen_utils.py index 4301215d5..43e21cc3b 100644 --- a/examples/flash_attention/varlen_utils.py +++ b/examples/flash_attention/varlen_utils.py @@ -9,22 +9,14 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": - lengths = torch.randint( - max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) elif mode == "third": lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) - padding_mask = ( - repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths return padding_mask -def generate_qkv(q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False): +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: q: (batch_size, seqlen_q, nheads, d) @@ -39,15 +31,12 @@ def generate_qkv(q, if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q - ) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) else: q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) max_seqlen_q = seqlen_q - output_pad_fn = lambda output_unpad: rearrange( - output_unpad, "(b s) h d -> b s h d", b=batch_size) + output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) @@ -55,8 +44,7 @@ def generate_qkv(q, else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, (batch_size + 
1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) max_seqlen_k = seqlen_k if qkvpacked: @@ -67,8 +55,7 @@ def generate_qkv(q, if query_padding_mask is not None: dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) else: - dqkv_pad_fn = lambda dqkv_unpad: rearrange( - dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) return ( qkv_unpad.detach().requires_grad_(), cu_seqlens_q, @@ -84,8 +71,7 @@ def generate_qkv(q, if key_padding_mask is not None: dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) else: - dkv_pad_fn = lambda dkv_unpad: rearrange( - dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) return ( q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(), diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 9ec3a0265..9e6f36017 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -15,18 +15,12 @@ def get_configs(): block_N = [64, 128] block_H = [64] - num_split = [2, 4, 8] + num_split = [1, 2, 4, 8] num_stages = [1, 2, 3] threads = [128] _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - configs = [{ - 'block_N': c[0], - 'block_H': c[1], - 'num_split': c[2], - 'num_stages': c[3], - 'threads': c[4] - } for c in _configs] + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] return configs @@ -42,43 +36,42 @@ def get_heuristic_config() -> Tuple[Dict, int]: if sm_version == 89: cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128) else: - cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=2, threads=128) + cfg = dict(block_N=128, block_H=64, num_split=8, num_stages=2, threads=128) return cfg, sm_version # TODO(lei): fix warp specialized and tma lower pass def get_pass_configs(): - return { - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - } + return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) -def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, - threads): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [batch, seqlen_kv, groups, dim] shape_v = [batch, seqlen_kv, groups, dim] shape_o = [batch, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // groups part_shape = [batch, heads, num_split, dim] valid_block_H = min(block_H, kv_group_num) valid_block_N = min(block_N, seqlen_kv // num_split) - @T.macro - def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - Output: 
T.Tensor([batch, heads, dim], dtype), + @T.prim_func + def flashattn_gqa_decode_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): + # split with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) @@ -96,25 +89,43 @@ def flash_attn( bid = bx hid = by + sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared) - T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local) + T.copy( + K[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + K_shared, + ) + T.copy( + mask[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + ], + mask_local, + ) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], - -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -125,23 +136,66 @@ def flash_attn( T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared) + T.copy( + V[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + V_shared, + ) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) - - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), + T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :]) + + # combine + with T.Kernel(heads, batch, 
threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local = T.alloc_fragment([num_split, 128], dtype) + lse_logsum_local = T.alloc_fragment([128], accum_dtype) + lse_max_local = T.alloc_fragment([128], accum_dtype) + scale_local = T.alloc_fragment([128], accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + for k, j in T.Parallel(num_split, 128): + lse_local[k, j] = glse[bz, by, k] + T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) + for k in T.serial(num_split): + for j in T.Parallel(128): + lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j]) + for j in T.Parallel(128): + lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + for j in T.Parallel(128): + scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j]) + # Note: Pay attention to dim and the number of threads in Parallel + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[i] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -160,34 +214,26 @@ def flash_attn_split( bid = bx hid = by - sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv((seqlen_kv // num_split), block_N) - for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy( - K[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head, :], K_shared) - T.copy( - mask[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head], mask_local) + T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, - j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), - acc_s[i, j], -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -198,88 +244,14 @@ def flash_attn_split( T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.copy( - V[bid, (seqlen_kv // num_split) * sid + - k * 
valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head, :], V_shared) + T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H, - sid, :]) - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), - ): - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim], dtype) - o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local = T.alloc_fragment([num_split, 128], dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_fragment([128], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), - # lse_local: (local_id, thread_id) - lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - for k, j in T.Parallel(num_split, 128): - lse_local[k, j] = glse[bz, by, k] - T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dim): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim): - Output[bz, by, i] = o_accum_local[i] - - @T.prim_func - def flashattn_gqa_decode_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), - ): - flash_attn_split(Q, K, V, mask, glse, Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), - ): - flash_attn(Q, K, V, mask, Output) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) if num_split > 1: return flashattn_gqa_decode_split @@ -300,27 +272,21 @@ def ref_program(query, key, value, mask, glse, Output_partial): dim = query.shape[-1] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] - value = rearrange(value, 
'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] if mask is not None: - mask = rearrange(mask, 'b s h -> b h s') + mask = rearrange(mask, "b s h -> b h s") mask = mask.unsqueeze(1) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -334,16 +300,12 @@ def flash_split_ref(Q, K, V, mask): seqlen_kv = K.size(1) num_head_groups = nheads // groups - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float) - acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), - device="cuda", - dtype=torch.float16) + acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float) scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) - scores_max_prev = torch.empty((batch, num_head_groups, groups), - device="cuda", - dtype=torch.float) + scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) @@ -351,25 +313,25 @@ def flash_split_ref(Q, K, V, mask): glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) Q_ = Q * scale - Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups) + Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups) for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bghd,bkhd->bghk', Q_, - K[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + acc_s = torch.einsum( + "bghd,bkhd->bghk", + Q_, + K[:, (seqlen_kv // num_split) * ks 
+ i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] if mask is not None: - mask_local = mask[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :] - mask_local = rearrange(mask_local, 'b s h -> b h s') + mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :] + mask_local = rearrange(mask_local, "b s h -> b h s") mask_local = mask_local.unsqueeze(1) - acc_s = acc_s.masked_fill(mask_local == 0, float('-inf')) + acc_s = acc_s.masked_fill(mask_local == 0, float("-inf")) scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] @@ -377,15 +339,16 @@ def flash_split_ref(Q, K, V, mask): acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_o += torch.einsum( - 'bghk,bkhd->bghd', acc_s_cast, - V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bghk,bkhd->bghd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum - acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d') - logsum_out = rearrange(logsum, 'b g h->b (h g)') + acc_o_out = rearrange(acc_o, "b g h d->b (h g) d") + logsum_out = rearrange(logsum, "b g h->b (h g)") acc_o_out /= logsum_out[:, :, None] - logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)') + logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)") gacc_o[ks, :, :, :] = acc_o_out glogsum[ks, :, :] = logsum_out @@ -421,7 +384,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -429,28 +392,23 @@ def calc_sim(x, y, name="tensor"): def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True): sim = calc_sim(x, y, name) - diff = 1. 
- sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff}') + print_red_warning(f"{name} Error: {diff}") if assert_: - raise AssertionError(f'{name} Error: {diff}') + raise AssertionError(f"{name} Error: {diff}") else: if print_: - print(f'passed: {name} diff={diff}') + print(f"passed: {name} diff={diff}") -def main(batch: int = 1, - heads: int = 32, - groups: int = 8, - kv_seqlen: int = 8192, - dim: int = 128, - tune: bool = False): +def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False): batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim qk_flops = 2 * batch * heads * kv_seqlen * dim pv_flops = 2 * batch * heads * kv_seqlen * dim total_flops = qk_flops + pv_flops - if (not tune): + if not tune: config, sm_version = get_heuristic_config() kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) @@ -470,7 +428,7 @@ def main(batch: int = 1, print(o_ref) assert_similar(o, o_ref, name="o_ref") - assert_similar(o_ref_split, o_ref, name="o_ref_split") + assert_similar(o, o_ref_split, name="o_ref_split") print("All checks pass.") latency = profiler.do_bench(ref_program, warmup=500) @@ -490,13 +448,21 @@ def main(batch: int = 1, print(f"Ref latency: {ref_latency}") +def run_regression_perf(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128): + batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim + config, _ = get_heuristic_config() + kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py index 16924ebe8..864ff3e54 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -1,14 +1,11 @@ import torch -import triton -import triton.language as tl import math import argparse import tilelang import tilelang.language as T -from tilelang.autotuner import autotune +from tilelang.profiler import do_bench torch.manual_seed(0) -tilelang.disable_cache() def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -19,184 +16,13 @@ def 
repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, - head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -@triton.jit -def _fwd_inner( - q, - k_ptrs, - v_ptrs, - s_ptrs, - m_i, - l_i, - acc, - offs_h, - mask_h, - offs_n, - seqlen, - softmax_scale, - lo, - hi, - stride_kt, - stride_vt, - stride_sh, - stride_sn, - BLOCK_N: tl.constexpr, -): - """Inner loop computation for attention""" - - for blk_idx in tl.range(lo, hi): - start_n = blk_idx * BLOCK_N - k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < seqlen) - v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < seqlen) - - qk = tl.dot(q, k) - qk *= softmax_scale - qk += tl.where(offs_n[None, :] + start_n < seqlen, 0, -1.0e9) - - row_max = tl.max(qk, 1) - tl.store(s_ptrs + offs_h * stride_sh + blk_idx * stride_sn, row_max, mask=mask_h) - - m_ij = tl.maximum(m_i, row_max) - qk -= m_ij[:, None] - p = tl.math.exp(qk) - l_ij = tl.sum(p, 1) - alpha = tl.math.exp(m_i - m_ij) - l_i = l_i * alpha + l_ij - m_i = m_ij - acc *= alpha[:, None] - p = p.to(v.type.element_ty) - acc += tl.dot(p, v) - - return m_i, l_i, acc - - - -@triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [4, 8]\ - for num_stages in [2, 4]\ - ], - key=['gqa_group_size', 'BLOCK_N', 'BLOCK_D', 'BLOCK_H'], -) -@triton.jit -def _fwd_kernel_varlen( - Q, # [token_q = b, h_q, dim] - K, # [token_k, h_kv, dim] - V, - O, - S, - s_aux, - softmax_scale, - cu_seqlens_k, - stride_qt, - stride_qh, - stride_qd, - stride_kt, - stride_kh, - stride_kd, - stride_vt, - stride_vh, - stride_vd, - stride_ot, - stride_oh, - stride_od, - stride_sb, - stride_sh, - stride_sn, #bmask shape [b, q_h, seq/BLOCK_N] - gqa_group_size: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D: tl.constexpr, -): - - off_z = tl.program_id(0) - off_h_for_kv = tl.program_id(1) - off_h_q = off_h_for_kv * gqa_group_size - - cu_k_start = tl.load(cu_seqlens_k + off_z) - cu_k_end = tl.load(cu_seqlens_k + off_z + 1) - - seqlen_k = cu_k_end - cu_k_start - - offs_h = tl.arange(0, BLOCK_H) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_D) - - Q_ptrs = Q + off_z * stride_qt + off_h_q * stride_qh - K_ptrs = K + (cu_k_start) * stride_kt + off_h_for_kv * stride_kh - V_ptrs = V + (cu_k_start) * stride_vt + off_h_for_kv * stride_vh - O_ptrs = O + off_z * stride_ot + off_h_q * stride_oh - S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh - - mask_h = offs_h < gqa_group_size - q = tl.load( - Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) - - if s_aux is not None: - sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) - l_i = tl.zeros([BLOCK_H], dtype=tl.float32) - m_i = tl.zeros([BLOCK_H], dtype=tl.float32) + sink - else: - l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) - m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) - - acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) - - k_ptrs = K_ptrs + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd - v_ptrs = V_ptrs + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd - - lo, hi = 0, tl.cdiv(seqlen_k, BLOCK_N) 
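# A minimal pure-Python sketch of the streaming-softmax update that the inner
# loop above performs per K/V block (m = running row max, l = running
# denominator, acc = running weighted sum). Scalar "values" stand in for the
# d-dimensional V rows; all names and numbers here are illustrative assumptions.
import math

def online_softmax_step(m, l, acc, scores, values):
    m_new = max(m, max(scores))
    alpha = math.exp(m - m_new)                  # rescales the previously accumulated state
    p = [math.exp(s - m_new) for s in scores]    # unnormalized block probabilities
    l_new = l * alpha + sum(p)
    acc_new = acc * alpha + sum(pi * vi for pi, vi in zip(p, values))
    return m_new, l_new, acc_new

m, l, acc = float("-inf"), 0.0, 0.0
for block_scores, block_values in [([0.1, 0.7], [1.0, 2.0]), ([1.5, -0.3], [3.0, 4.0])]:
    m, l, acc = online_softmax_step(m, l, acc, block_scores, block_values)
out = acc / l  # equals softmax([0.1, 0.7, 1.5, -0.3]) applied to [1.0, 2.0, 3.0, 4.0]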
- m_i, l_i, acc = _fwd_inner( - q, - k_ptrs, - v_ptrs, - S_ptrs, - m_i, - l_i, - acc, - offs_h, - mask_h, - offs_n, - seqlen_k, - softmax_scale, - lo, - hi, - stride_kt, - stride_vt, - stride_sh, - stride_sn, - BLOCK_N, - ) - - if s_aux is not None: - sink = tl.math.exp(sink - m_i) - l_i = l_i + sink - acc = acc / l_i[:, None] - - else: - l_recip = 1 / l_i[:, None] - acc = acc * l_recip - - for blk_idx in tl.range(lo, hi): - s = tl.load(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, mask=mask_h) - s = tl.exp(s - m_i) / l_i - tl.store(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, s, mask=mask_h) - - acc = acc.to(O.dtype.element_ty) - - tl.store( - O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, - acc, - mask=mask_h[:, None]) - - def get_configs(): import itertools + block_N = [64, 128] block_H = [64] num_split = [1] @@ -204,54 +30,37 @@ def get_configs(): threads = [128] _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - configs = [{ - 'block_N': c[0], - 'block_H': c[1], - 'num_split': c[2], - 'num_stages': c[3], - 'threads': c[4] - } for c in _configs] + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] return configs -@autotune(configs=get_configs(), warmup=10, rep=10) -@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") -def flashattn(batch, - heads, - k_heads, - max_seqlen_kv, - total_seqlen_k, - dim, - has_sink, - block_N=128, - block_H=64, - num_split=1, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +@tilelang.jit(out_idx=[-2, -1]) +def flashattn( + batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128 +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [total_seqlen_k, k_heads, dim] shape_v = [total_seqlen_k, k_heads, dim] shape_o = [batch, heads, dim] shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // k_heads valid_block_H = min(block_H, kv_group_num) - # TODO: check if max_seqlen_kv is correct for varlen case - - @T.macro - def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - Output: T.Tensor([batch, heads, dim], dtype), - S: T.Tensor(shape_s, dtype), + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), ): - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bid, hid, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -264,578 +73,148 @@ def flash_attn( scores_scale = T.alloc_fragment([block_H], accum_dtype) scores_sum = T.alloc_fragment([block_H], accum_dtype) logsum = T.alloc_fragment([block_H], accum_dtype) - S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) - # S_fragment = T.alloc_fragment([block_H, 
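Both the Triton and tilelang paths fold log2(e) into the softmax scale so that exp(x * softmax_scale) can be evaluated as exp2(x * scale), which lowers to the cheaper exp2/ffma path on GPU. A quick numerical check of the identity (the head dimension is illustrative):

import torch

dim = 128
softmax_scale = (1.0 / dim) ** 0.5
scale = softmax_scale * 1.44269504  # log2(e)

x = torch.randn(1024, dtype=torch.float64)
assert torch.allclose(torch.exp(x * softmax_scale), torch.exp2(x * scale), atol=1e-8)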
math.ceil(max_seqlen_kv / block_N)], accum_dtype) - s_aux_shared = T.alloc_shared([block_H], "float32") - - T.annotate_layout({ - # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - # K_shared: tilelang.layout.make_swizzled_layout(K_shared), - # V_shared: tilelang.layout.make_swizzled_layout(V_shared), - # O_shared: tilelang.layout.make_swizzled_layout(O_shared), - # S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) - - bid = bx - hid = by + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) + S_shared_cast = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) + s_aux_shared = T.alloc_shared([block_H], T.float32) + cur_kv_head = hid // (kv_group_num // valid_block_H) cur_start_k = cu_seqlens_k[bid] cur_end_k = cu_seqlens_k[bid + 1] cur_seqlen_k = cur_end_k - cur_start_k - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy(K[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], - K_shared) + T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], - # -T.infinity(accum_dtype)) - acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], - -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # scores_max_prev is m_i - # scores_max is row_max->m_ij in triton T.copy(scores_max, S_shared[:, k]) - # scores_scale is alpha in triton for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) - # scores_sum is l_ij in triton - # logsum is l_i in triton for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.copy(V[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], - V_shared) + T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_sink: - T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared) + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) for i in T.Parallel(block_H): logsum[i] += s_aux_shared[i] for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] - # T.copy(S_shared, S_fragment) - # for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): - # 
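The tilelang kernel writes the per-KV-block row max of the raw Q@K^T scores into `S_shared[:, k]` and later normalizes it with the global running max and `logsum`. Because exp is monotone, that normalized value is exactly the largest post-softmax probability inside the block, which is what makes `S` usable as a per-block importance score. A small PyTorch check of this equivalence (shapes illustrative):

import torch

def block_scores_from_logit_max(logits, block_n):
    # logits: [h, s] raw q@k^T * softmax_scale for one query token
    m = logits.max(dim=-1, keepdim=True).values
    l = torch.exp(logits - m).sum(dim=-1, keepdim=True)
    probs = torch.exp(logits - m) / l                          # full softmax
    h, s = logits.shape
    num_blocks = (s + block_n - 1) // block_n
    out = torch.zeros(h, num_blocks)
    for k in range(num_blocks):
        blk_max = logits[:, k * block_n:(k + 1) * block_n].max(dim=-1).values
        out[:, k] = torch.exp(blk_max - m[:, 0]) / l[:, 0]     # what the kernel stores in S
    # max-pooling the dense probabilities per block gives the same values
    ref = torch.stack([probs[:, k * block_n:(k + 1) * block_n].max(dim=-1).values
                       for k in range(num_blocks)], dim=-1)
    assert torch.allclose(out, ref, atol=1e-6)
    return out

block_scores_from_logit_max(torch.randn(8, 100), block_n=16)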
S_fragment[h, k] = T.exp2((S_fragment[h, k] - scores_max[h]) * scale) / logsum[h] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) - # T.copy(S_fragment, S_shared) - T.copy(S_shared[:valid_block_H, :], S[bid, - hid * valid_block_H:(hid + 1) * valid_block_H, :]) - - @T.prim_func - def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - Output: T.Tensor(shape_o, dtype), - S: T.Tensor(shape_s, dtype), - ): - flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + T.copy(S_shared, S_shared_cast) + T.copy(S_shared_cast[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) - # TODO: split version return flashattn_gqa_decode_no_split -def flash_attn_with_attn_pool_decode_tilelang( - Q: torch.Tensor, ## [tq = b, q_h, q_dim] - K: torch.Tensor, ## [tk, k_h, k_dim] - V: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_k: int, - real_max_k_seqlen: int, - num_split: int, - softmax_scale: float, - s_aux: torch.Tensor = None, - block_size: int = 64, - use_per_kv_head_sparse_index: bool = False, - tl_kernel=None, -): - num_tokens, q_h, head_size = Q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = K.size(1) - - assert Q.dim() == K.dim() == 3 - assert Q.size(2) == K.size(2) - assert cu_seqlens_k.dim() == 1 - assert head_size in {64, 128, 256} - assert Q.is_contiguous() - assert K.is_contiguous() - assert V.is_contiguous() - - gqa_group_size = q_h // k_h - - O_tl = torch.zeros_like(Q) - S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), - dtype=Q.dtype, - device=Q.device) - O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) - - if use_per_kv_head_sparse_index: - S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) - else: - S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) - - return O_tl, S_tl - - -def flash_attn_with_attn_pool_decode( - Q: torch.Tensor, ## [tq = b, q_h, q_dim] - K: torch.Tensor, ## [tk, k_h, k_dim] - V: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_k: int, - real_max_k_seqlen: int, - num_split: int, - softmax_scale: float, - s_aux: torch.Tensor = None, - block_size: int = 64, - use_per_kv_head_sparse_index: bool = False, -): - num_tokens, q_h, head_size = Q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = K.size(1) - - assert Q.dim() == K.dim() == 3 - assert Q.size(2) == K.size(2) - assert cu_seqlens_k.dim() == 1 - assert head_size in {64, 128, 256} - assert Q.is_contiguous() - assert K.is_contiguous() - assert V.is_contiguous() - - gqa_group_size = q_h // k_h - - BLOCK_D = head_size - BLOCK_N = block_size - BLOCK_H = 64 - - O = torch.zeros_like(Q) - S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), - dtype=Q.dtype, - device=Q.device) - - def grid(META): - return (batch, k_h) - - with torch.cuda.device(Q.device.index): - _fwd_kernel_varlen[grid]( - Q, - K, - V, - O, - S, - s_aux, - softmax_scale, - cu_seqlens_k, - *Q.stride(), - *K.stride(), - *V.stride(), - *O.stride(), - *S.stride(), - gqa_group_size, - BLOCK_H=BLOCK_H, - BLOCK_N=BLOCK_N, - BLOCK_D=BLOCK_D, - ) - - if use_per_kv_head_sparse_index: - S = torch.max_pool2d(S, 
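Downstream, the per-query-head block scores `S` of shape [b, q_heads, num_blocks] are max-pooled over the head dimension, either per KV-head group or over all heads, to get one sparsity score per block. A shape sketch with illustrative sizes:

import torch

b, q_heads, kv_heads, num_blocks = 2, 32, 8, 64
gqa_group_size = q_heads // kv_heads
S = torch.rand(b, q_heads, num_blocks)

per_kv_head = torch.max_pool2d(S, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
per_batch = torch.max_pool2d(S, kernel_size=(q_heads, 1), stride=(q_heads, 1))

assert per_kv_head.shape == (b, kv_heads, num_blocks)
assert per_batch.shape == (b, 1, num_blocks)
# Equivalent to a reshape + max over the grouped query heads:
ref = S.view(b, kv_heads, gqa_group_size, num_blocks).amax(dim=2)
assert torch.allclose(per_kv_head, ref)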
kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) - else: - S = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1)) - - return O, S - - -def test_equal_seqlen_decode_main(args): - """Test decode kernel with equal sequence lengths""" - print("Testing decode kernel with equal sequence lengths") - - batch_size = args.batch_size - q_heads = args.q_heads - kv_heads = args.kv_heads - k_seqlen = args.k_seqlen - real_max_k_seqlen = args.k_seqlen - head_size = args.head_size - block_size = args.block_size - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 - - # For decode, query is just 1 token per batch - q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) - v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) +def ref_attention(q, k, v, k_seqlens, q_heads, sink=None): + """ + Compute reference attention output and weights. + Args: + q: [b, q_heads, head_size] + k, v: [b, kv_heads, max_seqlen, head_size] + k_seqlens: [b] actual sequence lengths + sink: [q_heads] optional sink values + Returns: output [b, q_heads, head_size], attn_weights [b, q_heads, max_seqlen] + """ + batch_size, kv_heads, max_seqlen, head_size = k.shape softmax_scale = 1.0 / math.sqrt(head_size) - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values - print(f"Using sink attention with sink values: {sink}") - - # Convert to varlen format for K, V - k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) - v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) - - # Generate cumulative sequence lengths - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32) - max_seqlen_k = k_seqlen - - print(f"q shape: {q.shape}") - print(f"k_varlen shape: {k_varlen.shape}") - print(f"v_varlen shape: {v_varlen.shape}") + # Expand KV heads and compute attention scores + k = repeat_kv(k, q_heads // kv_heads) + v = repeat_kv(v, q_heads // kv_heads) + logits = torch.matmul(q.unsqueeze(2), k.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - num_tokens, q_h, head_size = q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink) - - # Test our decode kernel - O_triton, S_triton = flash_attn_with_attn_pool_decode( - q, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size) - O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( - q, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - tl_kernel=tl_kernel, - ) - for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 - - # Compute torch reference - q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] - k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] - v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + # Mask invalid positions + mask = torch.arange(max_seqlen, device=q.device).expand(batch_size, -1) >= k_seqlens.unsqueeze(1) + 
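`repeat_kv` expands [b, kv_heads, s, d] to [b, kv_heads * n_rep, s, d] so the GQA reference can use plain per-head matmuls; it is equivalent to `torch.repeat_interleave` on the head dimension. A quick check with illustrative sizes:

import torch

def repeat_kv_ref(x, n_rep):
    return torch.repeat_interleave(x, n_rep, dim=1)

x = torch.randn(2, 8, 16, 64)   # [b, kv_heads, s, d]
n_rep = 4
expanded = x[:, :, None, :, :].expand(2, 8, n_rep, 16, 64).reshape(2, 8 * n_rep, 16, 64)
assert torch.equal(expanded, repeat_kv_ref(x, n_rep))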
logits.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float("-inf")) if sink is None: - # Standard scaled dot-product attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] - attn_weights = torch.softmax(logits, dim=-1) - O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] + attn_weights = logits.softmax(dim=-1) else: - # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] - - sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] - logits_max = torch.max(logits, dim=-1, keepdim=True).values - logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) - sinks = torch.exp(sink_expanded - logits_or_sinks_max) - unnormalized_scores = torch.exp(logits - logits_or_sinks_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - attn_weights = unnormalized_scores / normalizer - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat).squeeze(2) # [batch, q_heads, head_size] - - # Compute attention score pooling - attn_score_pooled = torch.max_pool2d( - attn_weights.squeeze(2), # [b, q_heads, k_seqlen] - kernel_size=(q_heads, block_size), - stride=(q_heads, block_size), - ceil_mode=True).to(torch.float16) - - print("S_tilelang", S_tilelang) - print("attn_score_pooled", attn_score_pooled) - - max_diff_o = torch.max(torch.abs(O_triton - O_torch)) - max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch)) - max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled)) - - print(f"Max difference in O: {max_diff_o.item()}") - print(f"Max difference in S: {max_diff_s.item()}") - print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") - print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" - assert torch.allclose( - S_tilelang, attn_score_pooled, atol=1e-2, - rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" - print("✅ All tests passed!") + # Sink attention: softmax with additional sink term + sink_expanded = sink.view(1, q_heads, 1, 1) + logits_max = torch.maximum(logits.max(dim=-1, keepdim=True).values, sink_expanded) + exp_logits = torch.exp(logits - logits_max) + attn_weights = exp_logits / (exp_logits.sum(dim=-1, keepdim=True) + torch.exp(sink_expanded - logits_max)) + attn_weights.masked_fill_(mask.unsqueeze(1).unsqueeze(2), 0.0) + output = torch.matmul(attn_weights.to(v.dtype), v).squeeze(2) + return output, attn_weights.squeeze(2) -def test_varlen_decode_main(args): - """Test decode kernel with variable sequence lengths""" - batch_size = args.batch_size - q_heads = args.q_heads - kv_heads = args.kv_heads - max_k_seqlen = args.k_seqlen # Use as max sequence length - real_max_k_seqlen = args.k_seqlen - head_size = args.head_size - block_size = args.block_size - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 - - print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") - # Generate sink values if needed - sink = None - if args.test_sink: - sink = 
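The sink branch in `ref_attention` is a numerically stable softmax over the real logits plus one extra per-head sink logit whose probability mass is then discarded. A small equivalence check (head and sequence sizes illustrative, float64 for tight tolerances):

import torch

heads, s = 4, 10
logits = torch.randn(1, heads, 1, s, dtype=torch.float64)
sink = torch.randn(heads, dtype=torch.float64)

sink_e = sink.view(1, heads, 1, 1)
m = torch.maximum(logits.max(dim=-1, keepdim=True).values, sink_e)
num = torch.exp(logits - m)
attn = num / (num.sum(dim=-1, keepdim=True) + torch.exp(sink_e - m))

# Reference: append the sink as an extra logit column, softmax, then drop that column.
ref = torch.softmax(torch.cat([logits, sink_e], dim=-1), dim=-1)[..., :s]
assert torch.allclose(attn, ref, atol=1e-12)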
torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values - print(f"Using sink attention with sink values: {sink}") +def test_varlen_decode_main(args): + """Test decode kernel with variable sequence lengths.""" + batch_size, q_heads, kv_heads = args.batch_size, args.q_heads, args.kv_heads + max_k_seqlen, head_size, block_size = args.k_seqlen, args.head_size, args.block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 - # Generate variable length k sequences + # Generate variable length sequences and cumulative lengths k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) - print(f"k_seqlens: {k_seqlens}") - - # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) - total_k_tokens = 0 - for i in range(batch_size): - cu_seqlens_k[i] = total_k_tokens - total_k_tokens += k_seqlens[i] - cu_seqlens_k[batch_size] = total_k_tokens - - print(f"cu_seqlens_k: {cu_seqlens_k}") - - # Generate tensors - Q is [batch_size, q_heads, head_size] for decode - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - - softmax_scale = 1.0 / math.sqrt(head_size) - max_seqlen_k = int(k_seqlens.max()) - - print(f"Actual max_seqlen_k: {max_seqlen_k}") - print(f"q_decode shape: {q_decode.shape}") - print(f"k_varlen shape: {k_varlen.shape}") - print(f"v_varlen shape: {v_varlen.shape}") - - num_tokens, q_h, head_size = q_decode.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink) - - # Test our decode kernel - O_triton, S_triton = flash_attn_with_attn_pool_decode( - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size) - O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - tl_kernel=tl_kernel, - ) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + cu_seqlens_k[1:] = torch.cumsum(k_seqlens, dim=0).to(torch.int32).cuda() + total_k_tokens = cu_seqlens_k[-1].item() + + # Generate input tensors + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 if args.test_sink else None + + # Run tilelang kernel + tilelang.disable_cache() + tl_kernel = flashattn(batch_size, q_heads, kv_heads, max_k_seqlen, total_k_tokens, head_size, args.test_sink) + O_tl, S_tl = tl_kernel(q, k_varlen, v_varlen, cu_seqlens_k, sink) + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_heads, 1), stride=(q_heads, 1)) + + # Mask out invalid S positions for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 - - # Create torch reference - pad tensors for comparison - k_padded_list = [] - v_padded_list = [] + valid_blocks = 
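The varlen layout packs every batch's K/V tokens into a single [total_tokens, kv_heads, dim] tensor addressed through cumulative sequence lengths, and the per-block score mask only keeps ceil(seqlen / block_size) blocks per batch. A CPU-side sketch of that bookkeeping (sizes illustrative):

import math
import torch

k_seqlens = torch.tensor([5, 9, 3])
cu_seqlens_k = torch.zeros(len(k_seqlens) + 1, dtype=torch.int32)
cu_seqlens_k[1:] = torch.cumsum(k_seqlens, dim=0)
assert cu_seqlens_k.tolist() == [0, 5, 14, 17]
total_tokens = int(cu_seqlens_k[-1])

k_varlen = torch.randn(total_tokens, 2, 8)      # [total_tokens, kv_heads, dim]
for i in range(len(k_seqlens)):
    seg = k_varlen[cu_seqlens_k[i]:cu_seqlens_k[i + 1]]
    assert seg.shape[0] == k_seqlens[i]

# Block-level bookkeeping used when masking the per-block scores S:
block_size = 4
valid_blocks = [math.ceil(int(n) / block_size) for n in k_seqlens]
assert valid_blocks == [2, 3, 1]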
math.ceil(k_seqlens[i].item() / block_size) + S_tl[i, :, valid_blocks:] = 0 + # Prepare padded tensors for reference + actual_max = int(k_seqlens.max()) + k_padded = torch.zeros(batch_size, kv_heads, actual_max, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(batch_size, kv_heads, actual_max, head_size, device="cuda", dtype=dtype) for i in range(batch_size): - actual_k_len = k_seqlens[i] - - # Extract and pad k, v for this batch - k_start = cu_seqlens_k[i] - k_end = cu_seqlens_k[i + 1] - - # Pad to max_seqlen_k - k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) - v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) - - k_padded[:actual_k_len] = k_varlen[k_start:k_end] - v_padded[:actual_k_len] = v_varlen[k_start:k_end] - - k_padded_list.append(k_padded) - v_padded_list.append(v_padded) + seq_len = k_seqlens[i].item() + k_padded[i, :, :seq_len] = k_varlen[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].transpose(0, 1) + v_padded[i, :, :seq_len] = v_varlen[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].transpose(0, 1) - # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] - k_padded_batched = torch.stack( - k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - v_padded_batched = torch.stack( - v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - - # Expand q to match kv heads: [b, q_heads, 1, head_size] - q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] - - print(f"q_expanded shape: {q_expanded.shape}") - print(f"k_padded_batched shape: {k_padded_batched.shape}") - print(f"v_padded_batched shape: {v_padded_batched.shape}") - - # Compute torch reference - k_repeat = repeat_kv(k_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - v_repeat = repeat_kv(v_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - - if sink is None: - # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] - attn_score = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - - # Apply sequence length masking - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_score[i, :, :, actual_k_len:] = float('-inf') - - attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] - - # Mask out invalid positions - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_weights[i, :, :, actual_k_len:] = 0.0 - - # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] - else: - # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - - # Apply sequence length masking - for i in range(batch_size): - actual_k_len = k_seqlens[i] - logits[i, :, :, actual_k_len:] = float('-inf') - - sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] - logits_max = torch.max(logits, dim=-1, keepdim=True).values - logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) - sinks = torch.exp(sink_expanded - logits_or_sinks_max) - unnormalized_scores = torch.exp(logits - logits_or_sinks_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - attn_weights = unnormalized_scores / normalizer - - # Mask out invalid positions - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_weights[i, :, :, 
actual_k_len:] = 0.0 - - # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat) # [b, q_heads, 1, head_size] - - O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] - - # Compute attention score pooling for S - attn_score_pooled = torch.max_pool2d( - attn_weights.squeeze(2), # [b, q_heads, max_seqlen] - kernel_size=(q_heads, block_size), - stride=(q_heads, block_size), - ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] - - print(f"O_triton shape: {O_triton.shape}") - print(f"O_tilelang shape: {O_tilelang.shape}") - print(f"O_torch shape: {O_torch.shape}") - print(f"S_triton shape: {S_triton.shape}") - print(f"S_tilelang shape: {S_tilelang.shape}") - print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + # Compute reference + O_ref, attn_weights = ref_attention(q, k_padded, v_padded, k_seqlens.cuda(), q_heads, sink) + S_ref = torch.max_pool2d(attn_weights, kernel_size=(q_heads, block_size), stride=(q_heads, block_size), ceil_mode=True).to(dtype) # Compare results - max_diff_o = torch.max(torch.abs(O_triton - O_torch)) - max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) - print(f"Max difference in O: {max_diff_o.item()}") - print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") - - max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_s_tl = torch.max( - torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) - print(f"Max difference in S: {max_diff_s.item()}") - print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") - - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" - assert torch.allclose( - S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)], - attn_score_pooled, - atol=1e-2, - rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}" - + num_blocks = math.ceil(actual_max / block_size) + assert torch.allclose(O_tl, O_ref, atol=1e-2, rtol=1e-2), f"Output mismatch: {(O_tl - O_ref).abs().max()}" + assert torch.allclose(S_tl[:, :, :num_blocks], S_ref[:, :, :num_blocks], atol=1e-2, rtol=1e-2), "Score mismatch" print("✅ All tests passed!") -def do_bench(fn, *args, warmup=10, rep=10, **kwargs): - """ - Do benchmark for a function. 
- """ - start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - for _ in range(warmup): - fn(*args, **kwargs) - - torch.cuda.synchronize() - for i in range(rep): - start_event[i].record() - fn(*args, **kwargs) - end_event[i].record() - torch.cuda.synchronize() - - # Record clocks - times = torch.tensor( - [s.elapsed_time(e) for s, e in zip(start_event, end_event)], - dtype=torch.float, - ) - - return times.mean().item() - - def speed_benchmark_decode_comparison(args): """Speed benchmark for decode kernel""" batch_size = args.batch_size @@ -844,7 +223,7 @@ def speed_benchmark_decode_comparison(args): max_k_seqlen = args.k_seqlen head_size = args.head_size block_size = args.block_size - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 print("\n=== Decode Speed Benchmark Comparison ===") print("Configuration:") @@ -865,7 +244,7 @@ def speed_benchmark_decode_comparison(args): k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) total_k_tokens = 0 for i in range(batch_size): cu_seqlens_k[i] = total_k_tokens @@ -873,88 +252,68 @@ def speed_benchmark_decode_comparison(args): cu_seqlens_k[batch_size] = total_k_tokens # Generate tensors - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - - softmax_scale = 1.0 / math.sqrt(head_size) - max_seqlen_k = int(k_seqlens.max()) - - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values - print(" Using sink attention with sink values") - - print("Setup complete:") - print(f" Total K tokens: {total_k_tokens}") - print(f" Actual max K seq len: {max_seqlen_k}") + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 if args.test_sink else None if args.test_varlen: print(f" K sequence lengths: {k_seqlens.tolist()}") - # Warmup - num_tokens, q_h, head_size = q_decode.shape + _, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + + def run_once(): + tl_kernel(q_decode, k_varlen, v_varlen, cu_seqlens_k, sink) # Benchmark print("⚡ Benchmarking Tilelang kernel (100 iterations)...") tilelang_time = do_bench( - flash_attn_with_attn_pool_decode_tilelang, - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - args.k_seqlen, - 1, - softmax_scale, - sink, - block_size, - False, - tl_kernel, + run_once, ) print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") - # Benchmark - print("⚡ Benchmarking Triton kernel 
(100 iterations)...") - triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, - cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, - block_size) - print(f"Average decode kernel time Triton: {triton_time:.3f} ms") - print(f"Speedup: {(triton_time / tilelang_time):.3f}") +def main(): + args = argparse.Namespace( + batch_size=1, + q_heads=32, + kv_heads=8, + k_seqlen=8192, + head_size=128, + block_size=128, + dtype=T.float16, + ) + args.test_sink = True + args.test_varlen = True + args.dtype = T.float16 + args.num_split = 1 + test_varlen_decode_main(args) if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size') - parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads') - parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads') - parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length') - parser.add_argument( - '--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension') - parser.add_argument('--block_size', type=int, default=64, help='Block size for computation') - parser.add_argument( - '--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type') - parser.add_argument( - '--test_varlen', action='store_true', help='Test with truly variable sequence lengths') - parser.add_argument( - '--test_sink', action='store_true', help='Test with sink attention mechanism') - parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark') - parser.add_argument( - '--num_split', type=int, default=1, choices=[1, 16], help='Number of splits') + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") + parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") args = parser.parse_args() args.test_sink = True - args.test_varlen = False - args.dtype = 'float16' + args.test_varlen = True + args.dtype = T.float16 args.num_split = 1 - if args.benchmark: - speed_benchmark_decode_comparison(args) - elif args.test_varlen: - test_varlen_decode_main(args) - else: - test_equal_seqlen_decode_main(args) + # if args.benchmark: + # speed_benchmark_decode_comparison(args) + # else: + # test_varlen_decode_main(args) + + speed_benchmark_decode_comparison(args) diff --git a/examples/flash_decoding/example_mha_inference.py 
b/examples/flash_decoding/example_mha_inference.py index 3eabc9a76..24a90c57b 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -10,102 +10,24 @@ @tilelang.jit(out_idx=[5]) def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, seqlen_q, heads, dim] shape_kv = [batch, seqlen_kv, heads, dim] part_shape = [batch, seqlen_q, heads, num_split, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 - @T.macro - def MMA0( + @T.prim_func + def flashattn_mha_inference( + Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_kv, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - mid: T.int32, - hid: T.int32, - bid: T.int32, - sid: T.int32, - ): - T.copy( - K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + - (k + 1) * block_N, hid, :], K_shared) - # TODO: Handle causal split case - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( V: T.Tensor(shape_kv, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - hid: T.int32, - bid: T.int32, - sid: T.int32, - ): - T.copy( - V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + - (k + 1) * block_N, hid, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] + Output: T.Tensor(shape_q, dtype), ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_kv, dtype), - V: T.Tensor(shape_kv, dtype), - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), - ): - with T.Kernel( - T.ceildiv(seqlen_q, block_M), heads * batch, num_split, - threads=128) as (bx, by, bz): + # split + with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -126,43 +48,73 @@ def flash_attn_split( # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently # disable relevant tma copy and use SIMT as fallback for now - T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) + T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) # TODO: Handle causal split case loop_range = ( - T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv( - (mid + 1) * block_M, block_N)) if is_causal else T.ceildiv( - (seqlen_kv // num_split), block_N)) + T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N)) + if is_causal + else T.ceildiv((seqlen_kv // num_split), block_N) + ) for k in T.Pipelined(loop_range, num_stages=2): - MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid) + T.copy( + K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], + K_shared, + ) + # TODO: Handle causal split case + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
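For the causal path, the KV-block loop only has to visit blocks that intersect the causal triangle of query block `mid`, which is where the T.min(ceildiv(seqlen_kv, block_N), ceildiv((mid + 1) * block_M, block_N)) bound comes from. A tiny sketch of that bound (tile sizes illustrative):

import math

def causal_loop_range(mid, block_M, block_N, seqlen_kv):
    return min(math.ceil(seqlen_kv / block_N), math.ceil((mid + 1) * block_M / block_N))

block_M, block_N, seqlen_kv = 128, 64, 1024
for mid in range(math.ceil(seqlen_kv / block_M)):
    hi = causal_loop_range(mid, block_M, block_N, seqlen_kv)
    last_query = (mid + 1) * block_M - 1     # largest query index in this block
    assert hi * block_N > last_query         # visited KV positions reach the diagonal
    assert hi <= math.ceil(seqlen_kv / block_N)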
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy( + V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], + V_shared, + ) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M]) + T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M]) T.copy(acc_o, O_shared) - T.copy( - O_shared, - Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :], - disable_tma=True) - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_q, dtype), - ): + T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True) + + # combine with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz): po_local = T.alloc_fragment([block_M, dim], dtype) - po_shared = T.alloc_shared([block_M, dim], dtype) o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype) o_shared = T.alloc_shared([block_M, dim], dtype) lse_local = T.alloc_fragment([num_split, block_M], dtype) @@ -171,20 +123,17 @@ def combine( lse_max_local = T.alloc_fragment([block_M], accum_dtype) scale_local = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({ - o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), - o_shared: tilelang.layout.make_swizzled_layout(o_shared), - po_shared: tilelang.layout.make_swizzled_layout(po_shared), - }) - T.clear(lse_logsum_local) T.clear(o_accum_local) - T.copy(glse[ - bz, - by, - :, - bx * block_M:(bx + 1) * block_M, - ], lse_local) + T.copy( + glse[ + bz, + by, + :, + bx * block_M : (bx + 1) * block_M, + ], + lse_local, + ) T.reduce_max(lse_local, lse_max_local, dim=0, clear=False) for k in T.Pipelined(num_split): T.copy(lse_local[k, :], lse_local_split) @@ -193,11 +142,7 @@ def combine( for i in T.Parallel(block_M): lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] for k in T.Pipelined(num_split, num_stages=2): - T.copy( - Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], - po_shared, - disable_tma=True) - T.copy(po_shared, po_local) + T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_local) for i in T.Parallel(block_M): lse_local_split[i] = lse_local[k, i] for i in T.Parallel(block_M): @@ -205,19 +150,7 @@ def combine( for i, j in T.Parallel(block_M, dim): o_accum_local[i, j] += po_local[i, j] * scale_local[i] T.copy(o_accum_local, o_shared) - T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True) - - @T.prim_func - def flashattn_mha_inference( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_kv, dtype), - V: T.Tensor(shape_kv, dtype), - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] - Output: T.Tensor(shape_q, dtype), - ): - flash_attn_split(Q, K, V, glse, Output_partial) - combine(glse, Output_partial, Output) + T.copy(o_shared, 
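The split kernel stores, per KV split, a partial output already normalized by that split's own softmax sum, plus its log-sum-exp in the log2 domain (`glse`); the combine kernel then reweights each partial with exp2(lse_split - lse_total). A dense PyTorch sketch of the same reduction (sizes illustrative, no softmax scale applied on either side):

import torch

torch.manual_seed(0)
d, s, num_split = 16, 256, 4
q = torch.randn(d, dtype=torch.float64)
k = torch.randn(s, d, dtype=torch.float64)
v = torch.randn(s, d, dtype=torch.float64)
log2e = 1.4426950408889634

partials, lses = [], []
for ks in range(num_split):
    kb = k[ks * s // num_split:(ks + 1) * s // num_split]
    vb = v[ks * s // num_split:(ks + 1) * s // num_split]
    logits = kb @ q                                 # [s / num_split]
    m = logits.max()
    p = torch.exp2((logits - m) * log2e)            # exp(logits - m)
    partials.append((p @ vb) / p.sum())             # split-local normalized output
    lses.append(torch.log2(p.sum()) + m * log2e)    # log2 of the split's softmax sum

lses = torch.stack(lses)
lse_max = lses.max()
lse_total = torch.log2(torch.exp2(lses - lse_max).sum()) + lse_max
out = sum(p * torch.exp2(l - lse_total) for p, l in zip(partials, lses))

ref = torch.softmax(k @ q, dim=0) @ v
assert torch.allclose(out, ref, atol=1e-9)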
Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True) return flashattn_mha_inference @@ -225,10 +158,10 @@ def flashattn_mha_inference( def ref_program(Q, K, V, glse, Output_partial, causal): assert causal is False dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -256,7 +189,7 @@ def flash_split_ref(Q, K, V, causal): block_N = 128 seqlen_kv = K.size(1) - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float) acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float) @@ -273,14 +206,15 @@ def flash_split_ref(Q, K, V, causal): for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_, - K[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, seqlen, nheads, block_N] + acc_s = torch.einsum( + "bqhd,bkhd->bhqk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, seqlen, nheads, block_N] scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] scores_scale = torch.exp2(scores_max_prev - scores_max) @@ -288,9 +222,10 @@ def flash_split_ref(Q, K, V, causal): acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s_cast = acc_s.to(torch.float16) acc_o += torch.einsum( - 'bhqk,bkhd->bqhd', acc_s_cast, - V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhqk,bkhd->bqhd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum acc_o /= logsum[:, :, :, None].transpose(1, 2) @@ -298,8 +233,7 @@ def flash_split_ref(Q, K, V, causal): gacc_o[ks, :, :, :, :] = acc_o glogsum[ks, :, :, :] = logsum - return glogsum.to(torch.float16).permute(1, 2, 0, - 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) + return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): @@ -323,5 +257,13 @@ def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): + BLOCK_M = 128 + BLOCK_N = 64 + kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git 
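`ref_program` keeps tensors in [b, seq, heads, dim] layout, so the contraction "bqhd,bkhd->bhqk" is the per-head Q@K^T. A small equivalence check against an explicit permute + matmul (sizes illustrative):

import torch

b, sq, sk, h, d = 2, 4, 6, 3, 8
Q = torch.randn(b, sq, h, d)
K = torch.randn(b, sk, h, d)

scores_einsum = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores_matmul = torch.matmul(Q.permute(0, 2, 1, 3), K.permute(0, 2, 3, 1))
assert torch.allclose(scores_einsum, scores_matmul, atol=1e-6)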
a/examples/flash_decoding/regression_example_flash_decoding.py b/examples/flash_decoding/regression_example_flash_decoding.py new file mode 100644 index 000000000..476bceb34 --- /dev/null +++ b/examples/flash_decoding/regression_example_flash_decoding.py @@ -0,0 +1,17 @@ +import tilelang.testing +import example_gqa_decode +import example_mha_inference + + +def regression_example_gqa_decode(): + tilelang.testing.process_func(example_gqa_decode.run_regression_perf) + + +def regression_example_mha_inference(): + tilelang.testing.process_func( + example_mha_inference.run_regression_perf, BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False + ) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/flash_decoding/test_example_flash_decoding.py b/examples/flash_decoding/test_example_flash_decoding.py index c728dfe0e..2cbcd8404 100644 --- a/examples/flash_decoding/test_example_flash_decoding.py +++ b/examples/flash_decoding/test_example_flash_decoding.py @@ -2,9 +2,9 @@ import example_gqa_decode import example_mha_inference +import example_gqa_decode_varlen_logits -# TODO(lei): fix the correctness of gqa decode on sm90 @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_le(8, 9) def test_example_example_gqa_decode(): @@ -15,5 +15,9 @@ def test_example_example_mha_inference(): example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False) +def test_example_example_gqa_decode_varlen_logits(): + example_gqa_decode_varlen_logits.main() + + if __name__ == "__main__": tilelang.testing.main() diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py index a8d684965..4b843cdfe 100644 --- a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -9,17 +9,18 @@ @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def moe_forward_tilelang_shared(d_hidden, - d_expert, - n_shared_experts, - dtype, - num_tokens, - block_token=128, - block_dhidden=128, - block_dexpert=128, - threads=256, - num_stages=1): - +def moe_forward_tilelang_shared( + d_hidden, + d_expert, + n_shared_experts, + dtype, + num_tokens, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, +): scale = 1.44269504 # log2(e) # Parameters @@ -32,21 +33,19 @@ def moe_forward_tilelang_shared(d_hidden, shared_W_up_shape = (dexpert, dhidden) shared_W_down_shape = (dhidden, dexpert) - accum_type = "float32" + accum_type = T.float32 @T.prim_func def kernel_shared( - input: T.Tensor(input_shape, dtype), # type: ignore - shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore - shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore - shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore - up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore - output: T.Tensor(input_shape, dtype), # type: ignore + input: T.Tensor(input_shape, dtype), # type: ignore + shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore + shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore + shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore + up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore ): # Step 1: Compute gate and up logits - with T.Kernel( - T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), - threads=threads) as (bx, by): + 
with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): # Split the block to shared experts and routed experts input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) @@ -70,16 +69,13 @@ def kernel_shared( # Fuse with SiLU and element-wise product for i, j in T.Parallel(block_token, block_dexpert): - gate_logits_local[i, j] = gate_logits_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert]) # Step 2: Compute down logits - with T.Kernel( - T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by): up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type) @@ -97,21 +93,25 @@ def kernel_shared( return kernel_shared -@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def moe_forward_tilelang_routed(d_hidden, - d_expert, - n_routed_experts, - dtype, - group_sum, - group_count, - block_token=128, - block_dhidden=128, - block_dexpert=128, - threads=256, - num_stages=1, - k_pack=1, - coalesced_width=None): - +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) +def moe_forward_tilelang_routed( + d_hidden, + d_expert, + n_routed_experts, + dtype, + group_sum, + group_count, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, +): scale = 1.44269504 # log2(e) # Parameters @@ -124,7 +124,7 @@ def moe_forward_tilelang_routed(d_hidden, # group_count = len(group_sizes_list) # M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list]) M = math.ceil(group_sum / block_token) + group_count - accum_dtype = "float32" + accum_dtype = T.float32 # Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm input_shape = (group_sum, dhidden) @@ -132,22 +132,22 @@ def moe_forward_tilelang_routed(d_hidden, routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden) routed_expert_up_shape = (n_routed_experts, dexpert, dhidden) routed_expert_down_shape = (n_routed_experts, dhidden, dexpert) - routed_expert_weights_shape = (group_sum) - group_sizes_shape = (n_routed_experts) + routed_expert_weights_shape = group_sum + group_sizes_shape = n_routed_experts @T.prim_func def kernel( - input: T.Tensor(input_shape, dtype), # type: ignore - routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore - routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore - routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore - routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore - group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - 
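The fused activation computes SiLU as x * sigmoid(x), with the sigmoid expressed through exp2 and the log2(e) factor so it lowers to the fast exp2 path. A quick check against torch.nn.functional.silu:

import torch
import torch.nn.functional as F

scale = 1.44269504  # log2(e)
x = torch.randn(4096)
silu_exp2 = x * (1.0 / (1.0 + torch.exp2(-x * scale)))
assert torch.allclose(silu_exp2, F.silu(x), atol=1e-5)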
group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore - up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore - output: T.Tensor(input_shape, dtype), # type: ignore + input: T.Tensor(input_shape, dtype), # type: ignore + routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore + routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore + routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore + routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore + group_sizes: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_padded_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_idx_for_bx: T.Tensor((M,), T.int32), # type: ignore + up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore ): # Step 1: Compute gate and up logits with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): @@ -158,58 +158,41 @@ def kernel( gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) - cur_group_idx = T.alloc_local([1], "int32") - cur_group_size = T.alloc_local([1], "int32") - - T.use_swizzle(10, enable=True) + T.use_swizzle(10) m_start_padded = bx * block_token - cur_group_idx[0] = group_idx_for_bx[bx] + cur_group_idx = group_idx_for_bx[bx] - cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ - cur_group_idx[0]] - actual_rows = T.max( - 0, - T.min(block_token, cur_group_size[0] - - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + cur_group_size = group_sizes[cur_group_idx] + m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx] + actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx]))) T.clear(gate_logits_local) T.clear(up_logits_local) for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): T.copy( - input[m_start:m_start + block_token, k * block_dhidden:(k + 1) * block_dhidden], + input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden], input_shared, - coalesced_width=coalesced_width) + ) T.copy( - routed_expert_gate[cur_group_idx[0], - by * block_dexpert:(by + 1) * block_dexpert, - k * block_dhidden:(k + 1) * block_dhidden], - routed_expert_gate_shared, - coalesced_width=coalesced_width) - T.gemm( - input_shared, + routed_expert_gate[ + cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], routed_expert_gate_shared, - gate_logits_local, - k_pack=k_pack, - transpose_B=True) + ) + T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, transpose_B=True) T.copy( - routed_expert_up[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert, - k * block_dhidden:(k + 1) * block_dhidden], + routed_expert_up[ + cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], routed_expert_up_shared, - coalesced_width=coalesced_width) - T.gemm( - input_shared, - routed_expert_up_shared, - up_logits_local, - k_pack=k_pack, - transpose_B=True) + ) + T.gemm(input_shared, 
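The routed-expert kernel walks a padded token space in which every expert group starts on a `block_token` boundary; `group_idx_for_bx` maps each block to its group, and `m_start` / `actual_rows` translate back into the packed token layout. A host-side sketch of that bookkeeping (group sizes illustrative; this mirrors what the kernel's index math expects and is not code from this PR):

import math

block_token = 128
group_sizes = [300, 5, 0, 250]                      # tokens routed to each expert

group_offsets = [0]
for n in group_sizes[:-1]:
    group_offsets.append(group_offsets[-1] + n)

padded = [math.ceil(n / block_token) * block_token for n in group_sizes]
group_padded_offsets = [0]
for n in padded[:-1]:
    group_padded_offsets.append(group_padded_offsets[-1] + n)

group_idx_for_bx = []
for g, n in enumerate(padded):
    group_idx_for_bx += [g] * (n // block_token)

covered = 0
for bx, g in enumerate(group_idx_for_bx):
    m_start_padded = bx * block_token
    m_start = m_start_padded - group_padded_offsets[g] + group_offsets[g]
    actual_rows = max(0, min(block_token, group_sizes[g] - (m_start_padded - group_padded_offsets[g])))
    covered += actual_rows

assert covered == sum(group_sizes)
# M = ceil(group_sum / block_token) + group_count is an upper bound on the block count
assert len(group_idx_for_bx) <= math.ceil(sum(group_sizes) / block_token) + len(group_sizes)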
routed_expert_up_shared, up_logits_local, transpose_B=True) for i, j in T.Parallel(block_token, block_dexpert): - gate_logits_local[i, j] = gate_logits_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] for i, j in T.Parallel(block_token, block_dexpert): @@ -222,60 +205,40 @@ def kernel( routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype) - cur_group_idx = T.alloc_local([1], "int32") - cur_group_size = T.alloc_local([1], "int32") - - T.use_swizzle(10, enable=True) + T.use_swizzle(10) m_start_padded = bx * block_token - cur_group_idx[0] = group_idx_for_bx[bx] + cur_group_idx = group_idx_for_bx[bx] - cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ - cur_group_idx[0]] - actual_rows = T.max( - 0, - T.min(block_token, cur_group_size[0] - - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + cur_group_size = group_sizes[cur_group_idx] + m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx] + actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx]))) T.clear(output_local) for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): T.copy( - up_logits[m_start:m_start + block_token, - k * block_dexpert:(k + 1) * block_dexpert], + up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert], up_logits_shared, - coalesced_width=coalesced_width) + ) T.copy( - routed_expert_down[cur_group_idx[0], - by * block_dhidden:(by + 1) * block_dhidden, - k * block_dexpert:(k + 1) * block_dexpert], - routed_expert_down_shared, - coalesced_width=coalesced_width) - T.gemm( - up_logits_shared, + routed_expert_down[ + cur_group_idx, by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert + ], routed_expert_down_shared, - output_local, - k_pack=k_pack, - transpose_B=True) + ) + T.gemm(up_logits_shared, routed_expert_down_shared, output_local, transpose_B=True) for i, j in T.Parallel(block_token, block_dhidden): if i < actual_rows: - output[m_start + i, by * block_dhidden + - j] = output_local[i, j] * routed_expert_weights[m_start + i] + output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i] return kernel class Expert(nn.Module): - - def __init__(self, - config: Dict, - gate: torch.Tensor, - up: torch.Tensor, - down: torch.Tensor, - d_expert: Optional[int] = None): + def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None): super().__init__() self.config = config self.act_fn = nn.SiLU() @@ -294,14 +257,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MoEGate(nn.Module): - def __init__(self, config: Dict, weights: Dict): super().__init__() self.top_k: int = config["n_experts_per_token"] self.num_experts: int = config["n_routed_experts"] self.d_hidden: int = config["d_hidden"] - self.W_g_weight = weights['router.weight'].t() + self.W_g_weight = weights["router.weight"].t() def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: logits = x @ self.W_g_weight @@ -312,76 +274,69 @@ def 
forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: class MoE(nn.Module): - - def __init__(self, - config: Dict, - shared_kernel: tilelang.JITKernel, - routed_kernel: tilelang.JITKernel, - weights: Dict, - padding_M: int = 128): + def __init__( + self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128 + ): super().__init__() self.config = config self.shared_kernel = shared_kernel self.routed_kernel = routed_kernel self.padding_M = padding_M - self.experts = nn.ModuleList([ - Expert( - config, - gate=weights[f'experts.{i}.0.weight'], - up=weights[f'experts.{i}.1.weight'], - down=weights[f'experts.{i}.2.weight']) for i in range(config["n_routed_experts"]) - ]) + self.experts = nn.ModuleList( + [ + Expert( + config, + gate=weights[f"experts.{i}.0.weight"], + up=weights[f"experts.{i}.1.weight"], + down=weights[f"experts.{i}.2.weight"], + ) + for i in range(config["n_routed_experts"]) + ] + ) self.device = torch.device("cuda") self.gating_network = MoEGate(config, weights).to(self.device) shared_expert_dim = config["d_expert"] * config["n_shared_experts"] self.shared_expert = Expert( config=config, - gate=weights['shared_experts.0.weight'], - up=weights['shared_experts.1.weight'], - down=weights['shared_experts.2.weight'], - d_expert=shared_expert_dim).to(self.device) + gate=weights["shared_experts.0.weight"], + up=weights["shared_experts.1.weight"], + down=weights["shared_experts.2.weight"], + d_expert=shared_expert_dim, + ).to(self.device) self.expert_cache = torch.zeros( - (config["batch_size"] * config["seq_len"], config["d_hidden"]), - dtype=torch.float16, - device=self.device) - self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], - dim=0) - self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], - dim=0) - self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], - dim=0) + (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device + ) + self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0) + self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0) + self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0) self.stacked_expert_tokens = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_hidden"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) self.stacked_expert_weights = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device + ) self.stacked_expert_tokens_idxs = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), - dtype=torch.int64, - device=self.device) + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device + ) self.up_logits_shared = torch.empty( - (config["batch_size"] * config["seq_len"], self.config["d_expert"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, 
device=self.device + ) self.expert_output_shared = torch.empty( - (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device + ) self.up_logits_routed = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_expert"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) self.expert_output_routed = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_hidden"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -413,22 +368,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs - self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[ - idxs[start_idx:end_idx]] + self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]] group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device) - group_offset = torch.tensor( - tokens_per_expert - counts, dtype=torch.int32, device=self.device) + group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device) group_padded_offsets = [0 for _ in range(len(group_sizes))] for i in range(1, len(group_sizes)): - group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil( - (counts[i - 1] + 1) / self.padding_M) * self.padding_M + group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M block_token = 128 - M = math.ceil( - self.config["batch_size"] * self.config["seq_len"] * - self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"] + M = ( + math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token) + + self.config["n_routed_experts"] + ) group_idx_for_bx = [0 for _ in range(M)] for bx in range(M): @@ -437,8 +390,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if m_start_padded >= group_padded_offsets[i]: group_idx_for_bx[bx] = i - group_padded_offsets = torch.tensor( - group_padded_offsets, dtype=torch.int32, device=self.device) + group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device) group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device) # Multi-stream execution @@ -448,11 +400,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.cuda.stream(routed_stream): # Tilelang version: Grouped GEMM - self.routed_kernel(self.stacked_expert_tokens, self.stacked_expert_w_gate, - self.stacked_expert_w_up, self.stacked_expert_w_down, - self.stacked_expert_weights, group_sizes, group_offset, - group_padded_offsets, group_idx_for_bx, self.up_logits_routed, - self.expert_output_routed) + self.routed_kernel( + self.stacked_expert_tokens, + self.stacked_expert_w_gate, + self.stacked_expert_w_up, + self.stacked_expert_w_down, + self.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + 
self.up_logits_routed, + self.expert_output_routed, + ) # Scatter reduce self.expert_cache = torch.scatter_reduce( @@ -460,14 +420,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: 0, self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]), self.expert_output_routed, - reduce='sum') + reduce="sum", + ) routed_output = self.expert_cache.view(*orig_shape) with torch.cuda.stream(shared_stream): - - self.shared_kernel(x_flat, self.shared_expert.W_gate_weight, - self.shared_expert.W_up_weight, self.shared_expert.W_down_weight, - self.up_logits_shared, self.expert_output_shared) + self.shared_kernel( + x_flat, + self.shared_expert.W_gate_weight, + self.shared_expert.W_up_weight, + self.shared_expert.W_down_weight, + self.up_logits_shared, + self.expert_output_shared, + ) shared_output = self.expert_output_shared.view(*orig_shape) torch.cuda.synchronize() @@ -491,14 +456,15 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: """ input_tensor, weights, config = data - dtype_str = "float16" + dtype_str = T.float16 shared_kernel = moe_forward_tilelang_shared( config["d_hidden"], config["d_expert"], config["n_shared_experts"], dtype=dtype_str, - num_tokens=config["batch_size"] * config["seq_len"]) + num_tokens=config["batch_size"] * config["seq_len"], + ) routed_kernel = moe_forward_tilelang_routed( config["d_hidden"], config["d_expert"], @@ -511,8 +477,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: block_dexpert=128, threads=256, num_stages=1, - k_pack=1, - coalesced_width=2) + ) moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) @@ -521,13 +486,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: return output -def main(d_hidden=7168, - d_expert=2048, - n_routed_experts=8, - n_shared_experts=1, - n_experts_per_token=4, - batch_size=1, - seq_len=8192): +def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192): config = { "dhidden": d_hidden, "dexpert": d_expert, @@ -536,20 +495,131 @@ def main(d_hidden=7168, "nexpertspertoken": n_experts_per_token, "bs": batch_size, "seqlen": seq_len, - "seed": 81394 + "seed": 81394, } data = generate_input(**config) - - torch.cuda.synchronize() ref_output = ref_kernel(clone_data(data)).to(torch.float32) - torch.cuda.synchronize() tilelang_output = custom_kernel(clone_data(data)).to(torch.float32) - torch.cuda.synchronize() - torch.testing.assert_close(ref_output, tilelang_output, atol=1e-2, rtol=1e-2) print("✅ Tilelang and Torch match") +def run_regression_perf( + d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192 +): + config = { + "dhidden": d_hidden, + "dexpert": d_expert, + "nroutedexperts": n_routed_experts, + "nsharedexperts": n_shared_experts, + "nexpertspertoken": n_experts_per_token, + "bs": batch_size, + "seqlen": seq_len, + "seed": 81394, + } + from tilelang.profiler import do_bench + + data = generate_input(**config) + + x, weights, config = data + + dtype_str = "float16" + + shared_kernel = moe_forward_tilelang_shared( + config["d_hidden"], + config["d_expert"], + config["n_shared_experts"], + dtype=dtype_str, + num_tokens=config["batch_size"] * config["seq_len"], + ) + routed_kernel = moe_forward_tilelang_routed( + config["d_hidden"], + config["d_expert"], + config["n_routed_experts"], + dtype=dtype_str, + group_sum=config["batch_size"] * config["seq_len"] * 
config["n_experts_per_token"], + group_count=config["n_routed_experts"], + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + coalesced_width=2, + ) + + moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) + batch_size, seq_len, hidden_dim = x.shape + expert_indices, expert_scores = moe.gating_network(x) + flat_expert_indices = expert_indices.view(-1) + flat_expert_weights = expert_scores.view(-1) + x_flat = x.view(-1, hidden_dim) + idxs = flat_expert_indices.argsort() + counts = flat_expert_indices.bincount().cpu().numpy() + tokens_per_expert = counts.cumsum() + num_per_tok = moe.config["n_experts_per_token"] + token_idxs = idxs // num_per_tok + for expert_id, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + if start_idx == end_idx: + continue + exp_token_idxs = token_idxs[start_idx:end_idx] + expert_tokens = x_flat[exp_token_idxs] + moe.stacked_expert_tokens[start_idx:end_idx] = expert_tokens + moe.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs + moe.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]] + group_sizes = torch.tensor(counts, dtype=torch.int32, device=moe.device) + group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=moe.device) + group_padded_offsets = [0 for _ in range(len(group_sizes))] + for i in range(1, len(group_sizes)): + group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / moe.padding_M) * moe.padding_M + block_token = 128 + M = ( + math.ceil(moe.config["batch_size"] * moe.config["seq_len"] * moe.config["n_experts_per_token"] / block_token) + + moe.config["n_routed_experts"] + ) + group_idx_for_bx = [0 for _ in range(M)] + for bx in range(M): + m_start_padded = bx * block_token + for i in range(moe.config["n_routed_experts"]): + if m_start_padded >= group_padded_offsets[i]: + group_idx_for_bx[bx] = i + group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=moe.device) + group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=moe.device) + + def run_shared_kernel_only(): + moe.routed_kernel( + moe.stacked_expert_tokens, + moe.stacked_expert_w_gate, + moe.stacked_expert_w_up, + moe.stacked_expert_w_down, + moe.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + moe.up_logits_routed, + moe.expert_output_routed, + ) + + def run_routed_kernel_only(): + moe.routed_kernel( + moe.stacked_expert_tokens, + moe.stacked_expert_w_gate, + moe.stacked_expert_w_up, + moe.stacked_expert_w_down, + moe.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + moe.up_logits_routed, + moe.expert_output_routed, + ) + + return do_bench(run_routed_kernel_only, backend="cupti") + + if __name__ == "__main__": + tilelang.disable_cache() main() diff --git a/examples/fusedmoe/example_fusedmoe_torch.py b/examples/fusedmoe/example_fusedmoe_torch.py index 00219c6e9..6b6322aff 100644 --- a/examples/fusedmoe/example_fusedmoe_torch.py +++ b/examples/fusedmoe/example_fusedmoe_torch.py @@ -6,7 +6,6 @@ # Reference code in PyTorch class ExpertTorch(nn.Module): - def __init__(self, config: Dict, d_expert: Optional[int] = None): super().__init__() self.config = config @@ -25,7 +24,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MoEGateTorch(nn.Module): - def __init__(self, config: Dict): super().__init__() 
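As an aside, the group bookkeeping that `MoE.forward` and `run_regression_perf` set up for the routed grouped-GEMM kernel is easier to follow in isolation. A standalone sketch with made-up token counts, mirroring the arithmetic above (including the `counts[i - 1] + 1` padding term):

import math

block_token = 128
counts = [300, 0, 130, 50]  # hypothetical tokens routed to each of 4 experts

# Padded start offset of each expert's block of rows, as computed above.
group_padded_offsets = [0] * len(counts)
for i in range(1, len(counts)):
    group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil(
        (counts[i - 1] + 1) / block_token) * block_token

# Map every padded block index bx back to the expert (group) that owns it.
M = math.ceil(sum(counts) / block_token) + len(counts)
group_idx_for_bx = [0] * M
for bx in range(M):
    m_start_padded = bx * block_token
    for i in range(len(counts)):
        if m_start_padded >= group_padded_offsets[i]:
            group_idx_for_bx[bx] = i

print(group_padded_offsets)  # [0, 384, 512, 768]
print(group_idx_for_bx)      # [0, 0, 0, 1, 2, 2, 3, 3]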
self.top_k: int = config["n_experts_per_token"] @@ -43,12 +41,10 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: class MoETorch(nn.Module): - def __init__(self, config: Dict): super().__init__() self.config = config - self.experts = nn.ModuleList( - [ExpertTorch(config) for _ in range(config["n_routed_experts"])]) + self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])]) self.gating_network = MoEGateTorch(config) shared_expert_dim = config["d_expert"] * config["n_shared_experts"] self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim) @@ -67,8 +63,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return routed_output + shared_output @torch.no_grad() - def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, - flat_expert_weights: torch.Tensor) -> torch.Tensor: + def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor: expert_cache = torch.zeros_like(x) # test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) # test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) @@ -91,8 +86,7 @@ def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, expert_out = expert(expert_tokens) expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) - expert_cache.scatter_reduce_( - 0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') + expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum") return expert_cache @@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: moe = MoETorch(config) # Fill in the given weights of the model - moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight']) + moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"]) for i in range(num_experts): - gate_proj_weight = weights[f'experts.{i}.0.weight'] - up_proj_weight = weights[f'experts.{i}.1.weight'] - down_proj_weight = weights[f'experts.{i}.2.weight'] + gate_proj_weight = weights[f"experts.{i}.0.weight"] + up_proj_weight = weights[f"experts.{i}.1.weight"] + down_proj_weight = weights[f"experts.{i}.2.weight"] # Transpose weights to match expected shape for nn.Linear moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t()) moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t()) moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t()) - moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t()) - moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t()) - moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t()) + moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t()) + moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t()) + moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t()) output = moe(input_tensor) @@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: # Input generation for the reference code -def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, - nexpertspertoken: int, bs: int, seqlen: int, - seed: int) -> Tuple[torch.Tensor, Dict, Dict]: - +def generate_input( + dhidden: int, dexpert: int, nroutedexperts: int, 
nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int +) -> Tuple[torch.Tensor, Dict, Dict]: # Really dumb but for now _ isn't parsing correctly. d_hidden = dhidden d_expert = dexpert @@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper "seq_len": seq_len, } - gen = torch.Generator(device='cuda') + gen = torch.Generator(device="cuda") gen.manual_seed(seed) num_experts = n_routed_experts expert_dim = d_expert weights = {} - input_tensor = torch.randn((batch_size, seq_len, d_hidden), - device='cuda', - dtype=torch.float16, - generator=gen).contiguous() + input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous() # Initialize router weights - weights['router.weight'] = torch.randn( - (num_experts, d_hidden), device="cuda", dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) + weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden) for i in range(num_experts): - weights[f'experts.{i}.0.weight'] = torch.randn( - (d_hidden, expert_dim), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim) - - weights[f'experts.{i}.1.weight'] = torch.randn( - (d_hidden, expert_dim), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim) - - weights[f'experts.{i}.2.weight'] = torch.randn( - (expert_dim, d_hidden), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) - - weights['shared_experts.0.weight'] = torch.randn( - (d_hidden, expert_dim * n_shared_experts), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim * n_shared_experts) - weights['shared_experts.1.weight'] = torch.randn( - (d_hidden, expert_dim * n_shared_experts), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim * n_shared_experts) - weights['shared_experts.2.weight'] = torch.randn((expert_dim * n_shared_experts, d_hidden), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) + weights[f"experts.{i}.0.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.1.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.2.weight"] = torch.randn( + (expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) + + weights["shared_experts.0.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.1.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.2.weight"] = torch.randn( + (expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) return (input_tensor, weights, config) diff --git a/examples/fusedmoe/regression_example_fusedmoe.py b/examples/fusedmoe/regression_example_fusedmoe.py new file mode 100644 index 000000000..ac0f18aae --- /dev/null +++ b/examples/fusedmoe/regression_example_fusedmoe.py @@ -0,0 +1,19 @@ +import tilelang.testing +import example_fusedmoe_tilelang + + +def regression_example_fusedmoe_tilelang(): + 
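One more illustrative aside: the per-expert outputs are written back to token order with `scatter_reduce(..., reduce="sum")`, as in `MoE.forward` and `moe_infer` above. A toy version of that combine step, with hypothetical sizes and a hand-written index mapping:

import torch

n_tokens, d_hidden, top_k = 4, 3, 2
expert_cache = torch.zeros(n_tokens, d_hidden)

# One row of output per (token, expert) pair, already weighted by the router score.
expert_out = torch.randn(n_tokens * top_k, d_hidden)
# Token index that each expert-output row belongs to (argsorted order in the real code).
token_idxs = torch.tensor([0, 0, 1, 2, 2, 3, 1, 3])

combined = torch.scatter_reduce(
    expert_cache,
    0,
    token_idxs.view(-1, 1).repeat(1, d_hidden),
    expert_out,
    reduce="sum",
)
# Each token row of `combined` now holds the sum of its top_k expert contributions.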
tilelang.testing.process_func( + example_fusedmoe_tilelang.run_regression_perf, + d_hidden=1024, + d_expert=256, + n_routed_experts=8, + n_shared_experts=1, + n_experts_per_token=4, + batch_size=1, + seq_len=1024, + ) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/fusedmoe/test_example_fusedmoe.py b/examples/fusedmoe/test_example_fusedmoe.py index 806aff49e..ba8415895 100644 --- a/examples/fusedmoe/test_example_fusedmoe.py +++ b/examples/fusedmoe/test_example_fusedmoe.py @@ -4,13 +4,8 @@ def test_example_fusedmoe_tilelang(): example_fusedmoe_tilelang.main( - d_hidden=1024, - d_expert=256, - n_routed_experts=8, - n_shared_experts=1, - n_experts_per_token=4, - batch_size=1, - seq_len=1024) + d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024 + ) if __name__ == "__main__": diff --git a/examples/gdn/example_chunk_delta_bwd.py b/examples/gdn/example_chunk_delta_bwd.py index 518b0ee21..466c47182 100644 --- a/examples/gdn/example_chunk_delta_bwd.py +++ b/examples/gdn/example_chunk_delta_bwd.py @@ -4,6 +4,7 @@ import tilelang import tilelang.language as T +from tilelang.profiler import do_bench print(tilelang.__file__, flush=True) @@ -12,6 +13,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__, flush=True) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu except ImportError: @@ -24,7 +26,7 @@ torch.random.manual_seed(0) # torch.set_printoptions(profile="full") -from utils import * +from test_utils import assert_similar def prepare_input( @@ -49,6 +51,7 @@ def prepare_input( G = F.logsigmoid(G) try: from fla.ops.utils.cumsum import chunk_local_cumsum + G = chunk_local_cumsum(G, chunk_size) except ImportError: print("fla not found, skip cumsum") @@ -125,8 +128,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu( DV = dv.shape[-1] block_S = 64 BS = S // block_S - dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty( - (B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype) + dh, dh0, dv2 = ( + torch.empty((B, BS, H, DK, DV), dtype=output_dtype), + torch.empty((B, H, DK, DV), dtype=state_dtype), + torch.empty((B, S, H, DV), dtype=output_dtype), + ) dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype) dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype) Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype) @@ -138,34 +144,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu( for i_s in range(BS - 1, -1, -1): dh[:, i_s, :, :, :] = dh_tmp - dv_tmp = torch.matmul(K[:, i_s * block_S:(i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), - dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) + dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) if use_g: for i_bh in range(B * H): i_b, i_h = i_bh // H, i_bh % H for i_s2 in range(block_S): - if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, - i_h] <= 0: - dv_tmp[i_b, i_s2, - i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - - G[i_b, i_s * block_S + i_s2, i_h]) + if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0: + dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h]) else: dv_tmp[i_b, i_s2, i_h, :] = 0 - dv_tmp += dv[:, i_s * block_S:(i_s + 1) * block_S, :, :] - dv2[:, i_s * block_S:(i_s + 1) * block_S, :, :] = dv_tmp + dv_tmp += dv[:, i_s * 
block_S : (i_s + 1) * block_S, :, :] + dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp if use_g: G_last = G[:, i_s * block_S + block_S - 1, :] for i_bh in range(B * H): i_b, i_h = i_bh // H, i_bh % H dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h]) - Q_tmp = Q[:, i_s * block_S:(i_s + 1) * block_S, :, :] + Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :] for i_s2 in range(block_S): for i_k in range(DK): Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :]) Q_tmp *= scale - W_tmp = W[:, i_s * block_S:(i_s + 1) * block_S, :, :] - dO_tmp = dO[:, i_s * block_S:(i_s + 1) * block_S, :, :] + W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :] torch.backends.cuda.matmul.allow_tf32 = True dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3)) @@ -223,25 +225,24 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( @T.prim_func def kernel( - # Input - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - h0: T.Tensor(h0_shape, dtype=input_dtype), - dht: T.Tensor(dht_shape, dtype=input_dtype), - dO: T.Tensor(dO_shape, dtype=input_dtype), - dv: T.Tensor(dv_shape, dtype=input_dtype), - # Output - dh: T.Tensor(dh_shape, dtype=output_dtype), - dh0: T.Tensor(dh0_shape, dtype=state_dtype), - dv2: T.Tensor(dv2_shape, dtype=output_dtype), + # Input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + h0: T.Tensor(h0_shape, dtype=input_dtype), + dht: T.Tensor(dht_shape, dtype=input_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + # Output + dh: T.Tensor(dh_shape, dtype=output_dtype), + dh0: T.Tensor(dh0_shape, dtype=state_dtype), + dv2: T.Tensor(dv2_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): bb, bh = bbh // H, bbh % H b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype) - b_dh_shared_fp32 = T.alloc_shared((DK, block_DV), dtype=state_dtype) b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) @@ -249,17 +250,14 @@ def kernel( dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) - dO_shared_t = T.alloc_shared((block_DV, block_S), dtype="float32") - dO_fragment = T.alloc_fragment((block_S, block_DV), dtype="float32") - dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype="float32") + dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32) + dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32) + dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32) K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) - Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype="float32") W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) - G_last_local = T.alloc_local((1), dtype=gate_dtype) - G_last_local_exp = T.alloc_local((1), dtype=gate_dtype) G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared") G_fragment = 
T.alloc_fragment((block_S), dtype=gate_dtype) G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype) @@ -269,20 +267,15 @@ def kernel( T.use_swizzle(10) - T.annotate_layout({ - b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared), - b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), - dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - }) + T.annotate_layout( + { + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + } + ) if use_final_state_gradient: - T.copy(dht[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_dh_shared) + T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) T.copy(b_dh_shared, b_dh_fragment) else: T.clear(b_dh_fragment) @@ -293,57 +286,45 @@ def kernel( # Store the updated dh T.copy(b_dh_fragment, b_dh_shared) - T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) # Update dv - T.copy(K[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], K_shared) + T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared) T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) if use_g: - T.copy( - G[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh], - G_shared, - disable_tma=True) + T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True) T.copy(G_shared, G_fragment) - G_last_local[0] = G_shared[block_S - 1] - G_last_local_exp[0] = T.exp(G_last_local[0]) + G_last_local = G_shared[block_S - 1] + G_last_local_exp = T.exp(G_last_local) for i_s2 in T.Parallel(block_S): - G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2]) + G_fragment_post[i_s2] = T.exp(G_last_local - G_fragment[i_s2]) for i_s2, i_v in T.Parallel(block_S, block_DV): - # with T.If(G_last_local[0] - G_shared[i_s2] <= 0): - with T.If(G_last_local[0] - G_fragment[i_s2] <= 0): - with T.Then(): - dv_fragment[i_s2, - i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] - with T.Else(): - dv_fragment[i_s2, i_v] = 0 - - T.copy( - dv[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV], dv_shared) + dv_fragment[i_s2, i_v] = ( + dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] if G_last_local - G_fragment[i_s2] <= 0 else 0 + ) + + T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared) T.copy(dv_shared, dv_fragment_2) for i_s2, i_v in T.Parallel(block_S, block_DV): dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] # Store the updated dv T.copy(dv_fragment, dv_shared) - T.copy( - dv_shared, dv2[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) # Update dh - T.copy(Q[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) - T.copy(W[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], W_shared) + T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * 
block_S, bh, 0:DK], Q_shared) + T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared) T.clear(Q_fragment) if use_g: for i_k, i_v in T.Parallel(DK, block_DV): - b_dh_fragment[i_k, i_v] *= G_last_local_exp[0] + b_dh_fragment[i_k, i_v] *= G_last_local_exp T.copy(Q_shared, Q_fragment) for i_s2 in T.Parallel(block_S): G_fragment_exp[i_s2] = T.exp(G_shared[i_s2]) for i_s2, i_k in T.Parallel(block_S, DK): - # Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale else: T.copy(Q_shared, Q_fragment) @@ -353,9 +334,7 @@ def kernel( for i_s2, i_k in T.Parallel(block_S, DK): Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k] - T.copy( - dO[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV], dO_shared) + T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared) T.copy(dO_shared, dO_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v] @@ -369,7 +348,7 @@ def kernel( b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v] if use_initial_state: - T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -444,44 +423,61 @@ def run_test( num_stages=0, use_torch=False, ): - Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dh_ref, dh0_ref, dv2_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref, dh0_ref, dv2_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) # fla ref print("fla running...", flush=True) if use_g: - dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, - scale) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) else: G = G.fill_(0) - dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, - scale) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) # tilelang print("tilelang running...", flush=True) - kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, - chunk_size, scale, use_g, use_initial_state, - use_final_state_gradient, block_DV, threads, - num_stages) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + 
state_dtype, + chunk_size, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) # kernel = tilelang.compile(program) print(kernel.get_kernel_source()) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) - fla_time = do_bench( - chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) + fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) print(f"fla time: {fla_time} ms") @@ -496,19 +492,47 @@ def run_test( print("torch running...", flush=True) if use_g: dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( - Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state, - use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), - getattr(torch, accum_dtype), getattr(torch, - gate_dtype), getattr(torch, state_dtype)) + Q, + K, + W, + G, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dh_ref_torch = dh_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda() else: dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( - Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state, - use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), - getattr(torch, accum_dtype), getattr(torch, - gate_dtype), getattr(torch, state_dtype)) + Q, + K, + W, + None, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dh_ref_torch = dh_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda() @@ -521,31 +545,6 @@ def run_test( assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2") -def do_bench(fn, *args, warmup=10, rep=10, **kwargs): - """ - Do benchmark for a function. 
- """ - start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - for _ in range(warmup): - fn(*args, **kwargs) - - torch.cuda.synchronize() - for i in range(rep): - start_event[i].record() - fn(*args, **kwargs) - end_event[i].record() - torch.cuda.synchronize() - - # Record clocks - times = torch.tensor( - [s.elapsed_time(e) for s, e in zip(start_event, end_event)], - dtype=torch.float, - ) - - return times.mean().item() - - def main(): DK = 128 run_test( @@ -554,11 +553,11 @@ def main(): H=8, DK=DK, DV=128, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, scale=DK**-0.5, use_g=True, diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py index 4d6b657ff..c34d9b530 100644 --- a/examples/gdn/example_chunk_delta_h.py +++ b/examples/gdn/example_chunk_delta_h.py @@ -3,12 +3,15 @@ import sys # noqa: F401 import tilelang import tilelang.language as T +from tilelang.autotuner import autotune +from tilelang.profiler import do_bench # Add your fla repository path to sys.path # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h except ImportError: @@ -19,7 +22,7 @@ import torch.nn.functional as F from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 -from utils import * +from test_utils import assert_similar # (zhengju) We can slightly modify the generated cuda code from tilelang lowering # in the debug folder to make the performance better. 
To enable this callback, @@ -55,6 +58,7 @@ def prepare_input( G = F.logsigmoid(G) try: from fla.ops.utils.cumsum import chunk_local_cumsum + G = chunk_local_cumsum(G, chunk_size) except ImportError: print("fla not found, skip cumsum") @@ -80,7 +84,21 @@ def prepare_output( return h, final_state, V_new -@tilelang.jit(out_idx=[-3, -2, -1]) +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [128, 256] + num_stages = [1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) def tilelang_chunk_gated_delta_rule_fwd_h( # task config B, @@ -94,15 +112,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h( gate_dtype, state_dtype, chunk_size, - use_g=True, - use_initial_state=True, - store_final_state=True, - save_new_value=True, + use_g, + use_initial_state, + store_final_state, + save_new_value, # kernel config block_DK=64, - block_DV=64, - threads=256, - num_stages=0, + block_DV=32, + threads=128, + num_stages=1, ): block_S = chunk_size BS = S // block_S @@ -118,14 +136,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - U: T.Tensor(U_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), - h: T.Tensor(h_shape, dtype=output_dtype), - final_state: T.Tensor(final_state_shape, dtype=state_dtype), - V_new: T.Tensor(V_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + U: T.Tensor(U_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=output_dtype), + final_state: T.Tensor(final_state_shape, dtype=state_dtype), + V_new: T.Tensor(V_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): bb, bh = bbh // H, bbh % H @@ -139,39 +157,35 @@ def kernel( V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) - G_last_local = T.alloc_local((1), dtype=gate_dtype) + G_last_local = T.alloc_var(T.float32) G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype) G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype) - T.annotate_layout({ - b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared), - U_shared: tilelang.layout.make_swizzled_layout(U_shared), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - G_shared: tilelang.layout.make_swizzled_layout(G_shared), - }) + T.annotate_layout( + { + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + G_shared: tilelang.layout.make_swizzled_layout(G_shared), + } + ) T.use_swizzle(10) if use_initial_state: - T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared) + T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared) 
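For readers without fla installed, here is a plain-PyTorch stand-in for the gate preprocessing done in `prepare_input` above; it assumes `chunk_local_cumsum` performs an inclusive cumulative sum within each chunk along the sequence dimension:

import torch
import torch.nn.functional as F

B, S, H, chunk_size = 1, 256, 4, 64  # illustrative sizes only
G = F.logsigmoid(torch.randn(B, S, H, dtype=torch.float32))

# Inclusive cumsum inside each chunk of length chunk_size along S.
G_local = G.view(B, S // chunk_size, chunk_size, H).cumsum(dim=2).view(B, S, H)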
T.copy(b_h_shared, b_h_fragment) else: T.clear(b_h_fragment) for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): # Store previous result to the hidden tensor, like the epilogue - T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) # Recurrence - T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared) + T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared) T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) # U - W * S - T.copy( - U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], - U_shared) + T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared) T.copy(U_shared, U_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] @@ -179,27 +193,24 @@ def kernel( # Save V_new if save_new_value: T.copy(V_new_fragment, dst=V_new_shared) - T.copy( - V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) - T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared) + T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) # use_g if use_g: - G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh] + G_last_local = G[bb, (i_s + 1) * block_S - 1, bh] for i_s2, i_v in T.Parallel(block_S, block_DV): G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh] T.copy(G_shared, G_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): - with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): - with T.Then(): - V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp( - G_last_local[0] - G_fragment[i_s2, i_v]) - with T.Else(): - V_new_fragment[i_s2, i_v] = 0 - G_last_local[0] = T.exp(G_last_local[0]) + V_new_fragment[i_s2, i_v] = ( + V_new_fragment[i_s2, i_v] * T.exp2((G_last_local - G_fragment[i_s2, i_v]) * 1.442695) + if G_last_local - G_fragment[i_s2, i_v] <= 0 + else 0 + ) + G_last_local = T.exp2(G_last_local * 1.442695) for i_k, i_v in T.Parallel(DK, block_DV): - b_h_fragment[i_k, i_v] *= G_last_local[0] + b_h_fragment[i_k, i_v] *= G_last_local # Update intermediate results T.copy(V_new_fragment, V_new_shared) @@ -209,36 +220,11 @@ def kernel( # Save final state if store_final_state: - T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) return kernel -def do_bench(fn, *args, warmup=10, rep=10, **kwargs): - """ - Do benchmark for a function. 
- """ - start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - for _ in range(warmup): - fn(*args, **kwargs) - - torch.cuda.synchronize() - for i in range(rep): - start_event[i].record() - fn(*args, **kwargs) - end_event[i].record() - torch.cuda.synchronize() - - # Record clocks - times = torch.tensor( - [s.elapsed_time(e) for s, e in zip(start_event, end_event)], - dtype=torch.float, - ) - - return times.mean().item() - - def run_test( B, S, @@ -260,47 +246,77 @@ def run_test( threads=128, num_stages=0, ): - K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) - h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, state_dtype)) - h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, state_dtype)) + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + h_ref, final_state_ref, V_new_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) # fla ref - h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state, - store_final_state, chunk_size, - save_new_value) + h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + ) # tilelang - kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - use_g, use_initial_state, store_final_state, - save_new_value, block_DK, block_DV, threads, - num_stages) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + ) h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # (zhengju) If you want to print the generated cuda code, you can uncomment the following line # print("CUDA Code:\n", kernel.get_kernel_source()) - fla_time = do_bench(chunk_gated_delta_rule_fwd_h, K, W, U, G, initial_state, store_final_state, - chunk_size, save_new_value) + fla_time = do_bench( + chunk_gated_delta_rule_fwd_h, + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + ) tilelang_time = do_bench(kernel, K, W, U, G, initial_state) # check correctness try: h_ref_fp32 = h_ref.to(torch.float32) h_tilelang_fp32 = h_tilelang.to(torch.float32) - assert_similar( - h_ref_fp32, - h_tilelang_fp32, - eps=1e-5, - name="tilelang chunk gated delta rule fwd h", - raise_assert=False) + assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False) print("tilelang chunk gated delta 
rule fwd h passed √") except Exception as e: print("tilelang chunk gated delta rule fwd h failed ✗") @@ -314,7 +330,8 @@ def run_test( final_state_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd final_state", - raise_assert=False) + raise_assert=False, + ) print("tilelang chunk gated delta rule fwd final_state passed √") except Exception as e: print("tilelang chunk gated delta rule fwd final_state failed ✗") @@ -323,12 +340,7 @@ def run_test( try: V_new_ref_fp32 = V_new_ref.to(torch.float32) V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32) - assert_similar( - V_new_ref_fp32, - V_new_tilelang_fp32, - eps=1e-5, - name="tilelang chunk gated delta rule fwd V_new", - raise_assert=False) + assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False) print("tilelang chunk gated delta rule fwd V_new passed √") except Exception as e: print("tilelang chunk gated delta rule fwd V_new failed ✗") @@ -345,20 +357,20 @@ def main(): H=32, DK=128, DV=128, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, use_g=True, - use_initial_state=True, + use_initial_state=False, store_final_state=True, save_new_value=True, - block_DK=64, + block_DK=32, block_DV=32, threads=128, - num_stages=1, + num_stages=2, ) diff --git a/examples/gdn/example_chunk_o.py b/examples/gdn/example_chunk_o.py index 1c084be70..bb95f555f 100644 --- a/examples/gdn/example_chunk_o.py +++ b/examples/gdn/example_chunk_o.py @@ -9,6 +9,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_o import chunk_fwd_o except ImportError: @@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o( @T.prim_func def kernel( - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - HIDDEN: T.Tensor(H_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - O: T.Tensor(O_shape, dtype=output_dtype), + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + HIDDEN: T.Tensor(H_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + O: T.Tensor(O_shape, dtype=output_dtype), ): - with T.Kernel( - T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, - threads=threads) as (bv, bs, bbh): + with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh): bb, bh = bbh // H, bbh % H Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) @@ -109,28 +108,13 @@ def kernel( G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared") G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype) - T.annotate_layout({ - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - H_shared: tilelang.layout.make_swizzled_layout(H_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - T.clear(A_fragment) T.clear(O_fragment) T.disable_warp_group_reg_alloc() for i_k in T.Pipelined(T.ceildiv(DK, block_DK), 
num_stages=num_stages): - T.copy( - Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - Q_shared) - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) - T.copy( - HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK, - bv * block_DV:(bv + 1) * block_DV], H_shared) + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared) T.gemm(Q_shared, H_shared, O_fragment) T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True) @@ -143,20 +127,17 @@ def kernel( for i_s1, i_s2 in T.Parallel(block_S, block_S): G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(G_diff_local[i_s1, i_s2] <= 0): - with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( - G_diff_local[i_s1, i_s2]) - with T.Else(): - A_fragment[i_s1, i_s2] = 0 + A_fragment[i_s1, i_s2] = T.if_then_else( + G_diff_local[i_s1, i_s2] <= 0, + A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]), + 0, + ) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 < i_s2): # noqa: SIM117 - with T.Then(): - A_fragment[i_s1, i_s2] = 0 + if i_s1 < i_s2: + A_fragment[i_s1, i_s2] = 0 - T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared) T.copy(A_fragment, A_shared) T.gemm(A_shared, V_shared, O_fragment) @@ -164,8 +145,7 @@ def kernel( O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale T.copy(O_fragment, O_shared) - T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -191,8 +171,9 @@ def run_test( output_dtype_torch = getattr(torch, output_dtype) accum_dtype_torch = getattr(torch, accum_dtype) gate_dtype_torch = getattr(torch, gate_dtype) - Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch, - output_dtype_torch, accum_dtype_torch, gate_dtype_torch) + Q, K, V, HIDDEN, G = prepare_input( + B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch + ) scale = 1.0 / DK**0.5 O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) @@ -200,9 +181,25 @@ def run_test( block_S = chunk_size O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) - kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, - threads, num_stages) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) O_tilelang = kernel(Q, K, V, HIDDEN, G) try: @@ -221,10 +218,10 @@ def main(): DK=128, DV=128, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, use_g=True, block_DK=128, block_DV=128, diff --git 
a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 7e87a2c4f..19233de62 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -12,6 +12,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_o import chunk_bwd_dqkwg except ImportError: @@ -19,7 +20,7 @@ fla = None import torch -from utils import * +from test_utils import assert_similar torch.random.manual_seed(0) # torch.set_printoptions(profile="full") @@ -108,10 +109,8 @@ def prepare_output( @tilelang.jit( out_idx=[-4, -3, -2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) def tilelang_chunk_o_bwd_dqkwg( # task config B, @@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg( @T.prim_func def kernel( - # input - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - h: T.Tensor(h_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - dO: T.Tensor(dO_shape, dtype=input_dtype), - dh: T.Tensor(dh_shape, dtype=input_dtype), - dv: T.Tensor(dv_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - # output - dq: T.Tensor(dq_shape, dtype=output_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dw: T.Tensor(dw_shape, dtype=output_dtype), - dg: T.Tensor(dg_shape, dtype=gate_dtype), + # input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dh: T.Tensor(dh_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + # output + dq: T.Tensor(dq_shape, dtype=output_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dw: T.Tensor(dw_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), ): - with T.Kernel( - T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, - threads=threads) as (bk, bs, bbh): + with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh): bb, bh = bbh // H, bbh % H V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) @@ -202,27 +199,27 @@ def kernel( dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype) dg_fragment_2 = T.alloc_fragment((block_S,), dtype=gate_dtype) dg_fragment_final = T.alloc_fragment((block_S,), dtype=gate_dtype) - dg_last_local = T.alloc_local((2,), dtype=gate_dtype) + dg_last_local_0 = T.alloc_var(dtype=gate_dtype) + dg_last_local_1 = T.alloc_var(dtype=gate_dtype) + G_last_local = T.alloc_var(dtype=gate_dtype) + dg_last_fragment = T.alloc_fragment((block_DV * block_DK), dtype=gate_dtype) dg_last_fragment_scalar = T.alloc_fragment((1,), dtype=gate_dtype) dg_last_fragment_2 = T.alloc_fragment((block_S * block_DK), dtype=gate_dtype) dg_last_fragment_scalar_2 = T.alloc_fragment((1,), dtype=gate_dtype) - G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype, scope="shared") - G_last_local = T.alloc_local((1,), dtype=gate_dtype) + G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype) T.use_swizzle(10) - T.annotate_layout({ - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - 
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), - h_shared: tilelang.layout.make_swizzled_layout(h_shared), - dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - q_shared: tilelang.layout.make_swizzled_layout(q_shared), - k_shared: tilelang.layout.make_swizzled_layout(k_shared), - }) - - T.clear(dg_last_local) + T.annotate_layout( + { + q_shared: tilelang.layout.make_swizzled_layout(q_shared), + k_shared: tilelang.layout.make_swizzled_layout(k_shared), + } + ) + + T.clear(dg_last_local_0) + T.clear(dg_last_local_1) T.clear(G_last_local) T.clear(G_shared) T.clear(q_fragment) @@ -235,18 +232,10 @@ def kernel( T.clear(dw_fragment) for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) - T.copy( - dO[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], dO_shared) - T.copy( - h[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK, - i_v * block_DV:(i_v + 1) * block_DV], h_shared) - T.copy( - dh[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK, - i_v * block_DV:(i_v + 1) * block_DV], dh_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared) + T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared) + T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared) if use_g: T.clear(dg_last_fragment_scalar) @@ -254,33 +243,25 @@ def kernel( # for i_kv in T.Parallel(block_DK * block_DV): # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] for i_kv in T.Parallel(block_DK * block_DV): - dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % - block_DV] * dh_shared[i_kv // block_DV, - i_kv % block_DV] + dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) - dg_last_local[0] += dg_last_fragment_scalar[0] + dg_last_local_0 = dg_last_local_0 + dg_last_fragment_scalar[0] T.gemm(dO_shared, V_shared, ds_fragment, transpose_B=True) T.gemm(dO_shared, h_shared, dq_fragment, transpose_B=True) T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True) if use_dw: - T.copy( - dv[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], dv_shared) + T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared) T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True) if use_dw: for i_s, i_k in T.Parallel(block_S, block_DK): dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] - T.copy( - dw_fragment, dw[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - - T.copy(Q[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK], - q_shared) - T.copy(K[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK], - k_shared) + T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], 
k_shared) T.copy(q_shared, q_fragment) T.copy(k_shared, k_fragment) @@ -289,13 +270,12 @@ def kernel( T.clear(dg_fragment_2) for i_s, i_k in T.Parallel(block_S, block_DK): G_shared[i_s, i_k] = G[bb, bs * block_S + i_s, bh] - G_last_local[0] = G[bb, bs * block_S + block_S - 1, bh] + G_last_local = G[bb, bs * block_S + block_S - 1, bh] # Use gmem directly instead of local register - dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh]) + dg_last_local_0 = dg_last_local_0 * T.exp(G[bb, bs * block_S + block_S - 1, bh]) for i_s, i_k in T.Parallel(block_S, block_DK): - dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, - bh]) * scale + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale T.clear(dg_fragment_reduce_tmp) for i_s, i_k in T.Parallel(block_S, block_DK): dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k] @@ -303,12 +283,11 @@ T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False) for i_s, i_k in T.Parallel(block_S, block_DK): - with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0): - with T.Then(): - dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp( - G_last_local[0] - G[bb, bs * block_S + i_s, bh]) - with T.Else(): - dk_fragment[i_s, i_k] = 0 + dk_fragment[i_s, i_k] = ( + dk_fragment[i_s, i_k] * T.exp(G_last_local - G[bb, bs * block_S + i_s, bh]) + if G_last_local - G[bb, bs * block_S + i_s, bh] <= 0 + else 0 + ) T.clear(dg_fragment_reduce_tmp) for i_s, i_k in T.Parallel(block_S, block_DK): dg_fragment_reduce_tmp[i_s, i_k] = dk_fragment[i_s, i_k] * (-k_shared[i_s, i_k]) @@ -322,24 +301,20 @@ def kernel( i_s, i_k = i_sk // block_DK, i_sk % block_DK dg_last_fragment_2[i_sk] = dk_shared[i_s, i_k] * k_shared[i_s, i_k] T.reduce_sum(dg_last_fragment_2, dg_last_fragment_scalar_2, dim=-1, clear=False) - dg_last_local[1] = dg_last_fragment_scalar_2[0] + dg_last_local_1 = dg_last_fragment_scalar_2[0] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 >= i_s2 and - G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): - with T.Then(): - ds_fragment[i_s1, i_s2] = ds_fragment[ - i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - - G[bb, bs * block_S + i_s2, bh]) * scale - with T.Else(): - ds_fragment[i_s1, i_s2] = 0 + ds_fragment[i_s1, i_s2] = ( + (ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale) + if G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0 + else 0 + ) T.clear(ds_fragment_positive) T.clear(ds_fragment_positive_transpose) T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True) for i_s1, i_s2 in T.Parallel(block_S, block_S): - ds_fragment_positive[ - i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] + ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False) @@ -361,25 +336,16 @@ def kernel( T.gemm(ds_shared, q_shared, dk_fragment, transpose_A=True) for i_s in T.Parallel(block_S): - with T.If(i_s >= block_S - 1): # noqa: SIM117 - with T.Then(): - dg_fragment_final[ - i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1] - - T.copy( - dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - T.copy( - dk_fragment, dk[bb, bs *
block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) + dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local_0 + dg_last_local_1 + + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) for i_s in T.Parallel(block_S): dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s] else: for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 < i_s2): # noqa: SIM117 - with T.Then(): - ds_fragment[i_s1, i_s2] = 0 + ds_fragment[i_s1, i_s2] = 0 if i_s1 < i_s2 else ds_fragment[i_s1, i_s2] T.clear(dk_fragment_2) T.copy(ds_fragment, ds_shared) T.gemm(ds_shared, k_shared, dq_fragment) @@ -387,41 +353,12 @@ def kernel( for i_s, i_k in T.Parallel(block_S, block_DK): dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale - T.copy( - dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) return kernel -def do_bench(fn, *args, warmup=10, rep=10, **kwargs): - """ - Do benchmark for a function. - """ - start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - for _ in range(warmup): - fn(*args, **kwargs) - - torch.cuda.synchronize() - for i in range(rep): - start_event[i].record() - fn(*args, **kwargs) - end_event[i].record() - torch.cuda.synchronize() - - # Record clocks - times = torch.tensor( - [s.elapsed_time(e) for s, e in zip(start_event, end_event)], - dtype=torch.float, - ) - - return times.mean().item() - - def run_test( B, S, @@ -442,33 +379,53 @@ def run_test( threads=256, num_stages=0, ): - Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype), block_DK) + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dq_ref, dk_ref, dw_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype), block_DK) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) # ref if use_g: - dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( - Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) 
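Aside (illustration only, not part of the patch): the hunk above deletes example_chunk_o_bwd.py's hand-rolled `do_bench` helper. Its CUDA-event timing pattern is reproduced below in a compact, self-contained form (as a hypothetical `bench_cuda_events` helper) for readers who still want to time these kernels without the built-in profiler (`kernel.get_profiler().do_bench(...)`) that the gemm examples in this patch switch to.

```python
import torch


def bench_cuda_events(fn, *args, warmup=10, rep=10, **kwargs):
    """Mean latency of fn(*args, **kwargs) in milliseconds, timed with CUDA events."""
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    for i in range(rep):
        starts[i].record()
        fn(*args, **kwargs)
        ends[i].record()
    torch.cuda.synchronize()
    times = torch.tensor([s.elapsed_time(e) for s, e in zip(starts, ends)], dtype=torch.float)
    return times.mean().item()
```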
else: - dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( - Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) # tilelang - kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw, - block_DK, block_DV, threads, num_stages) - print(kernel.get_kernel_source()) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_dw, + block_DK, + block_DV, + threads, + num_stages, + ) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) if use_g: @@ -515,11 +472,11 @@ def main(): H=8, DK=DK, DV=DV, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, scale=DK**-0.5, # scale=1, diff --git a/examples/gdn/example_chunk_scaled_dot_kkt.py b/examples/gdn/example_chunk_scaled_dot_kkt.py index d07a4776a..c16374fe8 100644 --- a/examples/gdn/example_chunk_scaled_dot_kkt.py +++ b/examples/gdn/example_chunk_scaled_dot_kkt.py @@ -9,6 +9,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd except ImportError: @@ -56,9 +57,9 @@ def tilelang_chunk_scaled_dot_kkt_fwd( H, DK, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, use_g=True, # kernel config block_S=64, @@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=accum_dtype), - A: T.Tensor(output_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=accum_dtype), + A: T.Tensor(output_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -93,20 +94,13 @@ def kernel( G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared") G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - }) - T.fill(A_fragment, 0) T.disable_warp_group_reg_alloc() for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True) @@ -117,20 +111,18 @@ def kernel( for i_s1, i_s2 in T.Parallel(block_S, block_S): G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] for i_s1, i_s2 in 
T.Parallel(block_S, block_S): - with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): - with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( - G_diff_local[i_s1, i_s2]) - with T.Else(): - A_fragment[i_s1, i_s2] = 0 + A_fragment[i_s1, i_s2] = T.if_then_else( + G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2, + A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]), + 0, + ) else: for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 <= i_s2): # noqa: SIM117 - with T.Then(): - A_fragment[i_s1, i_s2] = 0 + if i_s1 <= i_s2: + A_fragment[i_s1, i_s2] = 0 T.copy(A_fragment, A_shared) - T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :]) return kernel @@ -149,24 +141,21 @@ def run_test( threads, num_stages, ): - K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype)) + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) # reference if use_g: - A_ref = chunk_scaled_dot_kkt_fwd( - K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) else: - A_ref = chunk_scaled_dot_kkt_fwd( - K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) # tilelang block_S = chunk_size - kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, - accum_dtype, use_g, block_S, block_DK, threads, - num_stages) + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) A_tilelang = kernel(K, Beta, G) try: @@ -186,13 +175,14 @@ def main(): H=32, DK=128, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, use_g=True, block_DK=64, threads=128, - num_stages=2) + num_stages=2, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_cumsum.py b/examples/gdn/example_cumsum.py index 9896c7ecf..0760b4964 100644 --- a/examples/gdn/example_cumsum.py +++ b/examples/gdn/example_cumsum.py @@ -10,6 +10,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.utils.cumsum import chunk_local_cumsum_scalar except ImportError: @@ -20,11 +21,8 @@ @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} +) def tilelang_chunk_local_cumsum_scalar( # task config B, @@ -34,43 +32,43 @@ def tilelang_chunk_local_cumsum_scalar( is_varlen=False, head_first=False, reverse=False, - input_dtype="float16", - output_dtype="float32", + input_dtype=T.float16, + output_dtype=T.float32, # kernel config block_S=64, threads=256, use_fragment=False, ): G_shape = (B, H, S) if head_first else (B, S, H) - assert chunk_size 
== 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" assert chunk_size == block_S, "chunk_size must be equal to block_S" @T.prim_func def kernel( - G: T.Tensor(G_shape, dtype=input_dtype), - G_new: T.Tensor(G_shape, dtype=output_dtype), + G: T.Tensor(G_shape, dtype=input_dtype), + G_new: T.Tensor(G_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared") if head_first: - T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared) + T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared) else: - T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared) + T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared) if use_fragment: G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared") T.copy(G_shared, G_fragment) T.cumsum(G_fragment, dim=1, reverse=reverse) if head_first: - T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) + T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) else: - T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) + T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) else: T.cumsum(G_shared, dim=1, reverse=reverse) if head_first: - T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) + T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) else: - T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) + T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) return kernel @@ -113,11 +111,8 @@ def run_test( # reference cumsum G_new_ref = chunk_local_cumsum_scalar( - g=G, - chunk_size=chunk_size, - reverse=reverse, - head_first=head_first, - output_dtype=getattr(torch, output_dtype)) + g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype) + ) # tilelang cumsum block_S = chunk_size @@ -159,10 +154,11 @@ def main(): chunk_size=64, reverse=True, head_first=False, - input_dtype="float32", - output_dtype="float32", + input_dtype=T.float32, + output_dtype=T.float32, threads=256, - use_fragment=False) + use_fragment=False, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_wy_fast.py b/examples/gdn/example_wy_fast.py index 0a0983a82..d36dcf9b7 100644 --- a/examples/gdn/example_wy_fast.py +++ b/examples/gdn/example_wy_fast.py @@ -9,6 +9,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd except ImportError: @@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=output_dtype), - W: T.Tensor(K_shape, dtype=output_dtype), - U: T.Tensor(V_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=output_dtype), + W: T.Tensor(K_shape, dtype=output_dtype), + U: T.Tensor(V_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ 
-95,49 +96,37 @@ def kernel( W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - U_shared: tilelang.layout.make_swizzled_layout(U_shared), - W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), - U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + } + ) T.disable_warp_group_reg_alloc() for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) for i_s, i_v2 in T.Parallel(block_S, block_DV): U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True) # First copy to smem, then copy to gmem to reduce U2RU instructions T.copy(U_fragment, U_shared) - T.copy( - U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV]) + T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): - W_Beta_shared[i_s, - i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] + W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True) # First copy to smem, then copy to gmem to reduce U2RU instructions T.copy(W_fragment, W_shared) - T.copy( - W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) return kernel @@ -159,15 +148,8 @@ def run_test( num_stages, ): K, V, Beta, G, A = prepare_input( - B, - S, - H, - DK, - DV, - chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - gate_dtype=getattr(torch, gate_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) @@ -191,7 +173,8 @@ def run_test( block_DK=block_DK, block_DV=block_DV, threads=threads, - num_stages=num_stages) + num_stages=num_stages, + ) print(kernel.get_kernel_source()) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) @@ -217,14 +200,15 @@ def main(): DK=128, DV=128, 
chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - gate_dtype="float32", - accum_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + gate_dtype=T.float32, + accum_dtype=T.float32, block_DK=64, block_DV=32, threads=128, - num_stages=3) + num_stages=3, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py index 618a82b4c..822f745f2 100644 --- a/examples/gdn/example_wy_fast_bwd_split.py +++ b/examples/gdn/example_wy_fast_bwd_split.py @@ -10,6 +10,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr except ImportError: @@ -93,10 +94,8 @@ def prepare_output( @tilelang.jit( out_idx=[-5, -4, -3, -2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) def tilelang_wy_fast_bwd( # task config B, @@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd( @T.prim_func def kernel( - # input - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=input_dtype), - dw: T.Tensor(dw_shape, dtype=input_dtype), - du: T.Tensor(du_shape, dtype=input_dtype), - # output - dA: T.Tensor(dA_shape, dtype=input_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dv: T.Tensor(dv_shape, dtype=output_dtype), - dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), - dg: T.Tensor(dg_shape, dtype=gate_dtype), + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + # output + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -187,7 +186,7 @@ def kernel( T.clear(dbeta_fragment_v) T.clear(dg_fragment) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh] @@ -195,51 +194,37 @@ def kernel( # Update dk for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): - K_shared_beta_g[i_s, - i_k2] = K_shared[i_s, - i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] - T.copy( - dw[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK], dw_shared) + K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], 
dw_shared) T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True) T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True) for i_s, i_k2 in T.Parallel(block_S, block_DK): - dk_fragment[ - i_s, - i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[ - i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dg_fragment_reduce_tmp[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[ - i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + dg_fragment_reduce_tmp[i_s, i_k2] = ( + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + ) T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False) # correct dk - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) # Update dv for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) for i_s, i_v2 in T.Parallel(block_S, block_DV): V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] - T.copy( - du[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], du_shared) + T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared) T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True) T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True) for i_s, i_v2 in T.Parallel(block_S, block_DV): @@ -247,30 +232,22 @@ def kernel( # for i_s, i_v2 in T.Parallel(block_S, block_DV): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] for i_s, i_v2 in T.Parallel(block_S, block_DV): - dbeta_fragment_reduce_tmpv[i_s, - i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, - i_v2] + dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False) - T.copy( - dv_fragment, dv[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV]) + T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) # Temporary store dbeta, dg and dA for i_s in T.Parallel(block_S): dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s] dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s] # correct dA - T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + 
T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :]) return kernel -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_wy_fast_bwd_split( # task config B, @@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split( @T.prim_func def kernel( - # input - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=input_dtype), - dw: T.Tensor(dw_shape, dtype=input_dtype), - du: T.Tensor(du_shape, dtype=input_dtype), - dA: T.Tensor(dA_shape, dtype=input_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dv: T.Tensor(dv_shape, dtype=output_dtype), - dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), - dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), - dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), + dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), + dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -350,7 +327,7 @@ def kernel( T.clear(dA_A_fragment_1) T.clear(dA_A_fragment_2) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh] @@ -361,34 +338,32 @@ def kernel( # for i_s in T.Parallel(block_S): # dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh] # dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh] - T.copy(dA[bb, bs * block_S:(bs + 1) * block_S, bh, :], dA_shared) + T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared) # T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) # Update dA T.copy(dA_shared, dA_fragment) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 <= i_s2): # noqa: SIM117 - with T.Then(): - dA_fragment[i_s1, i_s2] = 0 + if i_s1 <= i_s2: + dA_fragment[i_s1, i_s2] = 0 T.copy(dA_fragment, dA_shared) T.gemm(dA_shared, A_shared, dA_fragment, clear_accum=True, transpose_B=True) T.copy(dA_fragment, dA_shared) T.gemm(A_shared, dA_shared, dA_fragment, clear_accum=True, transpose_A=True) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 <= i_s2): - with T.Then(): - dA_fragment[i_s1, i_s2] = 0 - with T.Else(): - dA_fragment[i_s1, i_s2] = -dA_fragment[i_s1, i_s2] + dA_fragment[i_s1, i_s2] = T.if_then_else( + i_s1 <= i_s2, + 0, + -dA_fragment[i_s1, i_s2], + ) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): - with T.Then(): - dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - - G[bb, bs * block_S + 
i_s2, bh]) - with T.Else(): - dA_fragment[i_s1, i_s2] = 0 + dA_fragment[i_s1, i_s2] = T.if_then_else( + G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0, + dA_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]), + 0, + ) T.copy(dA_fragment, dA_shared) # acceptable dA diff @@ -397,12 +372,8 @@ def kernel( # Update dk using previous dk T.clear(A_fragment) for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) - T.copy( - dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK], dk_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared) T.copy(dk_shared, dk_fragment) for i_s, i_k2 in T.Parallel(block_S, block_DK): K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] @@ -411,18 +382,14 @@ def kernel( # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dbeta_fragment_reduce_tmpk[i_s, - i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, - i_k2] + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True) for i_s, i_k2 in T.Parallel(block_S, block_DK): dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2] - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) # Update dg and dbeta T.copy(A_fragment, A_shared) @@ -460,19 +427,25 @@ def run_test( threads=128, num_stages=0, ): - K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, - accum_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) BS = chunk_size dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() @@ -480,28 +453,55 @@ def run_test( 
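Aside (illustrative sketch, not part of the patch): the hunks above replace scoped `with T.If(...) / T.Then() / T.Else()` blocks with either a Python-level `if` or a `T.if_then_else` select expression inside `T.Parallel`. A minimal, hypothetical kernel showing both spellings of the same causal mask, using only constructs that appear in these examples:

```python
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def causal_mask(S, dtype=T.float32, threads=128):
    @T.prim_func
    def kernel(
        A: T.Tensor((S, S), dtype),
        B: T.Tensor((S, S), dtype),
    ):
        with T.Kernel(1, threads=threads) as bx:  # single block; bx unused
            for i_s1, i_s2 in T.Parallel(S, S):
                # Expression form: keep the lower triangle, zero the rest.
                B[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, A[i_s1, i_s2], 0)
                # Statement form with a Python-level `if` would read:
                #     if i_s1 < i_s2:
                #         B[i_s1, i_s2] = 0

    return kernel
```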
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() # ref - dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr( - K, V, G, Beta, A, dw, du, cu_seqlens=None) + dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None) # tilelang - kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, - num_stages) - dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( - K, V, Beta, G, A, dw, du) + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) torch.cuda.synchronize() - kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - block_DK, block_DV, threads, num_stages) - kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, - dg_tilelang_A_positive, dg_tilelang_A_negative) + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) torch.cuda.synchronize() dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang - dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( - dim=-1) + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) + + from test_utils import assert_similar - from utils import assert_similar assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) @@ -517,11 +517,11 @@ def main(): H=8, DK=DK, DV=DV, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, block_DK=32, block_DV=32, diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py index e184dbcac..6f9fa5d2f 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -1,16 +1,16 @@ -import tilelang.testing import torch +from tilelang import language as T B = 1 S = 1024 # small but for test only. 
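Aside (illustrative sketch, not part of the patch): the constants above now carry TileLang dtypes (`T.bfloat16`, `T.float32`) instead of plain strings, while the prepare_* helpers keep resolving the matching torch dtype with `getattr(torch, ...)`. That works only as long as the TileLang dtype behaves like its canonical name string; the hypothetical helper below makes that assumption explicit.

```python
import torch
from tilelang import language as T


def torch_dtype(tl_dtype) -> torch.dtype:
    # Assumes str(T.bfloat16) == "bfloat16", str(T.float32) == "float32", etc.,
    # the same assumption getattr(torch, input_dtype) makes in these tests.
    return getattr(torch, str(tl_dtype))


# e.g. an input tensor matching the kernels' bfloat16 inputs
x = torch.randn(2, 64, 4, 32, dtype=torch_dtype(T.bfloat16), device="cuda")
```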
H = 32 DK = 128 DV = 128 -input_dtype = "bfloat16" -output_dtype = "bfloat16" -accum_dtype = "float32" -gate_dtype = "float32" -state_dtype = "float32" +input_dtype = T.bfloat16 +output_dtype = T.bfloat16 +accum_dtype = T.float32 +gate_dtype = T.float32 +state_dtype = T.float32 chunk_size = 64 use_g = True use_initial_state = True @@ -20,21 +20,15 @@ block_DK = 64 block_DV = 32 threads = 128 -num_stages = 1 +num_stages = 0 def test_example_wy_fast_compilation(): from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input + K, V, Beta, G, A = prepare_input( - B, - S, - H, - DK, - DV, - chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - gate_dtype=getattr(torch, gate_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) # tilelang block_S = chunk_size kernel = tilelang_recompute_w_u_fwd( @@ -52,22 +46,31 @@ def test_example_wy_fast_compilation(): block_DK=block_DK, block_DV=block_DV, threads=threads, - num_stages=num_stages) + num_stages=num_stages, + ) print(kernel.get_kernel_source()) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) def test_example_wy_fast_bwd_split_compilation(): from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output - K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, - accum_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) BS = chunk_size dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() @@ -75,67 +78,146 @@ def test_example_wy_fast_bwd_split_compilation(): dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() # tilelang - kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, - num_stages) - dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( - K, V, Beta, G, A, dw, du) + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) torch.cuda.synchronize() - kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - block_DK, block_DV, threads, num_stages) - kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, - dg_tilelang_A_positive, dg_tilelang_A_negative) + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + 
state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) torch.cuda.synchronize() dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang - dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( - dim=-1) + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) def test_example_chunk_o_compilation(): from example_chunk_o import tilelang_chunk_fwd_o, prepare_input - Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) + + Q, K, V, HIDDEN, G = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) scale = 1.0 / DK**0.5 block_S = chunk_size - kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, - threads, num_stages) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841 def test_example_chunk_o_bwd_compilation(): from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input - Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, 1.0, use_g, True, - block_DK, block_DV, threads, num_stages) - dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, - W) # noqa: F841 + + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + True, + block_DK, + block_DV, + threads, + num_stages, + ) + + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841 if use_g: dg_tilelang = dg_tilelang.sum(dim=0) def test_example_chunk_scaled_dot_kkt_compilation(): from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input - K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype)) + + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) block_S = chunk_size - kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, - accum_dtype, use_g, block_S, block_DK, threads, - num_stages) + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) 
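Aside (illustrative sketch, not part of the patch): every compilation test in this file follows the same calling convention, so one toy example is enough to read them all. A factory decorated with `@tilelang.jit` returns a compiled kernel, and the parameters listed in `out_idx` are allocated by the runtime and handed back, so only the inputs are passed at the call site. The kernel below is hypothetical and assumes N is a multiple of block_N.

```python
import torch
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def scale_by_two(N, block_N=128, dtype=T.float32):
    @T.prim_func
    def kernel(
        X: T.Tensor((N,), dtype),
        Y: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            for i in T.Parallel(block_N):
                Y[bx * block_N + i] = X[bx * block_N + i] * 2.0

    return kernel


kernel = scale_by_two(1024)            # compile once for a fixed problem size
x = torch.randn(1024, device="cuda")
y = kernel(x)                          # Y (out_idx=[-1]) is allocated and returned
```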
A_tilelang = kernel(K, Beta, G) # noqa: F841 def test_example_cumsum_compilation(): from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output + G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype)) G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype)) block_S = chunk_size @@ -157,35 +239,82 @@ def test_example_cumsum_compilation(): def test_example_chunk_delta_h_compilation(): from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input - K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) - kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - use_g, use_initial_state, store_final_state, - save_new_value, block_DK, block_DV, threads, - num_stages) - h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, - initial_state) # noqa: F841 + + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + block_DK, + block_DV, + threads, + num_stages, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841 def test_example_chunk_delta_bwd_compilation(): from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input - Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, - chunk_size, 1.0, use_g, use_initial_state, - use_final_state_gradient, block_DV, threads, - num_stages) + + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841 if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + test_example_chunk_delta_bwd_compilation() diff --git a/examples/gdn/utils.py b/examples/gdn/test_utils.py similarity index 68% rename from examples/gdn/utils.py rename to examples/gdn/test_utils.py index 37f8d8e69..3588551ce 100644 --- a/examples/gdn/utils.py +++ b/examples/gdn/test_utils.py @@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * 
y).sum() / denominator return sim @@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): x_mask = torch.isfinite(x) y_mask = torch.isfinite(y) if not torch.all(x_mask == y_mask): - print_red_warning(f'{name} Error: isfinite mask mismatch') + print_red_warning(f"{name} Error: isfinite mask mismatch") if raise_assert: raise AssertionError - if not torch.isclose( - x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, - equal_nan=True).all(): - print_red_warning(f'{name} Error: nonfinite value mismatch') + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") if raise_assert: raise AssertionError x = x.masked_fill(~x_mask, 0) y = y.masked_fill(~y_mask, 0) sim = calc_sim(x, y, name) - diff = 1. - sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff}') + print_red_warning(f"{name} Error: {diff}") if raise_assert: raise AssertionError else: diff --git a/examples/gemm/README.md b/examples/gemm/README.md index d7833c97d..9ab7fb661 100644 --- a/examples/gemm/README.md +++ b/examples/gemm/README.md @@ -53,7 +53,7 @@ import tilelang from tilelang import Profiler import tilelang.language as T -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -176,7 +176,7 @@ import tilelang.language as T # that helps align data for MMA (Matrix Multiply-Accumulate) operations. from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -265,18 +265,18 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index f18cd388a..dfa431121 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -3,13 +3,12 @@ @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -58,5 +57,11 @@ def main(): print(f"tilelang Latency: {latency}ms") +def run_regression_perf(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py 
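Aside (not part of the patch): the `assert_similar` helper in the renamed examples/gdn/test_utils.py above scores two tensors with a symmetric cosine-style similarity, sim = 2 * sum(x * y) / sum(x * x + y * y), and requires diff = 1 - sim <= eps. Identical tensors give diff = 0, and for a small perturbation diff grows roughly with the squared relative error, which is why eps=1e-5 is a fairly tight bound. A quick numeric illustration:

```python
import torch


def calc_sim(x, y):
    x, y = x.double(), y.double()
    denom = (x * x + y * y).sum()
    return 1.0 if denom == 0 else (2 * (x * y).sum() / denom).item()


x = torch.randn(1024)
y = x + 1e-3 * torch.randn(1024)      # about 0.1% relative noise
print(1.0 - calc_sim(x, x))           # 0.0 (up to double rounding)
print(1.0 - calc_sim(x, y))           # roughly 5e-7, comfortably below eps=1e-5
```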
index 661ef1276..052bd64c6 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -51,9 +51,9 @@ def get_configs(M, N, K, with_roller=False, topk=20): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, ).with_arch(arch) func = carve_template.equivalent_function() @@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -100,13 +101,19 @@ def get_configs(M, N, K, with_roller=False, topk=20): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs -def get_best_config(M, N, K, with_roller=False): - +def get_best_config( + M, + N, + K, + with_roller: bool = False, + profile_backend: str = "event", +): def kernel( block_M=None, block_N=None, @@ -115,17 +122,16 @@ def kernel( thread_num=None, enable_rasteration=None, ): - dtype = "bfloat16" - accum_dtype = "float" + dtype = T.bfloat16 + accum_dtype = T.float32 @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -146,15 +152,19 @@ def main( return main - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( out_idx=[-1], target="auto", - ).set_profile_args( + ) + .set_profile_args( supply_type=tl.TensorSupplyType.Integer, ref_prog=ref_program, skip_check=False, + backend=profile_backend, ) + ) return autotuner.run(warmup=3, rep=20) @@ -167,52 +177,20 @@ def get_heuristic_config() -> dict: sm_version = sm_major * 10 + sm_minor print(f"CUDA device capability: {sm_version}") if sm_version in {80}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 2, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} elif sm_version in {90}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 64, - "num_stages": 3, - "thread_num": 256, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} else: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 0, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} @tl.jit(out_idx=[-1]) -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm_autotune( - 
A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -236,14 +214,22 @@ def gemm_autotune( return gemm_autotune -def main(M: int = 4096, - N: int = 4096, - K: int = 4096, - use_autotune: bool = False, - with_roller: bool = False): - use_autotune = True +def main( + M: int = 4096, + N: int = 4096, + K: int = 4096, + use_autotune: bool = False, + with_roller: bool = False, + profile_backend: str = "event", +): if use_autotune: - result = get_best_config(M, N, K, with_roller) + result = get_best_config( + M, + N, + K, + with_roller=with_roller, + profile_backend=profile_backend, + ) print(result.config) kernel = result.kernel else: @@ -252,8 +238,13 @@ def main(M: int = 4096, # benchmark profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) - tilelang_latency = profiler.do_bench() - ref_latency = profiler.do_bench(ref_program) + tilelang_latency = profiler.do_bench( + backend=profile_backend, + ) + ref_latency = profiler.do_bench( + ref_program, + backend=profile_backend, + ) profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) print(f"TileLang latency: {tilelang_latency}") print(f"Ref latency: {ref_latency}") @@ -261,20 +252,27 @@ def main(M: int = 4096, print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}") +def run_regression_perf(M: int = 4096, N: int = 4096, K: int = 4096): + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") - parser.add_argument( - "--use_autotune", - action="store_true", - default=False, - help="Whether to use autotune for matmul configs") - parser.add_argument( - "--with_roller", - action="store_true", - default=False, - help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--profile_backend", type=str, default="event", help="Profiler backend") args = parser.parse_args() - main(args.m, args.n, args.k, args.use_autotune, args.with_roller) + main( + args.m, + args.n, + args.k, + args.use_autotune, + args.with_roller, + args.profile_backend, + ) diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index 5c014ce3a..15e552587 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -4,8 +4,8 @@ import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) -from tilelang.transform import simplify_prim_func + TensorCoreIntrinEmitter, +) def make_swizzle_layout(shared_buf): @@ -24,7 +24,6 @@ def transform_func(i, j): 
@tilelang.jit(out_idx=[2]) -@simplify_prim_func def tl_matmul( M, N, @@ -34,18 +33,18 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -53,7 +52,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 64 warp_col_tiles = 64 - # chunk = 32 if in_dtype == "float16" else 64 + # chunk = 32 if in_dtype == T.float16 else 64 chunk = 32 shared_scope = "shared.dyn" @@ -99,12 +98,11 @@ def tl_matmul( @T.prim_func def gemm_intrinsics( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -112,10 +110,12 @@ def gemm_intrinsics( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -123,7 +123,6 @@ def gemm_intrinsics( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -133,7 +132,6 @@ def gemm_intrinsics( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a(A_local, A_shared, ki) @@ -163,7 +161,7 @@ def ref_program(A, B): def main(M=4096, N=4096, K=4096): - in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" + in_dtype, out_dtype, accum_dtype = T.float16, T.float16, T.float32 kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) src_code = kernel.get_kernel_source() # src_code is the generated cuda source @@ -181,5 +179,12 @@ def main(M=4096, N=4096, K=4096): profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) +def run_regression_perf(M=4096, N=4096, K=4096): + in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main(M=4096, N=4096, K=4096) diff --git a/examples/gemm/example_gemm_persistent.py b/examples/gemm/example_gemm_persistent.py index a2a7122d3..ad3d556ed 100644 --- a/examples/gemm/example_gemm_persistent.py +++ b/examples/gemm/example_gemm_persistent.py @@ -5,22 +5,12 @@ @tilelang.jit(out_idx=[-1]) -def matmul_non_persistent(M, - N, - K, - block_M, - block_N, - block_K, - threads, - num_stages, - dtype="float16", - accum_dtype="float"): - +def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, 
num_stages, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -43,18 +33,9 @@ def main( @tilelang.jit(out_idx=[-1]) -def matmul_persistent(M, - N, - K, - block_M, - block_N, - block_K, - threads, - num_stages, - dtype="float16", - accum_dtype="float", - use_persistent_primitive=True): - +def matmul_persistent( + M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32, use_persistent_primitive=True +): sm_num = driver.get_num_sms() m_blocks = T.ceildiv(M, block_M) n_blocks = T.ceildiv(N, block_N) @@ -63,9 +44,9 @@ def matmul_persistent(M, @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(sm_num, threads=threads) as (block_id): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -90,9 +71,9 @@ def main( @T.prim_func def main_persistent_primitive( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(sm_num, threads=threads) as (block_id): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -100,8 +81,7 @@ def main_persistent_primitive( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), dtype) - for bx, by in T.Persistent( - [T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): + for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[bx * block_M, k * block_K], A_shared) @@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096): num_stages = 3 persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) - persistent_profiler = persistent_kernel.get_profiler( - tensor_supply_type=tilelang.TensorSupplyType.Randn) + persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("Persistent GEMM: All check passed.") persistent_latency = persistent_profiler.do_bench(warmup=500) print(f"Persistent GEMM Latency: {persistent_latency} ms") print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops") - non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, - num_stages) - non_persistent_profiler = non_persistent_kernel.get_profiler( - tensor_supply_type=tilelang.TensorSupplyType.Randn) + non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("Non-Persistent GEMM: All check passed.") non_persistent_latency = non_persistent_profiler.do_bench(warmup=500) @@ -149,11 +126,22 @@ def main(M=4096, N=4096, K=4096): print(f"Persistent GEMM Speedup: {non_persistent_latency / 
persistent_latency}") +def run_regression_perf(M=4096, N=4096, K=4096): + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 64 + threads = 256 + num_stages = 3 + persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + return persistent_profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--M', type=int, default=8192, help='M dimension') - parser.add_argument('--N', type=int, default=8192, help='N dimension') - parser.add_argument('--K', type=int, default=8192, help='K dimension') + parser.add_argument("--M", type=int, default=8192, help="M dimension") + parser.add_argument("--N", type=int, default=8192, help="N dimension") + parser.add_argument("--K", type=int, default=8192, help="K dimension") args = parser.parse_args() M, N, K = args.M, args.N, args.K main(M, N, K) diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py index f4727412b..17dbcc568 100644 --- a/examples/gemm/example_gemm_schedule.py +++ b/examples/gemm/example_gemm_schedule.py @@ -3,13 +3,12 @@ @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm_schedule( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -65,5 +64,19 @@ def main(): print(kernel.get_kernel_source()) +def run_regression_perf(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm/regression_example_gemm.py b/examples/gemm/regression_example_gemm.py new file mode 100644 index 000000000..3583cf16a --- /dev/null +++ b/examples/gemm/regression_example_gemm.py @@ -0,0 +1,25 @@ +import tilelang.testing +import example_gemm +import example_gemm_autotune +import example_gemm_intrinsics +import example_gemm_schedule + + +def regression_example_gemm_autotune(): + tilelang.testing.process_func(example_gemm_autotune.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_gemm_intrinsics(): + tilelang.testing.process_func(example_gemm_intrinsics.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_gemm_schedule(): + tilelang.testing.process_func(example_gemm_schedule.run_regression_perf) + + +def regression_example_gemm(): + tilelang.testing.process_func(example_gemm.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/gemm_fp8/README.md b/examples/gemm_fp8/README.md index 9d7011a06..2b3dc9560 100644 --- a/examples/gemm_fp8/README.md +++ b/examples/gemm_fp8/README.md @@ -1 +1 @@ -**Notes**: Now we only support fp8 with mma instructions instead of `T.gemm`, because the cutlass version of tilelang is too old, we should update the cutlass version in future. 
\ No newline at end of file +**Notes**: Now we only support fp8 with mma instructions instead of `T.gemm`, because the cutlass version of tilelang is too old, we should update the cutlass version in future. diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd.py b/examples/gemm_fp8/example_tilelang_gemm_amd.py index 0e6ace757..16a9d5f32 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd.py @@ -2,6 +2,7 @@ import tilelang import tilelang.language as T from tilelang.utils.tensor import torch_assert_close +from tilelang.utils import determine_fp8_type, determine_torch_fp8_type import itertools @@ -17,10 +18,9 @@ def supply_prog(args): a_param, b_param = args M, K = a_param.shape N, _ = b_param.shape - a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) + fp8_dtype = determine_torch_fp8_type() + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) return [a, b] @@ -35,40 +35,36 @@ def get_configs(): valid_configs = [] - for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, - num_stages, num_threads, k_packs, - gemm_types): - valid_configs.append({ - "block_M": m, - "block_N": n, - "block_K": k, - "num_stages": stages, - "num_threads": t, - "k_pack": kp, - "gemm_type": gemm_type, - }) + for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "block_K": k, + "num_stages": stages, + "num_threads": t, + "k_pack": kp, + "gemm_type": gemm_type, + } + ) return valid_configs @tilelang.autotune( - configs=get_configs(), - cache_input_tensors=True, - ref_prog=ref_program, - manual_check_prog=manual_check_prog, - supply_prog=supply_prog) + configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog +) @tilelang.jit(out_idx=[-1]) def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type): - dtype = "float8_e4m3fnuz" - accum_dtype = "float" + dtype = determine_fp8_type() + accum_dtype = T.float32 @T.prim_func def gemm_fp8_rs( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_local = T.alloc_fragment((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -77,24 +73,17 @@ def gemm_fp8_rs( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_local) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_local, - B_shared, - C_local, - transpose_B=True, - k_pack=k_pack, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) T.copy(C_local, C[by * block_M, bx * block_N]) @T.prim_func def gemm_fp8_ss( - A: T.Tensor((M, K), dtype), - B: 
T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -103,13 +92,7 @@ def gemm_fp8_ss( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - k_pack=k_pack, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) T.copy(C_local, C[by * block_M, bx * block_N]) @@ -123,10 +106,9 @@ def gemm_fp8_ss( def test_gemm_fp8(M, N, K): kernel = fp8_matmul(M, N, K) - a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) + fp8_dtype = determine_torch_fp8_type() + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) c = kernel(a, b) ref_c = ref_program(a, b) torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2) diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py new file mode 100644 index 000000000..fc7fb4400 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py @@ -0,0 +1,225 @@ +import torch +import itertools +import tilelang +import tilelang.testing +from tilelang import tvm as tvm +import tilelang.language as T +from tilelang.tileop.base import GemmWarpPolicy +from tilelang.layout import make_swizzled_layout +from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter +from tilelang.utils import determine_fp8_type + +tilelang.testing.set_random_seed(0) + + +def get_configs(): + block_Ms = [32, 64, 128] + block_Ns = [32, 64, 128] + block_Ks = [64, 128] + num_stages = [0, 1, 2] + + valid_configs = [] + + for m, n, k, stages in itertools.product(block_Ms, block_Ns, block_Ks, num_stages): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "block_K": k, + "num_stages": stages, + } + ) + return valid_configs + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit(out_idx=[-1]) +def tl_matmul( + M, + N, + K, + block_M, + block_N, + block_K, + num_stages, + k_pack=2, + num_threads=256, + in_dtype=None, + out_dtype=T.float32, + accum_dtype=T.float32, + a_transposed=False, + b_transposed=True, +): + if in_dtype is None: + in_dtype = determine_fp8_type() + b_preshuffle = True + warp_size = 64 + num_warps = num_threads // warp_size + + policy = GemmWarpPolicy.Square + m_warp, n_warp = policy.compute_warp_partition(block_M, block_N, num_warps) + + shared_scope = "shared" + warp_row_tiles = block_M // m_warp + warp_col_tiles = block_N // n_warp + + # MMA Wrapper to Auto Generate Code for MMA + mfma_emitter = MatrixCorePreshuffleIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + 
block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=block_K, + k_pack=k_pack, + b_preshuffle=b_preshuffle, + ) + local_size_a = mfma_emitter.local_size_a + local_size_b = mfma_emitter.local_size_b + + warp_rows = mfma_emitter.warp_rows + warp_cols = mfma_emitter.warp_cols + + micro_size_y = mfma_emitter.micro_size_y + micro_size_k = mfma_emitter.micro_size_k + pack_size_k = micro_size_k * k_pack + + A_shape = (K, M) if a_transposed else (M, K) + A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) + + B_shape = ( + (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k) + if b_transposed + else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y) + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a * k_pack), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b * k_pack), in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzled_layout(A_shared), + C_local: mfma_emitter.make_mfma_store_layout(C_local), + } + ) + + num_ko = K // block_K + num_ki = block_K // (k_pack * micro_size_k) + + # Improve L2 Cache + # T.use_swizzle(panel_size=10) + T.clear(C_local) + for ko in T.Pipelined(num_ko, num_stages=num_stages): + # Load A into shared memory + if a_transposed: + T.copy(A[ko * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, ko * block_K], A_shared) + + for ki in T.serial(0, num_ki): + mfma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + mfma_emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_local, C_local, ki) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def shuffle_weight( + x: torch.Tensor, + layout=(16, 32), + k_pack=1, + is_transpose=False, +) -> torch.Tensor: + IN, IK = layout + BK = IK * k_pack + BN = IN + + N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2]) + assert N % BN == 0 + assert K % BK == 0 + + x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN) + x = x.permute(0, 2, 1, 3) + return x.contiguous() + + +def assert_tl_matmul_correctness(M, N, K, k_pack=1, a_transposed=False, b_transposed=True): + in_dtype = determine_fp8_type() + out_dtype = T.float32 + accum_dtype = T.float32 + kernel = tl_matmul( + M, + N, + K, + k_pack=k_pack, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + ) + + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + A_shape = (K, M) if a_transposed else (M, K) + B_shape = (N, K) if b_transposed else (K, N) + + A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype)) + B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype)) + + B_preshuffle = shuffle_weight(B, k_pack=k_pack, is_transpose=b_transposed) + C = kernel(A, B_preshuffle) + + profiler = kernel.get_profiler() + latency = profiler.do_bench() + + # Ensure that the latency is not None + assert 
latency is not None + print("time: ", latency) + + if a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.T.half(), B.T.half()).to(getattr(torch, out_dtype)) + elif a_transposed and not b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.T.half(), B.half()).to(getattr(torch, out_dtype)) + elif not a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.half(), B.T.half()).to(getattr(torch, out_dtype)) + else: + # Get Reference Result + ref_c = torch.matmul(A.half(), B.half()).to(getattr(torch, out_dtype)) + + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(512, 512, 512, k_pack=2) + + +if __name__ == "__main__": + test_assert_tl_matmul() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index a403ed068..3b575c78e 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -1,7 +1,7 @@ import torch import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type +from tilelang.utils import determine_fp8_type def calc_diff(x, y): @@ -12,13 +12,12 @@ def calc_diff(x, y): @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): @T.prim_func def gemm_fp8( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -37,12 +36,12 @@ def gemm_fp8( def test_gemm_fp8(M, N, K, dtype): - torch_dtype = map_torch_type(dtype) + torch_dtype = T.dtype(dtype).as_torch() kernel = matmul(M, N, K, 128, 128, 64, dtype) - a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) - b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) + a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) + b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) c = kernel(a, b) @@ -57,8 +56,24 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3') - test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2') + test_gemm_fp8(1024, 1024, 1024, determine_fp8_type()) + test_gemm_fp8(1024, 1024, 1024, determine_fp8_type("e5m2")) + + +def run_regression_perf(): + M, N, K = 4096, 4096, 4096 + dtype = determine_fp8_type() + kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + if torch.version.hip is None: + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + dtype = determine_fp8_type("e5m2") + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 + latency_e4m3 = profiler_e4m3.do_bench() + return latency_e4m3 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index 1d9207aff..39c6fc333 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ 
b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -1,11 +1,11 @@ import torch import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type +from tilelang.utils import determine_fp8_type @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): # for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128. # if block_K < 128, promote after 128/block_K iters. # if block_K > 128, promote after every iter. @@ -13,9 +13,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): @T.prim_func def gemm_fp8_2xAcc( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -55,18 +55,18 @@ def calc_diff(x, y): def test_gemm_fp8(M, N, K, dtype): - torch_dtype = map_torch_type(dtype) + torch_dtype = T.dtype(dtype).as_torch() kernel = matmul(M, N, K, 128, 128, 64, dtype) - a = torch.rand(M, K, dtype=torch.float16, device='cuda') + a = torch.rand(M, K, dtype=torch.float16, device="cuda") a = (100 * (2 * a - 1)).to(dtype=torch_dtype) - b = torch.rand(N, K, dtype=torch.float16, device='cuda') + b = torch.rand(N, K, dtype=torch.float16, device="cuda") b = (100 * (2 * b - 1)).to(dtype=torch_dtype) c = kernel(a, b) - ref_c = (a.float() @ b.float().T) + ref_c = a.float() @ b.float().T diff = calc_diff(c, ref_c) print(f"diff: {diff}") @@ -74,8 +74,26 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3') - test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2') + test_gemm_fp8(1024, 1024, 8192, determine_fp8_type()) + test_gemm_fp8(1024, 1024, 8192, determine_fp8_type("e5m2")) + + +def run_regression_perf(): + M, N, K = 1024, 1024, 8192 + dtype = determine_fp8_type() + kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + if torch.version.hip is None: + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + else: + latency_e4m3 = profiler_e4m3.do_bench() + if torch.version.hip is None: + dtype = determine_fp8_type("e5m2") + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 + return latency_e4m3 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index ed44aab69..1015a7463 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -4,10 +4,10 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) -from tilelang.transform import simplify_prim_func +from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter +from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter from tilelang.utils.tensor import map_torch_type +from tilelang.utils import determine_fp8_type tilelang.testing.set_random_seed(0) @@ -28,7 +28,6 
@@ def transform_func(i, j): @tilelang.jit(out_idx=[2]) -@simplify_prim_func def tl_matmul( M, N, @@ -38,29 +37,25 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "float8_e4m3", - "float8_e5m2", - "int8", - ], "Currently only float16 and int8 are supported" + T.float16, + T.float8_e4m3fn, + T.float8_e4m3fnuz, + T.float8_e5m2, + T.float8_e5m2fnuz, + T.int8, + ], "Currently only float16, float8, and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" - micro_size_x = micro_size_y = micro_size_k = 16 - - is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] - if out_dtype == "int32" or is_float8: - micro_size_k = 32 - # This is a debug config block_row_warps = 2 block_col_warps = 2 warp_row_tiles = 32 warp_col_tiles = 32 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -74,6 +69,38 @@ def tl_matmul( B_shape = (N, K) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) + is_hip = torch.version.hip is not None + # MMA Wrapper to Auto Generate Code for MMA/MFMA + if is_hip: + mma_emitter = MatrixCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + else: + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + micro_size_x = mma_emitter.M_DIM + micro_size_y = getattr(mma_emitter, "n_dim", getattr(mma_emitter, "N_DIM", micro_size_x)) + micro_size_k = mma_emitter.k_dim C_shared_shape = ( block_M // micro_size_x, block_N // micro_size_y, @@ -81,36 +108,20 @@ def tl_matmul( micro_size_y, ) - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) + threads = mma_emitter.threads + local_size_a = mma_emitter.local_size_a + local_size_b = mma_emitter.local_size_b + local_size_c = mma_emitter.local_size_out + warp_rows = mma_emitter.warp_rows + warp_cols = mma_emitter.warp_cols @T.prim_func def gemm_fp8_intrinsic( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, 
scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -118,10 +129,12 @@ def gemm_fp8_intrinsic( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -129,7 +142,6 @@ def gemm_fp8_intrinsic( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -139,7 +151,6 @@ def gemm_fp8_intrinsic( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -155,7 +166,10 @@ def gemm_fp8_intrinsic( ) # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) + if is_hip: + mma_emitter.mfma(A_local, B_local, C_local, ki) + else: + mma_emitter.mma(A_local, B_local, C_local) # Perform STMatrix mma_emitter.stmatrix( @@ -189,7 +203,12 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + elif in_dtype in { + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + }: A = torch.randn(M, K).to(in_dtype).cuda() B = torch.randn(N, K).to(in_dtype).cuda() else: @@ -215,8 +234,24 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def main(): - assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32") - assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32") + e4m3_dtype = determine_fp8_type() + assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32) + e5m2_dtype = determine_fp8_type("e5m2") + assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32) + + +def run_regression_perf(): + M, N, K = 4096, 4096, 4096 + out_dtype, accum_dtype = "float32", "float32" + in_dtype = determine_fp8_type() + kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + print(kernel_e4m3.get_kernel_source()) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + if torch.version.hip is None: + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + else: + latency_e4m3 = profiler_e4m3.do_bench() + return latency_e4m3 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py new file mode 100644 index 000000000..aa7e8b360 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -0,0 +1,124 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else 
(block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm_v2( + A_shared, + B_shared, + C_tmem, + trans_A, + trans_B, + mbar=mbar, + wg_wait=-1, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 64, 256, 32 +trans_A, trans_B = False, True +num_stages = 2 +threads = 256 +for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]: + for tvm_acc_dtype in [T.float16, T.float32]: # , torch.float16]: + torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) + torch_acc_dtype = map_torch_type(tvm_acc_dtype) + print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") + in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype + + func = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + ) + jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, + }, + ) + # jit_kernel.export_ptx("./dump.ptx") + # jit_kernel.export_sources("./dump.cu") + + a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + + c = jit_kernel(a, b) + ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float() + c = c.float() + diff = calc_diff(c, ref_c) + # assert diff < 1e-3, f"{diff}" + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}") + + profiler = jit_kernel.get_profiler() + latency = profiler.do_bench() + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_fp8/regression_example_gemm_fp8.py b/examples/gemm_fp8/regression_example_gemm_fp8.py new file mode 100644 index 000000000..3ba2f4f27 --- /dev/null +++ b/examples/gemm_fp8/regression_example_gemm_fp8.py @@ -0,0 +1,20 @@ +import tilelang.testing +import example_tilelang_gemm_fp8 +import example_tilelang_gemm_fp8_2xAcc +import example_tilelang_gemm_fp8_intrinsic + + +def regression_example_tilelang_gemm_fp8_2xAcc(): + tilelang.testing.process_func(example_tilelang_gemm_fp8_2xAcc.run_regression_perf) + + +def regression_example_tilelang_gemm_fp8_intrinsic(): + 
tilelang.testing.process_func(example_tilelang_gemm_fp8_intrinsic.run_regression_perf) + + +def regression_example_tilelang_gemm_fp8(): + tilelang.testing.process_func(example_tilelang_gemm_fp8.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/gemm_sm100/README.md b/examples/gemm_sm100/README.md index 73dd76c30..d630d2d0d 100644 --- a/examples/gemm_sm100/README.md +++ b/examples/gemm_sm100/README.md @@ -40,19 +40,19 @@ import tilelang.language as T @T.prim_func def main( - A: T.Tensor((M, K), "bfloat16"), - B: T.Tensor((N, K), "bfloat16"), - C: T.Tensor((M, N), "bfloat16"), + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((N, K), T.bfloat16), + C: T.Tensor((M, N), T.bfloat16), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): # 1. Allocate memory buffers - A_shared = T.alloc_shared((block_M, block_K), "bfloat16") # A matrix shared memory - B_shared = T.alloc_shared((block_N, block_K), "bfloat16") # B matrix shared memory - C_tmem = T.alloc_tmem([block_M, block_N], "float") # TCGEN5MMA output to Tensor Memory + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) # A matrix shared memory + B_shared = T.alloc_shared((block_N, block_K), T.bfloat16) # B matrix shared memory + C_tmem = T.alloc_tmem([block_M, block_N], T.float) # TCGEN5MMA output to Tensor Memory mbar = T.alloc_barrier(1) # mbarrier synchronization primitive - C_local = T.alloc_fragment((block_M, block_N), "float") # Register storage - C_shared = T.alloc_shared((block_M, block_N), "bfloat16") # Output shared memory + C_local = T.alloc_fragment((block_M, block_N), T.float) # Register storage + C_shared = T.alloc_shared((block_M, block_N), T.bfloat16) # Output shared memory # 2. Main computation loop for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): @@ -103,4 +103,3 @@ latency = profiler.do_bench() print(f"Latency: {latency} ms") print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS") ``` - diff --git a/examples/gemm_sm100/gemm_mma.py b/examples/gemm_sm100/gemm_mma.py index a58e5a7c0..226e33c01 100644 --- a/examples/gemm_sm100/gemm_mma.py +++ b/examples/gemm_sm100/gemm_mma.py @@ -4,13 +4,12 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -62,7 +61,8 @@ def main( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) print(jit_kernel.get_kernel_source()) # 3. 
Test the kernel in Python with PyTorch data import torch diff --git a/examples/gemm_sm100/gemm_tcgen5mma.py b/examples/gemm_sm100/gemm_tcgen5mma.py index 9008c7ef5..d3f384e98 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma.py +++ b/examples/gemm_sm100/gemm_tcgen5mma.py @@ -25,9 +25,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -38,17 +38,9 @@ def main( C_shared = T.alloc_shared((block_M, block_N), out_dtype) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_tmem, - trans_A, - trans_B, - mbar=mbar, - wg_wait=-1, - clear_accum=k == 0) + T.copy(A[by * block_M, k * block_K], A_shared) # not trans_A + T.copy(B[bx * block_N, k * block_K], B_shared) # trans_B + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) T.mbarrier_wait_parity(mbar, k % 2) T.copy(C_tmem, C_local) @@ -60,14 +52,13 @@ def main( M, N, K = 4096, 4096, 8192 -block_M, block_N, block_K = 128, 256, 128 +block_M, block_N, block_K = 128, 128, 128 trans_A, trans_B = False, True -in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float" -num_stages = 2 +in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float +num_stages = 0 if block_N >= 256 or block_M >= 256 or block_K >= 256 else 2 threads = 256 -func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) +func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) jit_kernel = tilelang.compile( func, out_idx=[2], @@ -75,7 +66,8 @@ def main( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) print(jit_kernel.get_kernel_source()) @@ -88,4 +80,4 @@ def main( profiler = jit_kernel.get_profiler() latency = profiler.do_bench() print(f"Latency: {latency} ms") -print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS") +print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py new file mode 100644 index 000000000..4b03ae83d --- /dev/null +++ b/examples/gemm_sp/example_custom_compress.py @@ -0,0 +1,342 @@ +import argparse + +import tilelang +import tilelang.language as T + +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.sparse import randn_semi_sparse +from tilelang.utils.tensor import torch_assert_close + +from tilelang.profiler import do_bench + +import torch + +torch.manual_seed(42) + +DEFAULT_CONFIG = { # take best config from autotune script + "4090": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, + "h20": { + T.float: { + "block_M": 128, + "block_N": 64, + 
"block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, +} + +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp_fp16_custom_compress( + M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout +): + e_factor, e_dtype = (16, T.int16) + + @T.prim_func + def gemm_sp_fp16_custom_compress( + A_sparse: T.Tensor((M, K // 2), T.float16), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + B_shared = T.alloc_shared((block_K, block_N), T.float16) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + if use_cutlass_layout: + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), + } + ) + T.clear(C_local) + T.disable_warp_group_reg_alloc() + T.use_swizzle(panel_size=10, enable=enable_rasterization) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(E[by * block_M, k * block_K // e_factor], E_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_sp_fp16_custom_compress + + +def torch_compress(dense): + """ + A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout. 
+ """ + if dense.dim() != 2: + raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") + + m, k = dense.shape + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}") + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + return (sparse, meta) + + +def decode_metadata(meta: torch.Tensor) -> torch.Tensor: + assert meta.dtype is torch.int16 + groups_per_meta = 16 // 4 # 4 groups per uint16 + out = [] + for g in range(groups_per_meta): + group_bits = (meta >> (g * 4)) & 0xF + idx0 = group_bits & 0x3 + idx1 = (group_bits >> 2) & 0x3 + out.append(torch.stack([idx0, idx1], dim=-1)) + return torch.concat(out, dim=-1).view(meta.shape[0], -1) + + +@tilelang.jit( + out_idx=[1, 2], + pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, + }, +) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), + } + ) + T.clear(A_sp_shared) + T.clear(E_shared) + # TODO: alloc_var seems buggy here + non_zero_cnt = T.alloc_local((1,), dtype=T.uint8) + non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8) + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + non_zero_cnt[0] = 0 + for i in range(elem): + non_zero_elt_log_idx[i] = 0 + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = 
A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel + + +def main(M=1024, N=1024, K=1024, use_cutlass_layout=False, use_torch_compressor=False, accum_dtype=T.float, cfg="4090"): + kernel = matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype], use_cutlass_layout=use_cutlass_layout) + + a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half) + b = torch.randn(K, N, device="cuda", dtype=torch.half) + + if use_torch_compressor: + assert not use_cutlass_layout, "torch sparse must be used with naive layout" + a_sparse, e = torch_compress(a) + else: + a_sparse, e = compress_kernel(M, K, 32, 32, T.float16, use_cutlass_layout=use_cutlass_layout)(a) + + c = kernel(a_sparse, e, b) + + ref_c = a @ b + + assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" + torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3) + print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}") + + latency = do_bench(lambda: kernel(a_sparse, e, b)) + ref_latency = do_bench(lambda: a @ b) + + total_flops = 2 * M * N * K + tflops = total_flops / latency / 1e9 + ref_tflops = total_flops / ref_latency / 1e9 + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") + parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") + args = parser.parse_args() + main( + M=args.m, + N=args.n, + K=args.k, + use_cutlass_layout=args.use_cutlass_layout, + use_torch_compressor=args.use_torch_compressor, + accum_dtype=args.accum_dtype, + cfg=args.cfg, + ) diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py index 505f2b883..769ea6736 100644 --- a/examples/gemm_sp/example_gemm_sp.py +++ b/examples/gemm_sp/example_gemm_sp.py @@ -1,99 +1,90 @@ -# Copyright (c) Tile-AI Corporation. -# Licensed under the MIT License. 
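+"""2:4 semi-structured sparse GEMM example (fp16) using the CUTLASS metadata layout, checked and benchmarked against a dense torch matmul."""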
import argparse import tilelang import tilelang.language as T -from tilelang.layout import make_metadata_layout +from tilelang.layout import make_cutlass_metadata_layout from tilelang.utils.sparse import compress, randn_semi_sparse from tilelang.contrib import nvcc -from triton.testing import do_bench +from tilelang.profiler import do_bench import torch arch = nvcc.get_target_compute_version() -ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} - -default_config = { # take best config from autotune script +DEFAULT_CONFIG = { # take best config from autotune script "4090": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 64, - 'num_stages': 1, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 256, - 'block_N': 128, - 'block_K': 64, - 'num_stages': 2, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } }, "h20": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } - } + T.float16: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, } +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + @tilelang.jit(out_idx=[-1]) -def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, - enable_rasterization): +def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization): e_factor, e_dtype = ARCH_INFO[arch] @T.prim_func def gemm_sp_fp16( - A_sparse: T.Tensor((M, K // 2), 'float16'), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), 'float16'), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), T.float16), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K // 2), 'float16') + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - B_shared = T.alloc_shared((block_K, block_N), 'float16') + B_shared = T.alloc_shared((block_K, block_N), T.float16) C_shared = T.alloc_shared((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout({ - E: - make_metadata_layout( - E, mma_dtype="float16", backend="cutlass", block_k=block_K, 
arch=arch), - E_shared: - make_metadata_layout( - E_shared, - mma_dtype="float16", - backend="cutlass", - block_k=block_K, - arch=arch), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) T.copy(E[by * block_M, k * block_K // e_factor], E_shared) @@ -106,30 +97,13 @@ def gemm_sp_fp16( return gemm_sp_fp16 -def main(): - parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") - parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") - parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") - parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True) - args = parser.parse_args() - kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, - **default_config[args.cfg][args.accum_dtype]) +def main(M=1024, N=1024, K=1024, accum_dtype=T.float, cfg="h20"): + kernel = matmul_sp_fp16(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype]) - a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) - b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) + a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half) + b = torch.randn(K, N, device="cuda", dtype=torch.half) - a_sparse, e = compress( - a, - transposed=False, - block_k=default_config[args.cfg][args.accum_dtype]['block_K'], - arch=arch) + a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[cfg][accum_dtype]["block_K"], arch=arch) c = kernel(a_sparse, e, b) ref_c = a @ b @@ -141,12 +115,19 @@ def main(): latency = do_bench(lambda: kernel(a_sparse, e, b)) ref_latency = do_bench(lambda: a @ b) - total_flops = 2 * args.m * args.n * args.k + total_flops = 2 * M * N * K tflops = total_flops / latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9 - print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s") - print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s") + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") + args = parser.parse_args() + main(M=args.m, N=args.n, K=args.k, accum_dtype=args.accum_dtype, cfg=args.cfg) diff --git a/examples/gemm_sp/test_example_gemm_sp.py b/examples/gemm_sp/test_example_gemm_sp.py new file mode 100644 index 000000000..fe26df144 --- /dev/null +++ b/examples/gemm_sp/test_example_gemm_sp.py @@ -0,0 +1,16 @@ +import tilelang.testing + +import example_custom_compress +import example_gemm_sp + + 
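+# Smoke tests: each one runs the corresponding example's main() with its default problem sizes.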
+def test_example_custom_compress(): + example_custom_compress.main() + + +def test_example_gemm_sp(): + example_gemm_sp.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index c96669711..64ffade8e 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -3,27 +3,16 @@ @tilelang.jit -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - split_k, - dtype="float16", - accum_dtype="float", - out_dtype="float32"): - +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): splitK = K // split_k @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) @@ -67,5 +56,28 @@ def main(): torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) +def run_regression_perf(): + M = 4096 + N = 4096 + K = 4096 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + import torch + + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b, c) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py index 145d622ed..3d33478cf 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py @@ -3,27 +3,16 @@ @tilelang.jit -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - split_k, - dtype="float16", - accum_dtype="float", - out_dtype="float32"): - +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): splitK = K // split_k @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) @@ -66,5 +55,29 @@ def main(): torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) +def run_regression_perf(): + M = 4096 + N = 4096 + K = 4096 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + import torch + + 
torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b, c) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm_splitk/regression_example_gemm_splitk.py b/examples/gemm_splitk/regression_example_gemm_splitk.py new file mode 100644 index 000000000..c76b7e55c --- /dev/null +++ b/examples/gemm_splitk/regression_example_gemm_splitk.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_tilelang_gemm_splitk +import example_tilelang_gemm_splitk_vectorize_atomicadd + + +def regression_example_tilelang_gemm_splitk(): + tilelang.testing.process_func(example_tilelang_gemm_splitk.run_regression_perf) + + +def regression_example_tilelang_gemm_splitk_vectorize_atomicadd(): + tilelang.testing.process_func(example_tilelang_gemm_splitk_vectorize_atomicadd.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index 31cf40647..b2e8e9369 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -39,7 +39,7 @@ def cdiv(a, b): # Two-tile SK + DP streamk_tiles = total_tiles % streamk_programs -if (total_tiles - streamk_tiles > streamk_programs): # (total_tiles // total_programs > 1) +if total_tiles - streamk_tiles > streamk_programs: # (total_tiles // total_programs > 1) streamk_tiles += streamk_programs blocking_tiles = total_tiles - streamk_tiles @@ -77,95 +77,71 @@ def tl_matmul_streamk( A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M) B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) - @T.macro - def compute_first_wave( - pid: T.int32, - A_buf: T.Tensor, - A_buf_shared: T.SharedBuffer, - B_buf: T.Tensor, - B_buf_shared: T.SharedBuffer, - C: T.Tensor, - C_local: T.LocalBuffer, - ): - start_iter = T.alloc_fragment((1,), "int32", "local") - end_iter = T.alloc_fragment((1,), "int32", "local") - - start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) - last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles) - - while start_iter[0] < last_iter: - end_iter[0] = T.min( - start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)), - last_iter, - ) - - tile_id = start_iter[0] // iters_per_tile - remain_iters = start_iter[0] % iters_per_tile - pid_m = tile_id // T.ceildiv(N, block_N) - pid_n = tile_id % T.ceildiv(N, block_N) - - T.clear(C_local) - for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages): - T.copy( - A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K], - A_buf_shared, - ) - T.copy( - B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K], - B_buf_shared, - ) - T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B) - - # last iteration of the tile always happens before its start on another SM - if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0): - T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) - else: - for i, j in T.Parallel(block_M, block_N): - T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j]) - - start_iter[0] = end_iter[0] - - @T.macro - def compute_full_tiles( - pid: T.int32, - A_buf: T.Tensor, - A_shared: T.SharedBuffer, - 
B_buf: T.Tensor, - B_shared: T.SharedBuffer, - C: T.Tensor, - C_local: T.LocalBuffer, - ): - - for p in T.serial(sm_patition_factor): - tile_id = pid + streamk_tiles + p * total_sm - pid_m = tile_id // T.ceildiv(N, block_N) - pid_n = tile_id % T.ceildiv(N, block_N) - T.clear(C_local) - - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): - T.copy(A_buf[pid_m * block_M, k * block_K], A_shared) - T.copy(B_buf[pid_n * block_N, k * block_K], B_shared) - T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B) - T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) - @T.prim_func def main( - A: T.Tensor(A_shape, dtypeAB), - B: T.Tensor(B_shape, dtypeAB), - C: T.Tensor((M, N), dtypeC), + A: T.Tensor(A_shape, dtypeAB), + B: T.Tensor(B_shape, dtypeAB), + C: T.Tensor((M, N), dtypeC), ): with T.Kernel(streamk_programs, threads=threads) as pid: - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) B_shared = T.alloc_shared(B_shared_shape, dtypeAB) A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local) + # compute first wave + start_iter = T.alloc_fragment((1,), T.int32, "local") + end_iter = T.alloc_fragment((1,), T.int32, "local") + + start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) + last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles) + while start_iter[0] < last_iter: + end_iter[0] = T.min( + start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)), + last_iter, + ) + + tile_id = start_iter[0] // iters_per_tile + remain_iters = start_iter[0] % iters_per_tile + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + + T.clear(C_local) + for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages): + T.copy( + A[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K], + A_shared, + ) + T.copy( + B[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K], + B_shared, + ) + T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B) + + # last iteration of the tile always happens before its start on another SM + if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0): + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + else: + for i, j in T.Parallel(block_M, block_N): + T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j]) + + start_iter[0] = end_iter[0] + + # compute full tiles if sm_patition_factor > 0: - compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local) + for p in T.serial(sm_patition_factor): + tile_id = pid + streamk_tiles + p * total_sm + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A[pid_m * block_M, k * block_K], A_shared_full_tiles) + T.copy(B[pid_n * block_N, k * block_K], B_shared_full_tiles) + T.gemm(A_shared_full_tiles, B_shared_full_tiles, C_local, transpose_B=trans_B) + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) return main @@ -181,9 +157,9 @@ def main(): BLOCK_SIZE_K, False, True, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, 2, 64, ) @@ -201,5 +177,30 @@ def main(): torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2) +def run_regression_perf(): + kernel = tl_matmul_streamk( + m, + n, + k, + streamk_tiles, + BLOCK_SIZE_M, + 
BLOCK_SIZE_N, + BLOCK_SIZE_K, + False, + True, + "float16", + "float16", + "float32", + 2, + 64, + ) + b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16) + torch.cuda.synchronize() + + from tilelang.profiler import do_bench + + return do_bench(lambda: kernel(A, B, b_c), backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm_streamk/test_example_tilelang_gemm_splitk.py b/examples/gemm_streamk/test_example_tilelang_gemm_streamk.py similarity index 100% rename from examples/gemm_streamk/test_example_tilelang_gemm_splitk.py rename to examples/gemm_streamk/test_example_tilelang_gemm_streamk.py diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 4e43dcd9a..a5ecffbd0 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -17,15 +17,14 @@ def naive_gemv( K: int, BLOCK_N: int, BLOCK_K: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): - @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn: tn = T.get_thread_binding(0) # tn = threadIdx.x @@ -38,8 +37,7 @@ def main( A_shared[tk] = A[bk * BLOCK_K + tk] B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] for tk in T.serial(BLOCK_K): - C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, - tk].astype(accum_dtype) + C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype) C[bn * BLOCK_N + tn] = C_reg[0] return main @@ -51,15 +49,14 @@ def naive_splitk_gemv( K: int, BLOCK_N: int, BLOCK_K: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): - @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn: tn = T.get_thread_binding(0) @@ -88,16 +85,16 @@ def splitk_gemv( BLOCK_N: int, BLOCK_K: int, reduce_threads: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): TILE_K = T.ceildiv(BLOCK_K, reduce_threads) @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -127,8 +124,8 @@ def splitk_gemv_vectorized( K: int, BLOCK_N: int, reduce_threads: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): MAX_TRANSACTION_SIZE_IN_BITS = 128 TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits @@ -136,9 +133,9 @@ def splitk_gemv_vectorized( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -168,8 +165,8 @@ def splitk_gemv_vectorized_tvm( K: int, BLOCK_N: int, reduce_threads: int, - dtype: str = "float16", - accum_dtype: str = "float", + 
dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): MAX_TRANSACTION_SIZE_IN_BITS = 128 TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits @@ -177,9 +174,9 @@ def splitk_gemv_vectorized_tvm( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -197,9 +194,9 @@ def main( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -209,7 +206,8 @@ def main( C_reduced[0], tk, dtype="handle", - )) + ) + ) C[bn * BLOCK_N + tn] = C_reduced[0] @@ -218,10 +216,8 @@ def main( def get_block_template_configs(): iter_params = dict( - block_M=[2, 4, 8, 32, 64, 128], - block_N=[2, 4, 8, 32, 64, 128], - num_stages=[0, 1, 2, 3, 4], - threads=[32, 64, 128, 256]) + block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256] + ) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -237,18 +233,11 @@ def get_block_template_configs(): }, out_idx=[2], ) -def gemv_alloc_reducer(M, - N, - block_M=128, - block_N=128, - num_stages=2, - threads=256, - dtype: str = "float16", - accum_dtype: str = "float"): - +def gemv_alloc_reducer( + M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float +): @T.prim_func - def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, - dtype)): # type: ignore + def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m: o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all") T.clear(o_reducer) @@ -287,17 +276,17 @@ def get_autotuned_kernel( BLOCK_N=None, reduce_threads=None, ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 MAX_TRANSACTION_SIZE_IN_BITS = 128 TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits BLOCK_K = reduce_threads * TILE_K @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -315,9 +304,9 @@ def main( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -327,21 +316,22 @@ def main( C_reduced[0], tk, dtype="handle", - )) + ) + ) C[bn * BLOCK_N + tn] = C_reduced[0] return main -def check_correctness_and_bench(kernel, N, K, bench_ref=True): +def 
check_correctness_and_bench(kernel, N, K, do_bench=True): profiler = kernel.get_profiler() profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2) - if bench_ref: + if do_bench: latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=50) print(f"Torch Latency: {latency} ms") - latency = profiler.do_bench(kernel, warmup=50) - print(f"TileLang Latency: {latency} ms\n") + latency = profiler.do_bench(kernel, warmup=50) + print(f"TileLang Latency: {latency} ms\n") def main(do_bench: bool = True): @@ -350,16 +340,16 @@ def main(do_bench: bool = True): parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") args, _ = parser.parse_known_args() N, K = args.n, args.k - check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K) - check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K) - check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K) - check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K) - check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K) - check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K) + check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K, do_bench=do_bench) + check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench) print("Test passed!") - if not do_bench: + if do_bench: best_result = get_autotuned_kernel(N, K) best_config = best_result.config kernel = splitk_gemv_vectorized_tvm(N, K, **best_config) @@ -374,5 +364,23 @@ def main(do_bench: bool = True): print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n") +def run_regression_perf(): + N, K = 4096, 4096 + latency = 0.0 + kernel_list = [ + naive_gemv(N, K, 128, 128), + naive_splitk_gemv(N, K, 32, 32), + splitk_gemv(N, K, 32, 32, 32), + splitk_gemv_vectorized(N, K, 2, 32), + splitk_gemv_vectorized_tvm(N, K, 2, 32), + gemv_alloc_reducer(N, K, block_M=128, block_N=128), + ] + for kernel in kernel_list: + profiler = kernel.get_profiler() + # Benchmark the TileLang kernel itself, not the PyTorch reference. 
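+        # Per-kernel latencies are summed here and averaged over the kernel list on return.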
+ latency += profiler.do_bench(backend="cupti") + return latency / len(kernel_list) + + if __name__ == "__main__": main() diff --git a/examples/gemv/regression_example_gemv.py b/examples/gemv/regression_example_gemv.py new file mode 100644 index 000000000..dd6f1d39f --- /dev/null +++ b/examples/gemv/regression_example_gemv.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_gemv + + +def regression_example_gemv(): + tilelang.testing.process_func(example_gemv.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/gemv/test_example_gemv.py b/examples/gemv/test_example_gemv.py index 3881ca769..323337a7a 100644 --- a/examples/gemv/test_example_gemv.py +++ b/examples/gemv/test_example_gemv.py @@ -1,5 +1,3 @@ -import tilelang.testing - import example_gemv @@ -8,4 +6,4 @@ def test_example_gemv(): if __name__ == "__main__": - tilelang.testing.main() + test_example_gemv() diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py index ac8da7e2c..49cce0d1d 100644 --- a/examples/grouped_gemm/example_grouped_gemm_bwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -5,78 +5,55 @@ import tilelang.language as T -@tilelang.jit( - out_idx=[2], pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) -def grouped_gemm_fwd(batch_sum, - batch_count, - K, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: a (torch.Tensor): Input tensor of shape (M, K). b (torch.Tensor): Input tensor of shape (G, K, N). 
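+        batch_sizes / batch_offsets / batch_padded_offsets (torch.Tensor): per-group row counts of a and their cumulative (unpadded and padded) row offsets.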
""" - accum_dtype = "float32" + accum_dtype = T.float32 @T.prim_func def kernel( - A: T.Tensor([batch_sum, K], dtype), # type: ignore - B: T.Tensor([batch_count, K, N], dtype), # type: ignore - C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore ): - - with T.Kernel( - T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) - cur_batch_idx = T.alloc_local([1], "int32") - cur_batch_size = T.alloc_local([1], "int32") + cur_batch_idx = T.alloc_var(dtype=T.int32) + cur_batch_size = T.alloc_var(dtype=T.int32) m_start_padded = bx * block_M for i in range(batch_count): - in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) - cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] + cur_batch_idx = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx) - cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ - cur_batch_idx[0]] - actual_rows = T.max( - 0, - T.min(block_M, - cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + cur_batch_size = batch_sizes[cur_batch_idx] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx] + batch_offsets[cur_batch_idx] + actual_rows = T.max(0, T.min(block_M, cur_batch_size + batch_padded_offsets[cur_batch_idx] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) - T.copy( - B[cur_batch_idx[0], k * block_K:(k + 1) * block_K, - by * block_N:(by + 1) * block_N], B_shared) + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx, k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): - with T.If(i < actual_rows), T.Then(): + if i < actual_rows: C[m_start + i, by * block_N + j] = C_local[i, j] return kernel class _GroupedGEMM(torch.autograd.Function): - @staticmethod def forward(ctx, a, b, batch_sizes): block_M = 64 @@ -99,15 +76,11 @@ def forward(ctx, a, b, batch_sizes): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes[i] + 1) / padding_M) * - padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes[i] + 1) / padding_M) * padding_M) batch_offsets = torch.tensor(batch_offsets_list, device=a.device, 
dtype=torch.int32) - batch_padded_offsets = torch.tensor( - batch_padded_offsets_list, device=a.device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32) - kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, - num_stages, threads) + kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads) o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets) ctx.save_for_backward(a, b, batch_sizes, batch_offsets) @@ -135,8 +108,7 @@ def maybe_contiguous(x): return x A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)] - kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, - num_stages, threads) + kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads) dB = kernel(A, grad_output, batch_sizes, batch_offsets) return None, dB, None @@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes_list[i] + 1) / padding_M) * - padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M) A = torch.randn(batch_sum, K, device=device, dtype=dtype) B = torch.randn(batch_count, K, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype) @@ -187,40 +157,24 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets -@tilelang.jit( - out_idx=[2], pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) -def grouped_gemm_bwd(batch_sum, - batch_count, - M, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: a (torch.Tensor): Input tensor of shape (M, K). b (torch.Tensor): Input tensor of shape (G, K, N). 
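+        batch_sizes / batch_offsets (torch.Tensor): per-group row counts and cumulative row offsets; out-of-range rows are zero-filled when loading each K-block.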
""" - accum_dtype = "float32" + accum_dtype = T.float32 @T.prim_func def kernel( - A: T.Tensor([batch_sum, M], dtype), # type: ignore - B: T.Tensor([batch_sum, N], dtype), # type: ignore - C: T.Tensor([batch_count, M, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, M], dtype), # type: ignore + B: T.Tensor([batch_sum, N], dtype), # type: ignore + C: T.Tensor([batch_count, M, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore ): - - with T.Kernel( - T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz): A_shared = T.alloc_shared([block_K, block_M], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -228,13 +182,9 @@ def kernel( T.clear(C_local) for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages): for i, j in T.Parallel(block_K, block_M): - A_shared[i, j] = T.if_then_else( - i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, - bx * block_M + j], 0) + A_shared[i, j] = T.if_then_else(i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0) for i, j in T.Parallel(block_K, block_N): - B_shared[i, j] = T.if_then_else( - i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, - by * block_N + j], 0) + B_shared[i, j] = T.if_then_else(i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0) T.gemm(A_shared, B_shared, C_local, transpose_A=True) T.copy(C_local, C[bz, bx * block_M, by * block_N]) @@ -242,23 +192,12 @@ def kernel( return kernel -def run_tilelang_grouped_gemm(batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages=2, - threads=128, - profile=False): - +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): padding_M = block_M device = torch.device("cuda") dtype = torch.float16 - A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( - batch_sizes_list, K, M, False, padding_M, device, dtype) + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, False, padding_M, device, dtype) A.requires_grad_(False) B.requires_grad_(True) @@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, O.backward(dO, retain_graph=True) dB, B.grad = B.grad.clone(), None - if ( - torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and \ - torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2) - ): + if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2): print("✅ Tilelang and Torch match") else: print("❌ Tilelang and Torch mismatch") @@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') - parser.add_argument('--K', type=int, default=8192, help='reduce dim') - parser.add_argument('--M', type=int, default=8192, help='output dim') - parser.add_argument('--trans_b', action="store_true", help="transpose B") - parser.add_argument('--profile', action="store_true", 
help="profile") + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") args = parser.parse_args() batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] @@ -301,14 +236,4 @@ def run_tilelang_grouped_gemm(batch_sizes_list, num_stages = 2 threads = 256 - run_tilelang_grouped_gemm( - batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages, - threads, - profile=args.profile) + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd.py b/examples/grouped_gemm/example_grouped_gemm_fwd.py index 9b58e3a21..b71472741 100644 --- a/examples/grouped_gemm/example_grouped_gemm_fwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_fwd.py @@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): torch.Tensor: Resulting tensor after grouped matrix multiplication. """ assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a" - assert b.shape[0] == len( - batch_sizes), "The first dimension of b must match the length of batch_sizes" + assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes" # Initialize output tensor output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype) @@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): @tilelang.jit(out_idx=[2]) -def grouped_gemm(batch_sizes_list, - K, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: a (torch.Tensor): Input tensor of shape (M, K). 
@@ -54,50 +45,43 @@ def grouped_gemm(batch_sizes_list, """ batch_sum = sum(batch_sizes_list) batch_count = len(batch_sizes_list) - accum_dtype = "float32" + accum_dtype = T.float32 total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list) @T.prim_func def kernel( - A: T.Tensor([batch_sum, K], dtype), # type: ignore - B: T.Tensor([batch_count, K, N], dtype), # type: ignore - C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore ): - with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) - cur_batch_idx = T.alloc_local([1], "int32") - cur_batch_size = T.alloc_local([1], "int32") + cur_batch_idx = T.alloc_var(dtype=T.int32) + cur_batch_size = T.alloc_var(dtype=T.int32) m_start_padded = bx * block_M for i in range(batch_count): - in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) - cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] + cur_batch_idx = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx) - cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ - cur_batch_idx[0]] - actual_rows = T.max( - 0, - T.min(block_M, - cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + cur_batch_size = batch_sizes[cur_batch_idx] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx] + batch_offsets[cur_batch_idx] + actual_rows = T.max(0, T.min(block_M, cur_batch_size + batch_padded_offsets[cur_batch_idx] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) - T.copy( - B[cur_batch_idx[0], k * block_K:(k + 1) * block_K, - by * block_N:(by + 1) * block_N], B_shared) + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx, k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): - with T.If(i < actual_rows), T.Then(): + if i < actual_rows: C[m_start + i, by * block_N + j] = C_local[i, j] return kernel @@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) A = torch.randn(batch_sum, K, device=device, dtype=dtype) B = 
torch.randn(batch_count, K, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype) @@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets -def run_tilelang_grouped_gemm(batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages=2, - threads=128, - profile=False): +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): padding_M = block_M batch_sum = sum(batch_sizes_list) - kernel = grouped_gemm( - tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) + kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) # print(kernel.get_kernel_source()) device = torch.device("cuda") dtype = torch.float16 - A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( - batch_sizes_list, K, M, trans_b, padding_M, device, dtype) + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype) out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets) ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b) # print(out) @@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, if profile: profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) - latency = profiler.do_bench( - warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) + latency = profiler.do_bench(warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) print(f"Latency: {latency} ms") print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops") @@ -173,12 +144,11 @@ def test_grouped_gemm(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') - parser.add_argument('--K', type=int, default=8192, help='reduce dim') - parser.add_argument('--M', type=int, default=8192, help='output dim') - parser.add_argument('--trans_b', action="store_true", help="transpose B") - parser.add_argument('--profile', action="store_true", help="profile") + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") args = parser.parse_args() batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] @@ -190,14 +160,4 @@ def test_grouped_gemm(): num_stages = 2 threads = 256 - run_tilelang_grouped_gemm( - batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages, - threads, - profile=args.profile) + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/hadamard_transform/example_hadamard.py b/examples/hadamard_transform/example_hadamard.py index 531d46891..65f463b71 100644 --- a/examples/hadamard_transform/example_hadamard.py +++ b/examples/hadamard_transform/example_hadamard.py @@ -17,7 +17,7 @@ def is_pow_of_2(n): def hadamard(b, n, dtype): assert is_pow_of_2(n), "n must be a power 
of 2" assert 2 <= n <= 32768, "n must be in [2, 32768]" - elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype] + elem_size = {T.float32: 4, T.float16: 2, T.bfloat16: 2}[dtype] logN = int(math.log2(n)) threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN] @@ -40,23 +40,21 @@ def hadamard(b, n, dtype): # print(f'{exchange_round=}') @T.macro - def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), - round: int): + def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), round: int): tx = T.get_thread_binding(0) for i in T.serial(round): tx_stride = 1 << i another_tx = tx ^ tx_stride - sign = ( - tx >> i - ) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] + sign = (tx >> i) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] for j in T.Pipelined(thread_elem, num_stages=1): buf[j] = T.tvm_warp_shuffle( - 0xffffffff, # mask of all threads + 0xFFFFFFFF, # mask of all threads local[j], another_tx % warp_size, warp_size, - warp_size) + warp_size, + ) local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j]) @T.prim_func @@ -78,10 +76,8 @@ def main(A: T.Tensor((b, n), dtype), B: T.Tensor((b, n), dtype)): for j in T.serial(chunknum): chunkbase = j * chunksize for k in T.serial(chunksize // 2): - local[chunkbase + - k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] - local[chunkbase + k + chunksize // - 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] + local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] + local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] # 3. 
Hadamard inside warp, n<=512 # In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory @@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor): assert x.ndim == 2 dim = x.shape[-1] assert is_pow_of_2(dim) - return F.linear( - x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) + return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='Batch size') - parser.add_argument('--dim', type=int, default=32768, help='Dimension') + parser.add_argument("--batch", type=int, default=64, help="Batch size") + parser.add_argument("--dim", type=int, default=32768, help="Dimension") args = parser.parse_args() B, D = args.batch, args.dim - x = torch.randn((B, D), device='cuda') - kernel = hadamard(B, D, 'float32') + x = torch.randn((B, D), device="cuda") + kernel = hadamard(B, D, T.float32) y = kernel(x) y_ref = ref_program(x) torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) - print('All tests passed.') + print("All tests passed.") profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) latency = profiler.do_bench(warmup=100) print("Tile-lang: {:.2f} ms".format(latency)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/kda/FLA_KDA/cumsum.py b/examples/kda/FLA_KDA/cumsum.py new file mode 100644 index 000000000..0fb3368f6 --- /dev/null +++ b/examples/kda/FLA_KDA/cumsum.py @@ -0,0 +1,469 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + + +import torch +import triton +import triton.language as tl + +from .fla_utils import prepare_chunk_indices, autotune_cache_kwargs, input_guard + +BS_LIST = [32, 64] + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": 
lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST for num_warps in [2, 4, 8]], + key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + if REVERSE: + b_o = tl.cumsum(b_s, axis=0, reverse=True) + else: + b_o = tl.cumsum(b_s, axis=0) + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [32, 64, 128, 256] + for num_warps in [2, 4, 8] + for num_stages in [1, 2, 3, 4] + ], + key=["B", "H", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_global_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + + b_z = tl.zeros([], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT - 1 - i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + b_ss = tl.sum(b_s, 0) + if REVERSE: + b_o = -b_o + b_ss + b_s + b_o += b_z + if i_c >= 0: + b_z += b_ss + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda 
args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [16, 32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [1, 2, 3, 4] + ], + key=["B", "H", "S", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_global_cumsum_vector_kernel( + s, + o, + scale, + cu_seqlens, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + + b_z = tl.zeros([BS], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT - 1 - i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + if REVERSE: + b_c = b_z[None, :] + tl.cumsum(b_s, axis=0, reverse=True) + else: + b_c = b_z[None, :] + tl.cumsum(b_s, axis=0) + if HAS_SCALE: + b_c *= scale + tl.store(p_o, b_c.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + b_z += tl.sum(b_s, 0) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: torch.Tensor = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, + chunk_indices: torch.LongTensor = None, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: torch.Tensor = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, + chunk_indices: torch.LongTensor = None, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta["S"], 
meta["BS"]), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +@input_guard +def chunk_global_cumsum_scalar( + s: torch.Tensor, + reverse: bool = False, + cu_seqlens: torch.Tensor = None, + scale: float = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T = s.shape + else: + B, T, H = s.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + + z = torch.empty_like(s, dtype=output_dtype or s.dtype) + grid = (N * H,) + chunk_global_cumsum_scalar_kernel[grid]( + s=s, + o=z, + scale=scale, + cu_seqlens=cu_seqlens, + T=T, + B=B, + H=H, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return z + + +@input_guard +def chunk_global_cumsum_vector( + s: torch.Tensor, + reverse: bool = False, + cu_seqlens: torch.Tensor = None, + scale: float = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T, S = s.shape + else: + B, T, H, S = s.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BS = min(32, triton.next_power_of_2(S)) + + z = torch.empty_like(s, dtype=output_dtype or s.dtype) + grid = (triton.cdiv(S, BS), N * H) + chunk_global_cumsum_vector_kernel[grid]( + s=s, + o=z, + scale=scale, + cu_seqlens=cu_seqlens, + T=T, + B=B, + H=H, + S=S, + BS=BS, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return z + + +@input_guard +def chunk_global_cumsum( + s: torch.Tensor, + reverse: bool = False, + cu_seqlens: torch.Tensor = None, + scale: float = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + if cu_seqlens is not None: + assert s.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(s.shape) == 3: + return chunk_global_cumsum_scalar( + s=s, + reverse=reverse, + cu_seqlens=cu_seqlens, + scale=scale, + head_first=head_first, + output_dtype=output_dtype, + ) + elif len(s.shape) == 4: + return chunk_global_cumsum_vector( + s=s, + reverse=reverse, + cu_seqlens=cu_seqlens, + scale=scale, + head_first=head_first, + output_dtype=output_dtype, + ) + else: + raise ValueError( + f"Unsupported input shape {s.shape}, " + f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` " + f"or [B, H, T]/[B, H, T, D] otherwise", + ) + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: torch.Tensor = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, + chunk_indices: torch.LongTensor = None, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + chunk_indices=chunk_indices, + ) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + chunk_indices=chunk_indices, + ) + else: + raise 
ValueError( + f"Unsupported input shape {g.shape}, which should be (B, T, H, D) if `head_first=False` or (B, H, T, D) otherwise", + ) diff --git a/examples/kda/FLA_KDA/fla_chunk_delta.py b/examples/kda/FLA_KDA/fla_chunk_delta.py new file mode 100644 index 000000000..3b0fc908d --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_delta.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl +from .fla_utils import prepare_chunk_indices, exp, exp2, USE_CUDA_GRAPH, autotune_cache_kwargs + +NUM_WARPS = [2, 4] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ], + key=["H", "K", "V", "BT", "USE_EXP2"], + use_cuda_graph=USE_CUDA_GRAPH, + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += ((boh * H + i_h) * K * V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + k += ((bos * H + i_h) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + if SAVE_NEW_VALUE: + v_new += ((bos * H + i_h) * V).to(tl.int64) + stride_v = H * V + stride_h = H * K * V + stride_k = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), 
(64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if USE_EXP2: + b_v = b_v * tl.where(m_t, exp2(b_g_last - b_g), 0)[:, None] + b_g_last = exp2(b_g_last) + else: + b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 *= b_g_last + if K > 64: + b_h2 *= b_g_last + if K > 128: + b_h3 *= b_g_last + if K > 192: + b_h4 *= b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k1, mask=(o_k1 < K), other=0.0) + if USE_EXP2: + b_h1 *= exp2(b_gk_last1)[:, None] + else: + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k2, mask=(o_k2 < K), other=0.0) + if USE_EXP2: + b_h2 *= exp2(b_gk_last2)[:, None] + else: + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k3, mask=(o_k3 < K), other=0.0) + if USE_EXP2: + b_h3 *= exp2(b_gk_last3)[:, None] + else: + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k4, mask=(o_k4 < K), other=0.0) + if USE_EXP2: + b_h4 *= exp2(b_gk_last4)[:, None] + else: + b_h4 *= exp(b_gk_last4)[:, None] + b_v = b_v.to(k.dtype.element_ty) + + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += 
tl.dot(b_k, b_v) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v) + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["dh0"] is not None, + "USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in ([4, 3, 2]) + for BV in [64, 32] + ], + key=["H", "K", "V", "BT", "BV", "USE_G", "USE_EXP2"], + use_cuda_graph=USE_CUDA_GRAPH, + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( + q, + k, + w, + g, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_dh2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_dh3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_dh4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + q += ((bos * H + i_h) * K).to(tl.int64) + k += ((bos * H + i_h) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + do += ((bos * H + i_h) * V).to(tl.int64) + dv += ((bos * H + i_h) * V).to(tl.int64) + dv2 += ((bos * H + i_h) * V).to(tl.int64) + dh += ((boh * H + i_h) * K * V).to(tl.int64) + if USE_GK: + gk += ((bos * H + i_h) * K).to(tl.int64) + + stride_v = H * V + stride_h = H * K * V + stride_k = H * K + if USE_INITIAL_STATE: + dh0 += i_nh * K * V + if USE_FINAL_STATE_GRADIENT: + dht += i_nh * K * V + + if USE_FINAL_STATE_GRADIENT: + p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_dh1 += tl.load(p_dht1, boundary_check=(0, 1)) + if K > 64: + p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * 
BV), (64, BV), (1, 0)) + b_dh2 += tl.load(p_dht2, boundary_check=(0, 1)) + if K > 128: + p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_dh3 += tl.load(p_dht3, boundary_check=(0, 1)) + if K > 192: + p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_dh4 += tl.load(p_dht4, boundary_check=(0, 1)) + + for i_t in range(NT - 1, -1, -1): + p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + bg_last = tl.load(g + (bos + last_idx) * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if USE_EXP2: + bg_last_exp = exp2(bg_last) + b_g_exp = exp2(b_g) + else: + bg_last_exp = exp(bg_last) + b_g_exp = exp(b_g) + + p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # Update dv + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0) + b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype)) + + if K > 64: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) + + if K > 128: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) + + if K > 192: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + if USE_EXP2: + b_dv *= tl.where(m_t, exp2(bg_last - b_g), 0)[:, None] + else: + b_dv *= tl.where(m_t, exp(bg_last - b_g), 0)[:, None] + b_dv += tl.load(p_dv, boundary_check=(0, 1)) + + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # Update dh + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + 
b_w = tl.load(p_w, boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + if USE_G: + b_dh1 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh1 *= exp2(b_gk_last1[:, None]) + else: + b_dh1 *= exp(b_gk_last1[:, None]) + b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 64: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh2 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh2 *= exp2(b_gk_last2[:, None]) + else: + b_dh2 *= exp(b_gk_last2[:, None]) + b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 128: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh3 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh3 *= exp2(b_gk_last3[:, None]) + else: + b_dh3 *= exp(b_gk_last3[:, None]) + b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 192: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh4 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh4 *= exp2(b_gk_last4[:, None]) + else: + b_dh4 *= exp(b_gk_last4[:, None]) + b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor = None, + gk: torch.Tensor = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? 
+ save_new_value: bool = True, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, + use_exp2: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, u.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + ) + return h, v_new, final_state + + +def chunk_gated_delta_rule_bwd_dhu( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + do: torch.Tensor, + dv: torch.Tensor, + g: torch.Tensor = None, + gk: torch.Tensor = None, + h0: torch.Tensor = None, + dht: torch.Tensor = None, + scale: float = None, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + chunk_indices: torch.LongTensor = None, + use_exp2: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *q.shape, do.shape[-1] + # N: the actual number of sequences in the batch with either equal or variable lengths + BT = 64 + assert K <= 256, "current kernel does not support head dimension being larger than 256." 
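+    # the kernel splits K into up to four 64-wide register tiles (b_dh1..b_dh4), hence the 256 cap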
+ + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + + dh = q.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.empty_like(dv) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[grid]( + q=q, + k=k, + w=w, + g=g, + gk=gk, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + ) + return dh, dh0, dv2 diff --git a/examples/kda/FLA_KDA/fla_chunk_inter.py b/examples/kda/FLA_KDA/fla_chunk_inter.py new file mode 100644 index 000000000..e6de9bb28 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_inter.py @@ -0,0 +1,193 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + + +import torch +import triton +import triton.language as tl + +from .fla_utils import prepare_chunk_indices, exp2, autotune_cache_kwargs, check_shared_mem + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [64, 128] if check_shared_mem("ampere") else [16, 32] + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_bwd_kernel_inter( + q, + k, + v, + g, + h, + do, + dh, + dq, + dk, + dv, + dw, + dg, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + o_t = i_t * BT + tl.arange(0, BT) + m_k = o_k < K + m_t = o_t < T + m_last = o_t == min(T, i_t * BT + BT) - 1 + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + g += (bos * H + i_h) * K + h += (i_tg * H + i_h) * K * V + do += (bos * H + i_h) * V + dh += (i_tg * H + i_h) * K * V + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dw += (bos * H + i_h) * K + dv += (bos * H + i_h) * V + dg += (bos * H + i_h) * K + + p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + p_gn = g + (min(T, i_t * BT + BT) - 1) * H * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h 
= tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + + p_dv = tl.make_block_ptr(dv, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + p_dw = tl.make_block_ptr(dw, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + + b_dgk *= exp2(b_gn) + b_dq *= scale + b_dq = b_dq * exp2(b_g) + b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0) + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + b_dg = b_q * b_dq - b_k * b_dk + m_last[:, None] * b_dgk + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_kda_bwd_dqkwg( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dv: torch.Tensor, + scale: float = None, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + dw = torch.empty_like(w) + dg = torch.empty_like(g) + + def grid(meta): + return (triton.cdiv(K, meta["BK"]), NT, B * H) + + chunk_kda_bwd_kernel_inter[grid]( + q=q, + k=k, + v=v, + g=g, + h=h, + do=do, + dh=dh, + dq=dq, + dk=dk, + dv=dv, + dw=dw, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return dq, dk, dw, dg diff --git a/examples/kda/FLA_KDA/fla_chunk_intra.py b/examples/kda/FLA_KDA/fla_chunk_intra.py new file mode 100644 index 000000000..244f05f1c --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_intra.py @@ -0,0 +1,650 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from .fla_utils import autotune_cache_kwargs, exp2, prepare_chunk_indices +from .cumsum import chunk_local_cumsum + +IS_TF32_SUPPORTED = False +if IS_TF32_SUPPORTED: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3") +else: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee") 
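+# `input_precision` controls how `tl.dot` handles fp32 inputs: "ieee" keeps full fp32
+# accuracy, "tf32" rounds them to TensorFloat-32, and "tf32x3" approximates fp32 accuracy
+# with three tf32 products; note that the unconditional assignment below is the value the
+# triangular solve actually uses.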
+SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32") +# ============================================================================ +# Fused inter + solve_tril kernel: compute off-diagonal Akk and solve in one pass +# ============================================================================ + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BK": BK}, num_warps=num_warps) for BK in [32, 64] for num_warps in [1, 2, 4]], + key=["H", "K", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_fwd_kernel_inter_solve_fused( + q, + k, + g, + beta, + Aqk, + Akk_diag, + Akk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + """ + Fused kernel: compute inter-subchunk Akk + solve_tril in one pass. + Prerequisite: token_parallel has already computed diagonal Akk blocks in Akk_diag. + + This kernel: + 1. Computes off-diagonal Aqk blocks -> writes to global + 2. Computes off-diagonal Akk blocks -> keeps in registers + 3. Loads diagonal Akk blocks from Akk_diag (fp32) + 4. Does forward substitution on diagonals + 5. Computes merged Akk_inv + 6. Writes Akk_inv to Akk + """ + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + i_tc0 = i_t * BT + i_tc1 = i_t * BT + BC + i_tc2 = i_t * BT + 2 * BC + i_tc3 = i_t * BT + 3 * BC + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + Aqk += (bos * H + i_h) * BT + Akk += (bos * H + i_h) * BT + Akk_diag += (bos * H + i_h) * BC + + m_tc1 = (i_tc1 + tl.arange(0, BC)) < T + m_tc2 = (i_tc2 + tl.arange(0, BC)) < T + m_tc3 = (i_tc3 + tl.arange(0, BC)) < T + + b_Aqk10 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk10 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk21 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk21 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk32 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk32 = tl.zeros([BC, BC], dtype=tl.float32) + + ################################################################################ + # 1. 
off-diagonal blocks + ################################################################################ + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k0 = tl.make_block_ptr(k, (K, T), (1, H * K), (i_k * BK, i_tc0), (BK, BC), (0, 1)) + p_g0 = tl.make_block_ptr(g, (K, T), (1, H * K), (i_k * BK, i_tc0), (BK, BC), (0, 1)) + b_kt0 = tl.load(p_k0, boundary_check=(0, 1)).to(tl.float32) + b_gt0 = tl.load(p_g0, boundary_check=(0, 1)).to(tl.float32) + + b_kt1, b_gt1 = b_kt0, b_gt0 + b_kt2, b_gt2 = b_kt0, b_gt0 + if i_tc1 < T: + p_q1 = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_k1 = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_g1 = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + + b_q1 = tl.load(p_q1, boundary_check=(0, 1)).to(tl.float32) + b_k1 = tl.load(p_k1, boundary_check=(0, 1)).to(tl.float32) + b_g1 = tl.load(p_g1, boundary_check=(0, 1)).to(tl.float32) + b_kt1 = tl.trans(b_k1) + b_gt1 = tl.trans(b_g1) + + b_gn1 = tl.load(g + i_tc1 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + b_gqn1 = tl.where(m_tc1[:, None], exp2(b_g1 - b_gn1[None, :]), 0) + b_qg1 = b_q1 * b_gqn1 + b_kg1 = b_k1 * b_gqn1 + b_kgt = b_kt0 * exp2(b_gn1[:, None] - b_gt0) + b_Aqk10 += tl.dot(b_qg1, b_kgt) + b_Akk10 += tl.dot(b_kg1, b_kgt) + + if i_tc2 < T: + p_q2 = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_g2 = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + + b_q2 = tl.load(p_q2, boundary_check=(0, 1)).to(tl.float32) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)).to(tl.float32) + b_g2 = tl.load(p_g2, boundary_check=(0, 1)).to(tl.float32) + b_kt2 = tl.trans(b_k2) + b_gt2 = tl.trans(b_g2) + + b_gn2 = tl.load(g + i_tc2 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + b_gqn2 = tl.where(m_tc2[:, None], exp2(b_g2 - b_gn2[None, :]), 0) + b_qg2 = b_q2 * b_gqn2 + b_kg2 = b_k2 * b_gqn2 + b_kgt = b_kt0 * exp2(b_gn2[:, None] - b_gt0) + b_Aqk20 += tl.dot(b_qg2, b_kgt) + b_Akk20 += tl.dot(b_kg2, b_kgt) + + b_kgt = b_kt1 * exp2(b_gn2[:, None] - b_gt1) + b_Aqk21 += tl.dot(b_qg2, b_kgt) + b_Akk21 += tl.dot(b_kg2, b_kgt) + + if i_tc3 < T: + p_q3 = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_k3 = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_g3 = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + b_q3 = tl.load(p_q3, boundary_check=(0, 1)).to(tl.float32) + b_k3 = tl.load(p_k3, boundary_check=(0, 1)).to(tl.float32) + b_g3 = tl.load(p_g3, boundary_check=(0, 1)).to(tl.float32) + + b_gn3 = tl.load(g + i_tc3 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + b_gqn3 = tl.where(m_tc3[:, None], exp2(b_g3 - b_gn3[None, :]), 0) + b_qg3 = b_q3 * b_gqn3 + b_kg3 = b_k3 * b_gqn3 + b_kgt = b_kt0 * exp2(b_gn3[:, None] - b_gt0) + b_Aqk30 += tl.dot(b_qg3, b_kgt) + b_Akk30 += tl.dot(b_kg3, b_kgt) + + b_kgt = b_kt1 * exp2(b_gn3[:, None] - b_gt1) + b_Aqk31 += tl.dot(b_qg3, b_kgt) + b_Akk31 += tl.dot(b_kg3, b_kgt) + + b_kgt = b_kt2 * exp2(b_gn3[:, None] - b_gt2) + b_Aqk32 += tl.dot(b_qg3, b_kgt) + b_Akk32 += tl.dot(b_kg3, b_kgt) + + ################################################################################ + # 2. 
save off-diagonal Aqk blocks and prepare Akk + ################################################################################ + if i_tc1 < T: + p_Aqk10 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + tl.store(p_Aqk10, (b_Aqk10 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b1 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,)) + b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) + b_Akk10 = b_Akk10 * b_b1[:, None] + if i_tc2 < T: + p_Aqk20 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Aqk21 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + tl.store(p_Aqk20, (b_Aqk20 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk21, (b_Aqk21 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b2 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,)) + b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) + b_Akk20 = b_Akk20 * b_b2[:, None] + b_Akk21 = b_Akk21 * b_b2[:, None] + if i_tc3 < T: + p_Aqk30 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_Aqk31 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_Aqk32 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0)) + tl.store(p_Aqk30, (b_Aqk30 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk31, (b_Aqk31 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk32, (b_Aqk32 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b3 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,)) + b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) + b_Akk30 = b_Akk30 * b_b3[:, None] + b_Akk31 = b_Akk31 * b_b3[:, None] + b_Akk32 = b_Akk32 * b_b3[:, None] + + ################################################################################ + # 3. load diagonal Akk blocks + ################################################################################ + p_Akk00 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk22 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk33 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc3, 0), (BC, BC), (1, 0)) + # each diagonal block is stored contiguously: row i of block s is at Akk_diag[t=i_t*BT+s*BC+i, :BC] + b_Ai00 = tl.load(p_Akk00, boundary_check=(0, 1)).to(tl.float32) + b_Ai11 = tl.load(p_Akk11, boundary_check=(0, 1)).to(tl.float32) + b_Ai22 = tl.load(p_Akk22, boundary_check=(0, 1)).to(tl.float32) + b_Ai33 = tl.load(p_Akk33, boundary_check=(0, 1)).to(tl.float32) + + ################################################################################ + # 4. 
forward substitution on diagonals + ################################################################################ + o_i = tl.arange(0, BC) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + b_Ai00 = -tl.where(m_A, b_Ai00, 0) + b_Ai11 = -tl.where(m_A, b_Ai11, 0) + b_Ai22 = -tl.where(m_A, b_Ai22, 0) + b_Ai33 = -tl.where(m_A, b_Ai33, 0) + + # Forward substitution: load from Akk_diag (stride H*BC, columns 0:BC) + for i in range(2, min(BC, T - i_tc0)): + b_a00 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a00 = tl.where(o_i < i, b_a00, 0.0) + b_a00 += tl.sum(b_a00[:, None] * b_Ai00, 0) + b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00) + for i in range(BC + 2, min(2 * BC, T - i_tc0)): + b_a11 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a11 = tl.where(o_i < i - BC, b_a11, 0.0) + b_a11 += tl.sum(b_a11[:, None] * b_Ai11, 0) + b_Ai11 = tl.where((o_i == i - BC)[:, None], b_a11, b_Ai11) + for i in range(2 * BC + 2, min(3 * BC, T - i_tc0)): + b_a22 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a22 = tl.where(o_i < i - 2 * BC, b_a22, 0.0) + b_a22 += tl.sum(b_a22[:, None] * b_Ai22, 0) + b_Ai22 = tl.where((o_i == i - 2 * BC)[:, None], b_a22, b_Ai22) + for i in range(3 * BC + 2, min(4 * BC, T - i_tc0)): + b_a33 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a33 = tl.where(o_i < i - 3 * BC, b_a33, 0.0) + b_a33 += tl.sum(b_a33[:, None] * b_Ai33, 0) + b_Ai33 = tl.where((o_i == i - 3 * BC)[:, None], b_a33, b_Ai33) + + b_Ai00 += m_I + b_Ai11 += m_I + b_Ai22 += m_I + b_Ai33 += m_I + + # ################################################################################ + # # 5. compute merged inverse using off-diagonals + # ################################################################################ + + # we used tf32x3 to maintain matrix inverse's precision whenever possible. + b_Ai10 = -tl.dot(tl.dot(b_Ai11, b_Akk10, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + b_Ai21 = -tl.dot(tl.dot(b_Ai22, b_Akk21, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + b_Ai32 = -tl.dot(tl.dot(b_Ai33, b_Akk32, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai22, input_precision=SOLVE_TRIL_DOT_PRECISION) + + b_Ai20 = -tl.dot( + b_Ai22, + tl.dot(b_Akk20, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk21, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai31 = -tl.dot( + b_Ai33, + tl.dot(b_Akk31, b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai21, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai30 = -tl.dot( + b_Ai33, + tl.dot(b_Akk30, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk31, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai20, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + + ################################################################################ + # 6. 
store full Akk_inv to Akk + ################################################################################ + + p_Akk00 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk10 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc1, BC), (BC, BC), (1, 0)) + p_Akk20 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk21 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + p_Akk22 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, 2 * BC), (BC, BC), (1, 0)) + p_Akk30 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_Akk31 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_Akk32 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0)) + p_Akk33 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 3 * BC), (BC, BC), (1, 0)) + + tl.store(p_Akk00, b_Ai00.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk10, b_Ai10.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk11, b_Ai11.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk20, b_Ai20.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk21, b_Ai21.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk22, b_Ai22.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk30, b_Ai30.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk31, b_Ai31.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk32, b_Ai32.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk33, b_Ai33.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4]], + key=["BK", "NC", "BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["B", "T"]) +def chunk_kda_bwd_kernel_intra( + q, + k, + g, + beta, + dAqk, + dAkk, + dq, + dq2, + dk, + dk2, + dg, + dg2, + db, + cu_seqlens, + chunk_indices, + B, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_kc, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_k, i_i = i_kc // NC, i_kc % NC + + all = B * T + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + i_ti = i_t * BT + i_i * BC + if i_ti >= T: + return + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + beta += bos * H + i_h + + dAqk += (bos * H + i_h) * BT + dAkk += (bos * H + i_h) * BT + dq += (bos * H + i_h) * K + dq2 += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dk2 += (bos * H + i_h) * K + dg += (bos * H + i_h) * K + dg2 += (bos * H + i_h) * K + db += (i_k * all + bos) * H + i_h + + p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + b_b = 
tl.load(p_b, boundary_check=(0,)) + + b_dq2 = tl.zeros([BC, BK], dtype=tl.float32) + b_dk2 = tl.zeros([BC, BK], dtype=tl.float32) + if i_i > 0: + p_gn = g + i_ti * H * K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp2(b_gn[None, :] - b_gk) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + # [BC, BK] + b_dq2 += tl.dot(b_dAqk, b_kg) + b_dk2 += tl.dot(b_dAkk, b_kg) + b_gqn = exp2(b_g - b_gn[None, :]) + b_dq2 *= b_gqn + b_dk2 *= b_gqn + + o_i = tl.arange(0, BC) + m_dA = (i_ti + o_i) < T + o_dA = (i_ti + o_i) * H * BT + i_i * BC + p_kj = k + i_ti * H * K + o_k + p_gkj = g + i_ti * H * K + o_k + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC] + b_dAqk = tl.load(dAqk + o_dA + j, mask=m_dA, other=0) + b_dAkk = tl.load(dAkk + o_dA + j, mask=m_dA, other=0) + # [BK] + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_kgj = b_kj[None, :] * exp2(b_g - b_gkj[None, :]) + b_dq2 += tl.where(m_i, b_dAqk[:, None] * b_kgj, 0.0) + b_dk2 += tl.where(m_i, b_dAkk[:, None] * b_kgj, 0.0) + + p_kj += H * K + p_gkj += H * K + b_db = tl.sum(b_dk2 * b_k, 1) + b_dk2 *= b_b[:, None] + + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dq2 = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T,), (H,), (i_ti,), (BC,), (0,)) + + b_dg2 = b_q * b_dq2 + b_dq2 = b_dq2 + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq2, b_dq2.to(p_dq2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + tl.debug_barrier() + b_dkt = tl.zeros([BC, BK], dtype=tl.float32) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + p_gn = g + (min(i_ti + BC, T) - 1) * H * K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT + i_j * BC,), (BC,), (0,)) + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + # [BC] + b_b = tl.load(p_b, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_kb = tl.load(p_k, 
boundary_check=(0, 1)) * b_b[:, None] + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + + o_j = i_t * BT + i_j * BC + o_i + m_j = o_j < T + # [BC, BK] + b_gkn = tl.where(m_j[:, None], exp2(b_gk - b_gn[None, :]), 0) + b_qg = b_q * b_gkn + b_kbg = b_kb * b_gkn + # [BC, BK] + b_dkt += tl.dot(b_dAqk, b_qg) + tl.dot(b_dAkk, b_kbg) + b_dkt *= exp2(b_gn[None, :] - b_g) + + o_dA = i_ti * H * BT + i_i * BC + o_i + p_qj = q + i_ti * H * K + o_k # [bs, i_ti, i_h*block_h, i_k*bk:(i_k+1)*bk] + p_kj = k + i_ti * H * K + o_k + p_gkj = g + i_ti * H * K + o_k + p_bj = beta + i_ti * H + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dAqk = tl.load(dAqk + o_dA + j * H * BT) + b_dAkk = tl.load(dAkk + o_dA + j * H * BT) + # [BK,] + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_kbj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) * tl.load(p_bj) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_gkq = exp2(b_gkj[None, :] - b_g) + b_dkt += tl.where(m_i, (b_dAkk[:, None] * b_kbj[None, :] + b_dAqk[:, None] * b_qj[None, :]) * b_gkq, 0.0) + + p_qj += H * K + p_kj += H * K + p_gkj += H * K + p_bj += H + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dk2 = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + + b_dg2 += (b_dk2 - b_dkt) * b_k + tl.load(p_dg, boundary_check=(0, 1)) + b_dk2 += tl.load(p_dk, boundary_check=(0, 1)) + b_dk2 += b_dkt + + tl.store(p_dk2, b_dk2.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg2, b_dg2.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_kda_bwd_intra( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + dAqk: torch.Tensor, + dAkk: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + db: torch.Tensor, + dg: torch.Tensor, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, + chunk_size: int = 64, +): + B, T, H, K = k.shape + BT = chunk_size + BC = min(16, BT) + BK = min(32, triton.next_power_of_2(K)) + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq2 = torch.empty_like(q) + dk2 = torch.empty_like(k) + db2 = beta.new_empty(NK, *beta.shape, dtype=torch.float) + dg2 = torch.empty_like(dg, dtype=torch.float) + grid = (NK * NC, NT, B * H) + chunk_kda_bwd_kernel_intra[grid]( + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dq2=dq2, + dk=dk, + dk2=dk2, + dg=dg, + dg2=dg2, + db=db2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + B=B, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + ) + dq = dq2 + dk = dk2 + db = db2.sum(0).add_(db) + dg = chunk_local_cumsum( + dg2, + chunk_size=chunk_size, + reverse=True, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + + return dq, dk, db, dg + + +def chunk_kda_fwd_inter_solve_fused( + q, + k, + gk, + beta, + Aqk, + Akk_diag, + Akk, + scale, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +): + B, T, H, K = k.shape + assert K 
<= 256 + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + BC = 16 + + grid = (NT, B * H) + chunk_kda_fwd_kernel_inter_solve_fused[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk_diag=Akk_diag, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + ) diff --git a/examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py b/examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py new file mode 100644 index 000000000..1dba20282 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py @@ -0,0 +1,168 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# Token-parallel implementation of the KDA intra-chunk kernel + +import torch +import triton +import triton.language as tl + +from .fla_utils import exp2, autotune_cache_kwargs + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BH": BH}, num_warps=num_warps) for BH in [1, 2, 4, 8] for num_warps in [1, 2, 4, 8]], + key=["K", "H"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T", "N"]) +def chunk_kda_fwd_kernel_intra_token_parallel( + q, + k, + g, + beta, + Aqk, + Akk, + scale, + cu_seqlens, + N, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BH: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_tg, i_hg = tl.program_id(0), tl.program_id(1) + + if IS_VARLEN: + i_n = 0 + left, right = 0, N + + # Unrolled binary search over cu_seqlens to find the sequence containing token i_tg. + # The iteration count can be capped by the expected maximum batch size if needed; + # 20 iterations cover up to 2^20 (~1M) sequences, which is usually enough. + for _ in range(20): + if left < right: + mid = (left + right) // 2 + if i_tg < tl.load(cu_seqlens + mid + 1).to(tl.int32): + right = mid + else: + left = mid + 1 + i_n = left + + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + i_t = i_tg - bos + else: + bos = (i_tg // T) * T + i_t = i_tg % T + + if i_t >= T: + return + + i_c = i_t // BT # chunk index + i_s = (i_t % BT) // BC # sub-chunk index + i_tc = i_c * BT # first token index of the chunk + i_ts = i_tc + i_s * BC # first token index of the sub-chunk + + q += bos * H * K + k += bos * H * K + g += bos * H * K + Aqk += bos * H * BT + Akk += bos * H * BC + beta += bos * H + + BK: tl.constexpr = triton.next_power_of_2(K) + o_h = tl.arange(0, BH) + o_k = tl.arange(0, BK) + m_h = (i_hg * BH + o_h) < H + m_k = o_k < K + + p_q = tl.make_block_ptr(q + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_beta = tl.make_block_ptr(beta + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,)) + # [BH, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_k = b_k * tl.load(p_beta, boundary_check=(0,)).to(tl.float32)[:, None] + + for j in range(i_ts, min(i_t + 1, min(T, i_ts + BC))): + p_kj = tl.make_block_ptr(k + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_gj = tl.make_block_ptr(g + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + # [BH, BK] + b_kj = tl.load(p_kj, boundary_check=(0, 1)).to(tl.float32) + b_gj 
= tl.load(p_gj, boundary_check=(0, 1)).to(tl.float32) + + b_kgj = b_kj * exp2(b_g - b_gj) + + b_kgj = tl.where(m_k[None, :], b_kgj, 0.0) + # [BH] + b_Aqk = tl.sum(b_q * b_kgj, axis=1) * scale + b_Akk = tl.sum(b_k * b_kgj, axis=1) * tl.where(j < i_t, 1.0, 0.0) + + tl.store(Aqk + i_t * H * BT + (i_hg * BH + o_h) * BT + j % BT, b_Aqk.to(Aqk.dtype.element_ty), mask=m_h) + tl.store(Akk + i_t * H * BC + (i_hg * BH + o_h) * BC + j - i_ts, b_Akk.to(Akk.dtype.element_ty), mask=m_h) + + +def chunk_kda_fwd_intra_token_parallel( + q: torch.Tensor, + k: torch.Tensor, + gk: torch.Tensor, + beta: torch.Tensor, + Aqk: torch.Tensor, + Akk: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + sub_chunk_size: int = 16, +) -> None: + """ + Token-parallel implementation: each token gets its own thread block. + Supports both fixed-length and variable-length sequences. + Reduces wasted computation on padding. + + Writes directly to Aqk and Akk tensors (in-place). + + Args: + q: [B, T, H, K] + k: [B, T, H, K] + gk: [B, T, H, K] cumsum of gates + beta: [B, T, H] + Aqk: [B, T, H, BT] output tensor to write to + Akk: [B, T, H, BC] output tensor for diagonal blocks (fp32) + scale: attention scale + chunk_size: BT (default 64) + sub_chunk_size: BC (default 16) + """ + B, T, H, K = q.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BT = chunk_size + BC = sub_chunk_size + + def grid(meta): + return (B * T, triton.cdiv(H, meta["BH"])) + + chunk_kda_fwd_kernel_intra_token_parallel[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + N=N, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + ) + return Aqk, Akk diff --git a/examples/kda/FLA_KDA/fla_chunk_o.py b/examples/kda/FLA_KDA/fla_chunk_o.py new file mode 100644 index 000000000..c29db9508 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_o.py @@ -0,0 +1,546 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + + +from .fla_utils import prepare_chunk_indices, exp, exp2, autotune_cache_kwargs, check_shared_mem + + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [64, 128] if check_shared_mem("ampere") else [16, 32] + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_fwd_kernel_o( + q, + v, + g, + h, + o, + A, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = 
tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + # [BT, BK] + if USE_EXP2: + b_qg = (b_q * exp2(b_g)).to(b_q.dtype) + else: + b_qg = (b_q * exp(b_g)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h.to(b_qg.dtype)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype) + b_o += tl.dot(b_A, b_v) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_bwd_kernel_dv( + k, + g, + A, + do, + dh, + dv, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.0) + # (SY 09/17) important to disallow tf32 here to maintain a good precision. 
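+    # Intra-chunk contribution: dv += (causally masked) A^T @ do; the inter-chunk contribution via dh is accumulated in the loop below.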
+ b_dv = tl.dot(b_A, b_do.to(b_A.dtype), allow_tf32=False) + + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + (bos + min(i_t * BT + BT, T) - 1) * H * K + i_h * K + o_k + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + b_gn = exp(tl.load(p_gn, mask=m_k, other=0)[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + # [BT, BV] + # (SY 09/17) it is ok to have bf16 interchunk gradient contribution here + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) + + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps) for BK in BK_LIST for BV in BV_LIST for num_warps in [2, 4, 8]], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_bwd_kernel_inter( + q, + k, + v, + g, + h, + do, + dh, + dq, + dk, + dq2, + dk2, + dg, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + g += (bos * H + i_h) * K + h += (i_tg * H + i_h) * K * V + do += (bos * H + i_h) * V + dh += (i_tg * H + i_h) * K * V + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dq2 += (bos * H + i_h) * K + dk2 += (bos * H + i_h) * K + dg += (bos * H + i_h) * K + + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + p_gn = g + (min(T, i_t * BT + BT) - 1) * H * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, 
b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + + b_dgk *= exp(b_gn) + b_dq *= scale + b_dq = b_dq * exp(b_gk) + b_dk = b_dk * exp(b_gn[None, :] - b_gk) + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + b_dg = b_q * b_dq - b_k * b_dk + # tl.debug_barrier() + b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :] + # Buggy due to strange triton compiler issue. + # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.) + # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :] + p_dq = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gla_fwd_o_gk( + q: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + A: torch.Tensor, + h: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, + use_exp2: bool = False, +): + B, T, H, K, V = *q.shape, v.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), NT, B * H) + + chunk_gla_fwd_kernel_o[grid]( + q=q, + v=v, + g=g, + h=h, + o=o, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + ) + return o + + +NUM_WARPS = [2, 4] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_G_GAMMA": lambda args: args["g_gamma"] is not None, + "USE_A": lambda args: args["A"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in NUM_WARPS for num_stages in [2, 3, 4]], + key=["H", "K", "V", "BT", "BK", "BV", "USE_G"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_bwd_kernel_dv_local( + q, + k, + g, + g_gamma, + A, + do, + dv, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_A: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + 
i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + + if USE_A: + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_bwd_dv_local( + q: torch.Tensor, + k: torch.Tensor, + do: torch.Tensor, + g: torch.Tensor = None, + g_gamma: torch.Tensor = None, + A: torch.Tensor = None, + scale: float = None, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +) -> torch.Tensor: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + # H100 can have larger block size + if check_shared_mem("hopper", k.device.index): + CONST_TILING = 128 + elif check_shared_mem: + CONST_TILING = 64 + else: + CONST_TILING = 32 + BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) + BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dv = torch.empty_like(do) + grid = (NT, B * H) + chunk_bwd_kernel_dv_local[grid]( + q=q, + k=k, + g=g, + g_gamma=g_gamma, + A=A, + do=do, + dv=dv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dv + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4]], + key=["BV", "BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_bwd_kernel_dA( + v, + do, + dA, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H * V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + b_dA += tl.dot(b_do, b_v) + + p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), 
(1, 0)) + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_dA = tl.where(m_s, b_dA * scale, 0.0) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gla_bwd_dA( + v: torch.Tensor, + do: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +): + B, T, H, V = v.shape + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BV = min(64, triton.next_power_of_2(V)) + + dA = v.new_empty(B, T, H, BT, dtype=torch.float32) + grid = (NT, B * H) + chunk_gla_bwd_kernel_dA[grid]( + v=v, + do=do, + dA=dA, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + V=V, + BT=BT, + BV=BV, + ) + return dA diff --git a/examples/kda/FLA_KDA/fla_utils.py b/examples/kda/FLA_KDA/fla_utils.py new file mode 100644 index 000000000..b278aec90 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_utils.py @@ -0,0 +1,240 @@ +import contextlib +import functools +import inspect +import os +import warnings +from collections.abc import Callable +from typing import Any +from packaging import version +from enum import Enum + +import torch +import triton +import triton.language.extra.libdevice as tldevice + + +device = "cuda" +device_torch_lib = getattr(torch, device) + +exp = tldevice.fast_expf +exp2 = tldevice.exp2 +log = tldevice.fast_logf +log2 = tldevice.fast_log2f + +IS_NVIDIA_HOPPER = True and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9) +USE_CUDA_GRAPH = True and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" + + +FLA_CACHE_RESULTS = os.getenv("FLA_CACHE_RESULTS", "1") == "1" +SUPPORTS_AUTOTUNE_CACHE = "cache_results" in inspect.signature(triton.autotune).parameters +autotune_cache_kwargs = {"cache_results": FLA_CACHE_RESULTS} if SUPPORTS_AUTOTUNE_CACHE else {} + + +# error check,copy from +def get_abs_err(x, y): + return (x.detach() - y.detach()).flatten().abs().max().item() + + +def get_err_ratio(x, y): + err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item() + base = (x.detach()).flatten().square().mean().sqrt().item() + return err / (base + 1e-8) + + +def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): + abs_atol = get_abs_err(ref, tri) + msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}" + print(msg) + error_rate = get_err_ratio(ref, tri) + if abs_atol <= err_atol: + return + if warning or (error_rate < 0.01 or abs_atol <= 0.3): + if error_rate > ratio: + warnings.warn(msg, stacklevel=2) + else: + assert error_rate < ratio, msg + + +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. 
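+
+    Example (a minimal sketch; the cache is keyed on the identity of the argument objects):
+        >>> @tensor_cache
+        ... def cumsum(x: torch.Tensor) -> torch.Tensor:
+        ...     return x.cumsum(0)
+        >>> t = torch.arange(4)
+        >>> cumsum(t) is cumsum(t)  # second call with the same tensor object returns the cached result
+        True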
+ """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if ( + last_args is not None + and last_kwargs is not None + and len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_chunk_indices( + cu_seqlens: torch.LongTensor, + chunk_size: int, +) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +# @functools.cache +# def get_multiprocessor_count(tensor_idx: int = 0) -> int: +# try: +# return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['multiprocessor_count'] +# except BaseException: +# # Maybe we use a NPU device. +# if triton.runtime.driver.active.get_current_target().backend == 'npu': +# return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['num_vectorcore'] +# else: +# return 1 +@functools.cache +def get_multiprocessor_count(tensor_idx: int = 0) -> int: + """ + Compatible across Triton versions: + - 2.0.x + - 2.1.0 + - 2.2.x and above + Supports CUDA and NPU. + """ + + # ---- Try the newer Triton 2.2+ API ---- + try: + drv = triton.runtime.driver.active + props = drv.utils.get_device_properties(tensor_idx) + return props.get("multiprocessor_count") or props.get("num_vectorcore") or 1 + except Exception: + pass + + # ---- Fallback: Triton 2.0 / 2.1 API ---- + try: + cuda = triton.runtime.driver.CudaDriver + dev = cuda.get_current_device() + props = cuda.get_device_properties(dev) + return props.get("multiprocessor_count", 1) + except Exception: + pass + + return 1 + + +def input_guard( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. 
+ """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = custom_device_ctx(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def check_pytorch_version(version_s: str = "2.4") -> bool: + return version.parse(torch.__version__) >= version.parse(version_s) + + +if check_pytorch_version("2.4"): + device = "cuda" + autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device) + autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device) + + def custom_device_ctx(index: int): + return device_torch_lib.device(index) +else: + assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0." + autocast_custom_fwd = device_torch_lib.amp.custom_fwd + autocast_custom_bwd = device_torch_lib.amp.custom_bwd + + def custom_device_ctx(index: int): + return torch.cuda.device(index) + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/examples/kda/FLA_KDA/fla_wy_fast.py b/examples/kda/FLA_KDA/fla_wy_fast.py new file mode 100644 index 000000000..a042c2a5f --- /dev/null +++ b/examples/kda/FLA_KDA/fla_wy_fast.py @@ -0,0 +1,312 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from .fla_utils import prepare_chunk_indices, exp2, autotune_cache_kwargs + + +@triton.heuristics( + { + "STORE_QG": lambda args: args["qg"] is not None, + "STORE_KG": lambda args: args["kg"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"DOT_PRECISION": DOT_PRECISION}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + for DOT_PRECISION in (["tf32x3", "ieee"]) + ], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + q, + k, + qg, + kg, + v, + beta, + w, + u, + A, + gk, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + STORE_QG: tl.constexpr, + STORE_KG: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = 
tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_b[:, None] # multiply k by beta + + p_gk = tl.make_block_ptr(gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kb *= exp2(b_gk) + if STORE_QG: + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_qg = b_q * exp2(b_gk) + tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1)) + if STORE_KG: + last_idx = min(i_t * BT + BT, T) - 1 + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + b_gn = tl.load(gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0) # g of the last token in the chunk + b_kg = b_k * tl.where((i_t * BT + tl.arange(0, BT) < T)[:, None], exp2(b_gn[None, :] - b_gk), 0) + p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1)) + + b_w = tl.dot(b_A, b_kb.to(b_k.dtype)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [2, 4] for num_stages in [2, 3, 4]], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def prepare_wy_repr_bwd_kernel( + k, + v, + beta, + gk, + A, + dA, + dw, + du, + dk, + dk2, + dv, + db, + dg, + dg2, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + 
i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_b = tl.make_block_ptr(beta + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_db = tl.make_block_ptr(db + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_b = tl.load(p_b, boundary_check=(0,)) + b_db = tl.zeros([BT], dtype=tl.float32) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk2 = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2 + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_gk = tl.make_block_ptr(gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_gk_exp = exp2(tl.load(p_gk, boundary_check=(0, 1))) + b_kbg = b_k * b_b[:, None] * b_gk_exp + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + + b_dA += tl.dot(b_dw, tl.trans(b_kbg).to(b_dw.dtype)) + b_dkbg = tl.dot(b_A, b_dw) + b_dk = b_dkbg * b_gk_exp * b_b[:, None] + tl.load(p_dk, boundary_check=(0, 1)) + b_db += tl.sum(b_dkbg * b_k * b_gk_exp, 1) + b_dg = b_kbg * b_dkbg + tl.load(p_dg, boundary_check=(0, 1)) + + tl.store(p_dk2, b_dk.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg2, b_dg.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_vb)) + b_dvb = tl.dot(b_A, b_du) + b_dv = b_dvb * b_b[:, None] + b_db += tl.sum(b_dvb * b_v, 1) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_dA = tl.where(m_A, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + + b_dA = tl.where(m_A, -b_dA, 0) + + # if using gk, save dA first and handle dk in another kernel + p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + q: torch.Tensor = None, + gk: torch.Tensor = None, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, +) -> tuple[torch.Tensor, 
torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = A.shape[-1] + BK = 64 + BV = 64 + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + w = torch.empty_like(k) + u = torch.empty_like(v) + qg = torch.empty_like(q) if q is not None else None + kg = torch.empty_like(k) if gk is not None else None + recompute_w_u_fwd_kernel[(NT, B * H)]( + q=q, + k=k, + qg=qg, + kg=kg, + v=v, + beta=beta, + w=w, + u=u, + A=A, + gk=gk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u, qg, kg + + +def prepare_wy_repr_bwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + gk: torch.Tensor, + A: torch.Tensor, + dk: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + dg: torch.Tensor, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = 64 + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + CONST_TILING = 64 + BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) + BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) + + dk2 = torch.empty_like(dk, dtype=torch.float) + dv = torch.empty_like(v) + dg2 = torch.empty_like(gk, dtype=torch.float) + dA = torch.empty_like(A, dtype=torch.float) + db = torch.empty_like(beta, dtype=torch.float) + prepare_wy_repr_bwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + gk=gk, + A=A, + dA=dA, + dw=dw, + du=du, + dk=dk, + dk2=dk2, + dv=dv, + db=db, + dg=dg, + dg2=dg2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + dk = dk2 + dg = dg2 + + return dk, dv, db, dg, dA diff --git a/examples/kda/README.md b/examples/kda/README.md new file mode 100644 index 000000000..f445a9f09 --- /dev/null +++ b/examples/kda/README.md @@ -0,0 +1,7 @@ +# KDA kernel implementation with TileLang +## Requirements +- TileLang: 0.1.6.post2+cuda.git729e66ca +- triton: 3.2.0 +- FLA: commit 9714c5 (used for comparison) + +We copy the needed files and functions from flash-linear-attention into FLA_KDA/ for easy comparison. 
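+
+## Running the comparisons
+Each `chunk_*.py` script builds a TileLang kernel, checks it against the corresponding FLA Triton kernel with `compare_tensors`, and prints `do_bench` timings for both. A minimal way to run one of them (a sketch; it assumes a CUDA GPU and that this directory is the working directory so that `FLA_KDA` and `test_utils_kda` are importable):
+
+```python
+import chunk_bwd_dv
+
+chunk_bwd_dv.main()  # compares the TileLang dv kernel against FLA's chunk_bwd_dv_local and prints fla_time / tilelang_time
+```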
diff --git a/examples/kda/chunk_bwd_dqkwg.py b/examples/kda/chunk_bwd_dqkwg.py new file mode 100644 index 000000000..d3d4df4b4 --- /dev/null +++ b/examples/kda/chunk_bwd_dqkwg.py @@ -0,0 +1,274 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_inter import chunk_kda_bwd_dqkwg +from test_utils_kda import do_bench, compare_tensors + +import torch + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + gate_dtype, +): + BS = S // chunk_size + q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + v_new = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + w = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + g = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + h = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + do = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + dh = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + + return q, k, v_new, w, g, h, dv, do, dh + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + gate_dtype, +): + dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() + dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() + dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + return dq, dk, dw, dg + + +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-4, -3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def chunk_bwd_dqkwg( + B, + S, + H, + DK, + DV, + scale, + chunk_size, + input_dtype, + gate_dtype, + block_DK=32, + block_DV=32, + threads=32, + num_stages=0, +): + block_S = chunk_size + BS = S // block_S + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + H_shape = (B, BS, H, DK, DV) + + @T.prim_func + def kernel( + Q: T.Tensor(K_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + G: T.Tensor(K_shape, dtype=gate_dtype), + h: T.Tensor(H_shape, dtype=input_dtype), + dv: T.Tensor(V_shape, dtype=input_dtype), + DO: T.Tensor(V_shape, dtype=input_dtype), + Dh: T.Tensor(H_shape, dtype=input_dtype), + dq: T.Tensor(K_shape, dtype=T.float32), + dk: T.Tensor(K_shape, dtype=T.float32), + dw: T.Tensor(K_shape, dtype=gate_dtype), + dg: T.Tensor(K_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh): + bb, bh = bbh // H, bbh % H + chunk_last_idx = T.min(S, (bs + 1) * block_S) - 1 + + dgkn_fragment = T.alloc_fragment((block_DK), dtype=T.float32) + dgkn_fragment_tmp = T.alloc_fragment((block_DK,), dtype=T.float32) + dq_fragment = T.alloc_fragment((block_S, block_DK), dtype=T.float32) + dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=T.float32) + dw_fragment = T.alloc_fragment((block_S, block_DK), dtype=T.float32) + dgk_shared = T.alloc_shared((block_S, block_DK), dtype=T.float32) + + h_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dh_shared = 
T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dgkn_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) # d of last token in a chunk + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + DO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + DV_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + G_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) # chunk G + Gn_shared = T.alloc_shared((block_DK), dtype=input_dtype) # chunk last token G + Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + + dkkn_shared = T.alloc_shared((block_S, block_DK), dtype=T.float32) + pp_shared = T.alloc_shared((block_DK), dtype=T.float32) + + T.clear(dgkn_fragment) + T.clear(dq_fragment) + T.clear(dk_fragment) + T.clear(dw_fragment) + + T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], G_shared) + T.copy(G[bb, chunk_last_idx, bh, bk * block_DK : (bk + 1) * block_DK], Gn_shared) + + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared) + T.copy(Dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.copy(DO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], DO_shared) + T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], DV_shared) + # += reduce_sum + for i_k1, i_v1 in T.Parallel(block_DK, block_DV): + dgkn_shared[i_k1, i_v1] = h_shared[i_k1, i_v1] * dh_shared[i_k1, i_v1] + T.reduce_sum(dgkn_shared, dgkn_fragment_tmp, dim=1, clear=True) # [block_DK] + for i_ks in T.Parallel(block_DK): + dgkn_fragment[i_ks] += dgkn_fragment_tmp[i_ks] + T.gemm(DO_shared, h_shared, dq_fragment, transpose_B=True, clear_accum=False) # [block_S, block_DK] + T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True, clear_accum=False) # [block_S, block_DK] + T.gemm(DV_shared, h_shared, dw_fragment, transpose_B=True, clear_accum=False) # [block_S, block_DK] + # chunk last token + for i_k0 in T.Parallel(block_DK): + dgkn_fragment[i_k0] = dgkn_fragment[i_k0] * T.exp2(Gn_shared[i_k0]) + + for i_s, i_k in T.Parallel(block_S, block_DK): + dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale * T.exp2(G_shared[i_s, i_k]) + dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp2(Gn_shared[i_k] - G_shared[i_s, i_k]) + + T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], Q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], K_shared) + + for i_s2, i_k2 in T.Parallel(block_S, block_DK): + dkkn_shared[i_s2, i_k2] = dk_fragment[i_s2, i_k2] * K_shared[i_s2, i_k2] + T.reduce_sum(dkkn_shared, pp_shared, dim=0, clear=True) + for i_k3 in T.Parallel(block_DK): + pp_shared[i_k3] += dgkn_fragment[i_k3] + + for i_s4, i_k4 in T.Parallel(block_S, block_DK): + 
dgk_shared[i_s4, i_k4] = ( + Q_shared[i_s4, i_k4] * dq_fragment[i_s4, i_k4] + - K_shared[i_s4, i_k4] * dk_fragment[i_s4, i_k4] + + T.if_then_else(chunk_last_idx == bs * block_S + i_s4, pp_shared[i_k4], 0.0) + ) + + T.copy(dgk_shared, dg[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + scale, + input_dtype, + gate_dtype, + qk_dtype, + chunk_size, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=64, + block_DV=32, + threads=128, + num_stages=0, +): + q, k, v_new, w, g, h, dv, do, dh = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, gate_dtype)) + + dq_ref, dk_ref, dw_ref, dg_ref = chunk_kda_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + ) + + dq, dk, dw, dg = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, gate_dtype)) + kernel = chunk_bwd_dqkwg( + B=B, S=S, H=H, DK=DK, DV=DV, scale=scale, chunk_size=chunk_size, input_dtype=input_dtype, gate_dtype=gate_dtype + ) + dq, dk, dw, dg = kernel(q, k, v_new, g, h, dv, do, dh) + + compare_tensors("dq", dq_ref, dq) + compare_tensors("dk", dk_ref, dk) + compare_tensors("dw", dw_ref, dw) + compare_tensors("dg", dg_ref, dg) + + fla_time = do_bench( + chunk_kda_bwd_dqkwg, + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + ) + tilelang_time = do_bench(kernel, q, k, v_new, g, h, dv, do, dh) + print("fla_time:", fla_time) + print("tilelang_time:", tilelang_time) + + +def main(): + run_test( + B=1, + S=8192, + H=64, + DK=128, + DV=128, + scale=1.0, + input_dtype="float32", + gate_dtype="float32", # gate must be float32 + qk_dtype="float32", + chunk_size=64, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=32, + block_DV=32, + threads=128, + num_stages=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_bwd_dv.py b/examples/kda/chunk_bwd_dv.py new file mode 100644 index 000000000..cdbe0a899 --- /dev/null +++ b/examples/kda/chunk_bwd_dv.py @@ -0,0 +1,150 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune +import sys # noqa: F401 + +from FLA_KDA.fla_chunk_o import chunk_bwd_dv_local +from test_utils_kda import compare_tensors, do_bench + +import torch + +torch.random.manual_seed(1) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + do_dtype, +): + q = torch.randn(B, S, H, DK, dtype=do_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=do_dtype).cuda() + DO = torch.randn(B, S, H, DV, dtype=do_dtype).cuda() + A = torch.randn(B, S, H, chunk_size, dtype=input_dtype).cuda() + return q, k, DO, A + + +def prepare_output( + B, + S, + H, + DV, + chunk_size, + output_dtype, +): + dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return dv + + +def get_configs(): + import itertools + + block_DV = [32, 64, 128] + threads = [32, 64, 128] + num_stages = [0, 1, 2, 3, 4] + _configs = list(itertools.product(block_DV, threads, num_stages)) + configs = [{"block_DV": c[0], "threads": c[1], "num_stages": c[2]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=5) +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_bwd_kernel_dv_local( + B, + S, + H, + DV, + input_dtype, + output_dtype, + do_dtype, + chunk_size, + block_DV=128, + 
threads=128, + num_stages=1, +): + block_S = BS = chunk_size + DO_shape = (B, S, H, DV) + A_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + DO: T.Tensor(DO_shape, dtype=do_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dv: T.Tensor(DO_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + + A_shared = T.alloc_shared((BS, BS), dtype=do_dtype) + DO_shared = T.alloc_shared((BS, block_DV), dtype=do_dtype) + dv_fragment = T.alloc_fragment((BS, block_DV), dtype=T.float32) + dv_shared = T.alloc_shared((BS, block_DV), dtype=output_dtype) + + T.copy(A[bb, bs * BS : (bs + 1) * BS, bh, :], A_shared) + for i_s1, i_s2 in T.Parallel(BS, BS): + A_shared[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, A_shared[i_s1, i_s2], 0.0) + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(DO[bb, bs * BS : (bs + 1) * BS, bh, i_v * block_DV : (i_v + 1) * block_DV], DO_shared) + T.gemm(A_shared, DO_shared, dv_fragment, transpose_A=True, clear_accum=True) # transpose_A: A^T + T.copy(dv_fragment, dv_shared) + T.copy(dv_shared, dv[bb, bs * BS : (bs + 1) * BS, bh, i_v * block_DV : (i_v + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + scale, + input_dtype, + do_dtype, + output_dtype, + chunk_size, +): + q, k, DO, A = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, do_dtype)) + dv_ref = chunk_bwd_dv_local(q, k, do=DO, A=A) + + dv_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, output_dtype)) + kernel = tilelang_chunk_bwd_kernel_dv_local( + B=B, + S=S, + H=H, + DV=DV, + input_dtype=input_dtype, + output_dtype=output_dtype, + do_dtype=do_dtype, + chunk_size=chunk_size, + ) + dv_tilelang = kernel(DO, A) + compare_tensors("dv", dv_ref, dv_tilelang) + + fla_time = do_bench(chunk_bwd_dv_local, q, k, do=DO, A=A) + tilelang_time = do_bench(kernel, DO, A) + print("fla_time: ", fla_time) + print("tilelang_time: ", tilelang_time) + + +def main(): + run_test( + B=1, + S=1024 * 8, # 32768 + H=64, + DK=128, + DV=128, + scale=1.0, + input_dtype="bfloat16", + do_dtype="float32", + output_dtype="bfloat16", + chunk_size=64, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_bwd_gla_dA.py b/examples/kda/chunk_bwd_gla_dA.py new file mode 100644 index 000000000..913fa9171 --- /dev/null +++ b/examples/kda/chunk_bwd_gla_dA.py @@ -0,0 +1,147 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_o import chunk_gla_bwd_dA +from test_utils_kda import compare_tensors, do_bench + +import torch + +torch.random.manual_seed(1) + + +def prepare_input( + B, + S, + H, + DV, + chunk_size, + input_dtype, + do_dtype, +): + DO = torch.randn(B, S, H, DV, dtype=do_dtype).cuda() + V_new = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return DO, V_new + + +def prepare_output( + B, + S, + H, + DV, + chunk_size, + d_type, +): + dA = torch.empty(B, S, H, chunk_size, dtype=d_type).cuda() + return dA + + +def get_configs(): + import itertools + + block_DV = [32, 64, 128] + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3, 4] + _configs = list(itertools.product(block_DV, threads, num_stages)) + configs = [{"block_DV": c[0], "threads": c[1], "num_stages": c[2]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=5) +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def 
tilelang_chunk_bwd_kernel_dv_local( + B, + S, + H, + DV, + scale, + input_dtype, + da_dtype, + do_dtype, + chunk_size, + block_DV=128, + threads=128, + num_stages=1, +): + block_S = BS = chunk_size + DO_shape = (B, S, H, DV) + V_shape = (B, S, H, DV) + dA_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + DO: T.Tensor(DO_shape, dtype=do_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + dA: T.Tensor(dA_shape, dtype=da_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype) + V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype) + dA_fragment = T.alloc_fragment((block_S, block_S), dtype=T.float32) + + T.clear(dA_fragment) + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(DO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], do_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.gemm(do_shared, V_shared, dA_fragment, transpose_B=True) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0) # keep only the lower-triangular part + T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, 0:block_S]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + scale, + input_dtype, + do_dtype, + da_dtype, + chunk_size, +): + DO, V_new = prepare_input(B, S, H, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, do_dtype)) + print(DO.dtype, V_new.dtype) + dA_ref = chunk_gla_bwd_dA(v=V_new, do=DO, scale=scale) + + dA_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, da_dtype)) + kernel = tilelang_chunk_bwd_kernel_dv_local( + B=B, + S=S, + H=H, + DV=DV, + scale=scale, + input_dtype=input_dtype, + da_dtype=da_dtype, + do_dtype=do_dtype, + chunk_size=chunk_size, + ) + dA_tilelang = kernel(DO, V_new) + compare_tensors("dA", dA_ref, dA_tilelang) + fla_time = do_bench(chunk_gla_bwd_dA, v=V_new, do=DO, scale=scale) + tilelang_time = do_bench(kernel, DO, V_new) + print("fla_time:", fla_time) + print("tilelang_time:", tilelang_time) + + +def main(): + run_test( + B=1, + S=1024 * 8, # 8192 + H=64, + DK=128, + DV=128, + scale=1.0, + input_dtype="bfloat16", + do_dtype="bfloat16", + da_dtype="float32", + chunk_size=64, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_bwd_intra.py b/examples/kda/chunk_bwd_intra.py new file mode 100644 index 000000000..6c66732b4 --- /dev/null +++ b/examples/kda/chunk_bwd_intra.py @@ -0,0 +1,493 @@ +# Reference: FLA_KDA/fla_chunk_intra.py +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_intra import chunk_kda_bwd_intra +from FLA_KDA.cumsum import chunk_local_cumsum +from test_utils_kda import compare_tensors, do_bench + +import torch + +torch.random.manual_seed(0) +torch.set_printoptions(profile="full") + + +def prepare_input( + B, + S, + H, + DK, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BT = chunk_size + q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + g = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + + # dAqk and dAkk are gradients w.r.t. 
Aqk and Akk + # Shape: (B, S, H, BT) + dAqk = torch.randn(B, S, H, BT, dtype=input_dtype).cuda() + dAkk = torch.randn(B, S, H, BT, dtype=input_dtype).cuda() + + # Initial gradients (will be updated by the kernel) + dq = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + dk = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + db = torch.randn(B, S, H, dtype=input_dtype).cuda() + dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + + return q, k, g, beta, dAqk, dAkk, dq, dk, db, dg + + +def prepare_output( + B, + S, + H, + DK, + chunk_size, + NK, + output_dtype, + gate_dtype, + state_dtype, +): + dq = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + db = torch.empty(NK, B, S, H, dtype=output_dtype).cuda() + dg = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda() + return dq, dk, db, dg + + +def get_configs(): + import itertools + + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3] + _configs = list(itertools.product(threads, num_stages)) + + configs = [{"threads": c[0], "num_stages": c[1]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=5, rep=5) +@tilelang.jit( + out_idx=[-4, -3, -2, -1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, +) +def tilelang_chunk_bwd_intra( + # task config + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + # kernel config + block_DK, + block_BC=16, + threads=128, + num_stages=0, +): + BT = chunk_size + BC = block_BC # sub-chunk size, typically 16 + + NC = BT // BC # number of sub-chunks + NT = T.ceildiv(S, BT) + NK = T.ceildiv(DK, block_DK) # number of K blocks + + K_shape = (B, S, H, DK) + Beta_shape = (B, S, H) + G_shape = (B, S, H, DK) + BT_shape = (B, S, H, BT) # for dAqk and dAkk + + dq_shape = (B, S, H, DK) + dk_shape = (B, S, H, DK) + db_shape = (B, S, H) + db2_shape = (NK, B, S, H) + dg_shape = (B, S, H, DK) + + @T.prim_func + def kernel( + # input + q: T.Tensor(K_shape, dtype=input_dtype), + k: T.Tensor(K_shape, dtype=input_dtype), + g: T.Tensor(G_shape, dtype=gate_dtype), + beta: T.Tensor(Beta_shape, dtype=input_dtype), + dAqk: T.Tensor(BT_shape, dtype=input_dtype), + dAkk: T.Tensor(BT_shape, dtype=input_dtype), + dq: T.Tensor(dq_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=input_dtype), + db: T.Tensor(db_shape, dtype=input_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), + # output + dq2: T.Tensor(dq_shape, dtype=output_dtype), + dk2: T.Tensor(dk_shape, dtype=output_dtype), + db2: T.Tensor(db2_shape, dtype=output_dtype), + dg2: T.Tensor(dg_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(DK, block_DK) * NC, NT, B * H, threads=threads) as (i_kc, i_t, i_bh): + i_k, i_i = i_kc // NC, i_kc % NC + bb, bh = i_bh // H, i_bh % H + + # actual sub-chunk index + i_ti = i_t * BT + i_i * BC + + # current sub-chunk data + q_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + k_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + beta_shared = T.alloc_shared((BC,), dtype=input_dtype) + g_current_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype) # last token's g in current sub-chunk + + dq_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + dk_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + dg_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + + # Allocate fragments + dq2_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + 
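+ # Accumulators kept in accum_dtype: dq2/dk2 collect the intra-chunk gradient contributions;
+ # db is later reduced from dk2 * k and dg2 from q * dq2, before the incoming dq/dk/dg are merged in.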
dk2_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + dg2_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + db_fragment = T.alloc_fragment((BC,), dtype=accum_dtype) + + # Initialize fragments + T.clear(dq2_fragment) + T.clear(dk2_fragment) + T.clear(dg2_fragment) + T.clear(db_fragment) + + # Temporary shared memory for previous sub-chunks + k_prev_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + g_prev_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + dAqk_prev_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + dAkk_prev_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + + # Temporary fragment for b_kg computation + kg_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + + kj_shared = T.alloc_shared((block_DK,), dtype=T.float32) + gkj_shared = T.alloc_shared((block_DK,), dtype=T.float32) + kgj_fragment = T.alloc_fragment((BC, block_DK), dtype=T.float32) + dAqk_col = T.alloc_shared((BC,), dtype=input_dtype) + dAkk_col = T.alloc_shared((BC,), dtype=input_dtype) + + # Load g, q, k for current sub-chunk + T.copy(q[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], q_shared) + T.copy(k[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], k_shared) + T.copy(g[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], g_current_shared) + T.copy(beta[bb, i_ti : i_ti + BC, bh], beta_shared) + + if i_i > 0: + chunk_first_idx = i_ti # chunk first token idx + + T.copy(g[bb, chunk_first_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gn_shared) # Get the first token's g value (b_gn) + + # Loop over previous sub-chunks (i_j from 0 to i_i-1) + # Since i_i is computed from i_kc % NC and NC is small, we can use conditional blocks + # Process each possible previous sub-chunk with conditional execution + for i_j in T.Pipelined(i_i, num_stages=num_stages): # i_j is index ofprevious sub_chunks + prev_ti = i_t * BT + i_j * BC + T.copy(k[bb, prev_ti : prev_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], k_prev_shared) + T.copy(g[bb, prev_ti : prev_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], g_prev_shared) + + T.copy(dAqk[bb, i_ti : i_ti + BC, bh, i_j * BC : (i_j + 1) * BC], dAqk_prev_shared) + T.copy(dAkk[bb, i_ti : i_ti + BC, bh, i_j * BC : (i_j + 1) * BC], dAkk_prev_shared) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + kg_fragment[i_bc, i_k2] = k_prev_shared[i_bc, i_k2] * T.exp2(gn_shared[i_k2] - g_prev_shared[i_bc, i_k2]) + + T.gemm(dAqk_prev_shared, kg_fragment, dq2_fragment, clear_accum=False) + T.gemm(dAkk_prev_shared, kg_fragment, dk2_fragment, clear_accum=False) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + gqn = T.exp2(g_current_shared[i_bc, i_k2] - gn_shared[i_k2]) + dq2_fragment[i_bc, i_k2] = dq2_fragment[i_bc, i_k2] * gqn + dk2_fragment[i_bc, i_k2] = dk2_fragment[i_bc, i_k2] * gqn + + # Process current sub-chunk diagonal + loop_length = T.min(BC, S - i_t * BT - i_i * BC) + for j in T.Pipelined(loop_length, num_stages=num_stages): + token_j_idx = i_ti + j + + T.copy(k[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], kj_shared) + T.copy(g[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gkj_shared) + T.copy(dAqk[bb, i_ti : i_ti + BC, bh, i_i * BC + j], dAqk_col) + T.copy(dAkk[bb, i_ti : i_ti + BC, bh, i_i * BC + j], dAkk_col) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + kgj_fragment[i_bc, i_k2] = kj_shared[i_k2] * T.exp2(g_current_shared[i_bc, i_k2] - gkj_shared[i_k2]) + dq2_fragment[i_bc, i_k2] += T.if_then_else(i_bc >= j, dAqk_col[i_bc] * 
kgj_fragment[i_bc, i_k2], 0.0) + dk2_fragment[i_bc, i_k2] += T.if_then_else(i_bc >= j, dAkk_col[i_bc] * kgj_fragment[i_bc, i_k2], 0.0) + + # Compute b_db = sum(b_dk2 * b_k, dim=1) + dk2_k_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dk2_k_fragment[i_bc, i_k2] = dk2_fragment[i_bc, i_k2] * k_shared[i_bc, i_k2] + T.reduce_sum(dk2_k_fragment, db_fragment, dim=1, clear=True) + + # b_dk2 *= b_b[:, None] + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dk2_fragment[i_bc, i_k2] = dk2_fragment[i_bc, i_k2] * beta_shared[i_bc] + + # Compute b_dg2 = b_q * b_dq2 (before adding dq to dq2) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dg2_fragment[i_bc, i_k2] = q_shared[i_bc, i_k2] * dq2_fragment[i_bc, i_k2] + + # Load dq and compute b_dq2 = b_dq2 + b_dq + T.copy(dq[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], dq_shared) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dq2_fragment[i_bc, i_k2] = dq2_fragment[i_bc, i_k2] + dq_shared[i_bc, i_k2] + + # # Store results + T.copy(dq2_fragment, dq2[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK]) + T.copy(db_fragment, db2[i_k, bb, i_ti : i_ti + BC, bh]) + + # Initialize dkt_fragment for processing subsequent sub-chunks and lower triangular part + dkt_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + T.clear(dkt_fragment) + + # Temporary shared memory for subsequent sub-chunks + q_next_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + k_next_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + g_next_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + beta_next_shared = T.alloc_shared((BC,), dtype=input_dtype) + dAqk_next_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + dAkk_next_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + + # Temporary fragments for computation + gkn_shared = T.alloc_shared((BC, block_DK), dtype=accum_dtype) + qg_shared = T.alloc_shared((BC, block_DK), dtype=accum_dtype) + kbg_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + kbg_shared = T.alloc_shared((BC, block_DK), dtype=accum_dtype) + dkt_temp_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + # T.use_swizzle(10) + + NC_actual = T.min(NC, T.ceildiv(S - i_t * BT, BC)) # Process subsequent sub-chunks (i_j from i_i+1 to NC-1) + if i_i < NC_actual - 1: + # Get the last token's g value in current sub-chunk + chunk_last_idx = T.min(S, i_ti + BC) - 1 + gn_last_shared = T.alloc_shared((block_DK,), dtype=gate_dtype) + T.copy(g[bb, chunk_last_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gn_last_shared) + + # Loop over subsequent sub-chunks + for i_j in T.Pipelined(i_i + 1, NC_actual, num_stages=num_stages): + i_tj = i_t * BT + i_j * BC + + T.copy(q[bb, i_tj : i_tj + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], q_next_shared) + T.copy(k[bb, i_tj : i_tj + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], k_next_shared) + T.copy(g[bb, i_tj : i_tj + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], g_next_shared) + T.copy(beta[bb, i_tj : i_tj + BC, bh], beta_next_shared) + + T.copy(dAqk[bb, i_tj : i_tj + BC, bh, i_i * BC : (i_i + 1) * BC], dAqk_next_shared) # [BC, BC] need transpose + T.copy(dAkk[bb, i_tj : i_tj + BC, bh, i_i * BC : (i_i + 1) * BC], dAkk_next_shared) # [BC, BC] need transpose + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + # kbg = k * beta + kbg_fragment[i_bc, i_k2] = k_next_shared[i_bc, i_k2] * beta_next_shared[i_bc] + gkn_shared[i_bc, i_k2] = T.if_then_else( + i_tj + i_bc < S, 
T.exp2(g_next_shared[i_bc, i_k2] - gn_last_shared[i_k2]), 0.0 + ) + + # Compute qg and kbg + for i_bc, i_k2 in T.Parallel(BC, block_DK): + qg_shared[i_bc, i_k2] = q_next_shared[i_bc, i_k2] * gkn_shared[i_bc, i_k2] + kbg_shared[i_bc, i_k2] = kbg_fragment[i_bc, i_k2] * gkn_shared[i_bc, i_k2] + + # Accumulate: dkt += dAqk^T @ qg + dAkk^T @ kbg + # Use transpose_A=True because dAqk/dAkk are loaded in (T, BT) layout but we need (BT, T) for gemm + T.gemm(dAqk_next_shared, qg_shared, dkt_temp_fragment, transpose_A=True, clear_accum=True) + T.gemm(dAkk_next_shared, kbg_shared, dkt_temp_fragment, transpose_A=True, clear_accum=False) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dkt_fragment[i_bc, i_k2] = dkt_fragment[i_bc, i_k2] + dkt_temp_fragment[i_bc, i_k2] + + # Scale dkt by exp2(gn_last - g_current) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + g_scale = T.exp2(gn_last_shared[i_k2] - g_current_shared[i_bc, i_k2]) + dkt_fragment[i_bc, i_k2] = dkt_fragment[i_bc, i_k2] * g_scale + + # Process lower triangular part of current sub-chunk diagonal + # This corresponds to j <= i_bc in the diagonal block + qj_shared = T.alloc_shared((block_DK,), dtype=T.float32) + kj_shared_lower = T.alloc_shared((block_DK,), dtype=T.float32) + gj_shared_lower = T.alloc_shared((block_DK,), dtype=T.float32) + bj_local = T.alloc_local((1), dtype=input_dtype) + dAqk_col_lower = T.alloc_shared((BC,), dtype=input_dtype) + dAkk_col_lower = T.alloc_shared((BC,), dtype=input_dtype) + + gkq_fragment = T.alloc_fragment((BC, block_DK), dtype=T.float32) + # dkt_lower_temp = T.alloc_fragment((BC, block_DK), dtype=T.float32) + kbj_fragment = T.alloc_fragment((block_DK,), dtype=T.float32) + + max_token_j_idx = T.min(S, i_ti + BC) + for j in T.Pipelined(BC, num_stages=num_stages): + token_j_idx = i_ti + j + + if token_j_idx < max_token_j_idx: + T.copy(q[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], qj_shared) # [BK] + T.copy(k[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], kj_shared_lower) + T.copy(g[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gj_shared_lower) + + bj_local[0] = beta[bb, token_j_idx, bh] + T.copy(dAqk[bb, token_j_idx, bh, i_i * BC : (i_i + 1) * BC], dAqk_col_lower) # [BC] + T.copy(dAkk[bb, token_j_idx, bh, i_i * BC : (i_i + 1) * BC], dAkk_col_lower) + + # Compute kbj = kj * bj + for i_k2 in T.Parallel(block_DK): + kbj_fragment[i_k2] = kj_shared_lower[i_k2] * bj_local[0] + # Compute gkq = exp2(gj - g_current) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + gkq_fragment[i_bc, i_k2] = T.exp2(gj_shared_lower[i_k2] - g_current_shared[i_bc, i_k2]) + + # Accumulate: dkt += (dAkk * kbj + dAqk * qj) * gkq for i_bc <= j + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dkt_fragment[i_bc, i_k2] += T.if_then_else( + i_bc <= j, + (dAkk_col_lower[i_bc] * kbj_fragment[i_k2] + dAqk_col_lower[i_bc] * qj_shared[i_k2]) * gkq_fragment[i_bc, i_k2], + 0.0, + ) + + # Load dk and dg + T.copy(dk[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared) + T.copy(dg[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], dg_shared) + + # Update dg2: dg2 += (dk2 - dkt) * k + dg + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dg2_fragment[i_bc, i_k2] = ( + dg2_fragment[i_bc, i_k2] + + (dk2_fragment[i_bc, i_k2] - dkt_fragment[i_bc, i_k2]) * k_shared[i_bc, i_k2] + + dg_shared[i_bc, i_k2] + ) + + # Update dk2: dk2 += dk + dkt + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dk2_fragment[i_bc, i_k2] += dk_shared[i_bc, i_k2] + dkt_fragment[i_bc, i_k2] + + # Store dk2 and dg2 
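+ # dg2 written here is still a partial result: run_test finishes it with a reverse chunk-local
+ # cumsum, and db2 (split over NK) is summed and added to db on the host side.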
+ T.copy(dk2_fragment, dk2[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK]) + T.copy(dg2_fragment, dg2[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + threads=128, + num_stages=0, + cu_seqlens=None, + chunk_indices=None, +): + q, k, g, beta, dAqk, dAkk, dq, dk, db, dg = prepare_input( + B, + S, + H, + DK, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + + # Reference implementation + dq_ref, dk_ref, db_ref, dg_ref = chunk_kda_bwd_intra( + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dk=dk, + db=db, + dg=dg, + ) + block_DK = min(64, tilelang.math.next_power_of_2(DK)) + NK = (DK + block_DK - 1) // block_DK + # TileLang implementation + kernel = tilelang_chunk_bwd_intra( + B=B, + S=S, + H=H, + DK=DK, + input_dtype=input_dtype, + output_dtype=output_dtype, + accum_dtype=accum_dtype, + gate_dtype=gate_dtype, + state_dtype=state_dtype, + chunk_size=chunk_size, + block_DK=block_DK, + ) + + dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = prepare_output( + B, S, H, DK, chunk_size, NK, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(q, k, g, beta, dAqk, dAkk, dq, dk, db, dg) + db_tilelang = db_tilelang.sum(0).add_(db) + dg_tilelang = chunk_local_cumsum( + dg_tilelang, + chunk_size=chunk_size, + reverse=True, + ) + + compare_tensors("dq", dq_tilelang, dq_ref) + compare_tensors("dk", dk_tilelang, dk_ref) + compare_tensors("db", db_tilelang, db_ref) + compare_tensors("dg", dg_tilelang, dg_ref) + + fla_time = do_bench( + chunk_kda_bwd_intra, + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dk=dk, + db=db, + dg=dg, + ) + tilelang_time = do_bench(kernel, q, k, g, beta, dAqk, dAkk, dq, dk, db, dg) + print(f"Fla time: {fla_time}") + print(f"Tilelang time: {tilelang_time}") + + +def main(): + DK = 128 + run_test( + B=1, + S=8192, + H=8, + DK=DK, + input_dtype=T.float32, + output_dtype=T.float32, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + threads=128, + num_stages=0, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_delta_bwd.py b/examples/kda/chunk_delta_bwd.py new file mode 100644 index 000000000..8c22488ca --- /dev/null +++ b/examples/kda/chunk_delta_bwd.py @@ -0,0 +1,309 @@ +# Reference: fla/ops/common/chunk_delta_h.py +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_delta import chunk_gated_delta_rule_bwd_dhu +from FLA_KDA.cumsum import chunk_local_cumsum +from test_utils_kda import do_bench, compare_tensors + +import torch +import torch.nn.functional as F + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() * 0.01 + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + # Note: G should be in logspace and do chunkwise cumsum + G = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + G = chunk_local_cumsum(G, 
chunk_size) + + h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() * 0.01 + + dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return Q, K, W, G, h0, dht, dO, dv + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + gate_dtype, + state_dtype, +): + BS = S // chunk_size + dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + dv2 = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return dh, dh0, dv2 + + +def get_configs(): + import itertools + + block_DV = [32, 64, 128] + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3, 4] + _configs = list(itertools.product(block_DV, threads, num_stages)) + + configs = [{"block_DV": c[0], "threads": c[1], "num_stages": c[2]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[-3, -2, -1]) +def tilelang_chunk_gated_delta_rule_bwd_dhu( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_gk=True, + use_initial_state=True, + use_final_state_gradient=True, + # kernel config + block_DV=64, + threads=256, + num_stages=0, +): + block_S = chunk_size + # Should support cu_seqlen + BS = S // block_S + + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + W_shape = (B, S, H, DK) + G_shape = (B, S, H, DK) + h0_shape = (B, H, DK, DV) + dht_shape = (B, H, DK, DV) + dO_shape = (B, S, H, DV) + dv_shape = (B, S, H, DV) + + dh_shape = (B, BS, H, DK, DV) + dh0_shape = (B, H, DK, DV) + dv2_shape = (B, S, H, DV) + + @T.prim_func + def kernel( + # Input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + GK: T.Tensor(G_shape, dtype=gate_dtype), + h0: T.Tensor(h0_shape, dtype=input_dtype), + dht: T.Tensor(dht_shape, dtype=input_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + # Output + dh: T.Tensor(dh_shape, dtype=output_dtype), + dh0: T.Tensor(dh0_shape, dtype=state_dtype), + dv2: T.Tensor(dv2_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype) + b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + GK_last_shared = T.alloc_shared((DK,), dtype=gate_dtype) + + if use_final_state_gradient: + T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) + T.copy(b_dh_shared, b_dh_fragment) + else: + T.clear(b_dh_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # The gradient should be stored in the reverse order + i_s_inv = 
T.ceildiv(S, block_S) - i_s - 1 # reverse indices + # Store the updated dh + T.copy(b_dh_fragment, b_dh_shared) + T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + # Update dv + T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared) + T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) + T.copy( + dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared + ) # copy old dv + T.copy(dv_shared, dv_fragment_2) + for i_s2, i_v in T.Parallel(block_S, block_DV): + dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] + # Store the updated dv + T.copy(dv_fragment, dv_shared) + T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + # Update dh + T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) # [block_S, DK] + T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared) # [block_S, DK] + T.copy( + dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared + ) # [block_S, block_DV] + + if use_gk: + last_idx = T.min((i_s_inv + 1) * block_S, S) - 1 # chunk last token gk + T.copy(GK[bb, last_idx, bh, :], GK_last_shared) + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] *= T.exp2(GK_last_shared[i_k]) + + T.gemm(Q_shared, dO_shared, b_dh_fragment_1, transpose_A=True, clear_accum=True) # [DK, block_DV] + + # dv_shared: [block_S, block_DV] + T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True, clear_accum=True) # [DK, block_DV] + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] * scale - b_dh_fragment_2[i_k, i_v] + + if use_initial_state: + T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_gk=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=64, + threads=256, + num_stages=0, + use_torch=False, +): + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + + dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + + # fla ref + print("fla running...", flush=True) + if use_gk: + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu( + q=Q, k=K, w=W, do=dO, dv=dv, gk=G, h0=h0, dht=dht, scale=scale, use_exp2=True + ) + + # tilelang + print("tilelang running...", flush=True) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_gk, + use_initial_state, + use_final_state_gradient, + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) + + fla_time = do_bench( + chunk_gated_delta_rule_bwd_dhu, q=Q, k=K, w=W, do=dO, dv=dv, gk=G, h0=h0, dht=dht, scale=scale, chunk_size=chunk_size + ) + tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) + + print(f"fla time: {fla_time} ms") + print(f"tilelang time: {tilelang_time} ms") + + compare_tensors("dh", dh_ref, 
dh_tilelang) + compare_tensors("dh0", dh0_ref, dh0_tilelang) + compare_tensors("dv2", dv2_ref, dv2_tilelang) + + +def main(): + DK = 128 + run_test( + B=1, + S=1024 * 8, + H=64, + DK=DK, + DV=128, + input_dtype="bfloat16", + output_dtype="bfloat16", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32", + chunk_size=64, + scale=DK**-0.5, + use_gk=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=32, + threads=128, + num_stages=1, + use_torch=False, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_delta_h_fwd.py b/examples/kda/chunk_delta_h_fwd.py new file mode 100644 index 000000000..fbb8bd988 --- /dev/null +++ b/examples/kda/chunk_delta_h_fwd.py @@ -0,0 +1,306 @@ +# Reference: fla/ops/common/chunk_delta_h.py + +import sys # noqa: F401 +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/your/path/to/flash-linear-attention") + +from FLA_KDA.fla_chunk_delta import chunk_gated_delta_rule_fwd_h +from FLA_KDA.cumsum import chunk_local_cumsum + +import torch +import torch.nn.functional as F + +from test_utils_kda import compare_tensors, do_bench + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + W = F.normalize(W, dim=-1, p=2) + U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + U = F.normalize(U, dim=-1, p=2) + G = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + G = chunk_local_cumsum(G, chunk_size) + initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + return K, W, U, G, initial_state + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + state_dtype, +): + BS = (S + chunk_size - 1) // chunk_size # ceildiv to match kernel iteration + h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + final_state = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + V_new = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return h, final_state, V_new + + +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [128, 256] + num_stages = [1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_gated_delta_rule_fwd_h( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_gk, + use_initial_state, + store_final_state, + save_new_value, + # kernel config + block_DK=64, + block_DV=32, + threads=128, + num_stages=1, +): + block_S = chunk_size + BS = (S + chunk_size - 1) // chunk_size # ceildiv to match kernel iteration + + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + W_shape = (B, S, H, DK) + U_shape = (B, S, H, DV) + GK_shape = (B, S, H, DK) + h_shape = (B, BS, H, DK, DV) + initial_state_shape = (B, H, DK, DV) + 
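+ # h stores one (DK, DV) state snapshot per chunk (BS = ceildiv(S, chunk_size)); the kernel
+ # writes the state *before* each chunk is absorbed, then applies the gated delta-rule update.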
final_state_shape = (B, H, DK, DV) + + @T.prim_func + def kernel( + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + U: T.Tensor(U_shape, dtype=input_dtype), + GK: T.Tensor(GK_shape, dtype=gate_dtype), + initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=output_dtype), + final_state: T.Tensor(final_state_shape, dtype=state_dtype), + V_new: T.Tensor(V_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype) + b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + + U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + GK_last_shared = T.alloc_shared((DK), dtype=gate_dtype) + + if use_initial_state: + T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared) + T.copy(b_h_shared, b_h_fragment) + else: + T.clear(b_h_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # Store previous result to the hidden tensor, like the epilogue + T.copy(b_h_shared, h[bb, i_s, bh, :, bv * block_DV : (bv + 1) * block_DV]) + + # Recurrence + T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, :], W_shared) + T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) + + # U - W * S + T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared) + T.copy(U_shared, U_fragment) + for i_s2, i_v in T.Parallel(block_S, block_DV): + V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] + + # Save V_new + if save_new_value: + T.copy(V_new_fragment, dst=V_new_shared) + T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) + # use_gk + if use_gk: + T.copy(GK[bb, (i_s + 1) * block_S - 1, bh, :], GK_last_shared) # block last token + for i_k, i_v in T.Parallel(DK, block_DV): + b_h_fragment[i_k, i_v] *= T.exp2(GK_last_shared[i_k]) + + # Update intermediate results + T.copy(V_new_fragment, V_new_shared) + T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True) + + T.copy(b_h_fragment, b_h_shared) + + # Save final state + if store_final_state: + T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=64, + block_DV=32, + threads=128, + num_stages=0, +): + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + h_ref, final_state_ref, V_new_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output( + B, S, H, DK, 
DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + + # fla ref + h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( + k=K, + w=W, + u=U, + gk=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + use_exp2=True, + ) + + # tilelang + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_gk, + use_initial_state, + store_final_state, + save_new_value, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) + + fla_time = do_bench( + chunk_gated_delta_rule_fwd_h, + k=K, + w=W, + u=U, + gk=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + use_exp2=True, + ) + tilelang_time = do_bench(kernel, K, W, U, G, initial_state) + + # check correctness + compare_tensors("h", h_ref, h_tilelang) + compare_tensors("final_state", final_state_ref, final_state_tilelang) + compare_tensors("V_new", V_new_ref, V_new_tilelang) + + print(f"tilelang time: {tilelang_time} ms") + print(f"fla time: {fla_time} ms") + + +def main(): + run_test( + B=1, + S=8192, + H=64, + DK=128, + DV=128, + input_dtype="float16", + output_dtype="float16", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32", + chunk_size=64, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=32, + block_DV=32, + threads=128, + num_stages=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_inter_solve_fused.py b/examples/kda/chunk_inter_solve_fused.py new file mode 100644 index 000000000..940dc20c8 --- /dev/null +++ b/examples/kda/chunk_inter_solve_fused.py @@ -0,0 +1,566 @@ +import tilelang +import tilelang.language as T + +from FLA_KDA.fla_chunk_intra import chunk_kda_fwd_inter_solve_fused +from FLA_KDA.cumsum import chunk_local_cumsum +from test_utils_kda import compare_tensors, do_bench + +import torch +import torch.nn.functional as F + + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + chunk_size, + sub_chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + gk = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() # 需要是cumsum + gk = F.logsigmoid(gk) + gk = chunk_local_cumsum(gk, chunk_size) + + Aqk = torch.empty(B, S, H, chunk_size, dtype=input_dtype).cuda() + Akk_diag = torch.ones(B, S, H, sub_chunk_size, dtype=torch.float32).cuda() + + return q, k, gk, beta, Aqk, Akk_diag + + +def prepare_output( + B, + S, + H, + chunk_size, + sub_chunk_size, + output_dtype, +): + Akk = torch.empty(B, S, H, chunk_size, dtype=output_dtype).cuda() + return Akk + + +@tilelang.jit(out_idx=[-2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_kda_fwd_inter_fused( + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + sub_chunk_size, + scale, + block_DK=32, + threads=32, + num_stages=1, +): + block_S = BS = chunk_size + BC = sub_chunk_size + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + GK_shape = (B, S, H, DK) + Beta_shape = (B, S, H) + Aqk_shape = (B, S, H, BS) + Akk_diag_shape = (B, S, H, BC) + """ + Fused kernel: 
compute inter-subchunk Akk + solve_tril in one pass. + Prerequisite: token_parallel has already computed diagonal Akk blocks in Akk_diag. + + This kernel: + 1. Computes off-diagonal Aqk blocks -> writes to global + 2. Computes off-diagonal Akk blocks -> keeps in registers + 3. Loads diagonal Akk blocks from Akk_diag (fp32) + 4. Does forward substitution on diagonals + 5. Computes merged Akk_inv + 6. Writes Akk_inv to Akk + """ + + @T.prim_func + def kernel( + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + GK: T.Tensor(GK_shape, dtype=gate_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + Akk_diag: T.Tensor(Akk_diag_shape, dtype=T.float32), + Aqk: T.Tensor(Aqk_shape, dtype=output_dtype), + Akk: T.Tensor(Aqk_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + + Aqk10_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk10_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk20_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk20_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk21_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk21_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk30_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk30_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk31_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk31_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk32_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk32_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk10_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk20_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk21_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk30_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk31_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk32_shared = T.alloc_shared((BC, BC), dtype=T.float32) + + K0_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + GK0_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + Q1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + K1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + GK1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + Q2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + K2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + GK2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + Q3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + K3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + GK3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + + Q_GK_scaled_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + K_GK_scaled_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + b_kt_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + + b_gn1_shared = T.alloc_shared((block_DK,), dtype=T.float32) + b_gn2_shared = T.alloc_shared((block_DK,), dtype=T.float32) + b_gn3_shared = T.alloc_shared((block_DK,), dtype=T.float32) + + b_gqn1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + b_gqn2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + b_gqn3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + + beta_1_shared = T.alloc_shared((BC,), dtype=T.float32) + beta_2_shared = T.alloc_shared((BC,), dtype=T.float32) + beta_3_shared = T.alloc_shared((BC,), dtype=T.float32) + # Akk_inv + Ai_00_shared = T.alloc_shared((BC, BC), 
dtype=T.float32) + Ai_10_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_11_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_20_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_21_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_22_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_30_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_31_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_32_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_33_shared = T.alloc_shared((BC, BC), dtype=T.float32) + + T.clear(Aqk10_fragment) + T.clear(Akk10_fragment) + T.clear(Aqk20_fragment) + T.clear(Akk20_fragment) + T.clear(Aqk21_fragment) + T.clear(Akk21_fragment) + T.clear(Aqk30_fragment) + T.clear(Akk30_fragment) + T.clear(Aqk31_fragment) + T.clear(Akk31_fragment) + T.clear(Aqk32_fragment) + T.clear(Akk32_fragment) + + i_tc0 = bs * BS + i_tc1 = bs * BS + BC + i_tc2 = bs * BS + 2 * BC + i_tc3 = bs * BS + 3 * BC + + ################################################################################ + # 1. off-diagonal blocks + ################################################################################ + + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * BS : bs * BS + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K0_shared) + T.copy(GK[bb, bs * BS : bs * BS + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK0_shared) + if i_tc1 < S: + T.copy(Q[bb, i_tc1 : i_tc1 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], Q1_shared) + T.copy(K[bb, i_tc1 : i_tc1 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K1_shared) + T.copy(GK[bb, i_tc1 : i_tc1 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK1_shared) + T.copy(GK[bb, i_tc1, bh, i_k * block_DK : (i_k + 1) * block_DK], b_gn1_shared) # subblock第一个token的GK + for i_c1, i_k1 in T.Parallel(BC, block_DK): + b_gqn1_shared[i_c1, i_k1] = T.if_then_else( + i_tc1 + i_c1 < S, T.exp2(GK1_shared[i_c1, i_k1] - b_gn1_shared[i_k1]), 0.0 + ) + Q_GK_scaled_shared[i_c1, i_k1] = Q1_shared[i_c1, i_k1] * b_gqn1_shared[i_c1, i_k1] + K_GK_scaled_shared[i_c1, i_k1] = K1_shared[i_c1, i_k1] * b_gqn1_shared[i_c1, i_k1] + b_kt_shared[i_c1, i_k1] = K0_shared[i_c1, i_k1] * T.exp2(b_gn1_shared[i_k1] - GK0_shared[i_c1, i_k1]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk10_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk10_fragment, transpose_B=True) + if i_tc2 < S: + T.copy(Q[bb, i_tc2 : i_tc2 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], Q2_shared) + T.copy(K[bb, i_tc2 : i_tc2 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K2_shared) + T.copy(GK[bb, i_tc2 : i_tc2 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK2_shared) + T.copy(GK[bb, i_tc2, bh, i_k * block_DK : (i_k + 1) * block_DK], b_gn2_shared) + for i_c2, i_k2 in T.Parallel(BC, block_DK): + b_gqn2_shared[i_c2, i_k2] = T.if_then_else( + i_tc2 + i_c2 < S, T.exp2(GK2_shared[i_c2, i_k2] - b_gn2_shared[i_k2]), 0.0 + ) + Q_GK_scaled_shared[i_c2, i_k2] = Q2_shared[i_c2, i_k2] * b_gqn2_shared[i_c2, i_k2] + K_GK_scaled_shared[i_c2, i_k2] = K2_shared[i_c2, i_k2] * b_gqn2_shared[i_c2, i_k2] + b_kt_shared[i_c2, i_k2] = K0_shared[i_c2, i_k2] * T.exp2(b_gn2_shared[i_k2] - GK0_shared[i_c2, i_k2]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk20_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk20_fragment, transpose_B=True) + for i_c3, i_k3 in T.Parallel(BC, block_DK): + b_kt_shared[i_c3, i_k3] = K1_shared[i_c3, i_k3] * T.exp2(b_gn2_shared[i_k3] - GK1_shared[i_c3, i_k3]) + 
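+ # Block (2,1): reuse the gate-scaled q2/k2 tiles against k1 rescaled by exp2(gn2 - g1).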
T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk21_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk21_fragment, transpose_B=True) + if i_tc3 < S: + T.copy(Q[bb, i_tc3 : i_tc3 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], Q3_shared) + T.copy(K[bb, i_tc3 : i_tc3 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K3_shared) + T.copy(GK[bb, i_tc3 : i_tc3 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK3_shared) + T.copy(GK[bb, i_tc3, bh, i_k * block_DK : (i_k + 1) * block_DK], b_gn3_shared) + for i_c4, i_k4 in T.Parallel(BC, block_DK): + b_gqn3_shared[i_c4, i_k4] = T.if_then_else( + i_tc3 + i_c4 < S, T.exp2(GK3_shared[i_c4, i_k4] - b_gn3_shared[i_k4]), 0.0 + ) + Q_GK_scaled_shared[i_c4, i_k4] = Q3_shared[i_c4, i_k4] * b_gqn3_shared[i_c4, i_k4] + K_GK_scaled_shared[i_c4, i_k4] = K3_shared[i_c4, i_k4] * b_gqn3_shared[i_c4, i_k4] + b_kt_shared[i_c4, i_k4] = K0_shared[i_c4, i_k4] * T.exp2(b_gn3_shared[i_k4] - GK0_shared[i_c4, i_k4]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk30_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk30_fragment, transpose_B=True) + for i_c5, i_k5 in T.Parallel(BC, block_DK): + b_kt_shared[i_c5, i_k5] = K1_shared[i_c5, i_k5] * T.exp2(b_gn3_shared[i_k5] - GK1_shared[i_c5, i_k5]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk31_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk31_fragment, transpose_B=True) + for i_c6, i_k6 in T.Parallel(BC, block_DK): + b_kt_shared[i_c6, i_k6] = K2_shared[i_c6, i_k6] * T.exp2(b_gn3_shared[i_k6] - GK2_shared[i_c6, i_k6]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk32_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk32_fragment, transpose_B=True) + + ################################################################################ + # 2. 
save off-diagonal Aqk blocks and prepare Akk + ################################################################################ + + if i_tc1 < S: + T.copy(Beta[bb, i_tc1 : i_tc1 + BC, bh], beta_1_shared) + for i_c21, i_c22 in T.Parallel(BC, BC): + Aqk10_fragment[i_c21, i_c22] = Aqk10_fragment[i_c21, i_c22] * scale + Akk10_fragment[i_c21, i_c22] = Akk10_fragment[i_c21, i_c22] * beta_1_shared[i_c21] + T.copy(Aqk10_fragment, Aqk[bb, i_tc1 : i_tc1 + BC, bh, 0:BC]) + T.copy(Akk10_fragment, Akk10_shared) + if i_tc2 < S: + T.copy(Beta[bb, i_tc2 : i_tc2 + BC, bh], beta_2_shared) + for i_c23, i_c24 in T.Parallel(BC, BC): + Aqk20_fragment[i_c23, i_c24] = Aqk20_fragment[i_c23, i_c24] * scale + Aqk21_fragment[i_c23, i_c24] = Aqk21_fragment[i_c23, i_c24] * scale + Akk20_fragment[i_c23, i_c24] = Akk20_fragment[i_c23, i_c24] * beta_2_shared[i_c23] + Akk21_fragment[i_c23, i_c24] = Akk21_fragment[i_c23, i_c24] * beta_2_shared[i_c23] + T.copy(Aqk20_fragment, Aqk[bb, i_tc2 : i_tc2 + BC, bh, 0:BC]) + T.copy(Aqk21_fragment, Aqk[bb, i_tc2 : i_tc2 + BC, bh, BC : 2 * BC]) + T.copy(Akk20_fragment, Akk20_shared) + T.copy(Akk21_fragment, Akk21_shared) + if i_tc3 < S: + T.copy(Beta[bb, i_tc3 : i_tc3 + BC, bh], beta_3_shared) + for i_c25, i_c26 in T.Parallel(BC, BC): + Aqk30_fragment[i_c25, i_c26] = Aqk30_fragment[i_c25, i_c26] * scale + Aqk31_fragment[i_c25, i_c26] = Aqk31_fragment[i_c25, i_c26] * scale + Aqk32_fragment[i_c25, i_c26] = Aqk32_fragment[i_c25, i_c26] * scale + Akk30_fragment[i_c25, i_c26] = Akk30_fragment[i_c25, i_c26] * beta_3_shared[i_c25] + Akk31_fragment[i_c25, i_c26] = Akk31_fragment[i_c25, i_c26] * beta_3_shared[i_c25] + Akk32_fragment[i_c25, i_c26] = Akk32_fragment[i_c25, i_c26] * beta_3_shared[i_c25] + T.copy(Aqk30_fragment, Aqk[bb, i_tc3 : i_tc3 + BC, bh, 0:BC]) + T.copy(Aqk31_fragment, Aqk[bb, i_tc3 : i_tc3 + BC, bh, BC : 2 * BC]) + T.copy(Aqk32_fragment, Aqk[bb, i_tc3 : i_tc3 + BC, bh, 2 * BC : 3 * BC]) + T.copy(Akk30_fragment, Akk30_shared) + T.copy(Akk31_fragment, Akk31_shared) + T.copy(Akk32_fragment, Akk32_shared) + + ################################################################################ + # 3. load diagonal Akk blocks + ################################################################################ + + T.copy(Akk_diag[bb, i_tc0 : i_tc0 + BC, bh, :], Ai_00_shared) + T.copy(Akk_diag[bb, i_tc1 : i_tc1 + BC, bh, :], Ai_11_shared) + T.copy(Akk_diag[bb, i_tc2 : i_tc2 + BC, bh, :], Ai_22_shared) + T.copy(Akk_diag[bb, i_tc3 : i_tc3 + BC, bh, :], Ai_33_shared) + for i_c1, i_c2 in T.Parallel(BC, BC): + Ai_00_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_00_shared[i_c1, i_c2], 0) + Ai_11_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_11_shared[i_c1, i_c2], 0) + Ai_22_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_22_shared[i_c1, i_c2], 0) + Ai_33_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_33_shared[i_c1, i_c2], 0) + + ################################################################################ + # 4. forward substitution on diagonals + ################################################################################ + a_00_shared = T.alloc_shared((BC,), dtype=T.float32) + Aa_mul_shared = T.alloc_shared((BC, BC), dtype=T.float32) + reduce_shared = T.alloc_shared((BC,), dtype=T.float32) + for i_i in T.Pipelined(2, T.min(BC, S - i_tc0), num_stages=num_stages): + T.copy(Akk_diag[bb, i_tc0 + i_i, bh, :], a_00_shared) # load row + for i_c in T.Parallel(BC): + a_00_shared[i_c] = T.if_then_else(i_c < i_i, -a_00_shared[i_c], 0.0) # mask:i_c
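
For reference on step 4 above (the forward-substitution / solve_tril pass), here is a minimal dense PyTorch sketch of the operation the docstring describes: inverting a unit-lower-triangular block `I + A` row by row, where `A` stands for the strictly lower-triangular entries of one `BC x BC` diagonal `Akk` block. The helper name `invert_unit_lower_triangular` and the self-check are illustrative only; the TileLang kernel works per `BC x BC` block in shared memory and folds the negation of the strictly-lower entries into its loop.

```python
import torch


def invert_unit_lower_triangular(A: torch.Tensor) -> torch.Tensor:
    """Return (I + A)^-1 for strictly lower-triangular A via forward substitution.

    A: (..., BC, BC) with zeros on and above the diagonal.
    Row i of the inverse is e_i - A[i, :i] @ Ainv[:i, :], computed in row order.
    """
    BC = A.shape[-1]
    Ainv = torch.zeros_like(A)
    idx = torch.arange(BC, device=A.device)
    Ainv[..., idx, idx] = 1.0  # start from the identity
    for i in range(1, BC):
        # rows 0..i-1 are already final; combine them with row i of A
        Ainv[..., i, :] -= (A[..., i : i + 1, :i] @ Ainv[..., :i, :]).squeeze(-2)
    return Ainv


if __name__ == "__main__":
    BC = 16
    A = torch.randn(BC, BC, dtype=torch.float64).tril(-1)
    Ainv = invert_unit_lower_triangular(A)
    eye = torch.eye(BC, dtype=torch.float64)
    err = ((eye + A) @ Ainv - eye).abs().max()
    print(f"max reconstruction error: {err.item():.3e}")
```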