diff --git a/.clang-tidy b/.clang-tidy index 2ddbefbf9..f9b77bce8 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,10 +1,12 @@ --- InheritParentConfig: true -ExtraArgs: ['-v'] +ExtraArgs: [] FormatStyle: file UseColor: true WarningsAsErrors: '*' -ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$' +# FIXME: Use `ExcludeHeaderFilterRegex` instead when all maintainers upgraded their `clang-tidy` +HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*' +# ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$' # NOTE: there must be no spaces before the '-', so put the comma last. Checks: >- diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 3ba13e0ce..0086358db 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1 +1 @@ -blank_issues_enabled: false +blank_issues_enabled: true diff --git a/.github/ISSUE_TEMPLATE/release-plan.yml b/.github/ISSUE_TEMPLATE/release-plan.yml new file mode 100644 index 000000000..a3528275c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/release-plan.yml @@ -0,0 +1,63 @@ +name: "Release Plan" +description: "Plan the next release" +title: "[Release Plan] vX.Y.Z" +labels: + - release-plan + - tracking +assignees: [] +body: + - type: input + id: version + attributes: + label: "Version" + placeholder: "v0.2.0" + validations: + required: true + + - type: input + id: milestone + attributes: + label: "Milestone" + description: "Link or name of the milestone for this release" + placeholder: "https://github.com/tile-ai/tilelang/milestone/XX" + + - type: textarea + id: scope + attributes: + label: "Scope" + description: "Goals and non-goals (brief)" + placeholder: | + - Goals: ... + - Non-goals: ... + + - type: textarea + id: tasks + attributes: + label: "Tasks" + description: "Task list; link issues/PRs" + value: | + - [ ] Features + - [ ] Fixes + - [ ] Docs + - [ ] API/Breaking changes + - [ ] Benchmarks + - [ ] Release notes + + - type: checkboxes + id: readiness + attributes: + label: "Readiness" + options: + - label: "All planned issues closed or deferred" + - label: "Docs updated" + - label: "CI green; artifacts verified" + - label: "Release notes drafted" + + - type: textarea + id: notes + attributes: + label: "Notes" + description: "Risks or communications (optional)" + placeholder: | + - Risk: ... + - Communication: ... 
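Note on the `.clang-tidy` change above: the negative-lookahead `HeaderFilterRegex` is meant to keep analysis away from `3rdparty/` and `tvm/` headers until everyone can move back to `ExcludeHeaderFilterRegex`. Below is a minimal sketch of what the two patterns are intended to match, using Python's `re` purely for illustration; clang-tidy's own regex engine may not support identical syntax, and the sample paths are hypothetical.

```python
import re

# New filter: a header is analyzed only when no "3rdparty/" or "tvm/"
# path component appears anywhere in it (negative lookahead).
header_filter = re.compile(r"^(?!.*(?:/|^)(3rdparty|tvm)/).*")
# Old filter (commented out above): paths excluded explicitly.
exclude_filter = re.compile(r"^(3rdparty|tvm)/.*$")

# Hypothetical sample paths, for illustration only.
for path in ["src/op/gemm.cc", "3rdparty/tvm/include/tvm/ir/expr.h", "tvm/src/tir/expr.cc"]:
    print(path,
          "-> analyzed:", bool(header_filter.match(path)),
          "| matched by old exclude pattern:", bool(exclude_filter.match(path)))
```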
diff --git a/.github/workflows/amd_ci.yml b/.github/workflows/amd_ci.yml index 2ef300b66..144c0f09f 100644 --- a/.github/workflows/amd_ci.yml +++ b/.github/workflows/amd_ci.yml @@ -11,7 +11,7 @@ jobs: runs-on: [self-hosted, amd, gpu] permissions: - contents: write + contents: write steps: - name: Checkout repository @@ -56,7 +56,7 @@ jobs: echo "------------------------------------" exit 1 fi - + - name: Commit and Push Changes uses: stefanzweifel/git-auto-commit-action@v5 with: @@ -86,7 +86,7 @@ jobs: set -e REQS_HASH=$(sha256sum requirements-rocm.txt | cut -d ' ' -f 1) MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}" - + echo "Installing requirements" if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then echo "venv exists and hash matches – reuse it" @@ -117,4 +117,4 @@ jobs: source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" cd testing/python/amd unset PYTHONPATH - python -m pytest -v test_tilelang_test_amd.py \ No newline at end of file + python -m pytest -v test_tilelang_test_amd.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a04edc1eb..8d5f3ffb4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,154 +1,342 @@ name: CI -on: [pull_request] +on: + pull_request: + types: + - labeled + - unlabeled + - opened + - synchronize + - reopened + # Allow to trigger the workflow manually + workflow_dispatch: + +permissions: + contents: read + +concurrency: + group: "${{ github.workflow }}-${{ github.ref }}" + cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: - PYTHON_VERSION: '3.12' - VENV_DIR: tilelang_ci + CLANG_TIDY_CMAKE_OPTIONS: "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON" # to be updated + PYTHONDEVMODE: "1" + PYTHONUNBUFFERED: "1" + PYTHONPATH: "" # explicit cleanup + PIP_USER: "" # explicit cleanup + COLUMNS: "100" + FORCE_COLOR: "1" + CLICOLOR_FORCE: "1" + UV_INDEX_STRATEGY: "unsafe-best-match" + UV_HTTP_TIMEOUT: "600" + XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated + UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated + PRE_COMMIT_HOME: "${{ github.workspace }}/.cache/pip/.pre-commit" # to be updated jobs: - format-check: - runs-on: [self-hosted, nvidia, hopper] + lint: + name: Quick Lint + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: recursive - permissions: - contents: write + - name: Setup Python 3.8 + id: setup-pylowest + uses: actions/setup-python@v6 + with: + python-version: "3.8" # use lowest supported version for linting + update-environment: false + + - name: Check AST with Python 3.8 + run: | + "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang + + - name: Setup Python 3.9 + uses: actions/setup-python@v6 + with: + python-version: "3.9" + update-environment: true + cache: pip + cache-dependency-path: | + pyproject.toml + requirements*.txt + .pre-commit-config.yaml + + - name: Pre-commit Lint + run: | + if ! pipx run pre-commit run --all-files --color=always --show-diff-on-failure; then + echo "::error::Pre-commit checks failed. Please run 'pre-commit install' and 'pre-commit run --all-files' locally to see the issues." 
+ exit 1 + fi + + tests: + name: Test for Python ${{ matrix.python-version }} with ${{ matrix.runner.toolkit }} (on ${{ matrix.runner.name }}) + if: | + github.repository_owner == 'tile-ai' && + (github.event_name != 'pull_request' || !github.event.pull_request.draft) + needs: [lint] + runs-on: ${{ matrix.runner.tags }} + strategy: + matrix: + runner: + - tags: [self-hosted, tilescale] + name: self-hosted-nvidia + # Format: [Nightly-]CUDA-.[.]. E.g., "CUDA-12.8" or "Nightly-CUDA-13.0". + # Use "Nightly-" prefix to use torch nightly builds. + toolkit: CUDA-12.8 + python-version: + - "3.12" + fail-fast: false + timeout-minutes: 120 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Ensure venv (local & persistent) - run: | - set -e - REQS_HASH=$(sha256sum requirements-test.txt 2>/dev/null | awk '{print $1}' || echo "no_requirements") - MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}" - - if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then - echo "venv exists and hash matches – reuse it" - else - echo "venv stale or missing – recreating" - rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER" - python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" - # shellcheck source=/dev/null - source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" - python -m pip install --upgrade pip --no-user - [[ -f requirements-test.txt ]] && \ - PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user - pip install flash_attn==2.5.8 --no-user --no-build-isolation - touch "$MARKER" - fi - - - name: Run format check - run: | - source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" - if ! output=$(./format.sh 2>&1); then - echo "------------------------------------" - echo "message:" - echo "$output" - printf '%s\n' "$output" | grep "Please review and stage the changes." 
- echo "------------------------------------" - exit 1 - fi - - - name: Commit and Push Changes - uses: stefanzweifel/git-auto-commit-action@v5 - with: - commit_message: "lint" - - build-test-nvidia: - runs-on: [self-hosted, nvidia, hopper] - needs: format-check - permissions: - contents: read - steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - fetch-depth: 0 - repository: ${{ github.event.pull_request.head.repo.full_name }} - ref: ${{ github.event.pull_request.head.ref }} - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Ensure venv (local & persistent) - run: | - set -e - REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true) - MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}" - # NOTE(wt): We disable the venv reuse for now to allow installing DeepEP - # echo "venv stale or missing – recreating" - rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" - python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" - source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" - python -m pip install --upgrade pip --no-user - [[ -f requirements-test.txt ]] && \ - PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user - # flash attention usually requires no isolation build - pip install flash_attn==2.5.8 --no-user --no-build-isolation - - - name: Install project (wheel form) - run: | - source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" - pip install . --no-user -v - bash tilelang/distributed/install_deepep.sh # Install DeepEP for testing purpose - - - name: Run examples - run: | - source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" - cd examples - unset PYTHONPATH - - # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1 - mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true) - if [ "${#DIST_TESTS[@]}" -gt 0 ]; then + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: recursive + + - name: Set environment (self-hosted runners) + if: startsWith(matrix.runner.name, 'self-hosted') + run: | + # Hide sensitive data in logs for self-hosted runners + if [[ -n "${{ secrets.SECRET_PATH_PREFIXES }}" ]]; then + echo "::add-mask::${{ secrets.SECRET_PATH_PREFIXES }}" + # Colon separated list of secrets to mask + for secret in $(echo "${{ secrets.SECRET_PATH_PREFIXES }}" | tr ':' '\n'); do + echo "::add-mask::${secret}" + done + fi + + # Use runner tool_cache as cache root for self-hosted runners to avoid internet connection + # issues and to share cache between jobs. + export XDG_CACHE_HOME="${{ runner.tool_cache }}/.ci-cache-${{ github.workflow }}" + echo "XDG_CACHE_HOME=${XDG_CACHE_HOME}" | tee -a "${GITHUB_ENV}" + echo "PIP_CACHE_DIR=${XDG_CACHE_HOME}/pip" | tee -a "${GITHUB_ENV}" + echo "UV_CACHE_DIR=${XDG_CACHE_HOME}/uv" | tee -a "${GITHUB_ENV}" + echo "PRE_COMMIT_HOME=${XDG_CACHE_HOME}/pip/.pre-commit" | tee -a "${GITHUB_ENV}" + + # Do not use ccache on self-hosted runners, as it will download/upload caches which is slow. + # Self-hosted runners usually have more CPU power to compile without ccache. 
+ - name: Setup ccache (GitHub-hosted runners) + id: setup-ccache + if: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + uses: hendrikmuhs/ccache-action@v1 + with: + create-symlink: true + evict-old-files: "7d" + append-timestamp: false + key: ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }}-${{ hashFiles('**/*.cc') }} + restore-keys: | + ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }}-${{ hashFiles('**/*.cc') }} + ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }} + ${{ runner.os }}-${{ runner.arch }} + + - name: Set environment (CUDA) + if: contains(matrix.runner.toolkit, 'CUDA') + run: | + TOOLKIT="${{ matrix.runner.toolkit }}" + CUDA_VERSION="${TOOLKIT##*-}" + CUDA_VERSION_MAJMIN="$(echo ${CUDA_VERSION} | cut -d '.' -f-2)" + CUDA_VERSION_MAJMIN_NODOT="${CUDA_VERSION_MAJMIN//./}" + if [[ "${TOOLKIT}" == "Nightly-"* ]]; then + # Use torch nightly builds + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/nightly/cu${CUDA_VERSION_MAJMIN_NODOT}" + else + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}" + fi + export UV_INDEX="${PIP_EXTRA_INDEX_URL}" + export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_CUDA=ON" + + echo "USE_CUDA=ON" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN=${CUDA_VERSION_MAJMIN}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN_NODOT=${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" + echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" + echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" + echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" + + if [[ ! -x "$(command -v nvcc)" ]]; then + export PATH="/usr/local/cuda/bin:${PATH}" + export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + echo "PATH=${PATH}" | tee -a "${GITHUB_ENV}" + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" | tee -a "${GITHUB_ENV}" + fi + if [[ -x "$(command -v nvcc)" ]]; then + echo "\$ $(command -v nvcc) --version" && nvcc --version + else + echo "::warning::nvcc not found in PATH!" + fi + + - name: Setup Python and uv with caching + id: setup-uv + uses: astral-sh/setup-uv@v7 + with: + python-version: ${{ matrix.python-version }} + activate-environment: true + # Do not use cache for self-hosted runners, as it will download/upload caches which is slow. 
+ enable-cache: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + prune-cache: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + # Use runner tool_cache for self-hosted runners + cache-local-path: ${{ env.UV_CACHE_DIR }} + ignore-nothing-to-cache: true + # Extra cache key to upload/download caches on GitHub-hosted runners + cache-suffix: uv-${{ runner.os }}-${{ runner.arch }}-${{ matrix.python-version }}-${{ matrix.runner.name }}-${{ matrix.runner.toolkit }} + cache-dependency-glob: | + pyproject.toml + requirements*.txt + .pre-commit-config.yaml + + - name: Setup venv + id: setup-venv + run: | + set -o pipefail + + uv pip install --upgrade pip setuptools wheel + if [[ "${UV_INDEX}" == *"/nightly/"* ]]; then + uv pip install --prerelease=allow -v torch + fi + uv pip install -v -r requirements-test.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + echo "import torch; print(f'torch: {torch.__version__}')" | uv run --no-project --script - + if [[ "${{ matrix.runner.toolkit }}" == *"CUDA"* ]]; then + uv pip install --no-build-isolation-package=flash-attn -v -r requirements-test-cuda.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + echo "import flash_attn; print(f'flash_attn: {flash_attn.__version__}')" | uv run --no-project --script - + # elif [[ "${{ matrix.runner.toolkit }}" == *"ROCm"* ]]; then + # uv pip install -v -r requirements-test-rocm.txt + # elif [[ "${{ matrix.runner.toolkit }}" == *"Metal"* ]]; then + # uv pip install -v -r requirements-test-metal.txt + else + echo "::error::Unknown toolkit: ${{ matrix.runner.toolkit }}" + exit 1 + fi + echo "::group::torch.utils.collect_env" + uv run --no-project -m -- torch.utils.collect_env + echo "::endgroup::" + + - name: Clear uv cache for self-hosted runners (if setup failed) + if: >- + ${{ + failure() && + startsWith(matrix.runner.name, 'self-hosted') && + (steps.setup-uv.conclusion == 'failure' || steps.setup-venv.conclusion == 'failure') + }} + run: | + echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure." + uv cache clean + + - name: Enable core dump generation (Linux / GitHub-hosted runners) + if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kernel.core_uses_pid=0 + sudo sysctl -w fs.suid_dumpable=1 + sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable + + - name: Enable core dump generation (macOS / GitHub-hosted runners) + if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kern.coredump=1 + sudo sysctl -w kern.sugid_coredump=1 + sysctl kern.corefile kern.coredump kern.sugid_coredump + + - name: Install project (wheel form) + run: | + uv pip install -v . 
+ bash tilelang/distributed/install_deepep.sh # Install DeepEP for testing purpose + export NCCL_IB_DISABLE=1 # Our CI machine's IB is incomplete, disable it to avoid unnecessary error msgs + + # - name: Run clang-tidy + # id: clang-tidy + # if: runner.os == 'Linux' + # run: | + # echo "\$ $(command -v clang-tidy) --version" && clang-tidy --version + + # # Download run-clang-tidy script + # RCT_URL=https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py + # echo "Downloading run-clang-tidy script from ${RCT_URL}" + # echo "import urllib.request; url = '${RCT_URL}'.rstrip('/'); urllib.request.urlretrieve(url, url.split('/')[-1])" | uv run --no-project --script - + # RUN_CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py) + + # if [[ -x "$(command -v clang-apply-replacements)" ]]; then + # echo "Using clang-apply-replacements from $(command -v clang-apply-replacements)" + # RUN_CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)") + # else + # echo "::warning::clang-apply-replacements not found in PATH, automatic fixing disabled." + # fi + + # # Run cmake to create the build directory with compile_commands.json + # cmake -S . -B cmake-build --fresh ${CLANG_TIDY_CMAKE_OPTIONS} # no quotes here + # echo "::group::compile_commands.json" + # ls -alh cmake-build/compile_commands.json + # uv run --no-project -m -- json.tool --no-ensure-ascii cmake-build/compile_commands.json + # echo "::endgroup::" + + # CXX_FILES=$(find src -type f -iname "*.[ch]pp" -o -iname "*.cc" -o -iname "*.c" -o -iname "*.h") + # rc=0 + # echo "::group::run-clang-tidy" + # "${RUN_CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \ + # -exclude-header-filter='^(3rdparty|tvm)/.*$' \ + # -p="cmake-build" ${CXX_FILES} || rc="$?" + # echo "::endgroup::" + # rm -rf cmake-build run-clang-tidy.py + # if (( rc != 0 )); then + # echo "::error::clang-tidy found issues (exit code: ${rc}). Please run 'clang-tidy --fix' locally to fix them." + # git diff --color=always || true + # exit "${rc}" + # fi + + - name: Run examples with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) + if: contains(matrix.runner.toolkit, 'CUDA') + run: | + cd examples + unset PYTHONPATH + PYTEST=( + uv run --no-project -m -- + pytest --verbose --color=yes --durations=0 --showlocals --cache-clear -r fE + ) + + # Run distributed tests (marked with @requires_distributed) with TILELANG_USE_DISTRIBUTED=1 + # DeepEP tests requires fullmesh nvl or internode environment, we disable for now echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:" - printf '%s\n' "${DIST_TESTS[@]}" - TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 "${DIST_TESTS[@]}" -v -r fE - else - echo "No distributed examples found." - fi - - # run remaining example tests (non-distributed) - mapfile -t OTHER_TESTS < <(find . -type f -name 'test*.py' ! -path '*/distributed/*' | grep -vE 'sink|vs_sparse' 2>/dev/null || true) # temporarily disable problematic tests - if [ "${#OTHER_TESTS[@]}" -gt 0 ]; then + TILELANG_USE_DISTRIBUTED=1 "${PYTEST[@]}" --maxfail=3 --numprocesses=1 -m distributed --ignore-glob='*deepep*' . || true + + # Run remaining example tests (non-distributed) + # Temporarily disable problematic tests: sink, vs_sparse echo "Running non-distributed examples:" - printf '%s\n' "${OTHER_TESTS[@]}" - python -m pytest -n 4 "${OTHER_TESTS[@]}" -v -r fE - else - echo "No non-distributed example tests found." 
- fi - - - name: Run tests - run: | - source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" - cd testing/python - unset PYTHONPATH - - # run distributed tests first with env var - mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true) - if [ "${#DIST_TESTS[@]}" -gt 0 ]; then + "${PYTEST[@]}" --maxfail=3 --numprocesses=2 -m "not distributed" -k "not sink and not vs_sparse" . || true + + # NVIDIA CUDA tests + - name: Run CUDA tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) + id: cuda-tests + if: contains(matrix.runner.toolkit, 'CUDA') + run: | + cd testing/python + unset PYTHONPATH + PYTEST=( + uv run --no-project -m -- + pytest --verbose --color=yes --durations=0 --showlocals --cache-clear -r fE + ) + + # Run distributed tests (marked with @requires_distributed) with TILELANG_USE_DISTRIBUTED=1 echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:" - printf '%s\n' "${DIST_TESTS[@]}" - TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 "${DIST_TESTS[@]}" -v -r fE - else - echo "No distributed tests found under testing/python." - fi - - # run remaining tests - mapfile -t OTHER_TESTS < <(find . -type f -name 'test*.py' ! -path '*/distributed/*' | grep -vE 'tilelibrary_gemm|jit_gemm_ctypes' 2>/dev/null || true) # temporarily disable problematic tests - if [ "${#OTHER_TESTS[@]}" -gt 0 ]; then + TILELANG_USE_DISTRIBUTED=1 "${PYTEST[@]}" --maxfail=3 --numprocesses=1 -m distributed . || true + + # Run remaining tests (non-distributed) + # Temporarily disable problematic tests: tilelibrary_gemm, jit_gemm_ctypes echo "Running non-distributed tests:" - printf '%s\n' "${OTHER_TESTS[@]}" - python -m pytest -n 4 "${OTHER_TESTS[@]}" -v -r fE - else - echo "No non-distributed tests found under testing/python." - fi + "${PYTEST[@]}" --maxfail=3 --numprocesses=2 -m "not distributed" -k "not tilelibrary_gemm and not jit_gemm_ctypes" . || true + + - name: List generated files + if: ${{ !cancelled() }} + run: | + find . -type f -name '*.py[co]' -delete + find . -depth -type d -name "__pycache__" -exec rm -r "{}" + + if git status --ignored --porcelain | grep -qvE '/$'; then + ls -alh $(git status --ignored --porcelain | grep -vE '/$' | grep -oE '\S+$') + fi diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index 904fbb13b..74132ffb3 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -1,5 +1,6 @@ name: Dist on: + workflow_dispatch: schedule: # gemini said this is 6:00 china time - cron: "0 22 * * *" @@ -28,6 +29,18 @@ concurrency: group: "${{ github.workflow }}-${{ github.ref }}" cancel-in-progress: true +env: + PYTHONDEVMODE: "1" + PYTHONUNBUFFERED: "1" + COLUMNS: "100" + FORCE_COLOR: "1" + CLICOLOR_FORCE: "1" + UV_INDEX_STRATEGY: "unsafe-best-match" + UV_HTTP_TIMEOUT: "600" + XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated + UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated + jobs: build-wheels: name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.target.runner }} with ${{ matrix.target.toolkit }} @@ -37,39 +50,41 @@ jobs: strategy: matrix: target: - - { runner: ubuntu-latest, toolkit: "CUDA-12.1" } - - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" } - - { runner: macos-latest, toolkit: "Metal" } + # NOTE(wt): Temporarily disable ARM and MacOS, as NVSHMEM only supports x86 (?) 
+ - { runner: ubuntu-latest, toolkit: "CUDA-12.8" } + # - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" } + - { runner: ubuntu-latest, toolkit: "Nightly-CUDA-13.0" } + # - { runner: ubuntu-24.04-arm, toolkit: "Nightly-CUDA-13.0" } + # - { runner: macos-latest, toolkit: "Metal" } python-version: - - "3.8" - # TVM is built with Python 3.8 Limited API, it should work with all Python >= 3.8. - # - "3.9" - # - "3.10" - # - "3.11" - # - "3.12" - # - "3.13" - # - "3.14" + # Wheels are built with Python 3.8 Limited API, they should work with all Python >= 3.8. + # Only build wheels against Python 3.8 Limited API to save CI resources. + - "3.9" fail-fast: false timeout-minutes: 120 runs-on: ${{ matrix.target.runner }} env: - NO_VERSION_LABEL: ${{ github.event_name == 'release' && 'OFF' || 'ON' }} + IS_RELEASE: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') }} + NO_VERSION_LABEL: "OFF" steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 1 submodules: recursive - # NB: CIBW builds wheels in containers on Linux - - name: Setup ccache (macOS only) - if: runner.os == 'macOS' + - name: Setup ccache uses: hendrikmuhs/ccache-action@v1 with: + max-size: "200MB" create-symlink: true - key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.python-version }}-${{ matrix.target.toolkit }} evict-old-files: "7d" + append-timestamp: false + key: wheel-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.cc') }} + restore-keys: | + wheel-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.cc') }} + wheel-${{ runner.os }}-${{ runner.arch }} - name: Set CIBW_BUILD run: | @@ -80,21 +95,81 @@ jobs: if [[ "${{ matrix.target.toolkit }}" == *"CUDA"* ]]; then CUDA_VERSION="${{ matrix.target.toolkit }}" - CUDA_VERSION="${CUDA_VERSION#CUDA-}" + CUDA_VERSION="${CUDA_VERSION##*-}" + CUDA_VERSION_MAJMIN="$(echo ${CUDA_VERSION} | cut -d '.' -f-2)" + CUDA_VERSION_MAJMIN_NODOT="${CUDA_VERSION_MAJMIN//./}" echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" + if [[ "${{ matrix.target.toolkit }}" == "Nightly-"* ]]; then + # Use torch nightly builds + export UV_INDEX="https://download.pytorch.org/whl/nightly/cu${CUDA_VERSION_MAJMIN_NODOT}" + else + export UV_INDEX="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}" + echo "UV_TORCH_BACKEND=cu${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" + fi + echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" + fi + + if [[ "${{ env.IS_RELEASE }}" == "true" ]]; then + if [[ "${{ matrix.target.toolkit }}" == "Nightly-"* ]]; then + # Avoid using same file name for different toolkit. + echo "NO_GIT_VERSION=ON" | tee -a "${GITHUB_ENV}" + else + echo "NO_VERSION_LABEL=ON" | tee -a "${GITHUB_ENV}" + fi + fi + + if [[ "${{ runner.os }}" == "Linux" ]]; then + HOST_CCACHE_DIR="$(ccache --get-config cache_dir)" + # Install torch for tilescale_ext._C build, then setup ccache + echo "CIBW_BEFORE_BUILD_LINUX=pip install torch --no-cache-dir && dnf install -y ccache && ccache -o cache_dir=/host${HOST_CCACHE_DIR}" | tee -a "${GITHUB_ENV}" fi - name: Build wheels - uses: pypa/cibuildwheel@v3.2 + uses: pypa/cibuildwheel@v3.3 with: package-dir: . 
output-dir: wheelhouse config-file: "{package}/pyproject.toml" + - name: Setup Python and uv with caching + id: setup-uv + uses: astral-sh/setup-uv@v7 + with: + python-version: "3.12" + activate-environment: true + + - name: Test built wheels + # Skip CUDA wheel tests on GitHub-hosted runners (no CUDA available) + # Tests should be run on self-hosted runners with CUDA or during release validation + if: ${{ !contains(matrix.target.toolkit, 'CUDA') || contains(matrix.target.runner, 'self-hosted') }} + run: | + for WHEEL in wheelhouse/*.whl; do + echo "Testing wheel: ${WHEEL}" + ( + set -e + uv venv --python=3.12 test-venv + source test-venv/bin/activate + + uv pip install --upgrade pip setuptools wheel + if [[ "${UV_INDEX}" == *"/nightly/"* ]]; then + uv pip install --prerelease=allow -v torch + fi + + uv pip install -v "${WHEEL}" + ( + set -e + cd / + uv run --no-project -- python -c "import tilelang; print(tilelang.__version__)" + ) + deactivate + rm -rf test-venv + ) + done + - name: Upload wheels # Not PR to save artifact storage, as wheels are only needed for releases. - if: github.event_name != 'pull_request' - uses: actions/upload-artifact@v4 + if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') + uses: actions/upload-artifact@v6 with: name: wheels-${{ matrix.python-version }}-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }} path: wheelhouse/*.whl @@ -102,14 +177,14 @@ jobs: list-artifacts: name: List artifacts - # Not PR to save artifact storage, as wheels are only needed for releases. - if: github.event_name != 'pull_request' + # Not PR to save artifact storage, as artifacts are only needed for releases. + if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') runs-on: ubuntu-latest needs: [build-wheels] timeout-minutes: 15 steps: - name: Download built wheels - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v7 with: pattern: wheels-* path: dist @@ -119,7 +194,7 @@ jobs: run: ls -lh dist/* - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: artifacts path: dist/* diff --git a/.github/workflows/pr-regression-test-bot.yml b/.github/workflows/pr-regression-test-bot.yml new file mode 100644 index 000000000..568ce8555 --- /dev/null +++ b/.github/workflows/pr-regression-test-bot.yml @@ -0,0 +1,273 @@ +name: Performance Regression Bot + +on: + issue_comment: + types: + - created + +permissions: + contents: read + issues: write + pull-requests: write + +concurrency: + # Use the issue/PR number to differentiate between different PRs + group: "${{ github.workflow }}-${{ github.event.issue.number }}" + cancel-in-progress: true + +env: + PYTHONDEVMODE: "1" + PYTHONUNBUFFERED: "1" + PYTHONPATH: "" # explicit cleanup + PIP_USER: "" # explicit cleanup + COLUMNS: "100" + FORCE_COLOR: "1" + CLICOLOR_FORCE: "1" + UV_INDEX_STRATEGY: "unsafe-best-match" + UV_HTTP_TIMEOUT: "600" + XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated + UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated + PRE_COMMIT_HOME: "${{ github.workspace }}/.cache/pip/.pre-commit" # to be updated + +jobs: + permissions-check: + name: Check bot permissions + if: | + github.repository_owner == 'tile-ai' && + github.event.issue.pull_request && + (contains(github.event.comment.body, '@regression-perf')) + runs-on: ubuntu-latest + steps: + - name: Get commenter 
permission + id: perm + uses: actions/github-script@v8 + with: + script: | + const username = context.payload.comment.user.login + const { owner, repo } = context.repo + const { data } = await github.rest.repos.getCollaboratorPermissionLevel({ owner, repo, username }) + core.setOutput('permission', data.permission) // admin|maintain|write|triage|read|none + + - name: Reject if not allowed + if: ${{ steps.perm.outputs.permission != 'admin' && steps.perm.outputs.permission != 'maintain' && steps.perm.outputs.permission != 'write' }} + run: | + echo "Not authorized: permission=${{ steps.perm.outputs.permission }}" + exit 1 + + pr-regression: + name: Performance regression test between PR and main + needs: [permissions-check] + runs-on: ${{ matrix.runner.tags }} + strategy: + matrix: + runner: + - tags: [self-hosted, nvidia] + name: self-hosted-nvidia + toolkit: CUDA-12.8 + python-version: + - "3.12" + fail-fast: false + timeout-minutes: 120 + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ref: refs/pull/${{ github.event.issue.number }}/merge + fetch-depth: 0 + submodules: recursive + + - name: Set environment (self-hosted runners) + if: startsWith(matrix.runner.name, 'self-hosted') + run: | + # Hide sensitive data in logs for self-hosted runners + if [[ -n "${{ secrets.SECRET_PATH_PREFIXES }}" ]]; then + echo "::add-mask::${{ secrets.SECRET_PATH_PREFIXES }}" + # Colon separated list of secrets to mask + for secret in $(echo "${{ secrets.SECRET_PATH_PREFIXES }}" | tr ':' '\n'); do + echo "::add-mask::${secret}" + done + fi + + # Use runner tool_cache as cache root for self-hosted runners to avoid internet connection + # issues and to share cache between jobs. + export XDG_CACHE_HOME="${{ runner.tool_cache }}/.ci-cache-${{ github.workflow }}" + echo "XDG_CACHE_HOME=${XDG_CACHE_HOME}" | tee -a "${GITHUB_ENV}" + echo "PIP_CACHE_DIR=${XDG_CACHE_HOME}/pip" | tee -a "${GITHUB_ENV}" + echo "UV_CACHE_DIR=${XDG_CACHE_HOME}/uv" | tee -a "${GITHUB_ENV}" + echo "PRE_COMMIT_HOME=${XDG_CACHE_HOME}/pip/.pre-commit" | tee -a "${GITHUB_ENV}" + + # Do not use ccache on self-hosted runners, as it will download/upload caches which is slow. + # Self-hosted runners usually have more CPU power to compile without ccache. + - name: Setup ccache (GitHub-hosted runners) + id: setup-ccache + if: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + uses: hendrikmuhs/ccache-action@v1 + with: + create-symlink: true + evict-old-files: "7d" + append-timestamp: false + key: ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }}-${{ hashFiles('**/*.cc') }} + restore-keys: | + ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }}-${{ hashFiles('**/*.cc') }} + ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }} + ${{ runner.os }}-${{ runner.arch }} + + - name: Set environment (CUDA) + if: contains(matrix.runner.toolkit, 'CUDA') + run: | + TOOLKIT="${{ matrix.runner.toolkit }}" + CUDA_VERSION="${TOOLKIT##*-}" + CUDA_VERSION_MAJMIN="$(echo ${CUDA_VERSION} | cut -d '.' 
-f-2)" + CUDA_VERSION_MAJMIN_NODOT="${CUDA_VERSION_MAJMIN//./}" + if [[ "${TOOLKIT}" == "Nightly-"* ]]; then + # Use torch nightly builds + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/nightly/cu${CUDA_VERSION_MAJMIN_NODOT}" + else + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}" + fi + export UV_INDEX="${PIP_EXTRA_INDEX_URL}" + export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_CUDA=ON" + + echo "USE_CUDA=ON" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN=${CUDA_VERSION_MAJMIN}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN_NODOT=${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" + echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" + echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" + echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" + + if [[ ! -x "$(command -v nvcc)" ]]; then + export PATH="/usr/local/cuda/bin:${PATH}" + export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + echo "PATH=${PATH}" | tee -a "${GITHUB_ENV}" + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" | tee -a "${GITHUB_ENV}" + fi + if [[ -x "$(command -v nvcc)" ]]; then + echo "\$ $(command -v nvcc) --version" && nvcc --version + else + echo "::warning::nvcc not found in PATH!" + fi + + - name: Setup Python and uv with caching + id: setup-uv + uses: astral-sh/setup-uv@v7 + with: + python-version: ${{ matrix.python-version }} + activate-environment: true + # Do not use cache for self-hosted runners, as it will download/upload caches which is slow. + enable-cache: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + prune-cache: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + # Use runner tool_cache for self-hosted runners + cache-local-path: ${{ env.UV_CACHE_DIR }} + ignore-nothing-to-cache: true + # Extra cache key to upload/download caches on GitHub-hosted runners + cache-suffix: uv-${{ runner.os }}-${{ runner.arch }}-${{ matrix.python-version }}-${{ matrix.runner.name }}-${{ matrix.runner.toolkit }} + cache-dependency-glob: | + pyproject.toml + requirements*.txt + + - name: Setup environments + id: setup-venv + run: | + set -e + + uv venv --python "${{ matrix.python-version }}" new + + source new/bin/activate + uv pip install -v -r requirements-test.txt + uv pip install -v . + + - name: Install Main version (Baseline) + run: | + set -e + git clean -dxf -e new/ -e .cache/ + git checkout main + git submodule update --init --recursive + uv venv --python "${{ matrix.python-version }}" old + source old/bin/activate + + uv pip install -v -r requirements-test.txt + uv pip install -v . + rm -rf tilelang build + + uv venv --python "${{ matrix.python-version }}" test_regression + source test_regression/bin/activate + uv pip install -v -r requirements-test.txt + + - name: Clear uv cache for self-hosted runners (if setup failed) + if: >- + ${{ + failure() && + startsWith(matrix.runner.name, 'self-hosted') && + (steps.setup-uv.conclusion == 'failure' || steps.setup-venv.conclusion == 'failure') + }} + run: | + echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure." 
+ uv cache clean + + - name: Enable core dump generation (Linux / GitHub-hosted runners) + if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kernel.core_uses_pid=0 + sudo sysctl -w fs.suid_dumpable=1 + sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable + + - name: Enable core dump generation (macOS / GitHub-hosted runners) + if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kern.coredump=1 + sudo sysctl -w kern.sugid_coredump=1 + sysctl kern.corefile kern.coredump kern.sugid_coredump + + - name: Run performance regression test + run: | + source test_regression/bin/activate + OLD_PYTHON=./old/bin/python NEW_PYTHON=./new/bin/python \ + PERF_REGRESSION_MD=regression_result.md PERF_REGRESSION_PNG=regression_result.png \ + python ./maint/scripts/test_perf_regression.py + + - name: Read markdown table + id: read_md + run: | + echo "content<<EOF" >> $GITHUB_OUTPUT + cat regression_result.md >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT + + - name: Upload result image as artifact + uses: actions/upload-artifact@v6 + with: + name: perf-regression-${{ github.run_id }} + path: regression_result.png + + - name: Post test results as PR comment + uses: actions/github-script@v8 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + // Read the file directly instead of passing via env/outputs to avoid escaping issues + const md = fs.readFileSync('regression_result.md', 'utf8'); + + const runUrl = `${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}`; + + const body = + 'Performance Regression Test Report\n' + + '============================\n\n' + + `Triggered by: @${context.payload.comment.user.login}\n` + + `Workflow run: ${runUrl}\n\n` + + 'Results\n' + + '-------\n\n' + + md + '\n\n' + + 'Artifacts\n' + + '---------\n\n' + + '- regression_result.png (speedup plot) is attached as a workflow artifact.
Download it from the workflow run page above.\n'; + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body + }); diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index 953303102..2197015b6 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -25,7 +25,7 @@ jobs: runs-on: [self-hosted, nvidia] steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: recursive diff --git a/.gitignore b/.gitignore index 75aa07f82..e85c2c094 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,8 @@ debug/ build/ *dist/ +dist*/ +!distributed*/ wheelhouse/ __pycache__ nnfusion.tar.gz @@ -110,3 +112,24 @@ nvshmem_src/ # CMake cmake-build/ cmake-build-*/ + +# Git version for sdist +.git_commit.txt + +# pre-commit cache +.pre-commit-cache/* + +# host checks logs +maint/host_checks/logs/* + +# ncu +*.ncu-rep + +# csv +*.csv + +# clang-tidy +/run-clang-tidy.py + +# perf regression test +.perf_regression/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 99a05f4c6..f52f91b53 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,15 +13,13 @@ repos: hooks: - id: check-symlinks - id: destroyed-symlinks - # FIXME: enable these hooks - # - id: trailing-whitespace - # - id: end-of-file-fixer + - id: trailing-whitespace + - id: end-of-file-fixer - id: check-added-large-files - id: check-merge-conflict fail_fast: true - # FIXME: enable these hooks - # - id: check-executables-have-shebangs - # - id: check-shebang-scripts-are-executable + - id: check-executables-have-shebangs + - id: check-shebang-scripts-are-executable - id: detect-private-key - id: check-yaml - id: check-toml @@ -32,39 +30,30 @@ repos: args: [--ignore-case] files: ^docs/spelling_wordlist\.txt$ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v21.1.2 # sync with requirements-lint.txt + rev: v21.1.7 # sync with requirements-lint.txt hooks: - id: clang-format - exclude: | - (?ix)( - ^.+\.(cu|cuh)$| - ^.+\.json$ - ) + types_or: [c++, c] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.1 # sync with requirements-lint.txt + rev: v0.14.9 # sync with requirements-lint.txt hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] - - repo: https://github.com/google/yapf - rev: v0.43.0 # sync with requirements-lint.txt - hooks: - - id: yapf - name: yapf-multiproc-bugfix - # yapf is not multiprocess safe, so we run a dummy yapf first. 
- args: [--in-place, docs/conf.py] - always_run: true - pass_filenames: false - - id: yapf - args: [--recursive, --in-place] + - id: ruff-format + args: [--exit-non-zero-on-format] - repo: https://github.com/codespell-project/codespell rev: v2.4.1 # sync with requirements-lint.txt hooks: - id: codespell additional_dependencies: [".[toml]"] - args: ["-L", "HDA"] exclude: | (?x)( ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$| ^.+\.svg$| ^.*\brequirements\b.*\.txt$ ) + - repo: https://github.com/jackdewinter/pymarkdown + rev: v0.9.33 + hooks: + - id: pymarkdown + args: ["--config", ".pymarkdown", "fix"] diff --git a/.pymarkdown b/.pymarkdown new file mode 100644 index 000000000..5394265ed --- /dev/null +++ b/.pymarkdown @@ -0,0 +1,37 @@ +{ + "plugins": { + "md003": { + "style": "atx" + }, + "md004": { + "style": "dash" + }, + "md013": { + "enabled": false + }, + "md026": { + "enabled": false + }, + "md029": { + "enabled": false + }, + "md031": { + "enabled": false + }, + "md032": { + "enabled": false + }, + "md033": { + "enabled": false + }, + "md034": { + "enabled": false + }, + "md040": { + "enabled": false + }, + "md041": { + "enabled": false + } + } +} diff --git a/3rdparty/tvm b/3rdparty/tvm index 5bf17a346..23bce012f 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 5bf17a34602931e7d7e01cbccf358a21fe972779 +Subproject commit 23bce012ffd255a24289eea6ceab74a40b94a096 diff --git a/CMakeLists.txt b/CMakeLists.txt index afeccaceb..4fb370d50 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,11 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "$ENV{CIBUILDWHEEL}") + # Warning came from tvm submodule + string(APPEND CMAKE_CXX_FLAGS " -Wno-dangling-reference") +endif() + set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake) if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.gitmodules" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.git") @@ -36,15 +41,74 @@ endif() find_program(CCACHE_PROGRAM ccache) if(CCACHE_PROGRAM) + message(STATUS "Using ccache: ${CCACHE_PROGRAM} with base_dir=${CMAKE_SOURCE_DIR}") + if(APPLE) + # Passing configs like `ccache base_dir=/xxx cc ...` is supported + # (likely) since ccache 4.x, which has been provided by homebrew. + # Our Linux builder image (manylinux2014 & manylinux_2_28) still + # provides ccache 3.x and do not support this form. + # `cibuildwheel` uses fixed folder on Linux (`/project`) as working directory, + # so cache would work without setting `base_dir`. 
+ set(CCACHE_PROGRAM "${CCACHE_PROGRAM};base_dir=${CMAKE_SOURCE_DIR}") + endif() set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher") set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher") set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher") +else() + find_program(SCCACHE_PROGRAM sccache) + if(SCCACHE_PROGRAM) + message(STATUS "Using sccache: ${SCCACHE_PROGRAM}") + set(CMAKE_C_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "C compiler launcher") + set(CMAKE_CXX_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher") + set(CMAKE_CUDA_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher") + endif() endif() # Configs -set(USE_CUDA OFF) -set(USE_ROCM OFF) -set(USE_METAL OFF) +set(TILELANG_BACKENDS CUDA ROCM METAL) + +set(TILELANG_BACKEND_DOC_CUDA "Enable CUDA backend (ON/OFF/or CUDA SDK path)") +set(TILELANG_BACKEND_DOC_ROCM "Enable ROCm backend (ON/OFF/or ROCm SDK path)") +set(TILELANG_BACKEND_DOC_METAL "Enable Metal backend") + +# TVM's config.cmake redefines USE_* options later, so we cache the user's choice +# (including explicit -DUSE_XXX arguments) before we include TVM and restore it +# afterwards. + +macro(tilelang_define_backend_option BACKEND) + set(_backend_var "USE_${BACKEND}") + set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}") + set(_user_override_var "TILELANG_USER_OVERRIDE_${_backend_var}") + + set(_user_override OFF) + if(DEFINED ${_user_override_var}) + set(_user_override "${${_user_override_var}}") + endif() + + if(DEFINED CACHE{${_backend_var}}) + get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE) + if(_cache_type STREQUAL "UNINITIALIZED") + set(_user_override ON) + endif() + endif() + + set(_default OFF) + if(DEFINED ${_backend_var}) + set(_default "${${_backend_var}}") + endif() + + option(${_backend_var} "${_doc}" "${_default}") + # Remember if the user explicitly set this option so that later logic + # won't auto-toggle backends they configured on the command line. + set(${_user_override_var} ${_user_override} CACHE INTERNAL + "User explicitly set ${_backend_var} during configuration" FORCE) + set(TILELANG_OPTION_${_backend_var} "${${_backend_var}}") +endmacro() + +foreach(BACKEND IN LISTS TILELANG_BACKENDS) + tilelang_define_backend_option(${BACKEND}) +endforeach() + set(PREBUILD_CYTHON ON) # Configs end @@ -55,6 +119,14 @@ if(EXISTS ${TVM_SOURCE}/cmake/config.cmake) else() message(FATAL_ERROR "Nor tvm provided or submodule checkout-ed.") endif() +# Re-apply TileLang's preferred backend settings after TVM's config may have +# overridden the USE_* cache entries. 
+foreach(BACKEND IN LISTS TILELANG_BACKENDS) + set(_backend_var "USE_${BACKEND}") + set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}") + set(${_backend_var} ${TILELANG_OPTION_${_backend_var}} CACHE STRING "${_doc}" FORCE) + set(${_backend_var} ${TILELANG_OPTION_${_backend_var}}) +endforeach() # Include directories for TileLang set(TILE_LANG_INCLUDES ${TVM_INCLUDES}) @@ -64,33 +136,50 @@ file(GLOB TILE_LANG_SRCS src/*.cc src/layout/*.cc src/transform/*.cc + src/transform/common/*.cc src/op/*.cc src/target/utils.cc + src/target/codegen_c_host.cc src/target/codegen_cpp.cc src/target/rt_mod_cpp.cc - # webgpu doesn't have system dependency - src/target/codegen_webgpu.cc # intrin_rule doesn't have system dependency src/target/intrin_rule*.cc ) -# Backend-specific checks and configs -if($ENV{USE_METAL}) - set(USE_METAL ON) -elseif(APPLE) - message(STATUS "Enable Metal support by default.") - set(USE_METAL ON) -elseif($ENV{USE_ROCM}) - set(USE_ROCM ON) -else() - if($ENV{USE_CUDA}) - set(USE_CUDA ON) - elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA}) - # Build CPU-only when we explicitly disable CUDA - set(USE_CUDA OFF) +# Always include CPU-safe runtime helpers +list(APPEND TILE_LANG_SRCS + src/runtime/error_helpers.cc +) + +# Track if the user explicitly selected a backend via cache options. +set(TILELANG_BACKEND_USER_SELECTED OFF) +foreach(BACKEND IN LISTS TILELANG_BACKENDS) + set(_backend_var "USE_${BACKEND}") + set(_override_var "TILELANG_USER_OVERRIDE_${_backend_var}") + if(${_backend_var} OR ${_override_var}) + set(TILELANG_BACKEND_USER_SELECTED ON) + endif() +endforeach() + +# Only auto-select a backend when the user didn't specify one explicitly. +if(NOT TILELANG_BACKEND_USER_SELECTED) + if($ENV{USE_METAL}) + set(USE_METAL ON) + elseif(APPLE) + message(STATUS "Enable Metal support by default.") + set(USE_METAL ON) + elseif($ENV{USE_ROCM}) + set(USE_ROCM ON) else() - message(STATUS "Enable CUDA support by default.") - set(USE_CUDA ON) + if($ENV{USE_CUDA}) + set(USE_CUDA ON) + elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA}) + # Build CPU-only when we explicitly disable CUDA + set(USE_CUDA OFF) + else() + message(STATUS "Enable CUDA support by default.") + set(USE_CUDA ON) + endif() endif() endif() @@ -104,7 +193,7 @@ if(USE_METAL) elseif(USE_ROCM) set(CMAKE_HIP_STANDARD 17) include(${TVM_SOURCE}/cmake/utils/FindROCM.cmake) - find_rocm($ENV{USE_ROCM}) + find_rocm(${USE_ROCM}) add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1) file(GLOB TILE_LANG_HIP_SRCS @@ -123,16 +212,29 @@ elseif(USE_CUDA) cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA) file(GLOB TILE_LANG_CUDA_SRCS - src/runtime/*.cc + src/runtime/runtime.cc + src/runtime/tilescale_cuda_module.cc src/target/ptx.cc src/target/codegen_cuda.cc + src/target/codegen_py.cc + src/target/codegen_utils.cc + src/target/codegen_cutedsl.cc src/target/rt_mod_cuda.cc + src/target/rt_mod_cutedsl.cc ) list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS}) list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS}) endif() +set(USE_Z3 ON CACHE STRING "Use Z3 SMT solver for TileLang optimizations") +set(USE_PYPI_Z3 ON CACHE BOOL "Use Z3 provided by PyPI z3-solver package") + +if(USE_Z3 AND USE_PYPI_Z3) + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/pypi-z3") + find_package(Z3 REQUIRED) +endif() + # Include tvm after configs have been populated add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL) @@ -140,7 +242,11 @@ add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL) 
add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=) add_library(tilelang_objs OBJECT ${TILE_LANG_SRCS}) + +# Set debug mode compile definitions +# We enable the debug option of TVM, i.e. TVM_LOG_DEBUG if(CMAKE_BUILD_TYPE STREQUAL "Debug") + message(STATUS "Building TileLang with DEBUG mode") target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG") endif() @@ -148,12 +254,20 @@ target_include_directories(tilelang_objs PRIVATE ${TILE_LANG_INCLUDES}) add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>) add_library(tilelang_module SHARED $<TARGET_OBJECTS:tilelang_objs>) -target_link_libraries(tilelang PUBLIC tvm_runtime) +target_link_libraries(tilelang PUBLIC tvm_runtime tvm) target_link_libraries(tilelang_module PUBLIC tvm) -if(APPLE) - # FIXME: libtilelang should only link against tvm runtime - target_link_libraries(tilelang PUBLIC tvm) -endif() + +# Place dev build outputs under build/lib for consistency +set_target_properties(tilelang PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" +) +set_target_properties(tilelang_module PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" +) # Build cython extension find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) @@ -173,26 +287,112 @@ if(NOT "${SKBUILD_SABI_VERSION}" STREQUAL "") endif() python_add_library(tilelang_cython_wrapper MODULE "${CMAKE_BINARY_DIR}/tilelang_cython_wrapper.cpp" ${USE_SABI} WITH_SOABI) -# Install extension into the tilelang package directory + +# Ensure dev builds drop the extension into build/lib alongside other shared libs +set_target_properties(tilelang_cython_wrapper PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" +) + +# Install the extension into tilelang/lib inside the wheel install(TARGETS tilelang_cython_wrapper - LIBRARY DESTINATION tilelang - RUNTIME DESTINATION tilelang - ARCHIVE DESTINATION tilelang) + LIBRARY DESTINATION tilelang/lib + RUNTIME DESTINATION tilelang/lib + ARCHIVE DESTINATION tilelang/lib) + +# Copy libz3.so to build folder to workaround isolated build env issue +if(USE_Z3 AND USE_PYPI_Z3) + get_target_property(Z3_LIBRARY_PATH z3::libz3 IMPORTED_LOCATION) + install(FILES "${Z3_LIBRARY_PATH}" DESTINATION "${CMAKE_BINARY_DIR}/tvm") + if(APPLE) + set_target_properties(tvm PROPERTIES BUILD_RPATH "@loader_path") + else() + set_target_properties(tvm PROPERTIES BUILD_RPATH "\$ORIGIN") + endif() +endif() -# let libtilelang to search tvm/tvm_runtime in same dir if(APPLE) - set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path") - set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path") -else() - set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN") - set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN") + set(TILELANG_INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") + if(USE_Z3 AND USE_PYPI_Z3) + # some z3 is placed in lib/ and some in bin/, we add both in rpath + list(APPEND TILELANG_INSTALL_RPATH "@loader_path/../../z3/lib" "@loader_path/../../z3/bin") + endif() +elseif(UNIX) + set(TILELANG_INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") + if(USE_Z3 AND USE_PYPI_Z3) + # cmake uses ; by default, we explicitly use : for linux + string(APPEND TILELANG_INSTALL_RPATH
":\$ORIGIN/../../z3/lib") + endif() endif() -install(TARGETS tvm tvm_runtime tilelang_module tilelang LIBRARY DESTINATION tilelang/lib) +set_target_properties( + tilelang tilelang_module tvm tvm_runtime + PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}") -# Copy tvm cython ext for wheels -# TODO: not necessary for editable builds -if(TVM_BUILD_FROM_SOURCE) - add_dependencies(tilelang tvm_cython) - install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python/tvm/ffi/core.abi3.so" DESTINATION tilelang/3rdparty/tvm/python/tvm/ffi/) +install( + TARGETS tvm tvm_runtime tilelang_module tilelang + LIBRARY DESTINATION tilelang/lib +) + +# Build tilescale_ext PyTorch C++ extension +if(USE_CUDA) + # Find Torch + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)" + OUTPUT_VARIABLE TORCH_CMAKE_PREFIX_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE TORCH_CMAKE_RESULT + ) + if(TORCH_CMAKE_RESULT EQUAL 0 AND EXISTS "${TORCH_CMAKE_PREFIX_PATH}") + list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX_PATH}") + endif() + + find_package(Torch QUIET) + if(Torch_FOUND) + message(STATUS "Building tilescale_ext with Torch ${Torch_VERSION}") + + set(TILESCALE_EXT_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/tilelang/utils/ts_ext/ts_ext_bindings.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/tilelang/utils/ts_ext/tensor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/tilelang/utils/ts_ext/ipc_ops.cpp + ) + + # Find libtorch_python.so + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import torch; import os; print(os.path.join(os.path.dirname(torch.__file__), 'lib', 'libtorch_python.so'))" + OUTPUT_VARIABLE TORCH_PYTHON_LIBRARY + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE TORCH_PYTHON_RESULT + ) + + python_add_library(tilescale_ext_C MODULE ${TILESCALE_EXT_SOURCES} WITH_SOABI) + target_compile_definitions(tilescale_ext_C PRIVATE TORCH_EXTENSION_NAME=_C) + target_include_directories(tilescale_ext_C PRIVATE + ${TORCH_INCLUDE_DIRS} + ${CUDAToolkit_INCLUDE_DIRS} + ) + + if(TORCH_PYTHON_RESULT EQUAL 0 AND EXISTS "${TORCH_PYTHON_LIBRARY}") + message(STATUS "Found libtorch_python: ${TORCH_PYTHON_LIBRARY}") + target_link_libraries(tilescale_ext_C PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY} CUDA::cudart) + else() + message(WARNING "libtorch_python.so not found, extension may have undefined symbols") + target_link_libraries(tilescale_ext_C PRIVATE ${TORCH_LIBRARIES} CUDA::cudart) + endif() + + target_compile_options(tilescale_ext_C PRIVATE -fPIC) + set_target_properties(tilescale_ext_C PROPERTIES + OUTPUT_NAME "_C" + CXX_STANDARD 17 + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ) + + # Install as tilescale_ext/_C.so so it can be imported as tilescale_ext._C + install(TARGETS tilescale_ext_C + LIBRARY DESTINATION tilescale_ext + RUNTIME DESTINATION tilescale_ext) + else() + message(WARNING "Torch not found, tilescale_ext will not be built") + endif() endif() diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 9e380d831..5eba9044a 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -17,23 +17,23 @@ diverse, inclusive, and healthy community. 
Examples of behavior that contributes to a positive environment for our community include: -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our mistakes, +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience -* Focusing on what is best not just for us as individuals, but for the overall +- Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: -* The use of sexualized language or imagery, and sexual attention or advances of +- The use of sexualized language or imagery, and sexual attention or advances of any kind -* Trolling, insulting or derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email address, +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in a +- Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e4b45e24b..45284e980 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ That would be awesome if you want to contribute something to TileLang! -### Table of Contents +## Table of Contents - [Report Bugs](#report-bugs) - [Ask Questions](#ask-questions) @@ -81,6 +81,8 @@ in the main directory. This installation is removable by: python3 -m pip uninstall tilelang ``` +We also recommend installing TileLang in a more manual way for better control over the build process, by compiling the C++ extensions first and set the `PYTHONPATH`. See [Working from Source via `PYTHONPATH`](https://tilelang.com/get_started/Installation.html#working-from-source-via-pythonpath) for detailed instructions. + ## Lint Check To check the linting, run: diff --git a/LICENSE b/LICENSE index 2122252e9..09dd51c8c 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,7 @@ MIT License Copyright (c) Tile-AI. 
- **During the period from December 1, 2024, to Mar 14, 2025, this project is + **During the period from December 1, 2024, to Mar 14, 2025, this project is subject to additional collaboration terms with Microsoft Corporation.** Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 88b206825..000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,10 +0,0 @@ -include VERSION -include CMakeLists.txt -include requirements.txt -include requirements-test.txt -include requirements-dev.txt -include tilelang/jit/adapter/cython/cython_wrapper.pyx -recursive-include src * -recursive-include 3rdparty * -recursive-exclude 3rdparty/clang* * -recursive-exclude 3rdparty/llvm* * diff --git a/README.md b/README.md index 3962010df..886a14868 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ # TileScale: Tile-based AI Compute at All Scales -TileScale is a distributed extension of TileLang. It expands TileLang's tile-level programming to multi-GPU, multi-node, and even distributed chip architecture scopes, with some new feature designs like tile-level communication and hierarchical programming introduced. +TileScale is a distributed extension of TileLang. It expands TileLang's tile-level programming to multi-GPU, multi-node, and even distributed chip architecture scopes, with some new feature designs like tile-level communication and hierarchical programming introduced. -TileScale is a distributed-native domain-specific language (DSL) and compiler stack designed for deep learning on next-generation distributed architectures. +TileScale is a distributed-native domain-specific language (DSL) and compiler stack designed for deep learning on next-generation distributed architectures. As AI model entering the "scaling-law" era, modern AI infrastructure is also scaling the computation across both intra-chip and inter-chip scopes. On one side, current large AI models are already executing on multiple GPUs or even multiple nodes connected by the high-performance links like NVLink or InfiniBand. On the other side, a bunch of next-gen AI accelerators are embracing new chip architectures—such as 3D IC, near/in-memory computing, wafer-scale accelerators, etc., which are all in distributed form inner the chip for better scalability. Together, these trends are shaping modern AI compute systems into a hybrid, multi-level of "distributed architecture". TileScale is the first programming and compiler stack to unify these intra-chip and inter-chip compute resources into a unified, hierarchical, distributed architecture, which virtualizes the whole distributed system as a unified "mega-device" to users. To facilitate programming, TileScale provides a set of consistent tile-level primitives across all hardware layers for compute, memory, and communication. Thus, users can just write tile-level computing logic or flow at certain layers of interest, then TileScale automatically compiles and optimizes the scheduling of computation, communication, memory access, and their overlap. The goal of TileScale is to define an open, streamlined programming model for future distributed architectures and systems, addressing the emerging needs of modern AI computation, such as fine-grained computation and communication overlap, flexible parallel mechanisms, dataflow computation, NUMA programming, etc. -#### The full technical white-paper is coming soon. +## The full technical white-paper is coming soon. 
## Hierarchical Distributed Architecture (HDA) Unlike traditional GPU SIMT programming, which assumes thread-level computation on a single device, TileScale is designed to manage compute, memory, and communication across all hierarchical scales, from threads and PEs to dies, chips, and nodes. It introduces a unified virtual device architecture, called Hierarchical Distributed Architecture (HDA), to abstract these distributed systems. @@ -32,16 +32,15 @@ At each layer, the associated memory may be shared among all units or distribute Following the hierarchical hardware architecture, TileScale exposes a hierarchical programming interface. The fundamental unit of computation in TileScale is at the *tile* granularity. TileScale provides consistent tile-level compute, memory, and communication operators corresponding to each hardware scales.
TileScale Programming Interface
- -* *Compute*: A compute primitive takes input tensor tiles at certain memory layer and produces output tensor tiles. The same compute primitive can be used at different scale level, which will be translated to different implementations. A primitive at a high-level scale can be implemented by the lower-level-scale primitives. For example, a block-scale operator can be implemented by a group of warp-scale or thread-scale primitives. - -* *Memory*: The memory primitives are used to copy data tiles at certain memory layer, as well as to copy data tile between different memory layers. - -* *Communicate*: The communication primitives are used to transfer data tiles between compute units over the network, as well as to manage the synchronization. TileScale provides both basic peer-to-peer communication primitives as well as the collective communication primitives like AllReduce, All2All, etc., at a specific scale level. + +- *Compute*: A compute primitive takes input tensor tiles at a certain memory layer and produces output tensor tiles. The same compute primitive can be used at different scale levels, where it will be translated to different implementations. A primitive at a high-level scale can be implemented by lower-level-scale primitives. For example, a block-scale operator can be implemented by a group of warp-scale or thread-scale primitives. + +- *Memory*: The memory primitives are used to copy data tiles at a certain memory layer, as well as to copy data tiles between different memory layers. + +- *Communicate*: The communication primitives are used to transfer data tiles between compute units over the network, as well as to manage synchronization. TileScale provides both basic peer-to-peer communication primitives and collective communication primitives such as AllReduce, All2All, etc., at a specific scale level. A primitive for a certain scale level may have multiple implementations. For example, a copy primitive could be implemented using TMA or LSU, while a remote copy across GPUs might be implemented using copy engines, TMA, or LSU. TileScale provides default implementations for each primitive, along with a compilation process to tune the best implementation. Users can also specify particular implementations through arguments in the tile primitives. -With this hierarchical interface, user can easily customize the computation at certain scale level. For example, we can leverage the DSMEM feature to implement a general cluster-scale GEMM primitive. - +With this hierarchical interface, users can easily customize the computation at a certain scale level. For example, we can leverage the DSMEM feature to implement a general cluster-scale GEMM primitive. ## System Overview and Design
TileScale system overview @@ -60,7 +59,7 @@ The layout and partition dimensions are either automatically inferred through a
### Parallel task scheduling -TileScale introduces a *T.Scale* primitive to control which hardware scale the current computations are conducted on. +TileScale introduces a *T.Scale* primitive to control which hardware scale the current computations are conducted on. It follows the SPMD (Single Program Multiple Data) programming model that scale the specified computation to all parallel units at this level. For example, the following *T.gemm* represents a warp GEMM, which executes on all warps in parallel. ```python @@ -81,18 +80,18 @@ with T.Kernel( T.gemm(A, B, C) ``` #### Task(warp) specialization -Additionally, the T.Scale primitive can also return the rank and the total number of ranks of the current scale level. This allows you to easily leverage the rank index for task specialization, such as warp specialization or any other scale-level specialization. +Additionally, the T.Scale primitive can also return the rank and the total number of ranks of the current scale level. This allows you to easily leverage the rank index for task specialization, such as warp specialization or any other scale-level specialization. ```python # warp specialize example with T.Scale("warpgroup") as wg_id, wg_num: if wg_id == 0: - # do something + # do something else: # do other thing ``` #### MPI-style programming -Combined with the communication primitives, you can also implement MPI-like programs if a communication channel exists across those ranks. For those compute units without hardware links, TileScale can also implement software channels by passing data through lower-level memory. +Combined with the communication primitives, you can also implement MPI-like programs if a communication channel exists across those ranks. For those compute units without hardware links, TileScale can also implement software channels by passing data through lower-level memory. 
```python # communication example: send data to neighbor GPU with T.Scale("device") as dev_id, dev_num: @@ -100,7 +99,7 @@ with T.Scale("device") as dev_id, dev_num: T.barrier() ``` -## Example: +## Example: ```python # Example of GEMM # 4-GPU Tensor Parallelism, using L2 to communicate @@ -119,12 +118,12 @@ def gemm( A_global = T.view(A, layout=T.FullCol) B_global = T.view(B, layout=T.FullRow) C_global = T.view(C, layout=T.Replica) - + with T.Scale("block"): A_local = T.alloc((block_M, block_K), dtype, level="l0") B_local = T.alloc((block_K, block_N), dtype, level="l0") C_local = T.alloc((block_M, block_N), accum_dtype, level="l0") - T.clear(C_local) + T.clear(C_local) for k in T.Pipelined(T.ceildiv(A_global.shape[1], block_K), num_stages=3): with T.Scale("warpgroup") as wg_id, wg_num: @@ -134,7 +133,7 @@ def gemm( T.copy(A_local_wg, A_global[by * block_M, k * block_K]) T.copy(B_local_wg, B_global[k * block_K, bx * block_N]) T.gemm(A_local_wg, B_local_wg, C_local_wg) - + # Allreduce C_local_wg through software-defined channel on L1 T.allreduce(C_local_wg) T.copy(C_global[by * block_M, bx * block_N], C_local) @@ -142,7 +141,7 @@ def gemm( with T.Scale("device") as dev_id, dev_num: # Allreduce C on L2 T.allreduce(C_global) - + ``` ```python # Example of FlashMLA @@ -156,8 +155,8 @@ def flash_mla( Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel( - device=(4), - block=(batch, heads // min(block_H, kv_group_num), + device=(4), + block=(batch, heads // min(block_H, kv_group_num), threads=256) ): with T.Scale("device"): @@ -182,8 +181,8 @@ def flash_mla( scores_scale = T.alloc([block_H], accum_dtype, level="l0") scores_sum = T.alloc([block_H], accum_dtype, level="l0") logsum = T.alloc([block_H], accum_dtype, level="l0") - - cur_kv_head = by // (kv_group_num // block_H) + + cur_kv_head = by // (kv_group_num // block_H) T.copy(Q_shared, Q_global[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) T.copy(Q_pe_shared, Q_pe_global[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) @@ -199,7 +198,7 @@ def flash_mla( T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - + T.copy(scores_max_prev, scores_max) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -217,7 +216,7 @@ def flash_mla( T.copy(acc_s_cast_local[:, block_N // 2:block_N], acc_s_local, dst=(wg_id + 1) % wg_num) # Or, you can use high level cooperative primitive # T.allgather(acc_s_local), and Cast ... - + for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dim): diff --git a/THIRDPARTYNOTICES.txt b/THIRDPARTYNOTICES.txt index b7c481841..3558662a8 100644 --- a/THIRDPARTYNOTICES.txt +++ b/THIRDPARTYNOTICES.txt @@ -1,5 +1,5 @@ -BitBLAS uses third-party material as listed below. The attached notices are -provided for informational purposes only. +BitBLAS uses third-party material as listed below. The attached notices are +provided for informational purposes only. 
Notice for apache/tvm ------------------------------- diff --git a/VERSION b/VERSION index 70f6c676e..e52aba075 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.6.post1 +0.1.7.post1 diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py index 6401276ac..3dd82aa5e 100644 --- a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) import flash_attn diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py index aefe4d420..0018e9c93 100644 --- a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -39,16 +36,15 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_N = 64 num_stages = 2 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] - dtype = "float16" - accum_dtype = "float" - block_mask_dtype = "bool" + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.bool def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def MMA0( K: T.Tensor(shape, dtype), @@ -60,11 +56,10 @@ def MMA0( by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, 
k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -79,22 +74,24 @@ def MMA1( by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. 
@@ -114,22 +111,21 @@ def Softmax( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -142,31 +138,29 @@ def main( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_mask = T.alloc_local([downsample_len], block_mask_dtype) + block_mask = T.alloc_fragment([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - for vj in T.serial(downsample_len): - block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + T.copy(BlockSparseMask[bz, by, bx, :], block_mask) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - if block_mask[k]: + if block_mask[k] != 0: MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main @@ -175,26 +169,23 @@ def main( def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / 
downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - program = blocksparse_flashattn( - BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) kernel = tilelang.compile(program, out_idx=4) def benchmark_fn(): diff --git a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py index e4828ce5f..85d754ae3 100644 --- a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", 
attn, v) return ref_output ref_latency = do_bench( diff --git a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py index 86ac894bc..7ebca93a6 100644 --- a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -56,7 +53,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) if mask_val == True: @@ -72,8 +68,7 @@ def _fwd_kernel_inner( # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N if LAST_K_BLOCK: - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, - float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -153,7 +148,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -191,24 +186,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -253,7 +236,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -271,24 +253,22 @@ def backward(ctx, do): def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, 
device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) diff --git a/benchmark/distributed/README.md b/benchmark/distributed/README.md index ac1cea257..21db28531 100644 --- a/benchmark/distributed/README.md +++ b/benchmark/distributed/README.md @@ -1 +1 @@ -To compare with [TileLink](https://arxiv.org/abs/2503.20313), please install [Triton-distributed](https://github.com/ByteDance-Seed/Triton-distributed). \ No newline at end of file +To compare with [TileLink](https://arxiv.org/abs/2503.20313), please install [Triton-distributed](https://github.com/ByteDance-Seed/Triton-distributed). diff --git a/benchmark/distributed/benchmark_ag_gemm.py b/benchmark/distributed/benchmark_ag_gemm.py index a4b0bd785..8ac3c244e 100644 --- a/benchmark/distributed/benchmark_ag_gemm.py +++ b/benchmark/distributed/benchmark_ag_gemm.py @@ -1,4 +1,4 @@ -'''Bugfix first: +"""Bugfix first: Triton-distributed/python/triton_dist/kernels/nvidia/allgather_gemm.py:566 ```python M = M_per_rank * ctx.num_ranks @@ -7,9 +7,9 @@ ```python M = M_per_rank * num_ranks ``` -''' +""" -#TODO: further tune the performance +# TODO: further tune the performance import argparse import torch @@ -27,36 +27,27 @@ @tilelang.jit( out_idx=-1, - pass_configs={"tl.disable_rdc": True} - #FIXME: https://github.com/tile-ai/tilelang/issues/659 + pass_configs={"tl.disable_rdc": True}, + # FIXME: https://github.com/tile-ai/tilelang/issues/659 ) -def matmut_transpose(rank, - num_ranks, - M, - N_per_rank, - K, - block_M, - block_N, - block_K, - dtype="float16", - threads=256, - persistent=False) -> tilelang.JITKernel: +def matmut_transpose( + rank, num_ranks, M, N_per_rank, K, block_M, block_N, block_K, dtype="float16", threads=256, persistent=False +) -> tilelang.JITKernel: accum_dtype = "float32" signal_dtype = "uint64" # NVSHMEM requires uint64 for signal assert M % block_M == 0 and N_per_rank % block_N == 0 and K % block_K == 0 - M_blocks, N_blocks, K_stages = T.ceildiv(M, block_M), T.ceildiv(N_per_rank, - block_N), T.ceildiv(K, block_K) + M_blocks, N_blocks, K_stages = T.ceildiv(M, block_M), T.ceildiv(N_per_rank, block_N), T.ceildiv(K, block_K) M_blocks_per_rank = M_blocks // num_ranks sm_num = driver.get_num_sms() # Get # of SMs for persistent kernel @T.prim_func def nonpersistent_kernel( - A: T.Tensor((M, K), dtype), # type: ignore - B: T.Tensor((N_per_rank, K), dtype), # type: ignore - signal: T.Tensor((num_ranks), signal_dtype), # type: ignore - C: T.Tensor((M, N_per_rank), dtype), # type: ignore + A: T.Tensor((M, K), dtype), # type: ignore + B: T.Tensor((N_per_rank, K), dtype), # type: ignore + signal: T.Tensor((num_ranks), signal_dtype), # type: ignore + C: T.Tensor((M, N_per_rank), dtype), # type: ignore ): with T.Kernel(N_blocks, M_blocks, threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -81,10 +72,10 @@ def nonpersistent_kernel( @T.prim_func def persistent_kernel( - A: T.Tensor((M, K), dtype), # type: ignore - B: T.Tensor((N_per_rank, K), dtype), # type: ignore - signal: T.Tensor((num_ranks), signal_dtype), # type: ignore - C: T.Tensor((M, N_per_rank), dtype), 
# type: ignore + A: T.Tensor((M, K), dtype), # type: ignore + B: T.Tensor((N_per_rank, K), dtype), # type: ignore + signal: T.Tensor((num_ranks), signal_dtype), # type: ignore + C: T.Tensor((M, N_per_rank), dtype), # type: ignore ): with T.Kernel(sm_num, threads=threads) as (block_id): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -145,9 +136,10 @@ def overlapped_ag_gemm( block_K=64, dtype=dtype, threads=threads, - persistent=persistent) + persistent=persistent, + ) if RANK == 0 and args.print_source: - print('We currently use cp-engine for producer, print consumer kernel code only...') + print("We currently use cp-engine for producer, print consumer kernel code only...") print(consumer.get_kernel_source()) ag_buffer = pynvshmem.nvshmem_create_tensor_list_intra_node( @@ -164,14 +156,13 @@ def overlapped_ag_gemm( gemm_stream.wait_stream(current_stream) with torch.cuda.stream(ag_stream): - ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A) + ag_buffer[rank][rank * M_per_rank : (rank + 1) * M_per_rank, :].copy_(A) pynvshmem.write64_on_stream(signal_buffer[rank], 1, ag_stream) - pynvshmem.nvshmemx_barrier_all_on_stream( - ag_stream.cuda_stream) # Ensure visible to all ranks + pynvshmem.nvshmemx_barrier_all_on_stream(ag_stream.cuda_stream) # Ensure visible to all ranks rank_orders = [(rank + i) % num_ranks for i in range(1, num_ranks)] for src_rank in rank_orders: - dst = ag_buffer[rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :] - src = ag_buffer[src_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :] + dst = ag_buffer[rank][src_rank * M_per_rank : (src_rank + 1) * M_per_rank, :] + src = ag_buffer[src_rank][src_rank * M_per_rank : (src_rank + 1) * M_per_rank, :] dst.copy_(src) pynvshmem.write64_on_stream(signal_buffer[src_rank], 1, ag_stream) @@ -188,19 +179,17 @@ def parse_args(): parser.add_argument("--M", type=int, default=8192) parser.add_argument("--N", type=int, default=49152) parser.add_argument("--K", type=int, default=12288) - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) parser.add_argument("--threads", type=int, default=256, help="number of threads in a block") - parser.add_argument( - "--persistent", action='store_true', default=False, help="use persistent GEMM consumers") + parser.add_argument("--persistent", action="store_true", default=False, help="use persistent GEMM consumers") parser.add_argument("--print_source", action="store_true", help="print kernel source code") parser.add_argument("--warmup", type=int, default=5, help="number of warmup iterations") parser.add_argument("--repeat", type=int, default=10, help="number of repeat iterations") return parser.parse_args() -if __name__ == '__main__': - assert torch.cuda.get_device_capability()[0] >= 9, '❗This benchmark requires sm_90 or higher' +if __name__ == "__main__": + assert torch.cuda.get_device_capability()[0] >= 9, "❗This benchmark requires sm_90 or higher" WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) assert WORLD_SIZE <= 8, "This benchmark is designed for intra-node AG-GEMM" @@ -231,12 +220,10 @@ def torch_ag_gemm(): # Benchmark Triton-dist (overlapped) ag_intranode_stream = torch.cuda.Stream(priority=-1) - ctx = create_ag_gemm_context( - A, B, RANK, PE_num, max_M=M, for_correctness=False, ag_intranode_stream=ag_intranode_stream) + ctx = create_ag_gemm_context(A, B, RANK, 
PE_num, max_M=M, for_correctness=False, ag_intranode_stream=ag_intranode_stream) def triton_ag_gemm(persistent, autotune): - return ag_gemm( - A, B, ctx=ctx, rank=RANK, num_ranks=PE_num, persistent=persistent, autotune=autotune) + return ag_gemm(A, B, ctx=ctx, rank=RANK, num_ranks=PE_num, persistent=persistent, autotune=autotune) dist.barrier(TP_GROUP) triton_ag_gemm = partial(triton_ag_gemm, persistent=False, autotune=False) @@ -257,8 +244,7 @@ def tilelang_ag_gemm(): print(f"rank {RANK} tilelang AG-GEMM avg time: {tl_t} ms") # Check correctness - assert torch.allclose( - tl_out, torch_out, atol=1e-2, rtol=1e-2), f'max error: {(tl_out - torch_out).abs().max()}' + assert torch.allclose(tl_out, torch_out, atol=1e-2, rtol=1e-2), f"max error: {(tl_out - torch_out).abs().max()}" print(f"rank {RANK} check passed.✅") dist.destroy_process_group() diff --git a/benchmark/distributed/benchmark_all_gather.py b/benchmark/distributed/benchmark_all_gather.py index 24d3445b2..676ad4853 100644 --- a/benchmark/distributed/benchmark_all_gather.py +++ b/benchmark/distributed/benchmark_all_gather.py @@ -30,9 +30,8 @@ def cp_engine_producer_all_gather_full_mesh_pull( if src_rank == rank: continue # peer: src_rank, offset src_rank[src_rank] -> rank[src_rank] - dst = remote_tensor_buffers[rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :] - src = remote_tensor_buffers[src_rank][src_rank * M_per_rank:(src_rank + 1) * - M_per_rank, :] + dst = remote_tensor_buffers[rank][src_rank * M_per_rank : (src_rank + 1) * M_per_rank, :] + src = remote_tensor_buffers[src_rank][src_rank * M_per_rank : (src_rank + 1) * M_per_rank, :] dst.copy_(src) pynvshmem.write64_on_stream( barrier_buffers[rank][src_rank], @@ -47,8 +46,8 @@ def allgather(PE_num, M, N, dtype="float16", threads=128): @T.prim_func def a2a_pull( - A: T.Tensor((M_per_rank, N), dtype), # type: ignore - B: T.Tensor((M, N), dtype), # type: ignore + A: T.Tensor((M_per_rank, N), dtype), # type: ignore + B: T.Tensor((M, N), dtype), # type: ignore ): with T.Kernel(M_per_rank // block_M, PE_num - 1, threads=threads) as (bx, by): mype = T.get_pe() @@ -57,7 +56,10 @@ def a2a_pull( T.getmem_nbi_block( T.address_of(B[peer * M_per_rank + bx * block_M, 0]), - T.address_of(A[bx * block_M, 0]), block_M * N * dtype_map[dtype].itemsize, peer) + T.address_of(A[bx * block_M, 0]), + block_M * N * dtype_map[dtype].itemsize, + peer, + ) # We don't need a barrier for the pull mode return a2a_pull @@ -65,12 +67,9 @@ def a2a_pull( def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--M", type=int, - default=8192) # Follow Triton-setting, we benchmark on (M, N) = (8192, 12288) + parser.add_argument("--M", type=int, default=8192) # Follow Triton-setting, we benchmark on (M, N) = (8192, 12288) parser.add_argument("--N", type=int, default=12288) - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) parser.add_argument("--threads", type=int, default=128, help="number of threads in a block") parser.add_argument("--print_source", action="store_true", help="print kernel source code") parser.add_argument("--warmup", type=int, default=5, help="number of warmup iterations") @@ -78,7 +77,7 @@ def parse_args(): return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) assert WORLD_SIZE <= 8, "This 
benchmark is designed for intra-node communication" @@ -111,13 +110,9 @@ def torch_ag(): # Benchmark Triton-dist def triton_ag(): - ag_buffer_ptrs = pynvshmem.nvshmem_create_tensor_list_intra_node( - [M, N], torch_dtype) # buffer for dist-triton allgather - signal = pynvshmem.nvshmem_create_tensor_list_intra_node( - ([PE_num]), torch.uint64) # each rank corresponds to one barrier - ag_buffer_ptrs[RANK][ - RANK * M_per_rank:(RANK + 1) * M_per_rank, - ].copy_(local_data) + ag_buffer_ptrs = pynvshmem.nvshmem_create_tensor_list_intra_node([M, N], torch_dtype) # buffer for dist-triton allgather + signal = pynvshmem.nvshmem_create_tensor_list_intra_node(([PE_num]), torch.uint64) # each rank corresponds to one barrier + ag_buffer_ptrs[RANK][RANK * M_per_rank : (RANK + 1) * M_per_rank,].copy_(local_data) signal[RANK].zero_() pynvshmem.nvshmemx_barrier_all_on_stream(torch.cuda.current_stream().cuda_stream) cp_engine_producer_all_gather_full_mesh_pull( @@ -134,7 +129,7 @@ def tilelang_ag(): ag_buffer = pynvshmem.nvshmem_create_tensor([M_per_rank, N], torch_dtype) ag_buffer.copy_(local_data) out = pynvshmem.nvshmem_create_tensor([M, N], torch_dtype) - out[RANK * M_per_rank:(RANK + 1) * M_per_rank, :].copy_(local_data) + out[RANK * M_per_rank : (RANK + 1) * M_per_rank, :].copy_(local_data) kernel(ag_buffer, out) return out @@ -145,8 +140,7 @@ def tilelang_ag(): # Tested on 4A100 with full-mesh NVLink, comparable with Triton-dist and ~20x faster than Torch # Check correctness - assert torch.allclose( - tl_out, torch_out, atol=0, rtol=0), f'max error: {(tl_out - torch_out).abs().max()}' + assert torch.allclose(tl_out, torch_out, atol=0, rtol=0), f"max error: {(tl_out - torch_out).abs().max()}" print(f"rank {RANK} check passed.✅") dist.destroy_process_group() diff --git a/benchmark/distributed/benchmark_all_to_all.py b/benchmark/distributed/benchmark_all_to_all.py index 6aae8b203..d2d0ded3a 100644 --- a/benchmark/distributed/benchmark_all_to_all.py +++ b/benchmark/distributed/benchmark_all_to_all.py @@ -13,19 +13,18 @@ def all_to_all(max_m, hidden, num_tot_experts, WORLD_SIZE, threads=128, dtype="float16"): - scale_dtype = "float" EXPERTS_PER_RANK = num_tot_experts // WORLD_SIZE @T.prim_func def main( - send_buf: T.Tensor((max_m, hidden), dtype), # type: ignore - recv_buf: T.Tensor((WORLD_SIZE * max_m * 2, hidden), dtype), # type: ignore - scale_send_buf: T.Tensor((max_m), scale_dtype), # type: ignore - scale_recv_buf: T.Tensor((WORLD_SIZE * max_m * 2), scale_dtype), # type: ignore - split_send_buf: T.Tensor((num_tot_experts), "int32"), # type: ignore - split_recv_buf: T.Tensor((num_tot_experts * 2), "int32"), # type: ignore - signal_buf: T.Tensor((WORLD_SIZE * 2), "uint64"), # type: ignore + send_buf: T.Tensor((max_m, hidden), dtype), # type: ignore + recv_buf: T.Tensor((WORLD_SIZE * max_m * 2, hidden), dtype), # type: ignore + scale_send_buf: T.Tensor((max_m), scale_dtype), # type: ignore + scale_recv_buf: T.Tensor((WORLD_SIZE * max_m * 2), scale_dtype), # type: ignore + split_send_buf: T.Tensor((num_tot_experts), "int32"), # type: ignore + split_recv_buf: T.Tensor((num_tot_experts * 2), "int32"), # type: ignore + signal_buf: T.Tensor((WORLD_SIZE * 2), "uint64"), # type: ignore ): with T.Kernel(WORLD_SIZE, threads=threads) as (bx): peer = bx @@ -63,17 +62,14 @@ def main( class TilelangAllToAll: - def __init__(self, ctx: AllToAllContext): self.ctx = ctx - self.func = all_to_all( - ctx.max_m, ctx.hidden, ctx.num_tot_experts, ctx.WORLD_SIZE, threads=128) + self.func = all_to_all(ctx.max_m, ctx.hidden, 
ctx.num_tot_experts, ctx.WORLD_SIZE, threads=128) self.kernel = tilelang.compile(self.func, pass_configs={"tl.disable_tma_lower": True}) if self.ctx.rank == 0: print(self.kernel.get_kernel_source()) - def __call__(self, send_tensor: torch.Tensor, send_split_cumsum: torch.Tensor, - send_scale: torch.Tensor | None): + def __call__(self, send_tensor: torch.Tensor, send_split_cumsum: torch.Tensor, send_scale: torch.Tensor | None): """ low-latency all-to-all communication """ @@ -161,7 +157,6 @@ def calc_gather_index( row_end: int, BLOCK_SIZE: int = 1024, ): - @triton.jit def _kernel( scatter_index: torch.Tensor, @@ -202,8 +197,7 @@ def _kernel( def calc_scatter_index_stable(choosed_experts: torch.Tensor): - return (choosed_experts.flatten().argsort(stable=True).argsort().int().view( - choosed_experts.shape)) + return choosed_experts.flatten().argsort(stable=True).argsort().int().view(choosed_experts.shape) def main(): @@ -227,7 +221,6 @@ def main(): ) def perf_triton(input: torch.Tensor, scale_tensor: torch.Tensor, exp_indices: torch.Tensor): - # prepare the indexes splits_gpu_cur_rank = torch.bincount(exp_indices.view(-1), minlength=args.G).to(torch.int32) split_cumsum = splits_to_cumsum(splits_gpu_cur_rank) @@ -237,20 +230,17 @@ def perf_triton(input: torch.Tensor, scale_tensor: torch.Tensor, exp_indices: to # calculate the gather idx accordingly gather_idx_cur_rank, _ = calc_gather_index(scatter_idx_cur_rank, 0, token_num * args.topk) # use torch native scatter forward(will not be included in the e2e time measurement) - scattered_input = torch.empty( - input.size(0) * args.topk, input.size(1), dtype=input.dtype, device=input.device) + scattered_input = torch.empty(input.size(0) * args.topk, input.size(1), dtype=input.dtype, device=input.device) scattered_scale_tensor = torch.empty( (scale_tensor.size(0) * args.topk), dtype=scale_tensor.dtype, device=scale_tensor.device, ) scattered_input.copy_(torch.index_select(input, dim=0, index=gather_idx_cur_rank)) - scattered_scale_tensor.copy_( - torch.index_select(scale_tensor, dim=0, index=gather_idx_cur_rank)) + scattered_scale_tensor.copy_(torch.index_select(scale_tensor, dim=0, index=gather_idx_cur_rank)) def fwd(): - return fast_all_to_all(all_to_all_ctx, scattered_input, split_cumsum, - scattered_scale_tensor if args.with_scale else None) + return fast_all_to_all(all_to_all_ctx, scattered_input, split_cumsum, scattered_scale_tensor if args.with_scale else None) torch.cuda._sleep(1000000000) # warmup @@ -269,21 +259,22 @@ def fwd(): # 1. dispatch dispatch_splits, dispatch_token, dispatch_scale = fast_all_to_all( - all_to_all_ctx, scattered_input, split_cumsum, - scattered_scale_tensor if args.with_scale else None) + all_to_all_ctx, scattered_input, split_cumsum, scattered_scale_tensor if args.with_scale else None + ) dispatch_token, dispatch_scale = all_to_all_post_process( - all_to_all_ctx, dispatch_splits, dispatch_token, - dispatch_scale if args.with_scale else None) + all_to_all_ctx, dispatch_splits, dispatch_token, dispatch_scale if args.with_scale else None + ) # 2. compute: moe_compute(dispatch_token, dispatch_scale, moe_weight, ...) # ... # 3. 
combine combine_splits, combine_token, combine_scale = fast_all_to_all( - all_to_all_ctx, dispatch_token, splits_to_cumsum(dispatch_splits), dispatch_scale) + all_to_all_ctx, dispatch_token, splits_to_cumsum(dispatch_splits), dispatch_scale + ) combine_token, combine_scale = all_to_all_post_process( - all_to_all_ctx, combine_splits, combine_token, - combine_scale if args.with_scale else None) + all_to_all_ctx, combine_splits, combine_token, combine_scale if args.with_scale else None + ) # 3.1. reduce: [num_tokens_local_rank * topk] => [num_tokens_local_rank] combine_reduced_out = torch.zeros_like(input) @@ -293,8 +284,7 @@ def fwd(): torch.testing.assert_close(combine_reduced_out, input * args.topk, rtol=1e-2, atol=1e-2) tilelang_all_to_all = TilelangAllToAll(all_to_all_ctx) - tilelang_all_to_all(scattered_input, split_cumsum, - scattered_scale_tensor if args.with_scale else None) + tilelang_all_to_all(scattered_input, split_cumsum, scattered_scale_tensor if args.with_scale else None) # torch.testing.assert_close(tilelang_out[1], dispatch_token, rtol=1e-2, atol=1e-2) # torch.testing.assert_close(tilelang_scale, dispatch_scale, rtol=1e-2, atol=1e-2) @@ -307,8 +297,7 @@ def fwd(): exp_indices = generate_random_exp_indices(token_num, args.G, args.topk) assert exp_indices.size(0) == token_num and exp_indices.size(1) == args.topk exp_indices = exp_indices.to("cuda") - input = ( - torch.rand(token_num, args.N, dtype=torch.float32).to(dtype_map[args.dtype]).to("cuda")) + input = torch.rand(token_num, args.N, dtype=torch.float32).to(dtype_map[args.dtype]).to("cuda") scale_tensor = torch.rand(token_num, dtype=torch.float32).to("cuda") torch.cuda.synchronize() diff --git a/benchmark/distributed/benchmark_gemm_rs.py b/benchmark/distributed/benchmark_gemm_rs.py index 5be4431c3..a4570d2f4 100644 --- a/benchmark/distributed/benchmark_gemm_rs.py +++ b/benchmark/distributed/benchmark_gemm_rs.py @@ -1,6 +1,6 @@ # Currently we only implement in Tilelang -#TODO: add Triton-dist v3.4 impl -#TODO: further tune the performance +# TODO: add Triton-dist v3.4 impl +# TODO: further tune the performance import argparse import torch @@ -8,40 +8,33 @@ import pynvshmem import tilelang import tilelang.language as T + # from tilelang.carver.arch import driver from tilelang.distributed import init_distributed, dtype_map, perf_fn tilelang.disable_cache() -@tilelang.jit(pass_configs={"tl.disable_rdc": True} - #FIXME: https://github.com/tile-ai/tilelang/issues/659 - ) -def fused_gemm_scatter(rank, - num_ranks, - M, - N, - K_per_rank, - block_M, - block_N, - block_K, - dtype="float16", - threads=128, - persistent=False) -> tilelang.JITKernel: +@tilelang.jit( + pass_configs={"tl.disable_rdc": True} + # FIXME: https://github.com/tile-ai/tilelang/issues/659 +) +def fused_gemm_scatter( + rank, num_ranks, M, N, K_per_rank, block_M, block_N, block_K, dtype="float16", threads=128, persistent=False +) -> tilelang.JITKernel: accum_dtype = "float32" assert M % block_M == 0 and N % block_N == 0 and K_per_rank % block_K == 0 - M_blocks, N_blocks, K_stages = T.ceildiv(M, block_M), T.ceildiv(N, block_N), T.ceildiv( - K_per_rank, block_K) + M_blocks, N_blocks, K_stages = T.ceildiv(M, block_M), T.ceildiv(N, block_N), T.ceildiv(K_per_rank, block_K) M_blocks_per_rank = M_blocks // num_ranks # sm_num = driver.get_num_sms() # Get # of SMs for persistent kernel @T.prim_func def nonpersistent_kernel( - A: T.Tensor((M, K_per_rank), dtype), # type: ignore - B: T.Tensor((N, K_per_rank), dtype), # type: ignore - C: T.Tensor((M_blocks, N_blocks, 
block_M, block_N), dtype), # type: ignore + A: T.Tensor((M, K_per_rank), dtype), # type: ignore + B: T.Tensor((N, K_per_rank), dtype), # type: ignore + C: T.Tensor((M_blocks, N_blocks, block_M, block_N), dtype), # type: ignore ): with T.Kernel(N_blocks, M_blocks, threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -63,8 +56,8 @@ def nonpersistent_kernel( T.copy(C_shared, C[by, bx, :, :]) peer = by // M_blocks_per_rank T.putmem_nbi_block( - T.address_of(C[by, bx, 0, 0]), T.address_of(C[by, bx, 0, 0]), - block_M * block_N * dtype_map[dtype].itemsize, peer) + T.address_of(C[by, bx, 0, 0]), T.address_of(C[by, bx, 0, 0]), block_M * block_N * dtype_map[dtype].itemsize, peer + ) assert not persistent return nonpersistent_kernel @@ -110,10 +103,10 @@ def overlapped_gemm_rs( block_K=block_K, dtype=dtype, threads=threads, - persistent=persistent) + persistent=persistent, + ) - gemm_output = pynvshmem.nvshmem_create_tensor_list_intra_node( - [M_blocks, N_blocks, block_M, block_N], dtype=input.dtype) + gemm_output = pynvshmem.nvshmem_create_tensor_list_intra_node([M_blocks, N_blocks, block_M, block_N], dtype=input.dtype) output = torch.empty((M_per_rank, N), dtype=input.dtype, device="cuda") fused_gemm_scatter_kernel(input, weight, gemm_output[rank]) dist.barrier(TP_GROUP) @@ -126,19 +119,17 @@ def parse_args(): parser.add_argument("--M", type=int, default=16384) parser.add_argument("--N", type=int, default=12288) parser.add_argument("--K", type=int, default=49152) - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) parser.add_argument("--threads", type=int, default=128, help="number of threads in a block") - parser.add_argument( - "--persistent", action='store_true', default=False, help="use persistent GEMM producers") + parser.add_argument("--persistent", action="store_true", default=False, help="use persistent GEMM producers") parser.add_argument("--print_source", action="store_true", help="print kernel source code") parser.add_argument("--warmup", type=int, default=5, help="number of warmup iterations") parser.add_argument("--repeat", type=int, default=10, help="number of repeat iterations") return parser.parse_args() -if __name__ == '__main__': - assert torch.cuda.get_device_capability()[0] >= 9, '❗This benchmark requires sm_90 or higher' +if __name__ == "__main__": + assert torch.cuda.get_device_capability()[0] >= 9, "❗This benchmark requires sm_90 or higher" WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) assert WORLD_SIZE <= 8, "This benchmark is designed for intra-node GEMM-RS" @@ -176,16 +167,14 @@ def torch_gemm_rs(): print("Use non-persistent GEMM producers...") def tilelang_gemm_rs(): - return overlapped_gemm_rs( - input, weight, rank=RANK, num_ranks=PE_num, persistent=args.persistent) + return overlapped_gemm_rs(input, weight, rank=RANK, num_ranks=PE_num, persistent=args.persistent) dist.barrier(TP_GROUP) tl_out, tl_t = perf_fn(tilelang_gemm_rs, warmup, repeat) print(f"rank {RANK} tilelang GEMM avg time: {tl_t} ms") # Check correctness - assert torch.allclose( - tl_out, torch_out, atol=1e-2, rtol=1e-2), f'max error: {(tl_out - torch_out).abs().max()}' + assert torch.allclose(tl_out, torch_out, atol=1e-2, rtol=1e-2), f"max error: {(tl_out - torch_out).abs().max()}" print(f"rank {RANK} check passed.✅") dist.destroy_process_group() diff --git 
a/benchmark/distributed/benchmark_reduce_scatter.py b/benchmark/distributed/benchmark_reduce_scatter.py index c6431f79a..277125bb6 100644 --- a/benchmark/distributed/benchmark_reduce_scatter.py +++ b/benchmark/distributed/benchmark_reduce_scatter.py @@ -11,13 +11,13 @@ tilelang.disable_cache() -#TODO: Bench on 4/8 H100 -#TODO: split N? -'''init_nvshmem_by_torch_process_group(_TP_GROUP) +# TODO: Bench on 4/8 H100 +# TODO: split N? +"""init_nvshmem_by_torch_process_group(_TP_GROUP) Note: Minor numerical differences exist between Triton/TileLang and Torch (~1e-2) due to the order reductions are handled in different implementations. (No error when #PE = 2) -''' +""" def reducescatter(PE_num, M, N, dtype="float16", threads=128): @@ -27,8 +27,8 @@ def reducescatter(PE_num, M, N, dtype="float16", threads=128): @T.prim_func def pull_reduce( - A: T.Tensor((M, N), dtype), # type: ignore - B: T.Tensor((M_per_rank, N), dtype), # type: ignore + A: T.Tensor((M, N), dtype), # type: ignore + B: T.Tensor((M_per_rank, N), dtype), # type: ignore ): with T.Kernel(M_per_rank // block_M, threads=threads) as (bx): mype = T.get_pe() @@ -42,15 +42,17 @@ def pull_reduce( T.getmem_nbi_block( T.address_of(A_shared[peer, 0, 0]), T.address_of(A[mype * M_per_rank + bx * block_M, 0]), - block_M * N * dtype_map[dtype].itemsize, peer) + block_M * N * dtype_map[dtype].itemsize, + peer, + ) base = mype * M_per_rank + bx * block_M - T.copy(A[base:base + block_M, :], A_shared[mype, :, :]) + T.copy(A[base : base + block_M, :], A_shared[mype, :, :]) T.fence() # Ensure reduce happens after all IO T.copy(A_shared, A_local) T.reduce_sum(A_local, A_local_sum, dim=0) - T.copy(A_local_sum, B[bx * block_M:bx * block_M + block_M, :]) + T.copy(A_local_sum, B[bx * block_M : bx * block_M + block_M, :]) return pull_reduce @@ -59,8 +61,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--M", type=int, default=8192) parser.add_argument("--N", type=int, default=16384) - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) parser.add_argument("--threads", type=int, default=128, help="number of threads in a block") parser.add_argument("--print_source", action="store_true", help="print kernel source code") parser.add_argument("--warmup", type=int, default=5, help="number of warmup iterations") @@ -68,8 +69,8 @@ def parse_args(): return parser.parse_args() -if __name__ == '__main__': - assert torch.cuda.get_device_capability()[0] >= 9, '❗This benchmark requires sm_90 or higher' +if __name__ == "__main__": + assert torch.cuda.get_device_capability()[0] >= 9, "❗This benchmark requires sm_90 or higher" WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) assert WORLD_SIZE <= 8, "This benchmark is designed for intra-node RS" @@ -83,7 +84,7 @@ def parse_args(): nelems = M * PE_num func = reducescatter(PE_num, M, N, dtype=dtype, threads=threads) - kernel = tilelang.compile(func, pass_configs={"tl.disable_tma_lower": True}, target='cuda') + kernel = tilelang.compile(func, pass_configs={"tl.disable_tma_lower": True}, target="cuda") # Get CUDA Source if RANK == 0 and args.print_source: @@ -142,8 +143,7 @@ def tilelang_rs(): print(f"rank {RANK} tilelang reduce_scatter avg time: {tl_t} ms") # Check correctness - assert torch.allclose( - tl_out, torch_out, atol=1e-2, rtol=1e-2), f'max error: {(tt_out - torch_out).abs().max()}' + assert 
torch.allclose(tl_out, torch_out, atol=1e-2, rtol=1e-2), f"max error: {(tl_out - torch_out).abs().max()}" print(f"rank {RANK} check passed.✅") dist.destroy_process_group() diff --git a/benchmark/distributed/ipc_impls/README.md b/benchmark/distributed/ipc_impls/README.md index d89d00956..59ad34e50 100644 --- a/benchmark/distributed/ipc_impls/README.md +++ b/benchmark/distributed/ipc_impls/README.md @@ -31,4 +31,3 @@ python benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py | 4,194,304 | 10.6560 | 2.2474 | 11.9145 | 2.2845 | > **Note:** All data presented above are unidirectional bandwidth. - diff --git a/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py b/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py index 5ab6265ae..b4836d1c3 100644 --- a/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py +++ b/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py @@ -12,15 +12,14 @@ from tilelang.distributed import init_distributed, perf_fn import pynvshmem -os.environ['NCCL_DEBUG'] = 'WARN' +os.environ["NCCL_DEBUG"] = "WARN" def nvshmem_kernel_push(size, threads): - @T.prim_func def nvshmem_push( - dst: T.Tensor((size), "float32"), # type: ignore - src: T.Tensor((size), "float32"), # type: ignore + dst: T.Tensor((size), "float32"), # type: ignore + src: T.Tensor((size), "float32"), # type: ignore ): with T.Kernel(1, threads=threads): T.putmem_block( @@ -35,11 +34,10 @@ def nvshmem_push( def nvshmem_kernel_pull(size, threads): - @T.prim_func def nvshmem_pull( - dst: T.Tensor((size), "float32"), # type: ignore - src: T.Tensor((size), "float32"), # type: ignore + dst: T.Tensor((size), "float32"), # type: ignore + src: T.Tensor((size), "float32"), # type: ignore ): with T.Kernel(1, threads=threads): T.getmem_block( @@ -53,8 +51,7 @@ def nvshmem_pull( return nvshmem_pull -def benchmark_nvshmem_bw(rank: int, num_ranks: int, group: dist.ProcessGroup, size: int, - args: argparse.Namespace): +def benchmark_nvshmem_bw(rank: int, num_ranks: int, group: dist.ProcessGroup, size: int, args: argparse.Namespace): assert num_ranks == 2, "this benchmark only supports 2 ranks" assert args.threads % 32 == 0, "threads must be divisible by 32" @@ -90,10 +87,8 @@ def pull_fn(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--warmup", type=int, default=10, help="number of warmup iterations (default: 10)") - parser.add_argument( - "--repeat", type=int, default=50, help="number of repeat iterations (default: 50)") + parser.add_argument("--warmup", type=int, default=10, help="number of warmup iterations (default: 10)") + parser.add_argument("--repeat", type=int, default=50, help="number of repeat iterations (default: 50)") parser.add_argument("--threads", type=int, default=128, help="Threads per block (default: 128)") args = parser.parse_args() @@ -102,8 +97,6 @@ def pull_fn(): size = 2**log_size push_bw, pull_bw = benchmark_nvshmem_bw(rank, num_ranks, group, size, args) if rank == 0: - print( - f"size={size*4} bytes, nvshmem push bw: {push_bw:.4f} GB/s, nvshmem pull bw: {pull_bw:.4f} GB/s" - ) + print(f"size={size * 4} bytes, nvshmem push bw: {push_bw:.4f} GB/s, nvshmem pull bw: {pull_bw:.4f} GB/s") dist.destroy_process_group() diff --git a/benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py b/benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py index c7d3f2556..c320688ac 100644 --- a/benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py +++ b/benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py @@ -8,15 +8,14 @@ from
tilelang.distributed import init_dist, perf_fn tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' +os.environ["NCCL_DEBUG"] = "WARN" def ipc_kernel_push(size, threads, unroll_factor): - @T.prim_func def ipc_push( - dst: T.Tensor((size), "float32"), # type: ignore - src: T.Tensor((size), "float32"), # type: ignore + dst: T.Tensor((size), "float32"), # type: ignore + src: T.Tensor((size), "float32"), # type: ignore ): with T.Kernel(1, threads=threads): rank = T.alloc_local([1], "uint64") @@ -29,18 +28,18 @@ def ipc_push( dst=T.address_of(dst[warp_start]), size=warp_copy_size, dst_pe=rank[0] ^ 1, - unroll_factor=unroll_factor) + unroll_factor=unroll_factor, + ) T.fence_sys() return ipc_push def ipc_kernel_pull(size, threads, unroll_factor): - @T.prim_func def ipc_pull( - dst: T.Tensor((size), "float32"), # type: ignore - src: T.Tensor((size), "float32"), # type: ignore + dst: T.Tensor((size), "float32"), # type: ignore + src: T.Tensor((size), "float32"), # type: ignore ): with T.Kernel(1, threads=threads): rank = T.alloc_local([1], "uint64") @@ -53,14 +52,14 @@ def ipc_pull( dst=T.address_of(dst[warp_start]), size=warp_copy_size, src_pe=rank[0] ^ 1, - unroll_factor=unroll_factor) + unroll_factor=unroll_factor, + ) T.fence_sys() return ipc_pull -def benchmark_ipc_bw(rank: int, num_ranks: int, group: dist.ProcessGroup, size: int, - args: argparse.Namespace, allocator): +def benchmark_ipc_bw(rank: int, num_ranks: int, group: dist.ProcessGroup, size: int, args: argparse.Namespace, allocator): assert num_ranks == 2, "this benchmark only supports 2 ranks" assert args.threads % 32 == 0, "threads must be divisible by 32" @@ -100,30 +99,22 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) allocator = tilelang.get_allocator( - size=2**30, - device="cuda", - is_distributed=True, - local_rank=rank, - num_local_ranks=num_ranks, - group=group) + size=2**30, device="cuda", is_distributed=True, local_rank=rank, num_local_ranks=num_ranks, group=group + ) for log_size in range(9, 21): size = 2**log_size push_bw, pull_bw = benchmark_ipc_bw(rank, num_ranks, group, size, args, allocator) if rank == 0: - print( - f"size={size*4} bytes, ipc push bw: {push_bw:.4f} GB/s, ipc pull bw: {pull_bw:.4f} GB/s" - ) + print(f"size={size * 4} bytes, ipc push bw: {push_bw:.4f} GB/s, ipc pull bw: {pull_bw:.4f} GB/s") dist.destroy_process_group() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--warmup", type=int, default=10, help="number of warmup iterations (default: 10)") - parser.add_argument( - "--repeat", type=int, default=50, help="number of repeat iterations (default: 50)") + parser.add_argument("--warmup", type=int, default=10, help="number of warmup iterations (default: 10)") + parser.add_argument("--repeat", type=int, default=50, help="number of repeat iterations (default: 50)") parser.add_argument("--threads", type=int, default=128, help="Threads per block (default: 128)") parser.add_argument("--unroll-factor", type=int, default=4, help="Unroll factor (default: 4)") args = parser.parse_args() diff --git a/benchmark/distributed/utils.py b/benchmark/distributed/utils.py index fba164121..87cf9cc24 100644 --- a/benchmark/distributed/utils.py +++ b/benchmark/distributed/utils.py @@ -13,7 +13,6 @@ class AllToAllContext: - def __init__( self, max_m: int, diff --git a/benchmark/mamba2/README.md b/benchmark/mamba2/README.md index 8c6d933d5..f0b4b7e80 100644 --- 
a/benchmark/mamba2/README.md +++ b/benchmark/mamba2/README.md @@ -45,9 +45,14 @@ PY | 16384 | 2.531 | 135.711 | | 32768 | 5.076 | 135.379 | +## Compare with Baselines + +- Triton: v3.5.0, mamba-ssm: v2.2.6.post3 +- Helion: v0.2.1 +
[Figure: Mamba2_chunk_scan performance comparison across compilers on NVIDIA H100]
-
\ No newline at end of file + diff --git a/benchmark/mamba2/benchmark_mamba_chunk_scan.py b/benchmark/mamba2/benchmark_mamba_chunk_scan.py index 78dfb135e..55f802b4f 100644 --- a/benchmark/mamba2/benchmark_mamba_chunk_scan.py +++ b/benchmark/mamba2/benchmark_mamba_chunk_scan.py @@ -5,6 +5,20 @@ import tilelang.language as T from einops import rearrange, repeat import itertools +import math +from tilelang.profiler import do_bench + +try: + from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd +except ImportError as err: + raise ImportError("Please install mamba-ssm to use the triton chunk scan operator.") from err + +try: + import helion + from helion._testing import run_example + import helion.language as hl +except ImportError as err: + raise ImportError("Please install helion to use the helion chunk scan operator.") from err def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): @@ -37,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] decay = torch.exp(dt_segment_sum) scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") - causal_mask = torch.tril( - torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) scores_decay = scores_decay.masked_fill(~causal_mask, 0) - out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), - rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) - out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange( - C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) out = out + out_prev out = rearrange(out, "b c l h p -> b (c l) h p") if D is not None: @@ -54,13 +69,114 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): return out +def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): + out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D) + return out + + +def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): + @helion.kernel() + def helion_mamba2_chunk_scan_kernel( + cb: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + dA_cumsum: torch.Tensor, + C: torch.Tensor, + prev_states: torch.Tensor, + D: torch.Tensor, + ) -> torch.Tensor: + """ + Argument: + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + C: (batch, seqlen, ngroups, dstate) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads,) + Return: + out: (batch, seqlen, nheads, headdim) + """ + + batch, nchunks, ngroups, chunk_size, _ = cb.shape + _, seqlen, nheads, headdim = x.shape + _, _, _, dstate = C.shape + assert nchunks == (seqlen + chunk_size - 1) // chunk_size + + block_m = hl.register_block_size(chunk_size) + block_n = hl.register_block_size(headdim) + block_k = hl.register_block_size(64, 64) + dstate = hl.specialize(dstate) + + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert x.shape == (batch, seqlen, nheads, headdim) + 
assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert C.shape == (batch, seqlen, ngroups, dstate) + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert D.shape == (nheads,) + + dtype = cb.dtype + accum_dtype = torch.float32 + assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype + + out = torch.empty_like(x) + + p = 1.44269504 + + for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile( + [nheads, chunk_size, headdim, batch, nchunks], + block_size=[1, block_m, block_n, 1, 1], + ): + acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype) + dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32) + scale_m_local = torch.exp2(dA_cumsum_local_m * p) + + C_local = C[ + tile_b.begin, + tile_m.index + tile_c.begin * chunk_size, + tile_h.begin // (nheads // ngroups), + :, + ] + prev_states_local = prev_states[tile_b.begin, tile_c.begin, tile_h.begin, tile_n, :] + acc_o = hl.dot(C_local, prev_states_local.T, acc=acc_o) + acc_o *= scale_m_local[:, None] + + for tile_k in hl.tile((tile_m.id + 1) * block_m, block_size=block_k): + cb_local = cb[ + tile_b.begin, + tile_c.begin, + tile_h.begin // (nheads // ngroups), + tile_m, + tile_k, + ] + dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) + cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p) + dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) + cb_local = (cb_local * dt_local[None, :]).to(dtype) + pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :] + cb_local = torch.where(pred, cb_local, torch.zeros_like(cb_local)) + x_local = x[ + tile_b.begin, + tile_c.begin * chunk_size + tile_k.index, + tile_h.begin, + tile_n, + ] + acc_o = hl.dot(cb_local, x_local, acc=acc_o) + + D_local = D[tile_h.begin].to(torch.float32) + x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32) + acc_o += x_residual * D_local + out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype) + + return out + + args = (cb, x, dt, dA_cumsum, C, states, D) + run_example(helion_mamba2_chunk_scan_kernel, ref_program, args) + + def get_configs(): - iter_params = dict( - block_M=[64, 128, 256], - block_N=[32, 64], - block_K=[64, 128, 256], - block_Dstate=[128], - num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -71,56 +187,58 @@ def get_configs(): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def chunk_scan_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - block_Dstate=128, - num_stages=2, - threads=128): - dtype = "float16" - accum_dtype = "float" +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, +): + dtype = T.float16 + accum_dtype = T.float32 nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 @T.prim_func def main( - cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore - x: T.Tensor((batch, seqlen, nheads, headdim), 
dtype), # type: ignore - dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore - prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore - D: T.Tensor((nheads), dtype), # type: ignore - Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore ): - with T.Kernel( - nheads, - T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) acc_o_shared = T.alloc_shared((block_M, block_N), dtype) - cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") + cb_shared = T.alloc_shared((block_M, block_K), dtype) cb_local = T.alloc_fragment((block_M, block_K), dtype) - dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared") + dA_cs_k_shared = T.alloc_shared((block_K), dtype) dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype) dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) - dt_shared = T.alloc_shared((block_K), dtype, scope="shared") + dt_shared = T.alloc_shared((block_K), dtype) dt_local = T.alloc_fragment((block_K), accum_dtype) - x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn") - dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared") + x_shared = T.alloc_shared((block_K, block_N), dtype) + dA_cs_m_shared = T.alloc_shared((block_M), dtype) scale_m_local = T.alloc_fragment((block_M), accum_dtype) C_shared = T.alloc_shared((block_M, block_Dstate), dtype) prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) D_local = T.alloc_fragment((1), accum_dtype) - x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn") + x_residual_shared = T.alloc_shared((block_M, block_N), dtype) x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype) batch_idx = by % batch @@ -130,27 +248,31 @@ def main( m_idx = bx // T.ceildiv(headdim, block_N) n_idx = bx % T.ceildiv(headdim, block_N) - T.annotate_layout({ - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), - cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), - x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) - }) + T.annotate_layout( + { + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) T.no_set_max_nreg() - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], - dA_cs_m_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) T.copy(dA_cs_m_shared, 
dA_cs_m_local) T.clear(acc_o) for i in T.Parallel(block_M): scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) T.copy( - C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) - T.copy( - prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, - 0:block_Dstate], prev_state_shared) + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] *= scale_m_local[i] @@ -159,34 +281,47 @@ def main( for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - cb[batch_idx, chunk_idx, bz // (nheads // ngroups), - m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], - cb_shared) + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) T.copy(cb_shared, cb_local) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cs_k_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) T.copy(dA_cs_k_shared, dA_cs_k_local) for i, j in T.Parallel(block_M, block_K): - cb_local[i, - j] = cb_local[i, - j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dt_shared, dt_local) for i, j in T.Parallel(block_M, block_K): cb_local[i, j] *= dt_local[j] for i, j in T.Parallel(block_M, block_K): - cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, - cb_local[i, j], 0) + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) T.gemm(cb_local, x_shared, acc_o) D_local[0] = D[bz] T.copy( - x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], - x_residual_shared) + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) T.copy(x_residual_shared, x_residual_local) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] += x_residual_local[i, j] * D_local[0] @@ -194,26 +329,41 @@ def main( T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) return main if __name__ == "__main__": 
parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) + nchunks = math.ceil(seq_len / chunk_size) total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate + print("Benchmarking TileLang...") kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate) best_latency = kernel.latency best_config = kernel.config @@ -221,3 +371,18 @@ def main( print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") + + cb = torch.randn(batch, nchunks, groups, chunk_size, chunk_size).half().cuda() + x = torch.randn(batch, seq_len, heads, dim).half().cuda() + dt = torch.randn(batch, heads, nchunks, chunk_size).half().cuda() + dA_cumsum = torch.randn(batch, heads, nchunks, chunk_size).half().cuda() + C = torch.randn(batch, seq_len, groups, dstate).half().cuda() + states = torch.randn(batch, nchunks, heads, dim, dstate).half().cuda() + D = torch.randn(heads).half().cuda() + + print("Benchmarking Triton...") + triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10) + print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}") + + print("Benchmarking Helion...") + chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D) diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py index c64f4fabf..643c1fd5e 100644 --- a/benchmark/matmul/benchmark_matmul.py +++ b/benchmark/matmul/benchmark_matmul.py @@ -6,6 +6,7 @@ import tilelang.language as T from tilelang.autotuner import autotune from tilelang import jit + # Configure logger logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -61,9 +62,9 @@ def get_configs(args, kwargs): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, ).with_arch(arch) func = carve_template.equivalent_function() @@ -101,9 +102,7 @@ def get_configs(args, kwargs): 
policy=[T.GemmWarpPolicy.Square], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ -112,7 +111,9 @@ def get_configs(args, kwargs): warmup=3, rep=20, ) -@jit(out_idx=[2],) +@jit( + out_idx=[2], +) def matmul( M, N, @@ -154,14 +155,14 @@ def matmul( # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -176,7 +177,6 @@ def main( # Bind x-dimension to block index in N, # y-dimension to block index in M. with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py index 94e36b385..4ef860c21 100644 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -6,7 +6,8 @@ import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.autotuner import autotune import itertools @@ -48,22 +49,22 @@ def tl_matmul( enable_rasteration=False, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config - # chunk = 32 if in_dtype == "float16" else 64 + # chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" block_M = block_row_warps * warp_row_tiles @@ -103,12 +104,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -116,10 +116,12 @@ def main( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # 
Improve L2 Cache T.use_swizzle(panel_size=10, enable=enable_rasteration) @@ -127,7 +129,6 @@ def main( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -137,7 +138,6 @@ def main( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a(A_local, A_shared, ki) @@ -194,9 +194,9 @@ def get_configs(args, kwargs): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float16, ).with_arch(arch) func = carve_template.equivalent_function() @@ -223,7 +223,6 @@ def get_configs(args, kwargs): for config in configs: print(config) else: - iter_params = dict( block_row_warps=[1, 2, 4], block_col_warps=[1, 2, 4], @@ -233,9 +232,7 @@ def get_configs(args, kwargs): stage=[0, 2], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ -247,14 +244,16 @@ def get_configs(args, kwargs): ref_prog=ref_program, skip_check=True, ) -@tl.jit(out_idx=[2],) +@tl.jit( + out_idx=[2], +) def matmul( M, N, K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float16, with_roller=False, block_row_warps=None, block_col_warps=None, @@ -291,19 +290,14 @@ def kernel(): parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--with_roller", - type=bool, - default=False, - help="Whether to use roller to deduce search spaces") - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") + parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces") + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") args = parser.parse_args() M, N, K = args.m, args.n, args.k - in_dtype = args.dtype - out_dtype = "float32" if in_dtype == "int8" else "float16" - accum_dtype = "float32" if in_dtype == "int8" else "float16" + in_dtype = T.dtype(args.dtype) + out_dtype = T.float32 if in_dtype == T.int8 else T.float16 + accum_dtype = T.float32 if in_dtype == T.int8 else T.float16 with_roller = args.with_roller with_roller = True # Compute total floating-point operations diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py index 4e4ed6128..7ecffc26a 100644 --- a/benchmark/matmul/benchmark_matmul_sp.py +++ b/benchmark/matmul/benchmark_matmul_sp.py @@ -9,7 +9,7 @@ from tilelang.autotuner import autotune from tilelang import jit from tilelang.contrib import nvcc -from tilelang.layout import make_metadata_layout +from tilelang.layout import make_cutlass_metadata_layout # Configure logger logger = logging.getLogger(__name__) @@ -70,7 +70,8 @@ def get_configs(M, N, K): thread_num, policy, enable_rasterization, - )) + ) + ) configs = [ { @@ -81,12 +82,13 @@ def get_configs(M, N, K): "thread_num": c[4], 
"policy": c[5], "enable_rasterization": c[6], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs -def matmul_sp(M, N, K, accum_dtype): +def matmul_sp(M, N, K, in_dtype, accum_dtype): """ Create an autotuned matrix multiplication kernel for matrices of shape: - A: (M, K) @@ -126,7 +128,9 @@ def matmul_sp(M, N, K, accum_dtype): warmup=3, rep=20, ) - @jit(out_idx=[2],) + @jit( + out_idx=[2], + ) def kernel( block_M=None, block_N=None, @@ -161,15 +165,14 @@ def kernel( """ # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float16" e_factor, e_dtype = ARCH_INFO[arch] @T.prim_func def main( - A_sparse: T.Tensor((M, K // 2), dtype), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), in_dtype), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), in_dtype), + C: T.Tensor((M, N), accum_dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -183,13 +186,11 @@ def main( """ # Bind x-dimension to block index in N, # y-dimension to block index in M. - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): # Allocate shared memory for A sub-block of shape (block_M, block_K) - A_shared = T.alloc_shared((block_M, block_K // 2), dtype) + A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) - B_shared = T.alloc_shared((block_K, block_N), dtype) + B_shared = T.alloc_shared((block_K, block_N), in_dtype) # Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) # Allocate a local fragment for intermediate accumulation @@ -202,14 +203,12 @@ def main( T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout({ - E: - make_metadata_layout( - E, mma_dtype="float16", backend="cutlass", block_k=block_K), - E_shared: - make_metadata_layout( - E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K), + } + ) # Loop over sub-blocks in K dimension, pipelined by num_stages for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): # Load a sub-block of A from global memory into A_shared @@ -220,7 +219,7 @@ def main( T.copy(B[k * block_K, bx * block_N], B_shared) # Perform a partial matrix multiplication: # C_local += A_shared @ B_shared - T.gemm_sp( + T.gemm_sp_v2( A_shared, E_shared, B_shared, @@ -244,18 +243,13 @@ def main( parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") parser.add_argument("--disable_cache", action="store_true") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") parser.add_argument( "--bench_torch_sparse", type=str, - choices=['cutlass', 'cusparselt'], + 
choices=["cutlass", "cusparselt"], default=None, - help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported" + help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported", ) args = parser.parse_args() @@ -268,7 +262,7 @@ def main( total_flops = 2 * M * N * K # matmul(...) returns (best_latency, best_config, ref_latency) - best_result = matmul_sp(M, N, K, args.accum_dtype) + best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype) best_latency = best_result.latency best_config = best_result.config A = torch.randn(M, K, dtype=torch.float16, device="cuda") @@ -277,7 +271,8 @@ def main( if args.bench_torch_sparse is not None: from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor - if args.bench_torch_sparse == 'cutlass': + + if args.bench_torch_sparse == "cutlass": SparseSemiStructuredTensor._FORCE_CUTLASS = True A_sp = to_sparse_semi_structured(A, transposed=False) torch_sparse_latency = do_bench(lambda: A_sp @ B) @@ -288,8 +283,6 @@ def main( print(f"Best config: {best_config}") if args.bench_torch_sparse is not None: - print( - f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}" - ) + print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}") print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}") diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py index 36b910355..64714b649 100644 --- a/benchmark/matmul_fp8/benchmark_matmul.py +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -1,7 +1,7 @@ import argparse import itertools +import torch import logging -import tilelang import tilelang.language as T from tilelang.autotuner import autotune from tilelang import jit @@ -62,9 +62,9 @@ def get_configs(args, kwargs): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, ).with_arch(arch) func = carve_template.equivalent_function() @@ -99,12 +99,11 @@ def get_configs(args, kwargs): block_K=[64, 128], num_stages=[0, 1, 2, 3], thread_num=[128, 256], + k_pack=[1, 2], policy=[T.GemmWarpPolicy.Square], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ -114,7 +113,9 @@ def get_configs(args, kwargs): warmup=3, rep=20, ) -@jit(out_idx=[2],) +@jit( + out_idx=[2], +) def matmul( M, N, @@ -125,6 +126,7 @@ def matmul( block_K=None, num_stages=None, thread_num=None, + k_pack=None, policy=None, enable_rasteration=None, ): @@ -156,14 +158,14 @@ def matmul( # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float8_e4m3" - accum_dtype = "float" + dtype = T.float8_e4m3fnuz if torch.version.hip is not None else T.float8_e4m3fn + accum_dtype = T.float32 @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -178,7 +180,6 @@ def main( # Bind x-dimension to block index in N, # y-dimension to block index in M. 
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) @@ -190,8 +191,6 @@ def main( # Enable (or disable) swizzling optimization T.use_swizzle(panel_size=10, enable=enable_rasteration) - # to utilize swizzle tma layout - T.annotate_layout({C_shared: tilelang.layout.make_swizzled_layout(C_shared)}) # Clear out the accumulation buffer T.clear(C_local) @@ -210,6 +209,7 @@ def main( C_local, transpose_B=True, policy=policy, + k_pack=k_pack, ) # Write back the results from C_local to the global memory C T.copy(C_local, C_shared) diff --git a/cmake/load_tvm.cmake b/cmake/load_tvm.cmake index 21fe6dfb5..cb21be95f 100644 --- a/cmake/load_tvm.cmake +++ b/cmake/load_tvm.cmake @@ -3,16 +3,28 @@ set(TVM_BUILD_FROM_SOURCE TRUE) set(TVM_SOURCE ${CMAKE_SOURCE_DIR}/3rdparty/tvm) -if(DEFINED $ENV{TVM_ROOT}) +if(DEFINED ENV{TVM_ROOT}) if(EXISTS $ENV{TVM_ROOT}/cmake/config.cmake) set(TVM_SOURCE $ENV{TVM_ROOT}) + message(STATUS "Using TVM_ROOT from environment variable: ${TVM_SOURCE}") endif() endif() +message(STATUS "Using TVM source: ${TVM_SOURCE}") + set(TVM_INCLUDES ${TVM_SOURCE}/include - ${TVM_SOURCE}/ffi/include ${TVM_SOURCE}/src ${TVM_SOURCE}/3rdparty/dlpack/include ${TVM_SOURCE}/3rdparty/dmlc-core/include ) + +if(EXISTS ${TVM_SOURCE}/ffi/include) + list(APPEND TVM_INCLUDES ${TVM_SOURCE}/ffi/include) +elseif(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/include) + list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/include) +endif() + +if(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include) + list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include) +endif() diff --git a/cmake/pypi-z3/FindZ3.cmake b/cmake/pypi-z3/FindZ3.cmake new file mode 100644 index 000000000..d7920f8f9 --- /dev/null +++ b/cmake/pypi-z3/FindZ3.cmake @@ -0,0 +1,30 @@ +if(Z3_FOUND) + return() +endif() +find_package(Python3 COMPONENTS Interpreter REQUIRED) +execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])" + OUTPUT_VARIABLE Z3_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE Z3_PYTHON_RESULT +) +if(NOT Z3_PYTHON_RESULT EQUAL 0 OR Z3_PATH STREQUAL "") + message(FATAL_ERROR "Failed to locate z3 Python package. 
Ensure z3-solver>=4.13.0 is installed.") +endif() +message("-- Find Z3 in path: ${Z3_PATH}") +find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS ${Z3_PATH}/include) +find_library(Z3_LIBRARY NO_DEFAULT_PATH NAMES z3 libz3 PATHS ${Z3_PATH}/bin ${Z3_PATH}/lib ${Z3_PATH}/lib64) +message("-- Found Z3 include dir: ${Z3_INCLUDE_DIR}") +message("-- Found Z3 library: ${Z3_LIBRARY}") +add_library(z3::libz3 SHARED IMPORTED GLOBAL) +set_target_properties(z3::libz3 + PROPERTIES + IMPORTED_LOCATION ${Z3_LIBRARY} + INTERFACE_INCLUDE_DIRECTORIES ${Z3_INCLUDE_DIR} +) +if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY) + message(FATAL_ERROR "Could not find Z3 library or include directory") +endif() +set(Z3_CXX_INCLUDE_DIRS ${Z3_INCLUDE_DIR}) +set(Z3_C_INCLUDE_DIRS ${Z3_INCLUDE_DIR}) +set(Z3_FOUND TRUE) diff --git a/docker/Dockerfile.cu118 b/docker/Dockerfile.cu118 index 9256fc09b..969b0e43c 100644 --- a/docker/Dockerfile.cu118 +++ b/docker/Dockerfile.cu118 @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/pytorch:22.12-py3 +FROM nvcr.io/nvidia/pytorch:22.12-py3 WORKDIR /root @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu120 b/docker/Dockerfile.cu120 index c89ce82ef..341fe40c0 100644 --- a/docker/Dockerfile.cu120 +++ b/docker/Dockerfile.cu120 @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/pytorch:23.01-py3 +FROM nvcr.io/nvidia/pytorch:23.01-py3 WORKDIR /root @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu121 b/docker/Dockerfile.cu121 index 5b092773d..f91029d75 100644 --- a/docker/Dockerfile.cu121 +++ b/docker/Dockerfile.cu121 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu123 b/docker/Dockerfile.cu123 index 2715536a8..b3d1217fd 100644 --- a/docker/Dockerfile.cu123 +++ b/docker/Dockerfile.cu123 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . 
-v CMD bash diff --git a/docker/Dockerfile.cu124 b/docker/Dockerfile.cu124 index fb9654f48..335f52565 100644 --- a/docker/Dockerfile.cu124 +++ b/docker/Dockerfile.cu124 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu125 b/docker/Dockerfile.cu125 index c409667cb..148e44b41 100644 --- a/docker/Dockerfile.cu125 +++ b/docker/Dockerfile.cu125 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu126 b/docker/Dockerfile.cu126 index 93593b5df..c031c2bc9 100644 --- a/docker/Dockerfile.cu126 +++ b/docker/Dockerfile.cu126 @@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . -v CMD bash diff --git a/docker/Dockerfile.cu128 b/docker/Dockerfile.cu128 index 1617bc79c..2b895ecd8 100644 --- a/docker/Dockerfile.cu128 +++ b/docker/Dockerfile.cu128 @@ -20,9 +20,12 @@ ENV LIBGL_ALWAYS_INDIRECT=1 RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all -RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev +RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev \ + build-essential cmake libedit-dev libxml2-dev cython3 + +RUN pip install cython RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ - && cd TileLang && ./install_cuda.sh + && cd TileLang && USE_CUDA=1 pip install -e . 
-v CMD bash diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 1fb23a9f3..5f61f0e2e 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -9,23 +9,43 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential git wget \ libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + rocm-dev rocm-libs hip-dev hipblas-dev rocblas-dev \ && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/* ENV PATH="/opt/conda/bin:${PATH}" ENV LIBGL_ALWAYS_INDIRECT=1 +ENV USE_ROCM=1 +ENV USE_CUDA=0 +ENV ROCM_HOME=/opt/rocm +ENV HIP_PLATFORM=amd +ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" RUN conda run -n py_3.10 conda install pip cmake -y && \ conda run -n py_3.10 conda install -c conda-forge libstdcxx-ng=12 -y && \ conda clean --all -RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev +RUN apt-get update && apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev && \ + apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/* -RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \ - conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh" +# Copy local tilelang directory instead of cloning from git +# Build from tilelang root: docker build -f docker/Dockerfile.rocm -t mi300:latest . +COPY . /root/tilelang -RUN conda init bash +RUN mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \ + conda run -n py_3.10 bash -c "export USE_ROCM=1 USE_CUDA=0 && pip install 'numpy<2.0' --force-reinstall" && \ + conda run -n py_3.10 bash -c "cd /root/tilelang && \ + # Backup and modify pyproject.toml to remove torch from dependencies \ + cp pyproject.toml pyproject.toml.bak && \ + sed -i '/^[[:space:]]*\"torch/d' pyproject.toml && \ + # Install tilelang with all dependencies except torch \ + USE_ROCM=1 USE_CUDA=0 pip install -e . -v && \ + # Restore original pyproject.toml \ + mv pyproject.toml.bak pyproject.toml" + +RUN conda init bash && \ + echo "conda activate py_3.10" >> /root/.bashrc SHELL ["/bin/bash", "-l", "-c"] -CMD ["bash", "-c", "source ~/.bashrc && conda activate py_3.10 && exec bash"] \ No newline at end of file +ENTRYPOINT ["/bin/bash", "--login", "-i"] diff --git a/docs/.gitignore b/docs/.gitignore index 4d8eb4049..79ba97163 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,2 +1,2 @@ _build/ -autoapi/ \ No newline at end of file +autoapi/ diff --git a/docs/CNAME b/docs/CNAME index ca903c694..6862cd2e9 100644 --- a/docs/CNAME +++ b/docs/CNAME @@ -1 +1 @@ -tilelang.com \ No newline at end of file +tilelang.com diff --git a/docs/README.md b/docs/README.md index 349c0eccc..896d778d2 100644 --- a/docs/README.md +++ b/docs/README.md @@ -27,4 +27,4 @@ cd _build/html python3 -m http.server ``` -Then you can view the documentation in your browser at `http://localhost:8000` (the port can be customized by appending ` -p PORT_NUMBER` in the python command above). +Then you can view the documentation in your browser at `http://localhost:8000` (the port can be customized by appending `-p PORT_NUMBER` in the python command above). 
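The ROCm Dockerfile above temporarily drops the `torch` dependency lines from `pyproject.toml` (the `cp` / `sed` / `mv` sequence) so the editable install reuses the PyTorch that ships with the base image. Purely as an illustration of what that `sed` line does — this is a sketch, not code used by the build — the same filtering step in Python would look roughly like:

```python
# Illustrative sketch only: mirrors the cp / sed / mv sequence in docker/Dockerfile.rocm,
# which removes lines starting with "torch from pyproject.toml before `pip install -e .`
# and restores the original file afterwards.
import shutil


def strip_torch_requirement(path: str = "pyproject.toml") -> None:
    shutil.copy(path, path + ".bak")  # cp pyproject.toml pyproject.toml.bak
    with open(path) as f:
        lines = f.readlines()
    # sed -i '/^[[:space:]]*"torch/d' pyproject.toml
    kept = [ln for ln in lines if not ln.lstrip().startswith('"torch')]
    with open(path, "w") as f:
        f.writelines(kept)


def restore_requirements(path: str = "pyproject.toml") -> None:
    shutil.move(path + ".bak", path)  # mv pyproject.toml.bak pyproject.toml
```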
diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 000000000..a1fee9c3d --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,10 @@ +/* Reduce the displayed size of the sidebar logo in Furo */ +.sidebar-logo { + max-height: 125px; + width: auto; +} + +/* Optional: keep container from growing too tall due to spacing */ +.sidebar-logo-container { + line-height: 0; +} diff --git a/docs/_static/img/logo-row.svg b/docs/_static/img/logo-row.svg index 633243f3a..e73244b74 100644 --- a/docs/_static/img/logo-row.svg +++ b/docs/_static/img/logo-row.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/_static/img/logo-v2.png b/docs/_static/img/logo-v2.png new file mode 100644 index 000000000..410773f60 Binary files /dev/null and b/docs/_static/img/logo-v2.png differ diff --git a/docs/_static/img/logo.png b/docs/_static/img/logo.png new file mode 100644 index 000000000..5d04697ce Binary files /dev/null and b/docs/_static/img/logo.png differ diff --git a/docs/_static/img/sparse_mma_storage_example.png b/docs/_static/img/sparse_mma_storage_example.png new file mode 100644 index 000000000..0b1639819 Binary files /dev/null and b/docs/_static/img/sparse_mma_storage_example.png differ diff --git a/docs/compiler_internals/inject_fence_proxy.md b/docs/compiler_internals/inject_fence_proxy.md index 81f498e57..7a89456ac 100644 --- a/docs/compiler_internals/inject_fence_proxy.md +++ b/docs/compiler_internals/inject_fence_proxy.md @@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the ### Timeline View ``` -generic initialize_descriptor → generic shared-store → async wgmma +generic initialize_wgmma_descriptor → generic shared-store → async wgmma │ │ │ └─ generic proxy ┴─ generic proxy ┴─ async proxy │ fence inserted here ↑ @@ -53,7 +53,7 @@ def kernel(): with T.Kernel(1): desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") smem = T.decl_buffer((128,), "float16", scope="shared") - T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32) + T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32) smem[0] = T.float16(0) T.ptx_wgmma_ss( "float16", @@ -83,7 +83,7 @@ def kernel(): with T.Kernel(1): desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") smem = T.decl_buffer((128,), "float16", scope="shared") - T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32) + T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32) smem[0] = T.float16(0) T.fence_proxy_async() T.ptx_wgmma_ss( diff --git a/docs/compiler_internals/tensor_checks.md b/docs/compiler_internals/tensor_checks.md new file mode 100644 index 000000000..ed5a9e691 --- /dev/null +++ b/docs/compiler_internals/tensor_checks.md @@ -0,0 +1,386 @@ +# Tensor Checks (Host-Side Auto-Validation) + +This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind. + +## Why Host-Side Checks +- ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars. +- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches. 
+- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages. + +## How To Inspect Host Source +You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging: + +```python +print(matmul_relu_kernel.get_host_source()) +``` + +--- + +## What The Host Checks + +### 1) Argument count and pointer kind +- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message. +- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error. + +### 2) Tensor checks (per tensor, after nullability decision) +- Nullability + - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`. + - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`. +- Rank (`ndim`) + - Runtime `ndim` must equal the compile-time rank. +- Data type (`dtype`) + - Match the triple `(code, bits, lanes)` with tolerance: + - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`. + - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`. + - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match). + - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped. +- Shape + - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency. + - Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints. +- Strides + - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality. + - Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`). +- `byte_offset` + - Must be 0 (non-zero raises an error) to keep addressing simple and aligned. +- Device info + - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend. + - When multiple tensors participate, assert that `device_id` matches across them. +- Data pointer + - Must be non-NULL when the tensor is required to be non-null by the nullability rule. + +### 3) Scalar checks +- `T.int*` family: require integer; error: `Expect arg[i] to be int`. +- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`. + +--- + +## Shapes and Symbolic Equations: Linear Solving +When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example: + +```python +@T.prim_func +def main( + A: T.Tensor((m,), dtype), + B: T.Tensor((m + n,), dtype), + C: T.Tensor((n * k,), dtype), +): + ... +``` + +This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime. + +--- + +## Nullability Rules and Examples +Which tensors may be NULL? + +- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL. 
+- Examples: + +1) Must be non-NULL (used) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + A[0] = 1 +``` +Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`. + +2) Still must be non-NULL (constant-true branch) +```python +some_cond: bool = True +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +3) Nullable (constant-false branch, statically unreachable) +```python +some_cond: bool = False +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +4) Must be non-NULL (runtime condition) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype), some_cond: T.bool): + if some_cond: + A[0] = 1 +``` +Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable. + +--- + +## Device Type Codes (DLPack) +Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`. +Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors. + +--- + +## Common Error Examples (What you’ll see) +- Argument count mismatch (num_args) + - Trigger: missing/extra argument + - Error: `: num_args should be N; expected: , got: N` + +- Pointer-typed argument expected + - Trigger: scalar passed where a tensor is expected + - Error: `: Expect arg[i] to be pointer` + +- Rank (ndim) mismatch + - Trigger: runtime rank differs from compile-time rank + - Error: `..ndim is expected to equal R, but got mismatched ndim` + +- Dtype mismatch + - Trigger: dtype not equal to the compiled dtype and not within the tolerance set + - Error: `..dtype is expected to be , but got incompatible dtype` + +- Shape constraint violation + - Trigger: a dimension doesn’t match a constant/symbol binding + - Error: `Argument ..shape[i] has an unsatisfied constraint: ... == ` + +- Strides check failed (e.g., non-contiguous layout) + - Trigger: transposed/sliced tensors that violate expected strides + - Error: `Argument ..strides[j] has an unsatisfied constraint: ... == ` + +- Device type mismatch + - Trigger: calling a CUDA kernel with CPU tensors, etc. + - Error: `..device_type mismatch [expected: ()] ...` + +- Device id mismatch + - Trigger: mixing tensors from different GPUs + - Error: `Argument ..device_id has an unsatisfied constraint: ... == ...` + +- NULL data pointer + - Trigger: tensor required to be non-null has a NULL data pointer + - Error: `. is expected to have non-NULL data pointer, but got NULL` + +- Scalar type mismatch + - Trigger: passing float to `T.int32`, or non-boolean to `T.bool` + - Error: `: Expect arg[i] to be int/boolean` + +--- + +## Troubleshooting Tips +- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields. +- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions. +- Align devices: ensure all participating tensors share the same `device_type` and `device_id`. +- Align dtype: use `.to()` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance. +- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time). + +--- + +## FAQ +- Can I disable the checks? + - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call. 
+- Is the overhead noticeable? + - The checks are lightweight (branches and field reads). Compared to Python-side checks, it’s faster; the dominating cost remains the Python→C boundary. Overall it’s cheaper than equivalent checks in Python. + +--- + +## Reference Example (Matmul + ReLU) + +```python +@T.prim_func +def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), +): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + +# For debugging, print the host source +print(matmul_relu_kernel.get_host_source()) +``` + +The host will insert all checks described above for this example. + +--- + +## Quick Error Reference (Short List) +- Argument count + - Trigger: missing/extra args; Error: `num_args should be N; expected: , got: N`. +- Pointer kind + - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`. +- Rank (ndim) + - Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`. +- Dtype + - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be `. +- Shape + - Trigger: constant/symbol binding violated; Error: `shape[i] ... == `. +- Strides + - Trigger: layout mismatch; Error: `strides[j] ... == `. +- Device type + - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`. +- Device id + - Trigger: tensors on different GPUs; Error: `device_id ... == ...`. +- Data pointer + - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`. +- Scalar types + - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`. + +--- + +## Host Error Troubleshooting (Minimal Repros) + +Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with: + +```python +# Convention: +# A: float16 [M, K] +# B: float16 [K, N] +# C: float16 [M, N] +# Target: CUDA (device_type=2) +fn = matmul_relu_kernel # your compiled function +M = N = K = 1024 +``` + +Adjust dtype/device if your kernel differs. + +### 0. Tip: print the host source +```python +print(fn.get_host_source()) +``` + +### 1. num_args mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +# Missing C +fn(A, B) +``` +Expected: `: num_args should be 3; expected: , got: 3`. + +Fix: pass all arguments per the signature. + +### 2. Expect pointer (tensor) but got scalar +```python +import torch + +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(1, B, C) +``` +Expected: `: Expect arg[0] to be pointer`. + +Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor). + +### 3. 
ndim mismatch +```python +import torch + +A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16) # rank=3 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.ndim is expected to equal 2, but got mismatched ndim`. + +Fix: ensure runtime rank equals compiled rank. + +### 4. dtype mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float32) # should be float16 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.dtype is expected to be float16, but got incompatible dtype`. + +Fix: `A = A.to(torch.float16)` or create with the correct dtype. + +### 5. Shape constant/symbol mismatch +```python +import torch + +A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16) # K mismatched +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .A_handle.shape[i] has an unsatisfied constraint: ... == `. + +Fix: satisfy linear constraints and constants across tensors. + +### 6. Strides check failure (non-contiguous) +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +A_nc = A.t() # transpose -> non-contiguous +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A_nc, B, C) +``` +Expected: `Argument .A_handle.strides[1] has an unsatisfied constraint: ... == 1`. + +Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel. + +### 7. device_type mismatch +```python +import torch + +A = torch.empty((M, K), device='cpu', dtype=torch.float16) +B = torch.empty((K, N), device='cpu', dtype=torch.float16) +C = torch.empty((M, N), device='cpu', dtype=torch.float16) +fn(A, B, C) # CUDA-targeted kernel +``` +Expected: `.A_handle.device_type mismatch [expected: 2 (cuda)] ...`. + +Fix: move tensors to the CUDA device. + +### 8. device_id mismatch (multi-GPU) +```python +import torch + +A = torch.empty((M, K), device='cuda:0', dtype=torch.float16) +B = torch.empty((K, N), device='cuda:1', dtype=torch.float16) +C = torch.empty((M, N), device='cuda:0', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .B_handle.device_id has an unsatisfied constraint: ... == ...`. + +Fix: place all tensors on the same GPU (e.g., `cuda:0`). + +### 9. NULL data pointer (advanced) +This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this. + +Expected: `. is expected to have non-NULL data pointer, but got NULL`. + +Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles. + +### 10. Scalar type mismatch (int / bool) +```python +import tilelang.language as T + +@T.prim_func +def scalar_check(x: T.int32, flag: T.bool()): + T.evaluate(0) + +scalar_check(1.0, True) # x is float -> Expect arg[0] to be int +scalar_check(1, 2.5) # flag is float -> Expect arg[1] to be boolean +``` + +Fix: pass correct scalar types, e.g., `scalar_check(1, True)`. + +--- + +## Closing Notes +- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently. 
+- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly. diff --git a/docs/conf.py b/docs/conf.py index 1b1289038..877b5582e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,5 @@ # General information about the project. -project = "Tile Language
" +project = "TileLang
" author = "Tile Lang Contributors" copyright = f"2025-2025, {author}" @@ -20,33 +20,27 @@ "autoapi.extension", ] -autoapi_type = 'python' -autoapi_dirs = ['../tilelang'] +autoapi_type = "python" +autoapi_dirs = ["../tilelang"] autoapi_options = [ - 'members', - 'undoc-members', - 'show-inheritance', - 'show-module-summary', - 'special-members', + "members", + "undoc-members", + "show-inheritance", + "show-module-summary", + "special-members", ] autoapi_keep_files = False # Useful for debugging the generated rst files autoapi_generate_api_docs = True -autodoc_typehints = 'description' +autodoc_typehints = "description" autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"] -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} +source_suffix = {".rst": "restructuredtext", ".md": "markdown"} -myst_enable_extensions = [ - "colon_fence", - "deflist", -] +myst_enable_extensions = ["colon_fence", "deflist"] redirects = {"get_started/try_out": "../index.html#getting-started"} @@ -62,13 +56,11 @@ html_theme = "furo" templates_path = [] html_static_path = ["_static"] -footer_copyright = "© 2025-2025 Tile Language" +html_css_files = ["custom.css"] +footer_copyright = "© 2025-2026 TileLang" footer_note = " " -html_theme_options = { - "light_logo": "img/logo-row.svg", - "dark_logo": "img/logo-row.svg", -} +html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"} header_links = [ ("Home", "https://github.com/tile-ai/tilelang"), diff --git a/docs/deeplearning_operators/deepseek_mla.md b/docs/deeplearning_operators/deepseek_mla.md index 08175778f..ed02b58b1 100644 --- a/docs/deeplearning_operators/deepseek_mla.md +++ b/docs/deeplearning_operators/deepseek_mla.md @@ -1,8 +1,7 @@ # 🚀 Write High Performance FlashMLA with TileLang on Hopper -
- Author: Yu Cheng + Author: Yu Cheng Author: Lei Wang
@@ -32,14 +31,14 @@ Figure 1: Performance under batch size=64 Figure 2: Performance under batch size=128 ``` -As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. +As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this. ## Implementation First, let's review the core computation logic of traditional FlashAttention: -```python +```python # acc_s: [block_M, block_N] # scores_max: [block_M] # scores_scale: [block_M] @@ -62,7 +61,7 @@ Compared to traditional attention operators like MHA (Multi-Headed Attention) or This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling. -Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. +Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory. @@ -106,7 +105,6 @@ T.use_swizzle(panel_size: int, order: str = "row") Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col". - ### Shared Memory Swizzling In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance. @@ -123,17 +121,14 @@ T.annotate_layout({ Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout. - ### Warp-Specialization The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. 
However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects. In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation. - ### Pipeline - Pipeline is a technique used to improve memory access efficiency by overlapping memory access and computation. In TileLang, pipeline can be implemented through the `T.pipelined` annotation: ```python @@ -142,14 +137,12 @@ T.pipelined(range: int, stage: int) Here, `range` specifies the range of the pipeline, and `stage` specifies the stage of the pipeline. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases. - ### Split-KV We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results. In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. - ## 🚀 On AMD MI300X Accelerators Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms. @@ -167,7 +160,7 @@ Key implementation differences between Hopper and MI300X architectures include: # Original shared memory allocation Q_shared = T.alloc_shared([block_H, dim], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) - + # Optimized register allocation Q_local = T.alloc_fragment([block_H, dim], dtype) Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) diff --git a/docs/deeplearning_operators/elementwise.md b/docs/deeplearning_operators/elementwise.md index 5e1243c26..6aa8e4085 100644 --- a/docs/deeplearning_operators/elementwise.md +++ b/docs/deeplearning_operators/elementwise.md @@ -8,7 +8,7 @@ :class: myclass1 myclass2 :name: a-tip-reference - This document is still **experimental** and may be incomplete. + This document is still **experimental** and may be incomplete. Suggestions and improvements are highly encouraged—please submit a PR! 
::: @@ -24,7 +24,7 @@ Please note that this tutorial does not delve deeply into the design principles ## Elementwise add in TileLang ```python -def elementwise_add(N, threads=256, dtype="bfloat16"): +def elementwise_add(N, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -43,7 +43,7 @@ Those familiar with CUDA programming might wonder where `threadIdx` fits into th The program can be compiled using the following code: ```python -program = elementwise_add(1024, threads=256, dtype="bfloat16") +program = elementwise_add(1024, threads=256, dtype=T.bfloat16) kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython") ``` Launching the kernel is straightforward, just call it directly like a function: @@ -89,7 +89,7 @@ def elementwise_add( In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this: ```python -program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16") +program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16) kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython") ``` @@ -102,7 +102,7 @@ TileLang automatically incorporates boundary-checking conditions; however, this When compiling the example below, let's set `N` to 2047: ```python -def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"): +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -176,7 +176,7 @@ While TileLang incorporates various optimizations for the aforementioned case, i In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design. ```python -def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"): +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -212,7 +212,7 @@ Aha, this CUDA code aligns closely with conventional programming practices, maki But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations. 
```python -def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"): +def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -280,8 +280,8 @@ To evaluate complexity, one could implement the same elementwise addition operat ```c++ template -__global__ void elementwise_add(nv_bfloat16* C, - const nv_bfloat16* A, +__global__ void elementwise_add(nv_bfloat16* C, + const nv_bfloat16* A, const nv_bfloat16* B, int N) { using namespace cute; diff --git a/docs/deeplearning_operators/gemv.md b/docs/deeplearning_operators/gemv.md index c75a961b8..38287f220 100644 --- a/docs/deeplearning_operators/gemv.md +++ b/docs/deeplearning_operators/gemv.md @@ -6,7 +6,7 @@ :::{warning} - This document is still **experimental** and may be incomplete. + This document is still **experimental** and may be incomplete. Suggestions and improvements are highly encouraged—please submit a PR! ::: @@ -206,7 +206,6 @@ def splitk_gemv( return main ``` - ## Vectorized Reads GEMV is less computation intensive than GEMM as the computation intensity and memory throughput will be the optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`: @@ -254,7 +253,6 @@ def splitk_gemv_vectorized( With vectorized read, now the kernel finishes in **~0.0084 ms**, which is getting close to cuBLAS performance. - ## `tvm_thread_allreduce` Instead of `atomicAdd` [`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) has implemented optimization when making an all-reduce across a number of threads, which should outperfrom out plain smem + `atomidAdd`: @@ -459,6 +457,5 @@ This corresponds closely to our `TileLang` program, with necessary synchronizati | splitk_gemv_vectorized | 0.00809 ms | | splitk_gemv_vectorized_tvm | 0.00675 ms | - Triton Time: 0.0077344514429569244 -In this tutorial, we implemented a simple GEMV kernel and learn that `TileLang` exposes low level control to user such as thread-level programming and CUDA primitives. \ No newline at end of file +In this tutorial, we implemented a simple GEMV kernel and learn that `TileLang` exposes low level control to user such as thread-level programming and CUDA primitives. diff --git a/docs/deeplearning_operators/matmul.md b/docs/deeplearning_operators/matmul.md index fea036ebe..12189eb8f 100644 --- a/docs/deeplearning_operators/matmul.md +++ b/docs/deeplearning_operators/matmul.md @@ -14,11 +14,11 @@ TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction: -* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM. +- **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM. 
-* **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc. +- **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc. -* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc. +- **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc. ```{figure} ../_static/img/overview.png :width: 50% @@ -52,12 +52,12 @@ While Level 1 in TileLang can be very comfortable for general users—since it r Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses: -* **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.). -* **`T.alloc_shared(...)`** to allocate GPU shared memory. -* **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation. -* **`T.Pipelined(...)`** to express software pipelining across the K dimension. -* **`T.Parallel(...)`** to parallelize data copy loops. -* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs). +- **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.). +- **`T.alloc_shared(...)`** to allocate GPU shared memory. +- **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation. +- **`T.Pipelined(...)`** to express software pipelining across the K dimension. +- **`T.Parallel(...)`** to parallelize data copy loops. +- **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs). ```python import tilelang @@ -147,14 +147,12 @@ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, - This sets up the block grid dimensions based on N/block_N and M/block_M. - `threads=128` specifies that each thread block uses 128 threads. The compiler will infer how loops map to these threads. - ```{figure} ../_static/img/Parallel.png :alt: Parallel :align: center ``` - 2. **Shared & Fragment Memory**: ```python @@ -182,7 +180,6 @@ for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): ``` - 4. 
**Parallel Copy**: ```python @@ -252,8 +249,8 @@ For more advanced usage—including partial lowering, explicitly controlling thr ## Further Resources -* [TileLang GitHub](https://github.com/tile-ai/tilelang) -* [BitBLAS](https://github.com/tile-ai/bitblas) -* [Triton](https://github.com/openai/triton) -* [Cutlass](https://github.com/NVIDIA/cutlass) -* [PyCUDA](https://documen.tician.de/pycuda/) +- [TileLang GitHub](https://github.com/tile-ai/tilelang) +- [BitBLAS](https://github.com/tile-ai/bitblas) +- [Triton](https://github.com/openai/triton) +- [Cutlass](https://github.com/NVIDIA/cutlass) +- [PyCUDA](https://documen.tician.de/pycuda/) diff --git a/docs/deeplearning_operators/matmul_sparse.md b/docs/deeplearning_operators/matmul_sparse.md new file mode 100644 index 000000000..8caa6182f --- /dev/null +++ b/docs/deeplearning_operators/matmul_sparse.md @@ -0,0 +1,261 @@ +# Sparse Matrix-Matrix Multiplication with Tile Library + +
+ Author: botbw +
+ +:::{warning} + This document is still **experimental** and may be incomplete. + + This feature is still **experimental** and need further optimization. + + Suggestions and improvements are highly encouraged—please submit a PR! +::: + +:::{tip} +It's suggested to go through `docs/deeplearning_operators/matmul.md` first. + +Example code can be found at `examples/gemm_sp`. +::: + +## Structured sparsity in the NVIDIA Ampere architecture + +Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation. + +:::{warning} + This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X. +::: + +```{figure} ../_static/img/sparse_mma_storage_example.png +:align: center + +Figure: Sparse MMA storage example (from PTX doc) +``` + +## Compress a dense tensor + +To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata. + +Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`). + +A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression. + +```python +from tilelang.utils.sparse import compress +A_sparse, E = compress(A, transposed=trans_A, block_k=block_K) +``` + +Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern. + +> NOTE: When using CUTLASS compressor, there is no naive position correspondence between the positions in `A_sparse`/`A` and `E`. (i.e. the 4-element group at [n, k] doesn't match the 4-bit metadata at [n, k] if you consider metadata as int4 tensor) +The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads). +For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**. + +## `T.gemm_sp` with CUTLASS's compressor + +:::{warning} + +It is strongly recommended to use T.gemm_sp_v2 due to its greater flexibility and faster compilation time. + +::: + +A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata. + +Check comments in below kernel code for required modification. 
+ +```python +def matmul_sp_sm80( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + trans_A, + trans_B, +): + is_8_bit = "8" in in_dtype + metadata_dtype = 'int32' if is_8_bit else 'int16' + E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ # Annotate reordered cutlass metadata layout + E: + make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: + make_cutlass_metadata_layout( + E_shared, mma_dtype=in_dtype, arch="8.0"), + }) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main +``` + +Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`. + +## `T.gemm_sp_v2` with a custom compressor + +To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`. + +Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors. + +The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs. 
+
+Suppose we have the following row vector:
+```python
+t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten()
+```
+
+The non-zero elements and their corresponding indices are:
+
+```python
+t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten()
+indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten()
+```
+
+The corresponding 16-bit metadata is:
+```python
+# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000])
+# Note: storage uses little-endian order: tensor(0b1000111001001101, dtype=torch.int16)
+# Note: the above code is not runnable in python as the interpreter won't take the binary
+# as 2's complement
+metadata_int16 = tensor(-29107)
+```
+
+You can decode an int16 metadata tensor using the following utility:
+```python
+def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
+    assert meta.dtype is torch.int16
+    groups_per_meta = 16 // 4
+    out = []
+    for g in range(groups_per_meta):
+        group_bits = (meta >> (g * 4)) & 0xF
+        idx0 = group_bits & 0x3
+        idx1 = (group_bits >> 2) & 0x3
+        out.append(torch.stack([idx0, idx1], dim=-1))
+    return torch.concat(out, dim=-1).view(meta.shape[0], -1)
+```
+
+The compressor can be implemented at either the `PyTorch`/`NumPy` level or the kernel level.
+
+For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.
+
+If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.
+ +```python + +@tilelang.jit(out_idx=[1, 2], pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, +}) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: # NOTE: Make sure compressor metadata layout + T.annotate_layout({ # is same with your computation kernel + E: + make_cutlass_metadata_layout( + E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: + make_cutlass_metadata_layout( + E_shared, + mma_dtype="float16", + arch="8.0", + block_k=block_K), + }) + T.clear(A_sp_shared) + T.clear(E_shared) + non_zero_cnt = T.alloc_local((1, ), dtype="uint8") + non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8") + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + T.clear(non_zero_cnt) + T.clear(non_zero_elt_log_idx) + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel +``` + +## A note on `gemm_sp` and `gemm_sp_v2` + +Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout. + +However, fixing a specific layout introduces several potential issues: + +1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling. + +2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically. + +3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.) + +`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout. 
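+
+## Putting it together (illustrative sketch)
+
+To close the loop, the sketch below compresses a 2:4-pruned dense matrix with the built-in compressor and runs it through the `matmul_sp_sm80` example from above, checking the result against a dense reference. It is a minimal illustration, not part of the example files: it assumes an sm80+ CUDA device, that the helpers referenced inside `matmul_sp_sm80` (e.g., `SparseTensorCoreIntrinEmitter`, `make_cutlass_metadata_layout`) are imported, and that `compress` returns metadata in the `int16` layout this kernel expects.
+
+```python
+import torch
+import tilelang
+from tilelang.utils.sparse import compress
+
+M = N = K = 1024
+block_M, block_N, block_K = 128, 128, 64
+
+# Build the prim_func from the sm80 example above and compile it (C is returned via out_idx).
+program = matmul_sp_sm80(M, N, K, block_M, block_N, block_K,
+                         "float16", "float16", "float32",
+                         num_stages=2, threads=128, trans_A=False, trans_B=False)
+kernel = tilelang.compile(program, out_idx=-1, target="cuda")
+
+# Make A 2:4 sparse by zeroing two of every four elements along K, then compress it.
+A = torch.randn(M, K, device="cuda", dtype=torch.float16)
+A[:, 0::4] = 0
+A[:, 1::4] = 0
+A_sparse, E = compress(A, transposed=False, block_k=block_K)
+
+B = torch.randn(K, N, device="cuda", dtype=torch.float16)
+C = kernel(A_sparse, E, B)
+torch.testing.assert_close(C, (A @ B).to(C.dtype), rtol=1e-2, atol=1e-2)
+```
+
+The same flow applies to `T.gemm_sp_v2`: swap the call inside the kernel and, if you use a custom compressor, keep its metadata layout consistent with the one the kernel annotates.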
diff --git a/docs/get_started/Installation.md b/docs/get_started/Installation.md index f441d1a83..ea980b59b 100644 --- a/docs/get_started/Installation.md +++ b/docs/get_started/Installation.md @@ -15,7 +15,7 @@ We currently provide three methods to install **TileScale**: ```bash docker pull nvcr.io/nvidia/pytorch:25.03-py3 -docker run --name tilescale --ipc=host --network=host --privileged --cap-add=SYS_ADMIN --shm-size=10g --gpus=all -it nvcr.io/nvidia/pytorch:25.03-py3 /bin/bash +docker run --name tilescale --ipc=host --network=host --privileged --cap-add=SYS_ADMIN --shm-size=10g --gpus=all -it nvcr.io/nvidia/pytorch:25.03-py3 /bin/bash echo -n > /etc/pip/constraint.txt bash Miniconda3-latest-Linux-x86_64.sh # install conda conda install -c conda-forge libstdcxx-ng @@ -44,7 +44,7 @@ Verify that **TileScale** is working correctly: python -c "import tilelang; print(tilelang.__version__)" ``` -You can now run TileScale examples and develop your applications. +You can now run TileScale examples and develop your applications. **Example Usage:** @@ -55,12 +55,11 @@ cd /home/tilelang TILELANG_USE_DISTRIBUTED=1 python examples/distributed/example_allgather_gemm_overlapped.py ``` - ## To use NVSHMEM APIs Before running the examples using NVSHMEM APIs (e.g., [example_allgather.py](../../examples/distributed/example_allgather.py)), you need to build NVSHMEM library for device-side code generation. -```bash +```bash pip install mpich # building NVSHMEM needs MPI export NVSHMEM_SRC="your_custom_nvshmem_dir" # default to 3rdparty/nvshmem_src cd tilelang/distributed diff --git a/docs/get_started/overview.md b/docs/get_started/overview.md index 18fa9f193..a7c154f31 100644 --- a/docs/get_started/overview.md +++ b/docs/get_started/overview.md @@ -15,49 +15,49 @@ Figure 1: High-level overview of the TileLang compilation flow. ## Programming Interfaces 1. **Beginner Level (Hardware-Unaware)** - - Intended for users who need to write code that is independent of specific hardware details. - - The goal is to let developers focus on the basic logic without worrying about memory hierarchies or hardware-specific optimizations. + - Intended for users who need to write code that is independent of specific hardware details. + - The goal is to let developers focus on the basic logic without worrying about memory hierarchies or hardware-specific optimizations. - *Note:* This interface is not yet fully implemented. 2. **Developer Level (Hardware-Aware with Tile Library)** - - Designed for developers who have a basic understanding of GPU memory hierarchies and performance considerations. - - Provides a **Tile Library**, containing predefined operations and patterns optimized for various hardware architectures. + - Designed for developers who have a basic understanding of GPU memory hierarchies and performance considerations. + - Provides a **Tile Library**, containing predefined operations and patterns optimized for various hardware architectures. - Users at this level can leverage these ready-made primitives without diving into low-level threading details. 3. **Expert Level (Hardware-Aware with Thread Primitives)** - - For highly experienced users who have an in-depth understanding of low-level hardware characteristics (e.g., threading models, memory coalescing). - - Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels. 
+ - For highly experienced users who have an in-depth understanding of low-level hardware characteristics (e.g., threading models, memory coalescing). + - Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels. - This level grants maximum flexibility for specialized optimizations tailored to specific GPU or multi-core architectures. ## Compilation Flow -1. **Tile Program** +1. **Tile Program** A high-level specification of the computation. Depending on the user’s expertise, they may write a purely hardware-unaware tile program or incorporate constructs from the Tile Library or thread primitives. -2. **Tile Program with Tile Library** +2. **Tile Program with Tile Library** When developers choose from the Tile Library, the original Tile Program is expanded with specialized library calls. These calls encapsulate efficient implementation patterns for different operations. -3. **Tile Program with Thread Primitives** +3. **Tile Program with Thread Primitives** Expert-level developers can explicitly use low-level threading constructs to hand-optimize data layout, synchronization, and memory usage. -4. **IRModule** +4. **IRModule** After the program is composed with libraries or thread primitives, it is lowered to an intermediate representation (IR) that captures the necessary hardware details. -5. **Source Code Generation (C/CUDA/HIP/LLVM/…)** +5. **Source Code Generation (C/CUDA/HIP/LLVM/…)** From the IR, the system generates target-specific source code. This source code is tuned for the desired backends or GPU architectures (e.g., NVIDIA, AMD). -6. **Hardware-Specific Executable/Runtime** +6. **Hardware-Specific Executable/Runtime** Finally, the generated source is compiled into hardware-specific executables, ready to run on the corresponding devices. The pipeline supports multiple GPU backends and can be extended to additional architectures. ## Tile-based Programming Model -[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in ``TileLang``, -illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining, +[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in ``TileLang``, +illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining, and operator calls to manage data movement and computation with fine-grained control. -In particular, this snippet ([Figure 2](#fig-overview-gemm) (a)) demonstrates how multi-level tiling -leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization +In particular, this snippet ([Figure 2](#fig-overview-gemm) (a)) demonstrates how multi-level tiling +leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization and reduce latency. -Overall, [Figure 2](#fig-overview-gemm) (b) showcases how the Python-like syntax of ``TileLang`` +Overall, [Figure 2](#fig-overview-gemm) (b) showcases how the Python-like syntax of ``TileLang`` allows developers to reason about performance-critical optimizations within a user-friendly programming model. 
```{figure} ../_static/img/MatmulExample.png diff --git a/docs/get_started/run_example.md b/docs/get_started/run_example.md index aced5d5a8..e25f42fb8 100644 --- a/docs/get_started/run_example.md +++ b/docs/get_started/run_example.md @@ -5,11 +5,11 @@ Before running, enable TileLang’s distributed mode: ```bash -export TILELANG_USE_DISTRIBUTED=1 +export TILELANG_USE_DISTRIBUTED=1 ``` Then start an example directly with Python: ```bash - python examples/distributed/primitives/example_put_warp.py + python examples/distributed/primitives/example_put_warp.py ``` ## Examples using NVSHMEM APIs @@ -18,4 +18,4 @@ Use the provided launcher `tilelang/distributed/launch.sh` to start programs tha ```bash GPUS=2 ./tilelang/distributed/launch.sh examples/distributed/example_allgather.py ``` -You can change GPUS to the number of local GPUs you want to use. The launcher will set the required environment variables and invoke `torch.distributed.run`. \ No newline at end of file +You can change GPUS to the number of local GPUs you want to use. The launcher will set the required environment variables and invoke `torch.distributed.run`. diff --git a/docs/get_started/targets.md b/docs/get_started/targets.md index c2b3f2fb5..3a464bd66 100644 --- a/docs/get_started/targets.md +++ b/docs/get_started/targets.md @@ -14,6 +14,7 @@ the generated code. The most frequent choices are listed below: | --------- | ----------- | | `auto` | Detects CUDA → HIP → Metal in that order. Useful when running the same script across machines. | | `cuda` | NVIDIA GPUs. Supports options such as `-arch=sm_80`, `-max_num_threads=1024`, etc. | +| `cutedsl` | NVIDIA CUTLASS/CuTe DSL backend. Requires `nvidia-cutlass-dsl`. `cuda` options can also be applied to this target. | | `hip` | AMD GPUs via ROCm. Options like `-mcpu=gfx90a` can be appended. | | `metal` | Apple Silicon GPUs (arm64 Macs). | | `llvm` | CPU execution; accepts the standard TVM LLVM switches. | diff --git a/docs/index.md b/docs/index.md index 5d9a158f8..ca5a125eb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,10 +2,10 @@ [GitHub](https://github.com/tile-ai/tilelang) -Tile Language (tile-lang) is a concise domain-specific language designed to streamline -the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). -By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM, -tile-lang allows developers to focus on productivity without sacrificing the +Tile Language (tile-lang) is a concise domain-specific language designed to streamline +the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). +By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM, +tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance. 
:::{toctree} @@ -17,13 +17,25 @@ get_started/overview get_started/targets ::: - :::{toctree} :maxdepth: 1 :caption: TUTORIALS tutorials/debug_tools_for_tilelang tutorials/auto_tuning +tutorials/logging +::: + +:::{toctree} +:maxdepth: 1 +:caption: PROGRAMMING GUIDES + +programming_guides/overview +programming_guides/language_basics +programming_guides/instructions +programming_guides/control_flow +programming_guides/autotuning +programming_guides/type_system ::: :::{toctree} @@ -33,6 +45,7 @@ tutorials/auto_tuning deeplearning_operators/elementwise deeplearning_operators/gemv deeplearning_operators/matmul +deeplearning_operators/matmul_sparse deeplearning_operators/deepseek_mla ::: @@ -42,6 +55,7 @@ deeplearning_operators/deepseek_mla compiler_internals/letstmt_inline compiler_internals/inject_fence_proxy +compiler_internals/tensor_checks ::: :::{toctree} diff --git a/docs/programming_guides/autotuning.md b/docs/programming_guides/autotuning.md new file mode 100644 index 000000000..9cc5a2d94 --- /dev/null +++ b/docs/programming_guides/autotuning.md @@ -0,0 +1,308 @@ +# Autotuning + +TileLang includes a built‑in autotuner that searches configuration spaces +for the best performing kernel, compiles candidates in parallel, validates +correctness, benchmarks them, and caches the best result for reuse. + +This guide covers two workflows: +- Decorator‑based: `@tilelang.autotune(configs=...)` stacked on `@tilelang.jit` +- Programmatic: `AutoTuner.from_kernel(...).set_*().run()` + +It also explains input tensor supply, validation, caching, and environment +variables that affect parallelism and cache behavior. + +## 1) Decorator‑based Autotune + +Use `@tilelang.autotune` above `@tilelang.jit` and expose tunable parameters as +function arguments with defaults. The autotuner overrides these parameters with +values from your config space. 
+ +```python +import tilelang +import tilelang.language as T + +def matmul_configs(M, N, K): + # Example space — tailor to your target + tiles = [64, 128] + stages = [2, 3] + threads = [128, 256] + return [ + dict(block_M=BM, block_N=BN, block_K=BK, num_stages=S, threads=TH) + for BM in tiles + for BN in tiles + for BK in [32, 64] + for S in stages + for TH in threads + ] + +@tilelang.autotune(configs=matmul_configs, warmup=25, rep=100, timeout=60) +@tilelang.jit(out_idx=[-1]) +def matmul(M: int, N: int, K: int, + block_M: int = 128, block_N: int = 128, block_K: int = 32, + threads: int = 128, num_stages: int = 3, + dtype: str = 'float16', accum_dtype: str = 'float32'): + + @T.prim_func + def kernel(A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_s = T.alloc_shared((block_M, block_K), dtype) + B_s = T.alloc_shared((block_K, block_N), dtype) + C_f = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_s) + T.copy(B[ko * block_K, bx * block_N], B_s) + T.gemm(A_s, B_s, C_f) + + T.copy(C_f, C[by * block_M, bx * block_N]) + + return kernel + +# Usage +# Provide inputs via context (recommended for reproducibility across configs) +import torch +M = N = K = 1024 +A = torch.randn(M, K, device='cuda', dtype=torch.float16) +B = torch.randn(K, N, device='cuda', dtype=torch.float16) +C = torch.empty(M, N, device='cuda', dtype=torch.float16) + +from tilelang.autotuner import set_autotune_inputs +with set_autotune_inputs(A, B, C): + tuned_kernel = matmul(M, N, K) # compiles, tunes, returns best kernel + tuned_kernel(A, B, C) # run best kernel +``` + +Notes +- `configs` can be a list of dicts or a callable `(args...) -> list[dict]`. Each + dict’s keys must match the tunable function arguments (e.g., `block_M`). +- The decorator returns a callable that runs autotune once per argument tuple + and caches the resulting best kernel in‑process. +- For explicit input control during tuning, wrap the call with + `set_autotune_inputs(...)`. Otherwise, `supply_type` (below) is used. + +## 2) Programmatic Autotune + +Use the `AutoTuner` class to manage configs and arguments more explicitly. 
+ +```python +from tilelang.autotuner import AutoTuner + +kernel_factory = matmul # the function above (already @tilelang.jit) +tuner = AutoTuner.from_kernel(kernel_factory(M, N, K), configs=matmul_configs(M, N, K)) + +tuner.set_profile_args( + warmup=25, rep=100, timeout=60, + supply_type=tilelang.TensorSupplyType.Auto, # or provide supply_prog/ref_prog + ref_prog=lambda A, B, C: torch.allclose(C, (A @ B).to(C.dtype), rtol=1e-2, atol=1e-2), +) + +tuner.set_compile_args( + target='auto', # or 'cuda'/'hip'/'metal' + execution_backend='auto', # resolves per-target + out_idx=[-1], # which outputs to return if multiple + pass_configs={ # optional TVM passes/flags + # tilelang.PassConfigKey.EXAMPLE_KEY: value, + }, +) + +artifact = tuner.run() # compiles + runs + validates all configs +best_kernel = artifact.kernel # JITKernel +best_latency = artifact.latency +best_config = artifact.config + +# Reuse best kernel +best_kernel(A, B, C) +``` + +### Example Gallery (in repo) +- examples/gdn/example_chunk_delta_h.py:101 — uses `@autotune` to sweep configs +- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py:451 — uses `@tilelang.autotune` +- examples/quickstart.py:84 — profiles a tuned kernel with `get_profiler` +- examples/hadamard_transform/example_hadamard.py:152 — profiler with custom warmup +- examples/dynamic_shape/example_dynamic.py:94 — profiler for dynamic shapes +- examples/gemm/example_gemm_persistent.py:135 — compare persistent vs non‑persistent + +Click any path to open the code and compare patterns. + +## Input Tensor Supply + +The tuner needs inputs to compile and benchmark kernels. Provide them in one of +three ways (priority order): + +1) Context manager (fixed inputs across configs) +```python +with set_autotune_inputs(A, B, C): + tuned = matmul(M, N, K) +``` + +2) Custom supplier program +```python +def supply_prog(signature): + # signature holds KernelParam objects describing shapes/dtypes + # Return a list of torch tensors matching the kernel’s arguments + return [A, B, C] + +tuner.set_profile_args(supply_prog=supply_prog) +``` + +3) Built‑in generators via `supply_type` +- `TensorSupplyType.Auto` (default): heuristic per dtype (uniform ints / fp ranges) +- `Integer`, `Uniform`, `Normal`, `Randn`, `Zero`, `One` + +Important +- Built‑in generators require static shapes; if your PrimFunc uses symbolic + dimensions (T.dyn), supply concrete inputs via (1) or (2). +- Float8 dtypes require PyTorch 2.1+ for `torch.float8_*` support. + +## Correctness Checking and Tolerances + +Use one of the following validation methods: +- `ref_prog`: Provide a reference program that receives the same inputs and + checks results. You can return a boolean or raise on mismatch. +- `manual_check_prog`: A callable that inspects outputs and raises on mismatch. +- `skip_check=True`: Skip correctness checks (faster, use with caution). + +Control numeric drift via: +- `rtol` and `atol` (defaults 1e‑2) +- `max_mismatched_ratio` (default 1%) + +## Configuration Spaces and Best Practices + +What to tune +- Tile sizes: `block_M`, `block_N`, `block_K` +- Software pipelining: `num_stages` +- Threads per block: `threads` (or (x, y) tuple) +- Optional: dtype variants, epilogues, small scheduling knobs + +Tips +- Start from a working baseline. Tune a small, meaningful space first. +- Respect hardware limits (shared memory bytes, registers per thread/block, + max threads per block). Eliminate impossible configs up‑front. +- Keep block sizes multiples of vector widths and warp sizes when relevant. 
+- Use `set_autotune_inputs` to ensure each config is measured on identical data. +- Record your best configs and bake them as defaults when stable. + +## Parallel Compilation/Benchmarking and Timeouts + +The tuner compiles configurations in parallel using a thread pool and benchmarks +them with a per‑config timeout. On CUDA, each worker sets the current device to +avoid context issues. + +Notes +- `timeout` uses POSIX signals; on non‑Unix systems, it may not take effect. +- Logs are written to `autotuner.log` in the working directory. + +## Caching + +The autotuner caches best artifacts both in‑memory (per process) and on disk under +`$TILELANG_CACHE_DIR/autotuner`. The cache key includes: +- TileLang version, function source, closure free‑vars +- Config list, compile args, profile args + +Disk cache contents (per key) +- Best config and latency: `best_config.json`, `latency.json` +- Kernel sources and library: `device_kernel.cu`, `host_kernel.cu`, `kernel_lib.so` (or `kernel.cubin`/`executable.so` depending on backend) +- Function and params: `function.pkl`, `params.pkl` + +Control via env vars (tilelang.env) +- `TILELANG_CACHE_DIR` (default `~/.tilelang/cache`) +- `TILELANG_TMP_DIR` (default `$TILELANG_CACHE_DIR/tmp`) +- Disable all kernel caches: `TILELANG_DISABLE_CACHE=1` +- Disable autotune disk cache only: `TILELANG_AUTO_TUNING_DISABLE_CACHE=1` + +CPU worker control +- `TILELANG_AUTO_TUNING_CPU_UTILITIES` (fraction, default 0.9) +- `TILELANG_AUTO_TUNING_CPU_COUNTS` (int, `-1` auto) +- `TILELANG_AUTO_TUNING_MAX_CPU_COUNT` (int, `-1` unlimited) + +Backend notes +- NVRTC backend persists `.cubin` and a Python launcher. +- Torch/DLPack backend may not save artifacts to disk; in this case, only + in‑memory caching applies and a warning is logged. + +## Alternative: Manual Sweeps with par_compile + +If you prefer manual control, use `JITImpl.par_compile` to compile a batch of +configs and drive your own benchmarking: + +```python +@tilelang.jit +def factory(M, N, K, block_M=128, block_N=128, block_K=32): + @T.prim_func + def k(A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16')): + ... + return k + +impl = factory # JITImpl +cfgs = [ + dict(block_M=64, block_N=128, block_K=32), + dict(block_M=128, block_N=128, block_K=64), +] +kernels = impl.par_compile(cfgs, num_workers=4) +# Now benchmark kernels[i](A, B, C) yourself +``` + +## Recording and Reusing Best Configs + +The programmatic path returns an `AutotuneResult` that can be saved and later +reloaded. This is useful for CI, multi‑host workflows, or shipping tuned configs. + +```python +artifact = tuner.run() # AutotuneResult + +# Save to disk +from pathlib import Path +save_dir = Path('out/best/matmul_1024') +artifact.save_to_disk(save_dir, verbose=True) + +# Reload later +from tilelang.autotuner.param import AutotuneResult, CompileArgs +restored = AutotuneResult.load_from_disk(save_dir, CompileArgs()) +best = restored.kernel +best(A, B, C) +``` + +Notes +- DLPack/Torch execution backend may not persist compiled binaries; in that + case, re‑compilation is needed on load or use a different backend. +- The directory contains human‑readable JSONs (best config/latency) and sources. 
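+
+If you prefer to pin the winning parameters in source instead of reloading the saved artifact, the best-config JSON in that directory can be fed back into a plain `@tilelang.jit` factory. The snippet below is a minimal sketch under stated assumptions: `best_config.json` is taken to be a flat mapping of the tunable arguments, and `build_matmul` is a hypothetical factory (the matmul above without `@tilelang.autotune`) exposing `block_M`, `block_N`, `block_K`, `num_stages`, and `threads` as keyword arguments.
+
+```python
+import json
+from pathlib import Path
+
+best_dir = Path('out/best/matmul_1024')   # directory passed to save_to_disk above
+best_config = json.loads((best_dir / 'best_config.json').read_text())
+
+# build_matmul: hypothetical plain @tilelang.jit factory taking the tuned
+# parameters as ordinary keyword arguments (no autotune decorator).
+pinned_kernel = build_matmul(M, N, K, **best_config)
+pinned_kernel(A, B, C)
+```
+
+Once the pinned values are stable, you can bake them into the factory's defaults and drop the tuning step from production runs.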
+ +## Advanced: Config Space Callables + +Derive config spaces from problem sizes to keep searches targeted and legal: + +```python +def matmul_configs(M, N, K): + large = min(M, N, K) >= 1024 + tiles = [128] if large else [64, 128] + for BM in tiles: + for BN in tiles: + for BK in [32, 64]: + for S in [2, 3]: + for TH in [128, 256]: + yield dict(block_M=BM, block_N=BN, block_K=BK, + num_stages=S, threads=TH) +``` + +## Device and Backend Selection + +Tune compile‑time options explicitly: +- `target='auto'|'cuda'|'hip'|'metal'` (normalized to a TVM Target) +- `execution_backend='auto'|'tvm_ffi'|'cython'|'nvrtc'|'torch'` +- `pass_configs={...}` to toggle TileLang/TVM passes for experiments + +On CUDA with multiple GPUs, the tuner sets the current device per worker thread +to avoid context mixups. + +## Troubleshooting +- “No configurations to tune”: Ensure `configs` is a non‑empty list or callable. +- Timeouts: Increase `timeout`; ensure inputs fit device memory; verify that + your reference check isn’t the bottleneck. +- Dynamic shapes: Provide concrete inputs via `set_autotune_inputs` or a custom + `supply_prog`. +- Disk cache disabled: Check `TILELANG_AUTO_TUNING_DISABLE_CACHE` and backend. diff --git a/docs/programming_guides/control_flow.md b/docs/programming_guides/control_flow.md new file mode 100644 index 000000000..158c51166 --- /dev/null +++ b/docs/programming_guides/control_flow.md @@ -0,0 +1,145 @@ +# Control Flow + +This guide covers the control‑flow primitives in TileLang and how they lower to +efficient GPU code. You will use these to structure loops, handle boundaries, +and express pipelined compute. + +## Overview +- Conditionals: `if` / `elif` / `else`, ternary (`x if c else y`) +- Loops: `T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined` +- While loops: `while` with a TIR condition +- Flow control: Python `break` / `continue` +- Safety: automatic OOB guards via the LegalizeSafeMemoryAccess pass + +The examples assume `import tilelang.language as T`. + +## Conditionals + +Standard Python `if`/`elif`/`else` is supported inside `@T.prim_func` kernels. +Conditions should be TIR expressions (e.g., `i < N`). Python plain booleans are +treated as compile‑time constants and will be folded. + +```python +for i in T.serial(N): + if i < N: # TIR condition + C[i] = A[i] + B[i] + else: + pass + +# Ternary +x = (A[i] if i < N else 0) +``` + +Short‑circuit boolean ops are supported. For multi‑dimensional bounds, use +`T.any_of` / `T.all_of` for clarity: + +```python +if T.all_of(i < M, j < N): + C[i, j] = A[i, j] + B[i, j] +``` + +Boundary handling note +- The LegalizeSafeMemoryAccess pass automatically inserts guards when an access + may be out‑of‑bounds, and elides them when proven safe. You can often omit + explicit `if` checks for simple edge handling, but keep them when you need + custom logic or clarity. + +## Loops + +### Serial + +`T.serial` creates a plain for‑loop. Common forms: + +```python +for i in T.serial(N): + ... # 0..N-1 + +for i in T.serial(0, N, 2): + ... # 0, 2, 4, ... +``` + +### Unroll + +`T.unroll` requests loop unrolling for small trip counts. + +```python +for k in T.unroll(K_TILE): + acc += a[k] * b[k] +``` + +Advanced: TileLang forwards unroll hints to TIR; factor/explicit knobs are +available for expert tuning. + +### Parallel (elementwise) + +`T.Parallel(ext0, ext1, ...)` builds nested loops that map well to elementwise +operations. 
The body receives all indices in one `for` header: + +```python +for i, j in T.Parallel(M, N): + C[i, j] = A[i, j] + B[i, j] +``` + +Optional: `coalesced_width=` can hint memory coalescing for the innermost loop. + +### Pipelined (software pipelining) + +`T.Pipelined(iters, num_stages=...)` overlaps producer/consumer stages (e.g., +Global→Shared copies with compute). This is the backbone of GEMM/attention +pipelines. + +```python +for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) # stage: copy A tile + T.copy(B[ko * BK, bx * BN], B_s) # stage: copy B tile + T.gemm(A_s, B_s, C_f) # stage: compute +``` + +### Persistent (advanced) + +`T.Persistent(domain, wave_size, index, group_size=...)` exposes persistent +thread‑block style looping. It is an advanced construct that TileLang lowers in +later passes and is typically used by specialized templates. + +## While Loops + +`while` is supported when the condition is a TIR expression. Avoid infinite +loops; TileLang will error if it detects a constant‑true condition. + +```python +i = 0 +while i < N: + ... + if done: + break + i += 1 +``` + +## Break and Continue + +Use Python `break`/`continue` to exit or skip within `T.serial`/`T.unroll`/ +`T.Parallel`/`while` loops. Keep the body clean after a `break`/`continue` for +readability; the compiler will ignore the dead path. + +## Putting It Together: Residual Tile Handling + +Below is a typical edge pattern for a 2D kernel. With LegalizeSafeMemoryAccess, +the explicit guard can be omitted when you don’t need a custom edge path. + +```python +for i, j in T.Parallel(M, N): + gi = by * BM + i + gj = bx * BN + j + if T.all_of(gi < M, gj < N): # optional in many cases + C[gi, gj] = A[gi, gj] + B[gi, gj] +``` + +## Debugging Conditions + +Use `T.print` to inspect values under predicates. For buffers, TileLang prints +from a single thread to avoid duplicate outputs. + +```python +if i == 0: + T.print(C, msg='C tile:') +``` diff --git a/docs/programming_guides/instructions.md b/docs/programming_guides/instructions.md new file mode 100644 index 000000000..69025c347 --- /dev/null +++ b/docs/programming_guides/instructions.md @@ -0,0 +1,180 @@ +# Instructions + +This page summarizes the core TileLang “instructions” available at the DSL +level, how they map to hardware concepts, and how to use them correctly. + +## Quick Categories +- Data movement: `T.copy`, `T.c2d_im2col`, staging Global ↔ Shared ↔ Fragment +- Compute primitives: `T.gemm`/`T.gemm_sp`, elementwise math (`T.exp`, `T.max`), + reductions (`T.reduce_sum`, `T.cumsum`, warp reducers) +- Control helpers: `T.clear`/`T.fill`, `T.reshape`/`T.view` +- Diagnostics: `T.print`, `T.device_assert` +- Advanced: atomics, memory barriers, warp‑group ops + +## Data Movement + +Use `T.copy(src, dst, coalesced_width=None, disable_tma=False, eviction_policy=None)` +to move tiles between memory scopes. It accepts `tir.Buffer`, `BufferLoad`, or +`BufferRegion`; extents are inferred or broadcast when possible. + +```python +# Global → Shared tiles (extents inferred from dst) +T.copy(A[by * BM, ko * BK], A_s) +T.copy(B[ko * BK, bx * BN], B_s) + +# Fragment/Register → Global (store result) +T.copy(C_f, C[by * BM, bx * BN]) +``` + +Semantics +- Extents are deduced from arguments; missing sides broadcast to the other’s rank. +- Access patterns are legalized and coalesced during lowering. Explicit + vectorization is not required in HL mode. 
+- Safety: the LegalizeSafeMemoryAccess pass inserts boundary guards when an + access may be out‑of‑bounds and drops them when proven safe. + +Other helpers +- `T.c2d_im2col(img, col, ...)`: convenience for conv‑style transforms. + +## Compute Primitives + +GEMM and sparse GEMM +- `T.gemm(A_shared, B_shared, C_fragment)`: computes a tile GEMM using shared + inputs and a fragment accumulator; lowered to target‑specific tensor cores. +- `T.gemm_sp(...)`: 2:4 sparse tensor core variant (see examples and README). + +Reductions and scans +- `T.reduce_sum`, `T.reduce_max`, `T.reduce_min`, `T.cumsum`, plus warp + reducers (`T.warp_reduce_sum`, etc.). +- Allocate and initialize accumulators via `T.alloc_fragment` + `T.clear` or + `T.fill`. + +Elementwise math +- Most math ops mirror TVM TIR: `T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, + `T.sigmoid`, etc. Compose freely inside loops. + +Reshape/view (no copy) +- `T.reshape(buf, new_shape)` and `T.view(buf, shape=None, dtype=None)` create + new views that share storage, with shape/dtype checks enforced. + +## Synchronization (HL usage) + +In HL pipelines, you usually don’t need to write explicit barriers. Passes such +as PipelinePlanning/InjectSoftwarePipeline/InjectTmaBarrier orchestrate +producer/consumer ordering and thread synchronization behind the scenes. + +If you need debugging or explicit checks: +- `T.device_assert(cond, msg='')` emits device‑side asserts on CUDA targets. +- `T.print(obj, msg='...')` prints scalars or buffers safely from one thread. + +## Putting It Together: GEMM Tile + +```python +@T.prim_func +def gemm( + A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16'), +): + with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by): + A_s = T.alloc_shared((BM, BK), 'float16') + B_s = T.alloc_shared((BK, BN), 'float16') + C_f = T.alloc_fragment((BM, BN), 'float32') + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) # Global → Shared + T.copy(B[ko * BK, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) # compute into fragment + + T.copy(C_f, C[by * BM, bx * BN]) # store back +``` + +## Instruction Reference (Concise) + +Below is a concise list of TileLang instructions grouped by category. For full +signatures, behaviors, constraints, and examples, refer to API Reference +(`autoapi/tilelang/index`). + +Data movement +- `T.copy(src, dst, ...)`: Move tiles between Global/Shared/Fragment. +- `T.c2d_im2col(img, col, ...)`: 2D im2col transform for conv. + +Memory allocation and descriptors +- `T.alloc_shared(shape, dtype, scope='shared.dyn')`: Allocate shared buffer. +- `T.alloc_fragment(shape, dtype, scope='local.fragment')`: Allocate fragment. +- `T.alloc_var(dtype, [init], scope='local.var')`: Scalar var buffer (1 elem). +- `T.alloc_barrier(arrive_count)`: Shared barrier buffer. +- `T.alloc_tmem(shape, dtype)`: Tensor memory (TMEM) buffer (Hopper+). +- `T.alloc_reducer(shape, dtype, op='sum', replication=None)`: Reducer buf. +- `T.alloc_descriptor(kind, dtype)`: Generic descriptor allocator. + - `T.alloc_wgmma_desc(dtype='uint64')` + - `T.alloc_tcgen05_smem_desc(dtype='uint64')` + - `T.alloc_tcgen05_instr_desc(dtype='uint32')` +- `T.empty(shape, dtype='float32')`: Declare function output tensors. + +Compute primitives +- `T.gemm(A_s, B_s, C_f)`: Tile GEMM into fragment accumulator. +- `T.gemm_sp(...)`: Sparse (2:4) tensor core GEMM. +- Reductions: `T.reduce_sum/max/min/abssum/absmax`, bitwise `and/or/xor`. 
+- Scans: `T.cumsum`, finalize: `T.finalize_reducer`. +- Warp reducers: `T.warp_reduce_sum/max/min/bitand/bitor`. +- Elementwise math: TIR ops (`T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, ...). +- Fast math: `T.__log/__log2/__log10/__exp/__exp2/__exp10/__sin/__cos/__tan`. +- IEEE math: `T.ieee_add/sub/mul/fmaf` (configurable rounding). +- Helpers: `T.clear(buf)`, `T.fill(buf, value)`. +- Views: `T.reshape(buf, shape)`, `T.view(buf, shape=None, dtype=None)`. + +Diagnostics +- `T.print(obj, msg='')`: Print scalar/buffer from one thread. +- `T.device_assert(cond, msg='')`: Device-side assert (CUDA). + +Logical helpers +- `T.any_of(a, b, ...)`, `T.all_of(a, b, ...)`: Multi-term predicates. + +Annotation helpers +- `T.use_swizzle(panel_size=..., enable=True)`: Rasterization hint. +- `T.annotate_layout({...})`: Attach explicit layouts to buffers. +- `T.annotate_safe_value(var, ...)`: Safety/const hints. +- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint. + +Atomics +- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`. +- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`. +- `T.atomic_max(dst, value, memory_order=None, return_prev=False)`. +- `T.atomic_min(dst, value, memory_order=None, return_prev=False)`. +- `T.atomic_load(dst)`, `T.atomic_store(dst, value)`. + +Custom intrinsics +- `T.dp4a(A, B, C)`: 4‑element dot‑product accumulate. +- `T.clamp(x, lo, hi)`: Clamp to [lo, hi]. +- `T.loop_break()`: Break from current loop via intrinsic. + +Barriers, TMA, warp‑group +- Barriers: `T.create_list_of_mbarrier(...)`, `T.get_mbarrier(i)`. +- Parity ops: `T.mbarrier_wait_parity(barrier, parity)`, `T.mbarrier_arrive(barrier)`. +- Expect tx: `T.mbarrier_expect_tx(...)`; sugar: `T.barrier_wait(id, parity=None)`. +- TMA: `T.create_tma_descriptor(...)`, `T.tma_load(...)`, + `T.tma_store_arrive(...)`, `T.tma_store_wait(...)`. +- Proxy/fences: `T.fence_proxy_async(...)`, `T.warpgroup_fence_operand(...)`. +- Warp‑group: `T.warpgroup_arrive()`, `T.warpgroup_commit_batch()`, + `T.warpgroup_wait(num_mma)`, `T.wait_wgmma(id)`. + +Lane/warp index +- `T.get_lane_idx(warp_size=None)`: Lane id in warp. +- `T.get_warp_idx_sync(warp_size=None)`: Canonical warp id (sync). +- `T.get_warp_idx(warp_size=None)`: Canonical warp id (no sync). +- `T.get_warp_group_idx(warp_size=None, warps_per_group=None)`: Group id. + +Register control +- `T.set_max_nreg(reg_count, is_inc)`, `T.inc_max_nreg(n)`, `T.dec_max_nreg(n)`. +- `T.annotate_producer_reg_dealloc(n=24)`, `T.annotate_consumer_reg_alloc(n=240)`. +- `T.no_set_max_nreg()`, `T.disable_warp_group_reg_alloc()`. + +## Notes on Dtypes + +Dtypes accept three equivalent forms: +- String: `'float32'` +- TileLang dtype: `T.float32` +- Framework dtype: `torch.float32` +All are normalized internally. See Type System for details. diff --git a/docs/programming_guides/language_basics.md b/docs/programming_guides/language_basics.md new file mode 100644 index 000000000..1152680c9 --- /dev/null +++ b/docs/programming_guides/language_basics.md @@ -0,0 +1,234 @@ +# Language Basics + +This page introduces the core TileLang (tile‑lang) DSL that you’ll use to write +high‑performance kernels. It focuses on how to define a kernel, express +iteration, move data across memory scopes, and run it with JIT. + +The examples use the conventional aliases: + +```python +import tilelang +import tilelang.language as T +from tilelang import jit +``` + +## 1. 
Defining a Kernel with `@T.prim_func` + +TileLang kernels are TIR (TVM IR) functions produced by the `@T.prim_func` +decorator. Arguments are annotated with shapes and dtypes via `T.Tensor` or +`T.Buffer`. + +Note on dtypes +- You can pass dtypes as a string (e.g., 'float32'), a TileLang dtype (e.g., `T.float32`), + or a framework dtype (e.g., `torch.float32`). TileLang normalizes all of these. + See Type System for details. + +```python +@T.prim_func +def add_kernel( + A: T.Tensor((N,), dtype), # dtype could be 'float32' | T.float32 | torch.float32 + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), +): + ... # kernel body +``` + +- Shapes may be concrete integers or symbolic. For symbolic, you can pass + Python ints through the outer `@jit` wrapper (shown below), or annotate with + `T.dyn` when you want a named symbolic dimension. + +```python +# Named symbolic dimension (optional) +K = T.dyn['K'] +@T.prim_func +def uses_dyn(A: T.Tensor((K,), 'float32')): + ... +``` + +### Dynamic symbolic dimensions: two ways + +TileLang supports two complementary ways to introduce symbolic (dynamic) dims: + +- Type-level annotations via `T.dyn[...]` (recommended for function signatures) + - Use in `T.Tensor((T.dyn['K'], ...), dtype)` or bind once then reuse (as above). + - Inside the kernel body, prefer reading from the buffer’s shape, e.g. `M = A.shape[0]`. + +- Term-level variables via `T.dynamic(name, dtype)` + - Creates a TIR `tir.Var` you can use directly in expressions/loops. + - Handy when you need to reference the dimension symbol in the body. + +```python +# 1) Annotation-only symbol; read the bound size via shape +K = T.dyn['K'] # dtype defaults to int32 +@T.prim_func +def foo(A: T.Tensor((K,), 'float32')): + N = A.shape[0] + for i in T.serial(N): + ... + +# 2) Explicit Var symbol usable in the body +K = T.dynamic('K', 'int32') # or T.dynamic('K') defaults to int32 +@T.prim_func +def bar(A: T.Tensor((K,), 'float32')): + for i in T.serial(K): + ... +``` + +Notes +- `T.symbolic(name, dtype)` is a deprecated alias of `T.dynamic`; prefer `T.dynamic`. +- Under `@jit`, concrete sizes come from the actual tensor arguments at the first call. +- Symbols in annotations do not need to be separate kernel arguments; TileLang binds them from argument shapes. + +## 2. Launching Work with `T.Kernel` + +`with T.Kernel(...)` declares a launch context and creates block/thread +bindings. For GPU backends, specify a grid and threads per block. + +```python +with T.Kernel(grid_x, grid_y, threads=128) as (bx, by): + ... # bx/by are blockIdx.x/y +``` + +You rarely need raw thread indices; most kernels use structured loops +(`T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined`) inside a `T.Kernel`. + +## 3. Loops and Control Flow + +Core loop constructs map to familiar hardware patterns: + +- `T.serial(start, stop[, step])`: plain for‑loop +- `T.unroll(start, stop[, step])`: unrolled loop +- `T.Parallel(ext0, ext1, ...)`: nested parallel loops (elementwise‑friendly) +- `T.Pipelined(iters, num_stages=N)`: software pipelining for producer/consumer + +```python +for i in T.serial(N): + ... + +for i, j in T.Parallel(M, N): + C[i, j] = A[i, j] + B[i, j] + +for k in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + # overlap copy/compute across stages + ... +``` + +Conditionals use standard Python `if`/`else`. Guard edges with predicates when +tile sizes do not divide problem sizes evenly. + +## 4. 
Memory Scopes and Allocation + +TileLang exposes key software‑managed scopes: + +- Global: device memory (default for `T.Tensor` arguments) +- Shared: on‑chip, block‑visible (`T.alloc_shared(shape, dtype)`) +- Fragment and scalars: per‑thread fragments and scalar vars but in Shared View + (`T.alloc_fragment`, `T.alloc_var`) + +```python +A_shared = T.alloc_shared((BM, BK), 'float16') +B_shared = T.alloc_shared((BK, BN), 'float16') +C_local = T.alloc_fragment((BM, BN), 'float32') +T.clear(C_local) # zero accumulators +``` + +## 5. Moving Data: `T.copy` + +Use `T.copy(src, dst)` to move tiles between scopes. It accepts buffers, +buffer regions, or buffer loads; extents are inferred or can be broadcast. + +```python +# Global -> Shared (tile copy), extents inferred from dst +T.copy(A[by * BM, ko * BK], A_shared) +T.copy(B[ko * BK, bx * BN], B_shared) + +# Fragment -> Global (store back) +T.copy(C_local, C[by * BM, bx * BN]) +``` + +`T.copy` performs coalescing and scope‑specific lowering during compilation. + +## 6. A Minimal End‑to‑End Example (Vector Add) + +```python +import tilelang +import tilelang.language as T +from tilelang import jit + +@jit # infers target from tensors at first call +def add(N: int, block: int = 256, dtype: str = 'float32'): + + @T.prim_func + def add_kernel( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block), threads=block) as bx: + for i in T.Parallel(block): + gi = bx * block + i + # Optional — LegalizeSafeMemoryAccess inserts a guard when an access may be OOB + C[gi] = A[gi] + B[gi] + + return add_kernel + +# Host side (PyTorch shown; NumPy/DLPack also supported) +import torch +N = 1 << 20 +A = torch.randn(N, device='cuda', dtype=torch.float32) +B = torch.randn(N, device='cuda', dtype=torch.float32) +C = torch.empty(N, device='cuda', dtype=torch.float32) + +kernel = add(N) +kernel(A, B, C) # runs on GPU +torch.testing.assert_close(C, A + B) +``` + +Notes +- The `@jit` wrapper returns a callable kernel after the first compilation. +- You can pass compile‑time tunables (tile sizes, dtypes) through the outer + Python function and bake them into the generated TIR. + +## 7. Tiled GEMM Skeleton + +Below is a minimal pattern for a tiled GEMM using shared memory staging and a +fragment accumulator. It mirrors the quickstart style found in the repository. + +```python +@T.prim_func +def gemm( + A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16'), +): + with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by): + A_s = T.alloc_shared((BM, BK), 'float16') + B_s = T.alloc_shared((BK, BN), 'float16') + C_f = T.alloc_fragment((BM, BN), 'float32') + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) + T.copy(B[ko * BK, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) # lowered to tensor‑core/ISA specific kernels + + T.copy(C_f, C[by * BM, bx * BN]) +``` + +## 8. Debugging and Printing + +Use `T.print` inside a kernel for quick introspection. TileLang emits printing +from a single thread for shared/fragment scopes to avoid floods. + +```python +T.print(C_f, msg='accumulator:') +T.print(A_s, msg='A tile:') +T.print(C[0], msg='C[0] = ') +``` + +## 9. 
Where to Go Next + +- Control flow details: see Programming Guides → Control Flow +- Memory topics: see Programming Guides → (removed cache/layout); basics are covered inline +- Autotuning tile sizes and mappings: Programming Guides → Autotuning +- Operator examples (GEMM, GEMV, attention): see Deep Learning Operators diff --git a/docs/programming_guides/overview.md b/docs/programming_guides/overview.md new file mode 100644 index 000000000..64b6d2039 --- /dev/null +++ b/docs/programming_guides/overview.md @@ -0,0 +1,27 @@ +# Programming Guides Overview + +This section provides a practical guide to writing high‑performance kernels with Tile Language (tile‑lang). +It mirrors the structure of a similar guide in another project and adapts it to tile‑lang concepts and APIs. + +- Audience: Developers implementing custom GPU/CPU kernels with tile‑lang +- Prereqs: Basic Python, NumPy/Tensor concepts, and familiarity with GPU programming notions +- Scope: Language basics, control flow, instructions, autotuning, and type system + +## What You’ll Learn +- How to structure kernels with TileLang’s core DSL constructs +- How to move data across global/shared/fragment and pipeline compute +- How to apply autotuning to tile sizes and schedules +- How to specify and work with dtypes in kernels + +## Suggested Reading Order +1. Language Basics +2. Control Flow +3. Instructions +4. Autotuning +5. Type System + +## Related Docs +- Tutorials: see existing guides in `tutorials/` +- Operators: examples in `deeplearning_operators/` + +> NOTE: This is a draft scaffold. Fill in code snippets and benchmarks as APIs evolve. diff --git a/docs/programming_guides/type_system.md b/docs/programming_guides/type_system.md new file mode 100644 index 000000000..60061df3f --- /dev/null +++ b/docs/programming_guides/type_system.md @@ -0,0 +1,41 @@ +# Type System + +This page lists the data types supported by TileLang and how to specify them in +kernels. For full details and the authoritative list, see the API Reference +(`autoapi/tilelang/index`) and `tilelang.language.v2.dtypes`. + +How to specify dtypes +- Use any of the following forms; TileLang normalizes them internally: + - String: `'float32'`, `'int8'`, `'bfloat16'`, ... + - TileLang dtype object: `T.float32`, `T.int8`, `T.bfloat16`, ... + - Framework dtype: `torch.float32`, `torch.int8`, `torch.bfloat16`, ... + +Common scalar types +- Boolean: `bool` +- Signed integers: `int8`, `int16`, `int32`, `int64` +- Unsigned integers: `uint8`, `uint16`, `uint32`, `uint64` +- Floating‑point: `float16` (half), `bfloat16`, `float32`, `float64` + +Float8 and low‑precision families +- Float8: `float8_e3m4`, `float8_e4m3`, `float8_e4m3b11fnuz`, `float8_e4m3fn`, + `float8_e4m3fnuz`, `float8_e5m2`, `float8_e5m2fnuz`, `float8_e8m0fnu` +- Float6: `float6_e2m3fn`, `float6_e3m2fn` +- Float4: `float4_e2m1fn` + +Vectorized element types (SIMD packs) +- For many base types, vector‑packed variants are available by lane count: + `x2`, `x4`, `x8`, `x16`, `x32`, `x64`. +- Examples: + - Integers: `int8x2`, `int8x4`, ..., `int32x2`, `int32x4`, ... + - Unsigned: `uint8x2`, `uint8x4`, ... + - Floats: `float16x2`, `float16x4`, `float32x2`, `float32x4`, ... + - Float8/6/4 families also provide `x2/x4/x8/x16/x32/x64` where applicable, + e.g., `float8_e4m3x2`, `float8_e4m3x4`, `float6_e2m3fnx8`, `float4_e2m1fnx16`. + +Notes +- Availability of certain low‑precision formats (float8/6/4) depends on target + architecture and backend support. 
+- Choose accumulation dtypes explicitly for mixed‑precision compute (e.g., + GEMM with `float16` inputs and `float32` accumulators). +- The complete, up‑to‑date list is exposed in + `tilelang.language.v2.dtypes` and rendered in the API Reference. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index e859d0e7b..6fd433459 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1,4 +1,5 @@ cancelled +HDA hsa ist LOD diff --git a/docs/tutorials/auto_tuning.md b/docs/tutorials/auto_tuning.md index 3f3cad832..33368a2f0 100644 --- a/docs/tutorials/auto_tuning.md +++ b/docs/tutorials/auto_tuning.md @@ -14,7 +14,7 @@ Auto-tuning a Tile Language program involves three main steps: ## Matrix Multiplication Example -The following example demonstrates auto-tuning matrix multiplication. Code has been simplified for readability - see `examples/gemm/example_gemm.py` for complete implementation. +The following example demonstrates auto-tuning matrix multiplication. Code has been simplified for readability - see `examples/gemm/example_gemm.py` for complete implementation. ### Step 1: Implement with Reserved Parameters Users can implement matrix multiplication in Tile Language while reserving parameters for optimization: @@ -145,4 +145,4 @@ for hint in roller_hints: config["thread_num"] = block_rows * block_cols * 32 config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization -``` \ No newline at end of file +``` diff --git a/docs/tutorials/debug_tools_for_tilelang.md b/docs/tutorials/debug_tools_for_tilelang.md index e18b13279..d98d4cb5e 100644 --- a/docs/tutorials/debug_tools_for_tilelang.md +++ b/docs/tutorials/debug_tools_for_tilelang.md @@ -12,7 +12,6 @@ A Tile Language program (hereafter referred to as a *program*) is transformed in 2. The program undergoes multiple *Passes* for transformation and optimization (the *lower* stage, see `tilelang/engine/lower.py`), finally producing an intermediate representation (e.g., LLVM or C for CPU, CUDA for NVIDIA GPUs, etc.). 3. The generated code is compiled by the respective compiler (e.g., nvcc) into a hardware-executable file. - ```{figure} ../_static/img/overview.png :width: 300 :alt: Overview of the compilation process @@ -22,9 +21,9 @@ A Tile Language program (hereafter referred to as a *program*) is transformed in During this process, users may encounter roughly three categories of issues: -* **Generation issues**: The Tile Language program fails to generate a valid hardware-executable file (i.e., errors during the lowering process). -* **Correctness issues**: The resulting executable runs, but produces incorrect results. -* **Performance issues**: The executable runs with performance significantly below the expected theoretical hardware limits. +- **Generation issues**: The Tile Language program fails to generate a valid hardware-executable file (i.e., errors during the lowering process). +- **Correctness issues**: The resulting executable runs, but produces incorrect results. +- **Performance issues**: The executable runs with performance significantly below the expected theoretical hardware limits. This tutorial focuses on the first two issues—how to debug generation and correctness problems. Performance tuning often requires using vendor-provided profiling tools (e.g., **Nsight Compute**, **rocProf**, etc.) for further hardware-level analysis, which we will address in future materials. 
@@ -52,7 +51,6 @@ func = matmul(1024, 1024, 1024, 128, 128, 32)
 
 TileLang essentially performs *progressive lowering*. For example, a `T.copy` may first be expanded into `T.Parallel` (see the pass `LowerTileOP`), which is then expanded again, eventually resulting in lower-level statements that can be translated to CUDA C code.
 
-
 ```{figure} ../_static/img/ir_transform_diagram.png
 :width: 400
 :alt: IR transformation diagram
@@ -171,6 +169,31 @@ The output messages will include something like:
 msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0
 ```
 
+### Visual Layout Inference for TileLang
+The **Visual Layout Inference** tool automatically generates visual diagrams that illustrate the mapping between logical indices, thread IDs, and register file locations.
+
+When TileLang performs layout inference, it determines how fragment buffers are distributed across threads. The visual layout tool captures this information and generates:
+1. **Textual output**: A human-readable description of the layout mapping
+2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping
+
+The visual layout inference tool is controlled through the `TL_LAYOUT_VISUALIZATION_ENABLE` and `TL_LAYOUT_VISUALIZATION_FORMATS` pass configurations. By default, `TL_LAYOUT_VISUALIZATION_ENABLE` is **disabled** to avoid performance overhead during compilation.
+
+When enabled, `TL_LAYOUT_VISUALIZATION_FORMATS` accepts string values to control the output formats:
+- "txt": Text output only (the default)
+- "all": Generate all formats (TXT, PDF, PNG, SVG)
+- "png": Generate PNG format only
+- "pdf": Generate PDF format only
+- "svg": Generate SVG format only
+- "txt,svg": Generate multiple formats (comma-separated) in addition to text output
+
+The "txt" output will include something like:
+```
+C_local inferenced layout:
+  Shape: [32, 32] -> [8]
+  Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
+  Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
+```
+
 ## Conclusion
 
 By carefully examining intermediate representations (IR) before final code generation—and by leveraging runtime printing through `T.print`—one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs.
diff --git a/docs/tutorials/logging.md b/docs/tutorials/logging.md
new file mode 100644
index 000000000..1a015432d
--- /dev/null
+++ b/docs/tutorials/logging.md
@@ -0,0 +1,116 @@
+Logging in Tilelang/TVM
+===================================================
+
+Author: SiriusNEO +
+ +## TVM Logging Overview + +Tilelang currently utilizes the logging system from TVM. The implementation can be found in: + +- [include/tvm/runtime/logging.h](https://github.com/apache/tvm/blob/main/include/tvm/runtime/logging.h): Macro definitions +- [src/runtime/logging.cc](https://github.com/apache/tvm/blob/main/src/runtime/logging.cc): Logging logic implementation + +The design style is inspired by [Google's glog](https://google.github.io/glog/stable/). + +## Logging Categories + +There are three primary macro types: + +```c++ +LOG(INFO) << "aaa"; +DLOG(INFO) << "aaa"; +VLOG(1) << "aaa"; +``` + +- **LOG**: Standard logging preserved in code for displaying necessary information at different levels during runtime. Most Tilelang C++ error reporting is implemented via `LOG(FATAL) << "error msg"`. +- **DLOG**: Debug logging for developer debugging output. DLOG is controlled at build time by the TVM_LOG_DEBUG environment variable and is **eliminated in Release builds through dead code elimination**. + - The key difference between LOG(DEBUG) and DLOG is this build-time elimination. We recommend using DLOG over LOG(DEBUG), as the latter has overlapping functionality and gets compiled into the release runtime. +- **VLOG**: [Verbose logging](https://google.github.io/glog/stable/logging/#verbose-logging), primarily for debugging. Its main feature is customizable verbosity levels. For example, VLOG(n) where n can be 1, 2, 3, 4, 5, or 6, enabling complex tracing requirements. In contrast, LOG and DLOG typically use predefined verbose levels like INFO and DEBUG. + - In practical Tilelang development, VLOG is used less frequently. + - TVM's VLOG is implemented using DLOG, thus inheriting DLOG's characteristics. + +Additional useful macros include various **CHECK** variants: + +```c++ +CHECK(cond) << "error msg"; +DCHECK(cond) << "error msg"; +ICHECK(cond) << "error msg"; +``` + +The implementation routes errors to LogFatal: + +```c++ +#define CHECK(x) \ + if (!(x)) \ + ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ + << "Check failed: (" #x << ") is false: " +``` +- **DCHECK**: Debug mode CHECK, only compiled in debug builds +- **ICHECK**: Internal Check that should exist in Release builds. When ICHECK fails, the entire system should report an error. + +## Logging Verbose Levels + +TVM defines 5 levels for LOG and DLOG (adding DEBUG compared to glog): + +```c++ +#define TVM_LOG_LEVEL_DEBUG 0 +#define TVM_LOG_LEVEL_INFO 1 +#define TVM_LOG_LEVEL_WARNING 2 +#define TVM_LOG_LEVEL_ERROR 3 +#define TVM_LOG_LEVEL_FATAL 4 +``` + +## Using Logging in TileLang Development + +### Guidelines + +For temporary debugging output in your code, there are no restrictions (you can even use std::cout). Just remember to remove it before submitting a PR. + +For meaningful logging that should remain in the Tilelang codebase: + +- Critical correctness checks: Use ICHECK with sufficient error messages to facilitate debugging when issues arise. +- Complex Pass debugging: For passes requiring intermediate output that may need future review (e.g., LayoutInference), use DLOG. +- General INFO/WARNING messages: Use standard LOG. + +### Enabling Log Output in Tilelang + +To specify current log level at runtime, we need to set the environment variable `TVM_LOG_LEVEL`. An example usage is: + +```c++ +TVM_LOG_DEBUG=1 python3 code.py +``` + +which enables all DEBUG/INFO (level <= 1) logs for all files. + +#### Detailed Rules for TVM_LOG_DEBUG Specification + +The parsing logic is in `logging.cc`. 
Reference: [HyperAI Zhihu Article](https://zhuanlan.zhihu.com/p/1933106843468665163). + +Launch Python with `TVM_LOG_DEBUG=`, where `` is a comma-separated list of level assignments in the form `=`. Important notes: + +- The special filename DEFAULT sets the LOG level for all files. +- `` can be set to -1 to disable LOG for that file. +- `` is the C++ source filename (e.g., .cc, not .h) relative to the `src/` directory in the TVM repository. The `src/` prefix is optional when specifying file paths. + +### Enabling Debug Mode + +To enable DLOG/DCHECK, developers need to first build Tilelang in Debug mode: + +```bash +cmake .. -DCMAKE_BUILD_TYPE=Debug -DUSE_CUDA=ON +``` + +Tilelang's CMake logic automatically adds the `TVM_LOG_DEBUG` macro, compiling all DLOG statements: + +```cmake +target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG") +``` + +Then you also need to specify the runtime environment variables. For example, to use `DLOG(INFO) << "xxx"` for debugging, run your code with INFO level (1): `TVM_LOG_DEBUG=1`. + +:::{note} + **Important**: There are two TVM_LOG_DEBUG variables. (1) Compile-time macro: Determines whether debug content (like DLOG) is compiled into the .so file. Referenced in C++ source via #ifdef TVM_LOG_DEBUG. This is automatically enabled when using Debug build mode in CMake. (2) Runtime environment variable: Controls logging level at runtime. TVM provides a specification for this variable, allowing control over per-file logging levels. + + These two should ideally have different names, but TVM uses the same name for both, which can cause confusion. +::: diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py index d47866e1e..788aec367 100644 --- a/examples/amd/example_amd_flash_attn_bwd.py +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import tilelang import tilelang.language as T -from tilelang.primitives.gemm.base import GemmWarpPolicy +from tilelang.tileop.base import GemmWarpPolicy import itertools import argparse from functools import partial @@ -11,22 +11,20 @@ def ref_program(Q, K, V, is_causal, groups=1): - assert Q.size( - 2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" - assert Q.size( - 2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" dim = Q.size(-1) K_ref = K.repeat_interleave(groups, dim=2) V_ref = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref) lse = torch.logsumexp(scores, dim=-1).float() return output, lse @@ -45,23 +43,23 @@ def get_fwd_configs(): valid_configs = [] - for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, - threads, num_stages, - 
enable_rasterization, k_pack, - panel_size, qk_coalesced_width, - v_coalesced_width): - valid_configs.append({ - "block_M": m, - "block_N": n, - "num_split_q": s, - "threads": t, - "num_stages": stages, - "enable_rasterization": r, - "k_pack": k, - "panel_size": p, - "qk_coalesced_width": qkw, - "v_coalesced_width": vw, - }) + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) return valid_configs @@ -85,23 +83,23 @@ def fast_flashattn( qk_coalesced_width: int, v_coalesced_width: int, ): - scale = (1.0 / dim)**0.5 + scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 vec_size = qk_coalesced_width v_vec_size = v_coalesced_width @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - LSE: T.Tensor([batch, heads, seq_len], accum_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + LSE: T.Tensor([batch, heads, seq_len], accum_dtype), ): with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -111,7 +109,7 @@ def main( num_q_blocks = T.ceildiv(seq_len, block_M) - bx_loop_var = T.alloc_var("int32") + bx_loop_var = T.alloc_var(T.int32) bx_loop_var = b_split with T.While(bx_loop_var < num_q_blocks): @@ -135,33 +133,21 @@ def main( m_prev = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype) - T.copy( - Q[bz, q_block_offset:q_block_offset + block_M, by, :], - Q_shared, - coalesced_width=vec_size) + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) - loop_end_k = ( - T.ceildiv(q_block_offset + - block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) row_sum = T.alloc_fragment([block_M], accum_dtype) for k in T.Pipelined(loop_end_k, num_stages=num_stages): kv_idx = k * block_N - T.copy( - K[bz, kv_idx:kv_idx + block_N, by // groups, :], - K_shared, - coalesced_width=vec_size) - T.copy( - V[bz, kv_idx:kv_idx + block_N, by // groups, :], - V_shared, - coalesced_width=v_vec_size) + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm( @@ -178,6 +164,8 @@ def main( T.copy(m_i, m_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) for i in T.Parallel(block_M): if m_prev[i] == 
-T.infinity(accum_dtype): @@ -214,8 +202,7 @@ def main( for i in T.Parallel(block_M): if q_block_offset + i < seq_len: - lse_val = T.if_then_else(l_i[i] > 0, - T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) + lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) LSE[bz, by, q_block_offset + i] = lse_val bx_loop_var = current_bx + num_split_q @@ -232,30 +219,30 @@ def get_bwd_configs(): panel_size = [7, 8, 9, 10] configs = [] - for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, - enable_rasterization, panel_size): - configs.append({ - "block_M": m, - "block_N": n, - "num_stages": stages, - "threads": t, - "enable_rasterization": r, - "panel_size": p, - }) + for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size): + configs.append( + { + "block_M": m, + "block_N": n, + "num_stages": stages, + "threads": t, + "enable_rasterization": r, + "panel_size": p, + } + ) return configs @tilelang.jit(out_idx=[2]) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 32 @T.prim_func - def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), - Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): + def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by): o = T.alloc_fragment([blk, blk], dtype) do = T.alloc_fragment([blk, blk], dtype) @@ -263,36 +250,51 @@ def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep @tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True) @tilelang.jit -def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, block_N: int, - num_stages: int, threads: int, enable_rasterization: bool, panel_size: int): - sm_scale = (1.0 / dim)**0.5 +def flashattn_bwd( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + panel_size: int, +): + sm_scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func - def flash_bwd_kernel(Q: T.Tensor(q_shape, - dtype), K: T.Tensor(kv_shape, - dtype), V: T.Tensor(kv_shape, dtype), - dO: T.Tensor(q_shape, dtype), lse: T.Tensor([batch, heads, seq_len], - accum_dtype), - Delta: T.Tensor([batch, heads, seq_len], - accum_dtype), dQ: T.Tensor(q_shape, accum_dtype), - dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)): + def flash_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + K: 
T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + dO: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch, heads, seq_len], accum_dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + dQ: T.Tensor(q_shape, accum_dtype), + dK: T.Tensor(kv_shape, accum_dtype), + dV: T.Tensor(kv_shape, accum_dtype), + ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -313,8 +315,8 @@ def flash_bwd_kernel(Q: T.Tensor(q_shape, dk = T.alloc_fragment([block_M, dim], accum_dtype) dq = T.alloc_fragment([block_N, dim], accum_dtype) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) @@ -322,22 +324,21 @@ def flash_bwd_kernel(Q: T.Tensor(q_shape, loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared) T.clear(qkT) T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, - P_acc[i, j], 0.0) + P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared) T.clear(dP) T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -345,7 +346,7 @@ def flash_bwd_kernel(Q: T.Tensor(q_shape, T.copy(P_acc, p_cast) T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared) for i, j in T.Parallel(block_M, block_N): p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale @@ -367,8 +368,8 @@ def flash_bwd_kernel(Q: T.Tensor(q_shape, @tilelang.jit(out_idx=[1]) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 64 @@ -376,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.copy( - dQ_in[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ_in[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post @@ -444,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100): return np.median(times) -def main(batch: int = 1, - heads: int = 8, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 1): - +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int 
= 1): device = "cuda" dtype = torch.float16 torch.manual_seed(42) torch.cuda.manual_seed(42) - print( - f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}" - ) + print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}") flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 5 * flops_per_gemm @@ -515,22 +508,19 @@ def main(batch: int = 1, o_ref.backward(dO) print("Verifying backward pass correctness...") - dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison( - dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) + dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) if dq_close: print("dQ is correct.") else: print("dQ mismatch detected.") - dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison( - dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) + dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) if dk_close: print("dK is correct.") else: print("dK mismatch detected.") - dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison( - dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) + dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) if dv_close: print("dV is correct.") else: @@ -551,9 +541,7 @@ def run_reference_fwd_bwd(): torch.cuda.synchronize() ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100) - print( - f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops" - ) + print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops") def run_complete_fwd_bwd(): o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v) @@ -591,12 +579,12 @@ def run_complete_fwd_bwd(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=8, help='heads') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--groups', type=int, default=1, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=1024, help="sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/amd/example_amd_flash_attn_fwd.py b/examples/amd/example_amd_flash_attn_fwd.py index 6ec5db1e5..ca9c361ff 100644 --- a/examples/amd/example_amd_flash_attn_fwd.py +++ b/examples/amd/example_amd_flash_attn_fwd.py @@ -2,29 +2,42 @@ import torch.nn.functional as F import tilelang import tilelang.language as T -from tilelang.primitives.gemm.base import GemmWarpPolicy +from tilelang.tileop.base import GemmWarpPolicy import itertools import argparse from functools import 
partial +# Custom supply function to ensure tensors are created on GPU +def supply_tensors_gpu(params): + """Supply function that creates tensors on GPU for ROCm/HIP.""" + tensors = [] + for param in params: + if hasattr(param, "shape") and hasattr(param, "dtype"): + # Force creation on GPU device + shape = [int(s) for s in param.shape] + tensor = torch.randn(shape, dtype=param.dtype, device="cuda") + tensors.append(tensor) + else: + tensors.append(param) + return tensors + + def ref_program(Q, K, V, is_causal, groups=1): - assert Q.size( - 2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" - assert Q.size( - 2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -43,27 +56,27 @@ def get_configs(): valid_configs = [] - for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, - threads, num_stages, - enable_rasterization, k_pack, - panel_size, qk_coalesced_width, - v_coalesced_width): - valid_configs.append({ - "block_M": m, - "block_N": n, - "num_split_q": s, - "threads": t, - "num_stages": stages, - "enable_rasterization": r, - "k_pack": k, - "panel_size": p, - "qk_coalesced_width": qkw, - "v_coalesced_width": vw, - }) + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) return valid_configs -@tilelang.autotune(configs=get_configs(), cache_input_tensors=True) +@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu) @tilelang.jit(out_idx=[3]) def fast_flashattn( batch, @@ -83,22 +96,22 @@ def fast_flashattn( qk_coalesced_width: int, v_coalesced_width: int, ): - scale = (1.0 / dim)**0.5 + scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 vec_size = qk_coalesced_width v_vec_size = v_coalesced_width @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with 
T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -108,7 +121,7 @@ def main( num_q_blocks = T.ceildiv(seq_len, block_M) - bx = T.alloc_var("int32") + bx = T.alloc_var(T.int32) bx = b_split with T.While(bx < num_q_blocks): @@ -132,32 +145,21 @@ def main( m_prev = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype) - T.copy( - Q[bz, q_block_offset:q_block_offset + block_M, by, :], - Q_shared, - coalesced_width=vec_size) + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) - loop_end_k = T.ceildiv(q_block_offset + block_M, - block_N) if is_causal else T.ceildiv(seq_len, block_N) + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) row_sum = T.alloc_fragment([block_M], accum_dtype) for k in T.Pipelined(loop_end_k, num_stages=num_stages): kv_idx = k * block_N - T.copy( - K[bz, kv_idx:kv_idx + block_N, by // groups, :], - K_shared, - coalesced_width=vec_size) - T.copy( - V[bz, kv_idx:kv_idx + block_N, by // groups, :], - V_shared, - coalesced_width=v_vec_size) + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm( @@ -171,6 +173,8 @@ def main( T.copy(m_i, m_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) for i in T.Parallel(block_M): sf = T.exp(m_prev[i] * scale - m_i[i] * scale) @@ -205,13 +209,7 @@ def main( return main -def main(batch: int = 1, - heads: int = 8, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 1): - +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul if is_causal: @@ -233,18 +231,16 @@ def main(batch: int = 1, print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") latency = profiler.do_bench(warmup=100) - print( - f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops" - ) + print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=8, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--groups', type=int, default=1, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", 
action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/analyze/README.md b/examples/analyze/README.md index 8171d8826..1c2788b0b 100644 --- a/examples/analyze/README.md +++ b/examples/analyze/README.md @@ -21,9 +21,9 @@ M = N = K = 1024 def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128): @T.prim_func - def main(A: T.Tensor((M, K), "float16"), - B: T.Tensor((N, K), "float16"), - C: T.Tensor((M, N), "float")): + def main(A: T.Tensor((M, K), T.float16), + B: T.Tensor((N, K), T.float16), + C: T.Tensor((M, N), T.float)): # ... (kernel definition) return main @@ -40,9 +40,9 @@ from tilelang.carver.arch import CUDA def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128): @T.prim_func - def main(data: T.Tensor((N, H, W, C), "float16"), - kernel: T.Tensor((K, K, C, F), "float16"), - out: T.Tensor((N, (H-K+1), (W-K+1), F), "float")): + def main(data: T.Tensor((N, H, W, C), T.float16), + kernel: T.Tensor((K, K, C, F), T.float16), + out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)): # ... (convolution kernel definition) return main @@ -64,10 +64,10 @@ class AnalysisResult: ``` ### `Analyzer` Class Methods #### `analysis(fn, device)` -* ​Parameters: - * fn: TVM IRModule or PrimFunc - * device: Device configuration object -* Returns: AnalysisResult +- ​Parameters: + - fn: TVM IRModule or PrimFunc + - device: Device configuration object +- Returns: AnalysisResult #### Supported Architectures ```python # Extendable to custom hardware via: "compute_capability": (cores_per_SM, clock_GHz, flops_per_cycle, max_SM_count) diff --git a/examples/analyze/example_conv_analyze.py b/examples/analyze/example_conv_analyze.py index 540fcf4b7..06e5a86e9 100644 --- a/examples/analyze/example_conv_analyze.py +++ b/examples/analyze/example_conv_analyze.py @@ -2,7 +2,6 @@ from tilelang.tools import Analyzer from tilelang.carver.arch import CUDA from tilelang.carver.arch import CDNA -from tilelang.layout import make_swizzled_layout import torch N = 64 @@ -25,38 +24,21 @@ def check_hopper(): return False -def kernel(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 is_hopper = check_hopper() @T.prim_func def conv( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -65,12 +47,6 @@ def conv( kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - 
T.annotate_layout({ - out_shared: make_swizzled_layout(out_shared), - data_shared: make_swizzled_layout(data_shared), - kernel_shared: make_swizzled_layout(kernel_shared), - }) - T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): if is_hopper: @@ -81,10 +57,8 @@ def conv( m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) diff --git a/examples/analyze/example_gemm_analyze.py b/examples/analyze/example_gemm_analyze.py index bfd934f6a..0367af126 100644 --- a/examples/analyze/example_gemm_analyze.py +++ b/examples/analyze/example_gemm_analyze.py @@ -15,14 +15,14 @@ def kernel( thread_num=None, enable_rasteration=None, ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/examples/attention_sink/README.md b/examples/attention_sink/README.md index ed4b7004e..2cba8f0cc 100644 --- a/examples/attention_sink/README.md +++ b/examples/attention_sink/README.md @@ -2,7 +2,6 @@ We compare with an optimized version of the official Triton implementation [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py). - ## Algorithm ### Forward The only change from vanilla FlashAttention is that `sinks` should be taken into consideration in the softmax, which requires an extra rescaling at the epilogue stage. @@ -43,4 +42,4 @@ where $P_{b, h, q}$ is the proportion of $sink_h$ in the softmax in the $b$-th b | 16384 | 64 | 309.46 | **400.62** | 1.29x | | 16384 | 128 | 418.99 | **549.11** | 1.31x | -> The backward performance will be further optimized in the future. \ No newline at end of file +> The backward performance will be further optimized in the future. 
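To make the sink softmax described in the README above concrete, here is a minimal forward reference in PyTorch. This is a sketch only, not the repository's API: it assumes `q`, `k`, `v` of shape `[batch, heads, seq, dim]`, a per-head `sinks` vector, causal attention with equal query/key lengths, and no GQA grouping; the `ref_program` functions in the example files remain the authoritative reference.

```python
import torch

def sink_attention_ref(q, k, v, sinks, sm_scale=None):
    """Causal attention where a per-head sink logit joins the softmax denominator only."""
    if sm_scale is None:
        sm_scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) * sm_scale
    causal = torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device).triu(1)
    scores = scores.masked_fill(causal, float("-inf"))
    m = scores.amax(dim=-1, keepdim=True)          # row max, for numerical stability
    p = torch.exp(scores - m)                      # unnormalized probabilities
    # The sink enlarges the normalizer but carries no value vector, which is why the
    # kernels below only touch `logsum` in the epilogue.
    denom = p.sum(dim=-1, keepdim=True) + torch.exp(sinks.float().view(1, -1, 1, 1) - m)
    return ((p / denom) @ v.float()).to(q.dtype)
```

Dividing by `denom` instead of `p.sum(...)` is exactly the extra epilogue rescaling the README refers to.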
diff --git a/examples/attention_sink/benchmark_gqa_sink_fwd.py b/examples/attention_sink/benchmark_gqa_sink_fwd.py index 00256286b..211ef1d18 100644 --- a/examples/attention_sink/benchmark_gqa_sink_fwd.py +++ b/examples/attention_sink/benchmark_gqa_sink_fwd.py @@ -1,10 +1,12 @@ import torch import argparse from tilelang.profiler import do_bench +from tilelang import language as T import triton import triton.language as tl from triton.tools.tensor_descriptor import TensorDescriptor from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs +from typing import Optional @triton.jit @@ -50,8 +52,7 @@ def triton_kernel( q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: lo, hi = 0, start_q + (start_m + 1) * BLOCK_M @@ -94,7 +95,7 @@ def triton_kernel( Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc) -def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor: +def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor: bs, n_heads, seq_q, head_dim = Q.shape _, n_heads_kv, seq_kv, _ = K.shape BLOCK_M = 64 @@ -119,7 +120,8 @@ def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tens BANDWIDTH=window_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - start_q=seq_kv - seq_q) + start_q=seq_kv - seq_q, + ) return o @@ -130,18 +132,18 @@ def main( seq_kv: int = 256, dim: int = 128, groups: int = 8, - window_size: int | None = None, + window_size: Optional[int] = None, dtype: str = "float16", tune: bool = False, ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -169,15 +171,14 @@ def main( block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) if torch.allclose( - triton_program(Q, K, V, sinks, window_size), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2): + triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ): print("Checks for triton passed.✅") else: print("Checks for triton failed.❌") @@ -197,20 +198,14 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - 
parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, - args.dtype, args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/benchmark_mha_sink_fwd.py b/examples/attention_sink/benchmark_mha_sink_fwd.py index 734870fe4..50747e6b0 100644 --- a/examples/attention_sink/benchmark_mha_sink_fwd.py +++ b/examples/attention_sink/benchmark_mha_sink_fwd.py @@ -1,10 +1,12 @@ import torch import argparse from tilelang.profiler import do_bench +from tilelang import language as T import triton import triton.language as tl from triton.tools.tensor_descriptor import TensorDescriptor from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs +from typing import Optional @triton.jit @@ -49,8 +51,7 @@ def triton_kernel( q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: lo, hi = 0, start_q + (start_m + 1) * BLOCK_M @@ -93,7 +94,7 @@ def triton_kernel( Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc) -def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor: +def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor: bs, n_heads, seq_q, head_dim = Q.shape seq_kv = K.shape[2] BLOCK_M = 64 @@ -116,26 +117,29 @@ def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tens BANDWIDTH=window_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - start_q=seq_kv - seq_q) + start_q=seq_kv - seq_q, + ) return o -def main(batch: int = 1, - heads: int = 32, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: int | None = None, - dtype: str = "float16", - tune: bool = False): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding 
window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -162,15 +166,14 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500) @@ -183,19 +186,13 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index f8f970ea4..cfdcd21b5 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -13,50 +13,50 @@ def get_bwd_configs(): sm_version = sm_major * 10 + sm_minor if sm_version == 80: return 64, 32, 1, 128 - elif sm_version == 90: - return 128, 32, 2, 256 else: - raise ValueError(f"Unsupported SM version: {sm_version}") + return 128, 32, 2, 256 @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd( - batch, - heads, - 
seq_len, - dim, - groups=1, - window_size=None, # None for full attention - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): - + batch, + heads, + seq_len, + dim, + groups=1, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, heads, seq_len, dim] kv_shape = [batch, head_kv, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(kv_shape, dtype), # type: ignore - V: T.Tensor(kv_shape, dtype), # type: ignore - Output: T.Tensor(q_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Sinks: T.Tensor([heads], dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + Output: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -72,8 +72,7 @@ def flash_fwd( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([heads], dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -81,34 +80,30 @@ def flash_fwd( sinks[i] = Sinks[by] end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) - start = T.alloc_local([1], 'int32') - if window_size is not None: - start[0] = T.max(0, (bx * block_M - window_size) // block_N) - else: - start[0] = 0 - - for k in T.Pipelined(start[0], end, num_stages=num_stages): - T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined(start, end, num_stages=num_stages): + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, - 0, -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This 
process is called Check_inf in FlashAttention3 code, and it only need to be done # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -125,32 +120,33 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -159,65 +155,61 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 64 
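For orientation, the `flash_bwd_prep` kernel above accumulates the per-row correction term $\Delta_{b,h,q} = \sum_{d} O_{b,h,q,d}\, dO_{b,h,q,d}$ that the backward pass consumes. An equivalent PyTorch sketch (illustrative only; same `[batch, heads, seq_len, dim]` layout and fp32 accumulation as in the kernel):

```python
import torch

def delta_ref(O: torch.Tensor, dO: torch.Tensor) -> torch.Tensor:
    # Row-wise sum over the head dimension of the elementwise product, accumulated in fp32.
    return (O.float() * dO.float()).sum(dim=-1)   # -> [batch, heads, seq_len]
```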
@T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd(batch, - heads, - seq_len, - dim, - groups, - window_size=None, - sm_scale=None, - dtype="float16"): # None for full attention +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype=T.float16): # None for full attention if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, heads, seq_len, dim] kv_shape = [batch, head_kv, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 block_M, block_N, num_stages, threads = get_bwd_configs() @@ -226,15 +218,15 @@ def flashattn_bwd(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(kv_shape, dtype), # type: ignore - V: T.Tensor(kv_shape, dtype), # type: ignore - dO: T.Tensor(q_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(kv_shape, accum_dtype), # type: ignore - dV: T.Tensor(kv_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + dO: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(kv_shape, accum_dtype), # type: ignore + dV: T.Tensor(kv_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -254,47 +246,44 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim], accum_dtype) dk_shared = T.alloc_shared([block_M, dim], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx // groups, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx // groups, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, bx // groups, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx // groups, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) - loop_ed = T.alloc_local([1], 'int32') - if window_size is not None: - loop_ed[0] = T.min( - T.ceildiv((by + 1) * block_M + window_size, block_N), - 
T.ceildiv(seq_len, block_N)) - else: - loop_ed[0] = T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) for i, j in T.Parallel(block_M, block_N): if window_size is not None: qkT[i, j] = T.if_then_else( - by * block_M + i <= k * block_N + j and - by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) else: - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -303,50 +292,46 @@ def flash_bwd( T.copy(dsT_cast, dsT_shared) T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) - T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared) + T.atomic_add(dV[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dk_shared) + T.atomic_add(dK[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dk_shared) return flash_bwd @tilelang.jit(out_idx=-1) -def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len] @T.prim_func def flash_bwd_dsink( - Sinks: T.Tensor([heads], dtype), # type: ignore - Delta: T.Tensor(shape, accum_dtype), # type: ignore - lse: T.Tensor(shape, accum_dtype), # type: ignore - dsinks: T.Tensor(shape, dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz): - sink = T.alloc_local([1], dtype) lse_fragment = T.alloc_fragment([block], accum_dtype) delta_fragment = T.alloc_fragment([block], accum_dtype) dsink_fragment = T.alloc_fragment([block], dtype) - sink[0] = Sinks[bx] - T.copy(lse[bz, bx, by * 
block:(by + 1) * block], lse_fragment) - T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) + sink = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - - lse_fragment[i]) * delta_fragment[i] - T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block]) + dsink_fragment[i] = -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, sinks, window_size, groups): - def maybe_contiguous(x): if x.stride(-1) != 1: return x.contiguous() @@ -354,7 +339,7 @@ def maybe_contiguous(x): q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)] BATCH, H, N_CTX, D_HEAD = q.shape - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) o, lse = kernel(q, k, v, sinks) ctx.save_for_backward(q, k, v, sinks, o, lse) @@ -367,7 +352,7 @@ def backward(ctx, do): q, k, v, sinks, o, lse = ctx.saved_tensors BATCH, H, N_CTX, D_HEAD = q.shape groups = ctx.groups - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) @@ -392,13 +377,14 @@ def backward(ctx, do): # Adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() batch_size, num_keys, num_key_value_heads, head_dim = key.shape @@ -434,32 +420,32 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def main(BATCH: int = 1, - H: int = 8, - N_CTX: int = 512, - D_HEAD: int = 64, - groups: int = 2, - window_size: int | None = None, - dtype: str = "float16"): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = "float16", +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= N_CTX - flops_per_matmul = 2.0 * BATCH * H * min( - window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + flops_per_matmul = 2.0 * BATCH * H * 
min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 total_flops = 5 * flops_per_matmul - Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()) - K = torch.randn( - BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() V = torch.randn_like(K).requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_() dO = torch.randn_like(Q) @@ -480,19 +466,14 @@ def main(BATCH: int = 1, # Checks rtol, atol = { - "float16": (1e-2, 1e-2), - "bfloat16": (2e-2, 2e-2), + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), }[dtype] - assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}' - assert torch.allclose( - dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' - assert torch.allclose( - dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' - assert torch.allclose( - dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}' - assert torch.allclose( - dsinks, dsinks_ref, rtol=rtol, - atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}' + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" print("All checks passed for tilelang kernels.✅") @@ -511,19 +492,57 @@ def tl_bwd(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + with torch.no_grad(): + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + V = torch.randn_like(K) + sinks = torch.randn(H, dtype=torch_dtype, device="cuda") + dO = torch.randn_like(Q) + fwd = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + O, lse = fwd(Q, K, V, sinks) + + def maybe_contiguous(x): + return x if x.stride(-1) == 1 else x.contiguous() + + do, q, k, v, sinks_c, o = [maybe_contiguous(x) for x in (dO, Q, K, V, sinks, O)] + k_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + Delta = k_prep(o, do) + k_bwd = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + k_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + q_shape = (BATCH, H, N_CTX, D_HEAD) + head_kv = H // groups + kv_shape = (BATCH, head_kv, N_CTX, D_HEAD) + dq = torch.zeros(q_shape, dtype=torch.float32, device="cuda") + dk = torch.zeros(kv_shape, dtype=torch.float32, device="cuda") + dv = torch.zeros(kv_shape, dtype=torch.float32, 
device="cuda") + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + _ = k_dsink(sinks_c, Delta, lse).sum(0).sum(1) + + def run_kernel_only(): + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + + latency_ms = do_bench(run_kernel_only, backend="cupti") + return latency_ms + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='Batch size') - parser.add_argument('--h', type=int, default=64, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--d_head', type=int, default=128, help='Head dimension') - parser.add_argument('--groups', type=int, default=8, help='Groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--groups", type=int, default=8, help="Groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py index 49a3ecbd8..fa73df0af 100644 --- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -6,7 +6,6 @@ from tilelang.autotuner import autotune from tilelang.profiler import do_bench import tilelang.language as T -from tilelang.layout import make_swizzled_layout import itertools import argparse from typing import Optional @@ -23,9 +22,11 @@ def get_configs(): rep=100, ) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( batch, heads, @@ -39,106 +40,30 @@ def flashattn( block_N=128, num_stages=2, threads=256, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): - if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, head_kv, seq_kv, dim] - accum_dtype = "float" + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - if 
window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) - else: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # NOTE(wt): check_inf is necessary for sliding window attention. - for i in T.Parallel(block_M): - if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
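The comment above is simply the base-2 rewrite of the stabilized softmax: $e^{x-m} = 2^{(x-m)\log_2 e} = 2^{x\log_2 e - m\log_2 e}$, which is why `scale` is pre-multiplied by $\log_2 e \approx 1.44269504$ and `T.exp2` is used throughout; the multiply and subtract in the exponent can then be emitted as a single FFMA rather than a separate fmul and fadd.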
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -155,61 +80,83 @@ def main( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) - - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.alloc_local([1], 'int32') - if window_size is not None: - start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N) - else: - start[0] = 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined( - start[0], - end, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. 
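A brief note on why the clamp below exists: with a sliding window an entire key block can be masked out, so the block row-max stays at $-\infty$; clamping it to 0 keeps the exponents well defined (otherwise $-\infty - (-\infty)$ yields NaN in both the rescale factor and the probabilities). A scalar sketch of the running-max update (softmax scaling omitted; purely illustrative, not the kernel's API):

```python
import math

def rescale_factor(m_prev, row_max_block):
    m_new = max(m_prev, row_max_block)
    if m_new == float("-inf"):       # fully masked block: clamp so exp() below stays finite
        m_new = 0.0
    # Factor applied to the already-accumulated acc_o and logsum.
    return m_new, math.exp(m_prev - m_new)
```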
+ for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Following functions are adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() batch_size, num_keys, num_key_value_heads, head_dim = key.shape @@ -245,23 +192,15 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - groups, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks @@ -272,18 +211,18 @@ def main( seq_kv: int = 256, dim: int = 128, groups: int = 8, - window_size: int | None = None, - dtype: str = "float16", + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, tune: bool = 
False, ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -311,15 +250,14 @@ def main( block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") # Benchmark tilelang @@ -328,22 +266,51 @@ def main( print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9)) +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + groups: int = 8, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), backend="cupti") + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, 
default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, - args.dtype, args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index ee1c35ece..66905f55d 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -20,40 +20,42 @@ def get_bwd_configs(): @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd( - batch, - heads, - seq_len, - dim, - window_size=None, # None for full attention, - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): - + batch, + heads, + seq_len, + dim, + window_size=None, # None for full attention, + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Sinks: T.Tensor([heads], dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -69,8 +71,7 @@ def flash_fwd( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([heads], dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -78,34 +79,30 @@ def flash_fwd( sinks[i] = Sinks[by] end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) - start = T.alloc_local([1], 'int32') - if window_size is not None: - start[0] = T.max(0, (bx * block_M - window_size) // block_N) - else: - start[0] = 0 - - for k in T.Pipelined(start[0], end, num_stages=num_stages): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined(start, end, num_stages=num_stages): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = 
T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, - 0, -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -122,32 +119,33 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -156,49 +154,52 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return 
flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 64 @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd( batch, heads, @@ -206,32 +207,31 @@ def flashattn_bwd( dim, window_size=None, # None for full attention sm_scale=None, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): - block_M, block_N, num_stages, threads = get_bwd_configs() if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -255,47 +255,43 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - 
T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) - loop_ed = T.alloc_local([1], 'int32') - if window_size is not None: - loop_ed[0] = T.min( - T.ceildiv((by + 1) * block_M + window_size, block_N), - T.ceildiv(seq_len, block_N)) - else: - loop_ed[0] = T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) for i, j in T.Parallel(block_M, block_N): if window_size is not None: qkT[i, j] = T.if_then_else( - by * block_M + i <= k * block_N + j and - by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) else: - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -304,51 +300,48 @@ def flash_bwd( T.copy(dsT_cast, dsT_shared) T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) - T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) - T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) return flash_bwd @tilelang.jit(out_idx=-1) -def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len] @T.prim_func def flash_bwd_dsink( - Sinks: T.Tensor([heads], dtype), # type: ignore - Delta: T.Tensor(shape, accum_dtype), # type: ignore - lse: T.Tensor(shape, accum_dtype), # type: ignore - dsinks: T.Tensor(shape, accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + 
Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz): - sink = T.alloc_local([1], dtype) lse_fragment = T.alloc_fragment([block], accum_dtype) delta_fragment = T.alloc_fragment([block], accum_dtype) dsink_fragment = T.alloc_fragment([block], accum_dtype) - sink[0] = Sinks[bx] - T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment) - T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) + sink = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - - lse_fragment[i]) * delta_fragment[i] - T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block]) + dsink_fragment[i] = -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, sinks, window_size): BATCH, H, N_CTX, D_HEAD = q.shape - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype) o, lse = kernel(q, k, v, sinks) ctx.save_for_backward(q, k, v, sinks, o, lse) @@ -366,7 +359,7 @@ def maybe_contiguous(x): return x do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)] - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) delta = kernel_prep(o, do) @@ -388,15 +381,15 @@ def maybe_contiguous(x): # Adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function's interface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -431,29 +424,23 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def main(BATCH: int = 1, - H: int = 1, - N_CTX: int = 512, - D_HEAD: int = 128, - window_size: int | None = None, - dtype: str = "float16"): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, 
window_size: Optional[int] = None, dtype: T.dtype = T.float16): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= N_CTX - flops_per_matmul = 2.0 * BATCH * H * min( - window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 total_flops = 5 * flops_per_matmul - Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()) + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() K = torch.randn_like(Q).requires_grad_() V = torch.randn_like(Q).requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_() @@ -475,19 +462,14 @@ def main(BATCH: int = 1, # Checks rtol, atol = { - "float16": (1e-2, 1e-2), - "bfloat16": (2e-2, 2e-2), + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), }[dtype] - assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}' - assert torch.allclose( - dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' - assert torch.allclose( - dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' - assert torch.allclose( - dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}' - assert torch.allclose( - dsinks, dsinks_ref, rtol=rtol, - atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}' + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" print("All checks passed for tilelang kernels.✅") @@ -506,18 +488,53 @@ def tl_bwd(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 512, + D_HEAD: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + with torch.no_grad(): + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + K = torch.randn_like(Q) + V = torch.randn_like(Q) + sinks = torch.randn(H, dtype=torch_dtype, device=Q.device) + dO = torch.randn_like(Q) + fwd = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size=window_size, dtype=dtype) + O, lse = fwd(Q, K, V, sinks) + + def maybe_contiguous(x): + return x if x.stride(-1) == 1 else x.contiguous() + + do, q, k, v, sinks_c, o = [maybe_contiguous(x) for x in (dO, Q, K, V, sinks, O)] + k_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + Delta = k_prep(o, do) + k_bwd = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype) + k_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + shape = (BATCH, H, N_CTX, D_HEAD) + dq = torch.zeros(shape, dtype=torch.float32, device=Q.device) + dk = torch.empty(shape, 
dtype=torch_dtype, device=Q.device) + dv = torch.empty(shape, dtype=torch_dtype, device=Q.device) + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + _ = k_dsink(sinks_c, Delta, lse).sum(0).sum(1) + + def run_kernel_only(): + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + + latency_ms = do_bench(run_kernel_only, backend="cupti") + return latency_ms + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='Batch size') - parser.add_argument('--h', type=int, default=64, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--d_head', type=int, default=128, help='Head dimension') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py index 7e59e277e..f24aa38b7 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -5,7 +5,6 @@ from tilelang.autotuner import autotune from tilelang.profiler import do_bench import tilelang.language as T -from tilelang.layout import make_swizzled_layout import itertools import argparse from typing import Optional @@ -18,117 +17,45 @@ def get_configs(): @autotune(configs=get_configs(), warmup=500, rep=100) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - window_size=None, # None for full attention - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - accum_dtype = "float" + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - 
for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) - else: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # NOTE(wt): check_inf is necessary for sliding window attention. - for i in T.Parallel(block_M): - if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
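# Illustration (standalone sketch, not part of this patch): the comment above relies on
# the identity exp(x - m) == exp2(x*log2(e) - m*log2(e)); folding log2(e) into `scale` up
# front lets the backend feed exp2 with a single fused multiply-add. Assumes only NumPy.
import numpy as np

LOG2E = 1.44269504  # the same constant multiplied into `scale` in the kernels
x = np.random.randn(8).astype(np.float32)
m = x.max()
ref = np.exp(x - m)                    # textbook softmax numerator
fast = np.exp2(x * LOG2E - m * LOG2E)  # exp2 form used by the kernel
assert np.allclose(ref, fast, rtol=1e-6, atol=1e-6)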
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -145,56 +72,76 @@ def main( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) - - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - - start = T.alloc_local([1], 'int32') - if window_size is not None: - start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N) - else: - start[0] = 0 - - for k in T.Pipelined(start[0], end, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined(start, end, num_stages=num_stages): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. 
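# Illustration (standalone sketch, not part of this patch): why the check_inf clamp in the
# next few lines is needed for sliding-window attention. If a row has seen only masked
# blocks, its running max stays -inf and the rescale factor exp2(m_prev*scale - m_new*scale)
# evaluates to nan; clamping the max to 0 turns it into exp2(-inf) = 0, which cleanly zeroes
# the (empty) accumulator instead of poisoning it. Plain-Python version of the arithmetic:
import math

scale = 1.0
m_prev = float("-inf")                                     # no valid score seen yet in this row
m_new = float("-inf")                                      # current block fully masked out
without_clamp = 2.0 ** (m_prev * scale - m_new * scale)    # -inf - (-inf) -> nan
m_new = 0.0 if m_new == float("-inf") else m_new           # the check_inf clamp
with_clamp = 2.0 ** (m_prev * scale - m_new * scale)       # 2**(-inf) -> 0.0
assert math.isnan(without_clamp) and with_clamp == 0.0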
+ for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function's interface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -229,41 +176,36 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks -def main(batch: int = 1, - heads: int = 1, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: int | None = None, - dtype: str = "float16", - tune: bool = 
False): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -290,19 +232,17 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") - latency = do_bench( - lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) + latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) print("Ref: {:.2f} ms".format(latency)) print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) @@ -310,21 +250,37 @@ def main(batch: int = 1, print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, heads, seq_q, seq_kv, dim, window_size, block_M=block_M, block_N=block_N, num_stages=num_stages, threads=threads, dtype=dtype + ) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), backend="cupti") + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + 
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py index eee2f3ac5..b47c8175f 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -6,7 +6,6 @@ from tilelang.autotuner import autotune from tilelang.profiler import do_bench import tilelang.language as T -from tilelang.layout import make_swizzled_layout import itertools import argparse from typing import Optional @@ -19,119 +18,46 @@ def get_configs(): @autotune(configs=get_configs(), warmup=500, rep=100) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - window_size=None, # None for full attention - sm_scale=None, - block_M=128, - block_N=128, - num_stages=2, - threads=256, - dtype: str = "float16"): - + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=128, + block_N=128, + num_stages=2, + threads=256, + dtype: T.dtype = T.float16, +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - accum_dtype = "float" + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) - else: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro 
- def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # NOTE(wt): check_inf is necessary for sliding window attention. - for i in T.Parallel(block_M): - if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -148,63 +74,84 @@ def main( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) - - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.alloc_local([1], 'int32') - if window_size is not None: - start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N) - else: - start[0] = 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined( - start[0], - end, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, 
scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Following functions are adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function'sinterface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function'sinterface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -239,41 +186,36 @@ def ref_program(query: torch.Tensor, output = 
torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks -def main(batch: int = 1, - heads: int = 32, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: int | None = None, - dtype: str = "float16", - tune: bool = False): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -300,15 +242,14 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) @@ -316,21 +257,38 @@ def main(batch: int = 1, print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, heads, seq_q, seq_kv, dim, window_size, block_M=block_M, block_N=block_N, num_stages=num_stages, threads=threads, dtype=dtype + ) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), backend="cupti") + return 
latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/regression_attention_sink.py b/examples/attention_sink/regression_attention_sink.py new file mode 100644 index 000000000..e2453173c --- /dev/null +++ b/examples/attention_sink/regression_attention_sink.py @@ -0,0 +1,64 @@ +import tilelang.testing +import example_mha_sink_fwd_bhsd +import example_mha_sink_fwd_bhsd_wgmma_pipelined +import example_mha_sink_bwd_bhsd +import example_gqa_sink_bwd_bhsd +import example_gqa_sink_fwd_bhsd_wgmma_pipelined + + +def regression_example_mha_sink_fwd_bhsd(): + tilelang.testing.process_func(example_mha_sink_fwd_bhsd.run_regression_perf) + + +def regression_example_mha_sink_fwd_bhsd_sliding_window(): + tilelang.testing.process_func( + example_mha_sink_fwd_bhsd.run_regression_perf, "regression_example_mha_sink_fwd_bhsd_sliding_window", window_size=128 + ) + + +def regression_example_mha_sink_fwd_bhsd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf) + + +def regression_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): + tilelang.testing.process_func( + example_mha_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf, + "regression_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window", + window_size=128, + ) + + +def regression_example_gqa_sink_fwd_bhsd_wgmma_pipelined(): + tilelang.testing.process_func(example_gqa_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf) + + +def regression_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): + tilelang.testing.process_func( + example_gqa_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf, + "regression_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window", + window_size=128, + ) + + +def regression_example_mha_sink_bwd_bhsd(): + tilelang.testing.process_func(example_mha_sink_bwd_bhsd.run_regression_perf) + + +def 
regression_example_mha_sink_bwd_bhsd_sliding_window(): + tilelang.testing.process_func( + example_mha_sink_bwd_bhsd.run_regression_perf, "regression_example_mha_sink_bwd_bhsd_sliding_window", window_size=128 + ) + + +def regression_example_gqa_sink_bwd_bhsd(): + tilelang.testing.process_func(example_gqa_sink_bwd_bhsd.run_regression_perf) + + +def regression_example_gqa_sink_bwd_bhsd_sliding_window(): + tilelang.testing.process_func( + example_gqa_sink_bwd_bhsd.run_regression_perf, "regression_example_gqa_sink_bwd_bhsd_sliding_window", window_size=128 + ) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/bitnet-1.58b/.gitignore b/examples/bitnet-1.58b/.gitignore index 6ea887496..2bcdfd92b 100644 --- a/examples/bitnet-1.58b/.gitignore +++ b/examples/bitnet-1.58b/.gitignore @@ -1 +1 @@ -models/ \ No newline at end of file +models/ diff --git a/examples/bitnet-1.58b/README.md b/examples/bitnet-1.58b/README.md index 2b587eab4..b9898741b 100644 --- a/examples/bitnet-1.58b/README.md +++ b/examples/bitnet-1.58b/README.md @@ -2,7 +2,6 @@ license: mit --- - This is a Tilelang Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`. ## Make Checkpoints for vLLM @@ -43,7 +42,6 @@ python3 inference_with_bitblas_format.py | bitnet-3b-1.58bits | vllm-tilelang | 379.25 | 117.43 | 752.55 | | bitnet-3b-1.58bits | vllm-tilelang-cuda-graph | 2543.58 | 1621.08 | 2731.79 | - ## BitBLAS Results ### Performance @@ -94,4 +92,4 @@ The differences between the reported numbers and the reproduced results are poss journal={arXiv preprint arXiv:2402.17764}, year={2024} } -``` \ No newline at end of file +``` diff --git a/examples/bitnet-1.58b/benchmark.sh b/examples/bitnet-1.58b/benchmark.sh index 6a2550d45..839443dc6 100755 --- a/examples/bitnet-1.58b/benchmark.sh +++ b/examples/bitnet-1.58b/benchmark.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 | tee b16_i32_o128.log python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 | tee b1_i512_o64.log diff --git a/examples/bitnet-1.58b/benchmark_generate.py b/examples/bitnet-1.58b/benchmark_generate.py index d6f21ed50..d678b91a4 100644 --- a/examples/bitnet-1.58b/benchmark_generate.py +++ b/examples/bitnet-1.58b/benchmark_generate.py @@ -12,8 +12,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): # Encode the input prompts as a batch - input_ids = tokenizer( - prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) + input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) # Generate cos and sin values (commented out as not used in generation) seq_length = input_ids.size(1) @@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): end_time = time.time() # Decode the output ids to text - generated_texts = [ - tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids - ] + generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids] generation_time = end_time - start_time num_tokens = sum(len(output_id) for output_id in output_ids) @@ -52,8 +49,8 @@ def generate_text_batch(model, 
tokenizer, prompts, max_length=100): def profile(model, input_data): - import numpy as np + model = model.cuda() model.eval() @@ -74,25 +71,29 @@ def get_runtime(num_repeats=1): return np.mean(times) -model_path = '1bitLLM/bitnet_b1_58-3B' +model_path = "1bitLLM/bitnet_b1_58-3B" def main(): parser = argparse.ArgumentParser() - parser.add_argument('--bs', default=16, type=int) - parser.add_argument('--in_seq_len', default=32, type=int) - parser.add_argument('--out_seq_len', default=128, type=int) - parser.add_argument('--bitblas', action='store_true') + parser.add_argument("--bs", default=16, type=int) + parser.add_argument("--in_seq_len", default=32, type=int) + parser.add_argument("--out_seq_len", default=128, type=int) + parser.add_argument("--bitblas", action="store_true") args = parser.parse_args() bs = args.bs in_seq_len = args.in_seq_len out_seq_len = args.out_seq_len is_bitblas = args.bitblas - model = BitnetForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=True, - torch_dtype=torch.float16, - ).cuda().half() + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) if is_bitblas: with torch.no_grad(): model.quantize() @@ -109,5 +110,5 @@ def main(): print(generate_text_batch(model, tokenizer, prompts, max_length=max_length)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/benchmark_inference_latency.py b/examples/bitnet-1.58b/benchmark_inference_latency.py index 9ce7a3898..788fc5565 100644 --- a/examples/bitnet-1.58b/benchmark_inference_latency.py +++ b/examples/bitnet-1.58b/benchmark_inference_latency.py @@ -6,13 +6,14 @@ torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) def profile(model, input_data): import time import numpy as np + model = model.cuda() model.eval() @@ -35,8 +36,8 @@ def get_runtime(num_repeats=1): def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', - device_map='auto', + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, @@ -52,5 +53,5 @@ def main(): print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/configuration_bitnet.py b/examples/bitnet-1.58b/configuration_bitnet.py index 5f4937b87..63c499db3 100644 --- a/examples/bitnet-1.58b/configuration_bitnet.py +++ b/examples/bitnet-1.58b/configuration_bitnet.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
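# Illustration (standalone sketch, not part of this patch): the `profile`/`get_runtime`
# helpers touched above measure GPU latency with wall-clock time; that is only meaningful
# when the device is synchronized before the timer starts and again after the timed
# launches, roughly like this (assumes torch and a CUDA device):
import time
import torch

def bench_ms(fn, warmup=10, repeats=100):
    for _ in range(warmup):       # warm up allocator, JIT/autotuning, caches
        fn()
    torch.cuda.synchronize()      # drain warmup work before starting the clock
    tic = time.time()
    for _ in range(repeats):
        fn()
    torch.cuda.synchronize()      # wait for all timed kernels to finish
    return (time.time() - tic) * 1e3 / repeats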
-""" LLaMA model configuration""" +"""LLaMA model configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -180,16 +180,10 @@ def _rope_scaling_validation(self): return if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " - f"got {self.rope_scaling}") + raise ValueError(f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}") rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, - float) or rope_scaling_factor <= 1.0: - raise ValueError( - f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}") + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/examples/bitnet-1.58b/eval_correctness.py b/examples/bitnet-1.58b/eval_correctness.py index ac1e34072..11d47004b 100644 --- a/examples/bitnet-1.58b/eval_correctness.py +++ b/examples/bitnet-1.58b/eval_correctness.py @@ -47,8 +47,8 @@ def generate_text(model, tokenizer, prompt, max_length=100): def profile(model, input_data): - import numpy as np + model = model.cuda() model.eval() @@ -69,18 +69,22 @@ def get_runtime(num_repeats=1): return np.mean(times) -model_path = '1bitLLM/bitnet_b1_58-3B' +model_path = "1bitLLM/bitnet_b1_58-3B" def main(): - model = BitnetForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=False, - torch_dtype=torch.float16, - ).cuda().half() + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=False, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) - input_id = tokenizer("Hello")['input_ids'] + input_id = tokenizer("Hello")["input_ids"] input_id = torch.tensor(input_id).unsqueeze(0).cuda() print("original model generated text:") @@ -91,5 +95,5 @@ def main(): print(generate_text(model, tokenizer, "Hello", max_length=100)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/eval_gpu_memory.py b/examples/bitnet-1.58b/eval_gpu_memory.py index 597cbbfcd..00c914cb3 100644 --- a/examples/bitnet-1.58b/eval_gpu_memory.py +++ b/examples/bitnet-1.58b/eval_gpu_memory.py @@ -6,13 +6,14 @@ torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) def profile(model, input_data): import time import numpy as np + model = model.cuda() model.eval() @@ -35,17 +36,17 @@ def get_runtime(num_repeats=1): def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', - device_map='auto', + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, ).half() - 
print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB") with torch.no_grad(): model._post_process_weights() - print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/eval_ppl.py b/examples/bitnet-1.58b/eval_ppl.py index 61c8488e4..97db2d0f5 100644 --- a/examples/bitnet-1.58b/eval_ppl.py +++ b/examples/bitnet-1.58b/eval_ppl.py @@ -15,9 +15,9 @@ torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--seed', default=0, type=int) -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) -parser.add_argument('--seqlen', default=2048, type=int) +parser.add_argument("--seed", default=0, type=int) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) +parser.add_argument("--seqlen", default=2048, type=int) def calulate_loss(model, input, loss_fct): @@ -29,12 +29,16 @@ def calulate_loss(model, input, loss_fct): def main(args): - datasets = ['c4', 'wikitext2'] - model = BitnetForCausalLM.from_pretrained( - args.hf_path, - use_flash_attention_2=True, - torch_dtype=torch.float16, - ).cuda().half() + datasets = ["c4", "wikitext2"] + model = ( + BitnetForCausalLM.from_pretrained( + args.hf_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) with torch.no_grad(): model._post_process_weights() tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False) @@ -48,9 +52,9 @@ def main(args): for ii in progress: input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1) loss = calulate_loss(model, input, loss_fct) - count += (input.size(-1) - 1) + count += input.size(-1) - 1 acc_loss += loss.item() - progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}") + progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}") avg_loss = acc_loss / count / math.log(2) ppl.append(2**avg_loss) @@ -60,7 +64,7 @@ def main(args): print("Avg PPL:", sum(ppl) / len(ppl)) -if __name__ == '__main__': +if __name__ == "__main__": torch.set_grad_enabled(False) args = parser.parse_args() random.seed(args.seed) diff --git a/examples/bitnet-1.58b/eval_utils.py b/examples/bitnet-1.58b/eval_utils.py index 46241eedf..72480c392 100644 --- a/examples/bitnet-1.58b/eval_utils.py +++ b/examples/bitnet-1.58b/eval_utils.py @@ -15,21 +15,17 @@ def set_seed(seed): def get_test_dataset(dataset_name, tokenizer, seqlen=2048): if dataset_name == "wikitext2": - testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - testdata = "".join(testdata['text']).split('\n') + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + testdata = "".join(testdata["text"]).split("\n") elif dataset_name == "c4": - testdata = load_dataset( - 'allenai/c4', - data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, - split='validation')['text'] + testdata = load_dataset("allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation")[ + "text" + ] else: raise NotImplementedError testdata = [item for item in testdata if item != ""] - tokenized_text = [ - tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] - for item in testdata - ] + tokenized_text = [tokenizer(item, add_special_tokens=False)["input_ids"] + 
[tokenizer.eos_token_id] for item in testdata] data, doc = [], [tokenizer.bos_token_id] for sen in tokenized_text: @@ -45,7 +41,6 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048): class LMEvalAdaptor(BaseLM): - def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): super().__init__() @@ -137,5 +132,4 @@ def _model_call(self, inps): return out def _model_generate(self, context, max_length, eos_token_id): - return self.model.generate( - context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) + return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py index e5af16cc4..7b8b7b95c 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py @@ -76,13 +76,13 @@ def bitnet_158_int8xint2_decode( reduce_thread=32, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" storage_nbit = 8 num_bits = 2 @@ -94,7 +94,7 @@ def bitnet_158_int8xint2_decode( MAX_TRANSACTION_SIZE_IN_BITS = 128 micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits micro_size_k_compressed = micro_size_k // num_elems_per_byte - storage_dtype = "int8" + storage_dtype = T.int8 block_K = reduce_thread * micro_size_k use_dp4a = True @@ -102,17 +102,17 @@ def bitnet_158_int8xint2_decode( @T.prim_func def kernel( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer(C_shape, out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( - T.ceildiv(N, n_partition), - M, - threads=(reduce_thread, n_partition), + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), ) as ( - bx, - by, + bx, + by, ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) @@ -133,8 +133,7 @@ def kernel( for v in T.vectorized(micro_size_k_compressed): B_quant_local[v] = B[ bx * n_partition + ni, - ko * (reduce_thread * micro_size_k_compressed) + - kr * micro_size_k_compressed + v, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, ] T.call_extern( @@ -156,9 +155,9 @@ def kernel( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -168,7 +167,8 @@ def kernel( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: C[by, bx * n_partition + ni] = reduced_accum_res[0] @@ -194,12 +194,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): # interleave weight numpy implementation -def interleave_weight(qweight, nbits=4, target_dtype="float16"): - assert target_dtype in ["float16", "int8"] +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert 
target_dtype in [T.float16, T.int8] # reinterpret the data type of qweight to int32 qweight = qweight.view(np.int32) new_qweight = np.zeros_like(qweight) - bits_stride = 8 if target_dtype == "int8" else 16 + bits_stride = 8 if target_dtype == T.int8 else 16 mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f num_groups = 32 // bits_stride elems_per_group = bits_stride // nbits @@ -209,7 +209,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift - if nbits == 1 and target_dtype == "int8": + if nbits == 1 and target_dtype == T.int8: # special handling for 1b interleave n16_weight = new_qweight & np.int32(0xF0F00F0F) n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 @@ -217,12 +217,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 return n16_weight.view(np.int8) - elif nbits == 2 and target_dtype == "float16": + elif nbits == 2 and target_dtype == T.float16: n8_weight = new_qweight & np.int32(0xFF0000FF) n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 return n8_weight.view(np.int8) - elif nbits == 1 and target_dtype == "float16": + elif nbits == 1 and target_dtype == T.float16: n8_weight = new_qweight & 0xF000000F n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 @@ -234,13 +234,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): return new_qweight.view(np.int8) -def assert_bitnet_158_int8xint2_decode_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - fast_decoding=True): +def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) print(program) kernel = tilelang.compile(program) @@ -265,4 +259,4 @@ def assert_bitnet_158_int8xint2_decode_correctness(M, if __name__ == "__main__": - assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, "int8", "int32", "int32") + assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py index d8b1f6228..f4a60098a 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -8,11 +8,13 @@ from tilelang import tvm as tvm from tvm import DataType from tilelang.intrinsics.mma_layout import ( - make_mma_swizzle_layout as make_swizzle_layout,) + make_mma_swizzle_layout as make_swizzle_layout, +) import numpy as np from tilelang.intrinsics.mma_macro_generator import ( - INT4TensorCoreIntrinEmitter,) + INT4TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func torch.manual_seed(42) @@ -86,9 +88,9 @@ def bitnet_158_int8xint2_prefill( Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low‑precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C. 
The returned prim_func expects: - - A: shape (M, K) with dtype `in_dtype` ("float16" or "int8"). + - A: shape (M, K) with dtype `in_dtype` (T.float16 or T.int8). - B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte). - - C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32"). + - C: output buffer shape (M, N) with dtype `out_dtype` (T.float16, T.float32, or T.int32). Details: - Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter. @@ -96,15 +98,15 @@ def bitnet_158_int8xint2_prefill( - block_row_warps, block_col_warps: number of warps per block in row/col. - warp_row_tiles, warp_col_tiles: tiles per warp. - chunk: K-sized chunk per block (block_K). - - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32"). + - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == T.int32). - Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior. - Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values. Parameters: M, N, K (int): Global matrix dimensions. - in_dtype (str): Input and decoded B element dtype; "float16" or "int8". - out_dtype (str): Output C dtype; one of "float16", "float32", "int32". - accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32"). + in_dtype (str): Input and decoded B element dtype; T.float16 or T.int8. + out_dtype (str): Output C dtype; one of T.float16, T.float32, T.int32. + accum_dtype (str): Accumulator dtype used by MMA (e.g., T.int32). fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used). block_row_warps (int): Warps in block row dimension. block_col_warps (int): Warps in block column dimension. @@ -116,18 +118,18 @@ def bitnet_158_int8xint2_prefill( T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution. """ assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if accum_dtype == "int32": + if accum_dtype == T.int32: micro_size_k = 32 num_elems_per_byte = 4 @@ -136,7 +138,7 @@ def bitnet_158_int8xint2_prefill( local_size_compressed = local_size // num_elems_per_byte shared_scope = "shared.dyn" - storage_dtype = "int8" + storage_dtype = T.int8 # Pipeline Stage stage = 2 @@ -181,38 +183,36 @@ def bitnet_158_int8xint2_prefill( @T.prim_func def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), ): """ - GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. + GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. - This kernel: - - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. 
- - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. - - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. - - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. + This kernel: + - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. + - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. + - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. + - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. - Parameters: - A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. - B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. - C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). + Parameters: + A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. + B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. + C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). - Side effects: - Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. + Side effects: + Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. 
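Editor's note: the decode step described in this docstring expands four 2-bit weights from each stored int8 byte before the MMA computation. As a rough host-side illustration of that storage layout (assuming the low-bits-first packing that `general_compress` produces for 2-bit data; this is an editor's sketch, not the kernel's device decode path), packing and unpacking can be written in NumPy as:

```python
import numpy as np


def pack_int2(values):
    """Pack 2-bit values (0..3) into int8 storage, four per byte, low bits first."""
    flat = values.reshape(-1).astype(np.uint8)
    assert flat.size % 4 == 0
    packed = np.zeros(flat.size // 4, dtype=np.uint8)
    for k in range(4):
        packed |= ((flat[k::4] & 0x3) << (2 * k)).astype(np.uint8)
    return packed.view(np.int8)


def unpack_int2(packed):
    """Inverse of pack_int2: recover the 2-bit values as int8 in [0, 3]."""
    data = packed.view(np.uint8)
    out = np.empty(data.size * 4, dtype=np.int8)
    for k in range(4):
        out[k::4] = (data >> (2 * k)) & 0x3
    return out


w = np.random.randint(0, 4, size=(8,), dtype=np.uint8)
assert np.array_equal(unpack_int2(pack_int2(w)).astype(np.uint8), w)
```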
""" with T.Kernel( - T.ceildiv(N, block_N), - T.ceildiv(M, block_M), - threads=threads, - prelude=decode_i2s_to_i8s, + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, + prelude=decode_i2s_to_i8s, ) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) - B_dequantize_shared = T.alloc_shared( - B_dequantize_shared_shape, in_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) @@ -221,12 +221,14 @@ def main( B_local = T.alloc_local([local_size_compressed], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) - thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + thread_bindings = T.get_thread_binding(0) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -234,7 +236,6 @@ def main( T.clear(C_frag) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -243,12 +244,9 @@ def main( for j, k in T.Parallel(block_N, block_K // num_elems_per_byte): B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] - for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): - index = ( - i * threads * local_size_compressed + - thread_bindings * local_size_compressed + v) + index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v vi, vj = T.index_to_coordinates(index, B_shared_shape) B_local[v] = B_shared[vi, vj] @@ -260,12 +258,11 @@ def main( ) for v in T.vectorized(0, local_size): - index = (i * threads * local_size + thread_bindings * local_size + v) + index = i * threads * local_size + thread_bindings * local_size + v vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape) B_dequantize_shared[vi, vj] = B_dequantize_local[v] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_frag, @@ -320,12 +317,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): # interleave weight numpy implementation -def interleave_weight(qweight, nbits=4, target_dtype="float16"): - assert target_dtype in ["float16", "int8"] +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert target_dtype in [T.float16, T.int8] # reinterpret the data type of qweight to int32 qweight = qweight.view(np.int32) new_qweight = np.zeros_like(qweight) - bits_stride = 8 if target_dtype == "int8" else 16 + bits_stride = 8 if target_dtype == T.int8 else 16 mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f num_groups = 32 // bits_stride elems_per_group = bits_stride // nbits @@ -335,7 +332,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): shift = (offset % 
num_groups) * bits_stride + (offset // num_groups) * nbits new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift - if nbits == 1 and target_dtype == "int8": + if nbits == 1 and target_dtype == T.int8: # special handling for 1b interleave n16_weight = new_qweight & np.int32(0xF0F00F0F) n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 @@ -343,12 +340,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 return n16_weight.view(np.int8) - elif nbits == 2 and target_dtype == "float16": + elif nbits == 2 and target_dtype == T.float16: n8_weight = new_qweight & np.int32(0xFF0000FF) n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 return n8_weight.view(np.int8) - elif nbits == 1 and target_dtype == "float16": + elif nbits == 1 and target_dtype == T.float16: n8_weight = new_qweight & 0xF000000F n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 @@ -360,13 +357,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): return new_qweight.view(np.int8) -def assert_bitnet_158_int8xint2_prefill_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - fast_decoding=True): +def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) print(program) kernel = tilelang.compile(program) @@ -391,4 +382,4 @@ def assert_bitnet_158_int8xint2_prefill_correctness(M, if __name__ == "__main__": - assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, "int8", "int32", "int32") + assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py index 986463598..e3d35df4b 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py @@ -6,7 +6,8 @@ import tvm.tl.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from bitblas.base import simplify_prim_func torch.manual_seed(0) @@ -37,18 +38,18 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -56,7 +57,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 64 warp_col_tiles = 64 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -101,12 +102,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared 
= T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -116,10 +116,12 @@ def main( thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -127,7 +129,6 @@ def main( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -137,7 +138,6 @@ def main( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -183,7 +183,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None print(src_code) - if in_dtype == "int8": + if in_dtype == T.int8: A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8) B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8) else: @@ -209,12 +209,12 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") - assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) if __name__ == "__main__": # bitblas.testing.main() - # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") - # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") - assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + # assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32) + assert_tl_matmul_correctness(16384, 16384, 16384, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/load_from_quantized.py b/examples/bitnet-1.58b/load_from_quantized.py index 26a32f974..8c775aa4c 100644 --- a/examples/bitnet-1.58b/load_from_quantized.py +++ b/examples/bitnet-1.58b/load_from_quantized.py @@ -49,7 +49,13 @@ def generate_text(model, tokenizer, prompt, max_length=100): def main(): # load quantized model - qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) diff --git a/examples/bitnet-1.58b/maint/README.md b/examples/bitnet-1.58b/maint/README.md index 63cc3e275..6bccdf93a 100644 --- a/examples/bitnet-1.58b/maint/README.md +++ b/examples/bitnet-1.58b/maint/README.md @@ -2,7 +2,6 @@ license: mit --- - This is a BitBLAS Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with BitBLAS INT8xINT2 Kernel. 
We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`. ## Latest News @@ -88,4 +87,4 @@ The differences between the reported numbers and the reproduced results are poss journal={arXiv preprint arXiv:2402.17764}, year={2024} } -``` \ No newline at end of file +``` diff --git a/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py index 1e29a553a..2604ef387 100644 --- a/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py +++ b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py @@ -25,9 +25,9 @@ args = parser.parse_args() model_name_or_path = args.model_name_or_path -saved_model_path = os.path.join( - dirpath, "models", - f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +saved_model_path = ( + os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +) def generate_text(model, tokenizer, prompt, max_length=100): @@ -67,7 +67,10 @@ def main(): model_name_or_path, use_flash_attention_2=False, torch_dtype=torch.float16, - ).cuda().half()) + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") @@ -112,10 +115,16 @@ def main(): file_path = cached_file(model_name_or_path, file) os.system(f"cp {file_path} {saved_model_path}") # load quantized model - qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) print("quantized model generated text:") print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh b/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh index 741c3a124..b0430588a 100755 --- a/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh +++ b/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + # retrieve the native model input and saved model directory MODEL_DIR=$1 SAVED_MODEL_DIR=$2 diff --git a/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh b/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh index a2df0eb8c..66356d3d8 100755 --- a/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh +++ b/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + # require git lfs if ! 
command -v git-lfs &> /dev/null; then echo "Please install git-lfs first by running 'sudo apt install git-lfs'" diff --git a/examples/bitnet-1.58b/maint/quantize_config.json b/examples/bitnet-1.58b/maint/quantize_config.json index e2b24123a..80fbf02f0 100644 --- a/examples/bitnet-1.58b/maint/quantize_config.json +++ b/examples/bitnet-1.58b/maint/quantize_config.json @@ -7,4 +7,4 @@ "model_name_or_path": "1bitLLM/bitnet_b1_58-3B", "quant_method": "bitnet", "checkpoint_format": "bitnet" -} \ No newline at end of file +} diff --git a/examples/bitnet-1.58b/maint/upload_models.sh b/examples/bitnet-1.58b/maint/upload_models.sh index b764b0da6..7c6d76e32 100755 --- a/examples/bitnet-1.58b/maint/upload_models.sh +++ b/examples/bitnet-1.58b/maint/upload_models.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + MODEL_DIR=$1 REMOTE_DIR=$2 diff --git a/examples/bitnet-1.58b/modeling_bitnet.py b/examples/bitnet-1.58b/modeling_bitnet.py index 6e3c42b6f..1830995ee 100644 --- a/examples/bitnet-1.58b/modeling_bitnet.py +++ b/examples/bitnet-1.58b/modeling_bitnet.py @@ -64,8 +64,7 @@ def find_layers(module, layers=None, name=""): return {name: module} res = {} for name1, child in module.named_children(): - res.update( - find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) return res @@ -87,7 +86,6 @@ def _get_unpad_data(attention_mask): class BitnetRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): """ BitnetRMSNorm is equivalent to T5LayerNorm @@ -108,34 +106,23 @@ def forward(self, hidden_states): class BitnetRotaryEmbedding(nn.Module): - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / ( - self.base - **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer( - "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): @@ -156,14 +143,12 @@ def cos_cached(self): @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, - None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, 
None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, - str) and device_type != "mps" else "cpu" + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) @@ -174,8 +159,8 @@ def forward(self, x, position_ids): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -207,7 +192,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BitnetMLP(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -245,7 +229,6 @@ def forward(self, x): class BitnetMLPFuseGateUp(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -272,8 +255,7 @@ def __init__(self, config): def from_bit_mlp(cls, bit_mlp: BitnetMLP): module = cls(bit_mlp.config) # assign the weights - module.gate_up_proj.weight = nn.Parameter( - torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) + module.gate_up_proj.weight = nn.Parameter(torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) module.down_proj = bit_mlp.down_proj module.ffn_layernorm = bit_mlp.ffn_layernorm return module @@ -295,8 +277,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, - head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -311,7 +292,8 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") + "when creating this class." + ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -325,8 +307,8 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." 
+ ) self.q_proj = BitLinear( self.hidden_size, @@ -387,10 +369,8 @@ def forward( value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -399,30 +379,24 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -448,7 +422,8 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") + "when creating this class." + ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -462,8 +437,8 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." 
+ ) self.qkv_proj = BitLinear( self.hidden_size, @@ -497,17 +472,12 @@ def from_bit_attention(cls, bit_attention: BitnetAttention): module = cls(bit_attention.config, bit_attention.layer_idx) # assign the weights module.qkv_proj.weight = nn.Parameter( - torch.cat([ - bit_attention.q_proj.weight, bit_attention.k_proj.weight, - bit_attention.v_proj.weight - ], - dim=0)) + torch.cat([bit_attention.q_proj.weight, bit_attention.k_proj.weight, bit_attention.v_proj.weight], dim=0) + ) if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None: module.qkv_proj.bias = nn.Parameter( - torch.cat([ - bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias - ], - dim=0)) + torch.cat([bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias], dim=0) + ) module.o_proj = bit_attention.o_proj module.inner_attn_ln = bit_attention.inner_attn_ln if bit_attention.config.rope_scaling is None: @@ -528,16 +498,13 @@ def forward( bsz, q_len, _ = hidden_states.size() qkv_states = self.qkv_proj(hidden_states) query_states, key_states, value_states = torch.split( - qkv_states, [ - self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim - ], - dim=-1) + qkv_states, + [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], + dim=-1, + ) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -546,30 +513,24 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 
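Editor's note: the eager attention path around this point (in both the plain and fused-QKV variants) computes softmax(Q·K^T / sqrt(d) + mask) in float32, casts back to the query dtype, applies dropout, and then multiplies by V. A compact reference of that computation, using illustrative tensor names rather than the module's own buffers:

```python
import math

import torch
import torch.nn.functional as F


def eager_attention(q, k, v, causal_mask=None, dropout_p=0.0, training=False):
    """q, k, v: (batch, heads, seq, head_dim); causal_mask is additive (0 / large negative)."""
    attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
    if causal_mask is not None:
        attn = attn + causal_mask[..., : k.shape[-2]]
    attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)  # fp32 softmax
    attn = F.dropout(attn, p=dropout_p, training=training)
    return torch.matmul(attn, v)
```

Upcasting the softmax to float32 before casting back avoids precision and overflow issues with float16 logits, which is why the code above does it explicitly.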
attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -622,10 +583,8 @@ def forward( # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -635,8 +594,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -665,14 +623,14 @@ def forward( logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") + f" {target_dtype}." + ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) + attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.inner_attn_ln(attn_output) @@ -683,14 +641,9 @@ def forward( return attn_output, attn_weights, past_key_value - def _flash_attention_forward(self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None): + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. 
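Editor's note: for the padded case described in this docstring, the flash-attention path first strips padding tokens and builds the variable-length metadata (flat token indices, cumulative sequence lengths, and the longest sequence in the batch) that `_upad_input` feeds to the variable-length flash-attention kernel. A small sketch of that metadata computation against a toy attention mask (the function name here is illustrative):

```python
import torch
import torch.nn.functional as F


def unpad_metadata(attention_mask):
    """Flat indices of real tokens, cumulative sequence lengths, and max length."""
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens.max())


# left-padded batch: row 0 has 2 real tokens, row 1 has 4
mask = torch.tensor([[0, 0, 1, 1],
                     [1, 1, 1, 1]])
indices, cu_seqlens, max_len = unpad_metadata(mask)
print(indices)     # tensor([2, 3, 4, 5, 6, 7])
print(cu_seqlens)  # tensor([0, 2, 6], dtype=torch.int32)
print(max_len)     # 4
```

`pad_input` then scatters the unpadded attention output back to the original (batch, seq_len) layout using the same `indices`.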
@@ -720,7 +673,8 @@ def _flash_attention_forward(self, if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length) + query_states, key_states, value_states, attention_mask, query_length + ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -740,13 +694,7 @@ def _flash_attention_forward(self, attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal) + attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal) return attn_output @@ -754,28 +702,24 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, - device=query_layer.device) # There is a memcpy here, that is very bad. + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -794,13 +738,11 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class BitnetDecoderLayer(nn.Module): - def __init__(self, config: BitnetConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx) + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = BitnetMLP(config) self.input_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -834,7 +776,8 @@ def forward( if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`", - stacklevel=2) + stacklevel=2, + ) residual = hidden_states @@ -925,8 +868,7 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = dtype = self.config._pre_quantization_dtype else: dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) + layer.self_attn.past_key_value = cache_cls(self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) def _reset_cache(self): for layer in self.model.layers: @@ -1025,9 +967,7 @@ def __init__(self, config: BitnetConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([ - BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList([BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1055,21 +995,15 @@ def forward( cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one") if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.") use_cache = False if inputs_embeds is None: @@ -1083,10 +1017,7 @@ def forward( if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1143,12 +1074,9 @@ def forward( next_cache = None if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if isinstance(next_decoder_cache, Cache) else next_decoder_cache) + next_cache = next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1172,14 +1100,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache target_length = self.config.max_position_embeddings else: # dynamic cache - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1) - - causal_mask = torch.full((sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device) + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) @@ -1188,10 +1111,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq( - 0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( - padding_mask, min_dtype) + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. @@ -1201,8 +1122,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[:mask_shape[0], :mask_shape[1], - offset:mask_shape[2] + offset, :mask_shape[3]] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = mask_slice return causal_mask @@ -1279,9 +1199,7 @@ def forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -1327,13 +1245,9 @@ def forward( attentions=outputs.attentions, ) - def prepare_inputs_for_generation(self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - **kwargs): + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs + ): # With static cache, the `past_key_values` is None # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False @@ -1344,13 +1258,13 @@ def prepare_inputs_for_generation(self, past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - past_length = cache_position[ - 0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None else None) - cache_length = past_length if max_cache_length is None else torch.min( - max_cache_length, past_length) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] @@ -1361,7 +1275,7 @@ def prepare_inputs_for_generation(self, # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: @@ -1369,8 +1283,7 @@ def prepare_inputs_for_generation(self, # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if (max_cache_length is not None and attention_mask is not None and - cache_length + input_ids.shape[1] > max_cache_length): + if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length: attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids") @@ -1379,7 +1292,7 @@ def prepare_inputs_for_generation(self, position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1]:] + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1392,39 +1305,38 @@ def prepare_inputs_for_generation(self, input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cache_position is None: - cache_position = torch.arange( - past_length, past_length + input_length, device=input_ids.device) + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) else: cache_position = cache_position[-input_length:] if has_static_cache: past_key_values = None - model_inputs.update({ - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - }) + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: - reordered_past += (tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past),) + reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) return reordered_past @staticmethod def recursive_set(model, name, attr): - ''' - set layers.25.mlp.up_proj to attr - ''' + """ + set layers.25.mlp.up_proj to attr + """ - names = name.split('.') + names = name.split(".") obj = model for n in names[:-1]: obj = getattr(obj, n) @@ -1521,6 +1433,7 @@ def from_quantized( fuse_gateup = quant_config.get("fuse_gateup", True) import accelerate + if checkpoint_format == "bitblas": model = cls(config) for name, module in model.named_modules(): @@ -1567,7 +1480,6 @@ def from_quantized( LLAMA_START_DOCSTRING, ) class BitnetForSequenceClassification(BitnetPreTrainedModel): - def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -1631,8 +1543,7 @@ def forward( else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, - self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: @@ -1646,8 +1557,7 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or - labels.dtype == torch.int): + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 
self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" diff --git a/examples/bitnet-1.58b/nvidia_measure_memory.sh b/examples/bitnet-1.58b/nvidia_measure_memory.sh index e8998f309..82cf4855f 100755 --- a/examples/bitnet-1.58b/nvidia_measure_memory.sh +++ b/examples/bitnet-1.58b/nvidia_measure_memory.sh @@ -1 +1,3 @@ +#!/usr/bin/env bash + nvidia-smi --query-gpu=memory.used --format=csv -lms 500 diff --git a/examples/bitnet-1.58b/tokenization_bitnet.py b/examples/bitnet-1.58b/tokenization_bitnet.py index 6fea3252a..2adfd6dee 100644 --- a/examples/bitnet-1.58b/tokenization_bitnet.py +++ b/examples/bitnet-1.58b/tokenization_bitnet.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for LLaMA.""" + import os from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -37,12 +38,10 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": - "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": - "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @@ -159,14 +158,10 @@ def __init__( **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken( - bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken( - eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken( - unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken( - pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token if legacy is None: logger.warning_once( @@ -174,7 +169,8 @@ def __init__( " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." " If you want to use the new behavior, set `legacy=False`. 
This should only be set if you understand what it" " means, and thoroughly read the reason why this was added as explained in" - " https://github.com/huggingface/transformers/pull/24565") + " https://github.com/huggingface/transformers/pull/24565" + ) legacy = True self.legacy = legacy @@ -214,8 +210,7 @@ def get_spm_processor(self, from_slow=False): with open(self.vocab_file, "rb") as f: sp_model = f.read() - model_pb2 = import_protobuf( - f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") + model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") model = model_pb2.ModelProto.FromString(sp_model) normalizer_spec = model_pb2.NormalizerSpec() normalizer_spec.add_dummy_prefix = False @@ -261,8 +256,7 @@ def tokenize(self, text: "TextInput", **kwargs) -> List[str]: tokens = super().tokenize(text, **kwargs) - if len(tokens - ) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: tokens = tokens[1:] return tokens @@ -284,7 +278,7 @@ def _tokenize(self, text, **kwargs): # 1. Encode string + prefix ex: " Hey" tokens = self.sp_model.encode(self.unk_token + text, out_type=str) # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] - return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" @@ -332,12 +326,9 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") return - out_vocab_file = os.path.join(save_directory, - (filename_prefix + "-" if filename_prefix else "") + - VOCAB_FILES_NAMES["vocab_file"]) + out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile( - self.vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: @@ -357,10 +348,9 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): return output - def get_special_tokens_mask(self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, - already_has_special_tokens: bool = False) -> List[int]: + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` method. @@ -377,20 +367,16 @@ def get_special_tokens_mask(self, `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
""" if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) + return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) bos_token_id = [1] if self.add_bos_token else [] eos_token_id = [1] if self.add_eos_token else [] if token_ids_1 is None: return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + - ([0] * len(token_ids_1)) + eos_token_id) + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id - def create_token_type_ids_from_sequences(self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: + def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT sequence pair mask has the following format: @@ -473,9 +459,9 @@ def default_chat_template(self): "{% elif message['role'] == 'assistant' %}" "{{ ' ' + content.strip() + ' ' + eos_token }}" "{% endif %}" - "{% endfor %}") - template = template.replace("USE_DEFAULT_PROMPT", - "true" if self.use_default_system_prompt else "false") + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) diff --git a/examples/bitnet-1.58b/utils_quant.py b/examples/bitnet-1.58b/utils_quant.py index 5f5db5dbc..5a50edb39 100644 --- a/examples/bitnet-1.58b/utils_quant.py +++ b/examples/bitnet-1.58b/utils_quant.py @@ -24,15 +24,14 @@ def weight_quant(weight, num_bits=1): def activation_quant(x, num_bits=8): dtype = x.dtype x = x.float() - Qn = -(2**(num_bits - 1)) - Qp = 2**(num_bits - 1) - 1 + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) result = (x * s).round().clamp(Qn, Qp) / s return result.type(dtype) class BitLinearBitBLAS(nn.Module): - def __init__( self, in_features: int, @@ -68,7 +67,7 @@ def __init__( self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING) self.format = "bitnet" - self.Qp = 2**(self.input_bits - 1) - 1 + self.Qp = 2 ** (self.input_bits - 1) - 1 def _get_or_create_bitblas_operator(self, config, enable_tuning): if global_operator_cache.size() == 0: @@ -99,8 +98,7 @@ def replace_weight_param_with_qweight(self): @classmethod def from_bit_linear(cls, bitlinear, weight_group=1): - bitblas_linear = cls( - bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) + bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group) bitblas_linear.register_buffer("qweight", qweight) bitblas_linear.register_buffer("sw", sw) @@ -158,8 +156,8 @@ def weight_quant(weight): @torch.compile def activation_quant(self, x, num_bits=8): x = x.float() - Qn = -(2**(num_bits - 1)) - Qp = 2**(num_bits - 1) - 1 + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) result = (x * s).round().clamp(Qn, Qp) return 
result.type(torch.int8), s @@ -173,9 +171,8 @@ def post_quant_process(self, input, si, sw): # for the correctness evaluation. def native_forward(self, input): - quant_input = (input + (activation_quant(input, self.input_bits) - input).detach()) - quant_weight = ( - self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()) + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() out = nn.functional.linear(quant_input, quant_weight) if self.bias is not None: @@ -214,7 +211,6 @@ def forward(self, input): # Naive BitLinear from HuggingFace class BitLinear(nn.Linear): - def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): super(BitLinear, self).__init__(*kargs, **kwargs) """ @@ -224,10 +220,8 @@ def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): self.input_bits = input_bits def forward(self, input): - quant_input = input + (activation_quant(input, self.input_bits) - input).detach() - quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - - self.weight).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() out = nn.functional.linear(quant_input, quant_weight) if self.bias is not None: diff --git a/examples/bitnet-1.58b/vllm_workspace/conftest.py b/examples/bitnet-1.58b/vllm_workspace/conftest.py index 951f38991..e9e2997ef 100644 --- a/examples/bitnet-1.58b/vllm_workspace/conftest.py +++ b/examples/bitnet-1.58b/vllm_workspace/conftest.py @@ -20,7 +20,7 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import TokenizerPoolConfig -from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) +from vllm.distributed import destroy_distributed_environment, destroy_model_parallel from vllm.inputs import TextPrompt from vllm.logger import init_logger from vllm.sequence import SampleLogprobs @@ -56,12 +56,13 @@ class _ImageAssetsBase(UserList[ImageAsset]): class _ImageAssets(_ImageAssetsBase): - def __init__(self) -> None: - super().__init__([ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), - ]) + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: """ @@ -136,7 +137,6 @@ def image_assets() -> _ImageAssets: class HfRunner: - def wrap_device(self, input: _T) -> _T: if not is_cpu(): return input.to("cuda") @@ -166,7 +166,8 @@ def __init__( SentenceTransformer( model_name, device="cpu", - ).to(dtype=torch_dtype)) + ).to(dtype=torch_dtype) + ) else: if is_vision_model: auto_cls = AutoModelForVision2Seq @@ -184,7 +185,8 @@ def __init__( torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs, - )) + ) + ) self.tokenizer = AutoTokenizer.from_pretrained( model_name, @@ -204,8 +206,7 @@ def __init__( ) except Exception: logger.warning( - "Unable to auto-load processor from HuggingFace for " - "model %s. Using tokenizer instead.", + "Unable to auto-load processor from HuggingFace for model %s. 
Using tokenizer instead.", model_name, ) self.processor = self.tokenizer @@ -362,7 +363,7 @@ def generate_greedy_logprobs_limit( last_hidden_states, self.model.get_output_embeddings().weight.t(), ) - if (getattr(self.model.get_output_embeddings(), "bias", None) is not None): + if getattr(self.model.get_output_embeddings(), "bias", None) is not None: logits += self.model.get_output_embeddings().bias.unsqueeze(0) logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) @@ -389,8 +390,7 @@ def generate_greedy_logprobs_limit( all_output_strs.append(self.tokenizer.decode(output_ids)) outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) @@ -409,7 +409,6 @@ def hf_runner(): class VllmRunner: - def __init__( self, model_name: str, @@ -514,12 +513,10 @@ def generate_greedy_logprobs( num_logprobs: int, images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams( - temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) + greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] def generate_beam_search( self, diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py index 55a24543e..ea18239cb 100644 --- a/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py @@ -32,15 +32,14 @@ ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitblas", - # set enforce_eager = False to enable cuda graph - # set enforce_eager = True to disable cuda graph - enforce_eager=False, + ckpt_path, + dtype="half", + quantization="bitblas", + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: - bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], - max_tokens=1024) + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024) print("bitnet inference:") print(bitbnet_outputs[0][0]) print(bitbnet_outputs[0][1]) diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py index 4f5f87f6f..f631fb306 100644 --- a/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py @@ -33,13 +33,13 @@ ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitnet_bitblas", - gpu_memory_utilization=0.5, - # set enforce_eager = False to enable cuda graph - # set enforce_eager = True to disable cuda graph - enforce_eager=False, + ckpt_path, + dtype="half", + 
quantization="bitnet_bitblas", + gpu_memory_utilization=0.5, + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128) print("bitnet inference output:") diff --git a/examples/bitnet-1.58b/vllm_workspace/utils.py b/examples/bitnet-1.58b/vllm_workspace/utils.py index daa9d8f52..e96b19e28 100644 --- a/examples/bitnet-1.58b/vllm_workspace/utils.py +++ b/examples/bitnet-1.58b/vllm_workspace/utils.py @@ -3,8 +3,7 @@ TokensText = Tuple[List[int], str] -def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], - name_0: str, name_1: str): +def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str): """ Compare the two sequences generated by different models, which should be equal. @@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok output_ids_0, output_str_0 = outputs_0 output_ids_1, output_str_1 = outputs_1 - assert output_str_0 == output_str_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]] -def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], - outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): +def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): """ Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. @@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], # Loop through generated tokens. for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): - # If generated tokens don't match, then if output_id_0 != output_id_1: # Each predicted token must be in top N logprobs of the other - assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" # Break out since sequences will now diverge. break diff --git a/examples/blocksparse_attention/README.md b/examples/blocksparse_attention/README.md index 89f75b81d..34bf3c637 100644 --- a/examples/blocksparse_attention/README.md +++ b/examples/blocksparse_attention/README.md @@ -1,6 +1,5 @@ # Block-Sparse Flash-Attention -Tilelang implementation of block-sparse flash-attention kernels. - -The kernels have been used in [Rectified Sparse Attention](https://arxiv.org/abs/2506.04108) and [SeerAttention-R](https://arxiv.org/abs/2506.08889). +Tilelang implementation of block-sparse flash-attention kernels. 
+The kernels have been used in [Rectified Sparse Attention](https://arxiv.org/abs/2506.04108) and [SeerAttention-R](https://arxiv.org/abs/2506.08889). diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py index 014f0c5fc..b94e602f6 100644 --- a/examples/blocksparse_attention/block_sparse_attn_triton.py +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -1,7 +1,6 @@ # ruff: noqa: E712 import math import torch - import triton import triton.language as tl import torch.nn.functional as F @@ -15,10 +14,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -56,7 +52,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) # print @@ -73,8 +68,7 @@ def _fwd_kernel_inner( # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N if LAST_K_BLOCK: - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, - float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -154,7 +148,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -192,24 +186,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -254,7 +236,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -278,9 +259,9 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = 
torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) @@ -288,9 +269,7 @@ def test_topk_sparse_attention(): downsample_len = math.ceil(SEQ_LEN / downsample_factor) print("downsample_len", downsample_len) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 print("x_ds.shape", x_ds.shape) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -302,22 +281,21 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # print("ref_output", ref_output) # print("triton_output", triton_output) # Verify accuracy - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -329,9 +307,9 @@ def test_topk_sparse_attention_qlt_kl(): torch.manual_seed(0) # Create inputs. - q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) # softmax scale sm_scale = 1.0 / (D_HEAD**0.5) @@ -339,8 +317,7 @@ def test_topk_sparse_attention_qlt_kl(): print("downsample_factor", downsample_factor) downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension print("downsample_len", downsample_len) - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) # Force the first column to be high so that the first block is always selected. 
x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -351,26 +328,25 @@ def test_topk_sparse_attention_qlt_kl(): past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # Verify accuracy. - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference when qlen < klen" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" print("Pass topk sparse attention test with qlen < klen") diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index 7e90db7e5..9a394710f 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -1,8 +1,8 @@ import math import torch - import tilelang import tilelang.language as T +from tilelang.profiler import do_bench import torch.nn.functional as F @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -30,105 +27,34 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F @tilelang.jit( - out_idx=[4], pass_configs={ + out_idx=[4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): block_M = 64 block_N = 64 num_stages = 1 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] - dtype = "float16" - accum_dtype = "float" - block_mask_dtype = "bool" + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.bool def kernel_func(block_M, block_N, num_stages, threads): - - @T.macro - def MMA0( - K: T.Tensor(shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: 
T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def blocksparse_flashattn( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -141,31 +67,59 @@ def blocksparse_flashattn( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_mask = T.alloc_local([downsample_len], block_mask_dtype) + block_mask = T.alloc_fragment([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - for vj in T.serial(downsample_len): - block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + T.copy(BlockSparseMask[bz, by, bx, :], block_mask) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[k] != 0: - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. 
+ # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return blocksparse_flashattn @@ -180,18 +134,16 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -202,15 +154,15 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) @@ -224,5 +176,26 @@ def main(): test_topk_sparse_attention() +def run_regression_perf(): + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 32, 256, 64 + TOPK = 2 + BLOCK = 64 + torch.manual_seed(0) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, 
SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + + def run_kernel_only(): + kernel(q, k, v, block_mask) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index e29982162..6e7321452 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -8,22 +8,26 @@ import argparse import time import math +from tilelang.profiler import do_bench from heuristic import num_splits_heuristic def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // heads_kv @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) - def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, - max_num_blocks_per_seq, max_selected_blocks): + }, + ) + def kernel_func( + block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks + ): shape_q = [batch, heads, dim] shape_k = [num_pages, page_block_size, heads_kv, dim] shape_v = [num_pages, page_block_size, heads_kv, dim_v] @@ -35,19 +39,20 @@ def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, thread assert block_N <= page_block_size and page_block_size % block_N == 0 block_ratio = page_block_size // block_N - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + # flash_attn_split + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype) @@ -67,7 +72,7 @@ def flash_attn_split( sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + 
block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -75,7 +80,7 @@ def flash_attn_split( num_blocks = max_selected_blocks blocks_per_split = T.floordiv(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) start = blocks_per_split * sid + T.min(sid, remaining_blocks) has_valid_block = False for k in T.Pipelined(loop_range, num_stages=num_stages): @@ -85,30 +90,20 @@ def flash_attn_split( block_table_idx = T.floordiv(logical_block_idx, block_ratio) block_tile_idx = T.floormod(logical_block_idx, block_ratio) physical_block_idx = block_table[bid, block_table_idx] - T.copy( - K[physical_block_idx, - block_tile_idx * block_N:(block_tile_idx + 1) * block_N, - cur_kv_head, :], K_shared) + T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.if_then_else( - logical_block_idx * block_N + j >= cache_seqlens[bid], - -T.infinity(accum_dtype), acc_s[i, j]) + logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] + ) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i], - scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -117,10 +112,7 @@ def flash_attn_split( T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim_v): acc_o[i, j] *= scores_scale[i] - T.copy( - V[physical_block_idx, - block_tile_idx * block_N:(block_tile_idx + 1) * block_N, - cur_kv_head, :], V_shared) + T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_valid_block: for i, j in T.Parallel(block_H, dim_v): @@ -137,74 +129,47 @@ def flash_attn_split( if i < valid_block_H: Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): + # combine with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - max_split = T.alloc_local([1], "int32") - - T.annotate_layout({ - lse_logsum_local: - 
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_local_split[0] = glse[bz, by, k] - if (lse_local_split[0] != 0): - max_split[0] = k - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split[0]: - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): - if k <= max_split[0]: + if k <= max_split: for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim_v): Output[bz, by, i] = o_accum_local[i] - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, - Output_partial) - combine(glse, Output_partial, Output) - return main return kernel_func class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -250,18 +215,11 @@ def forward(self, query, key, value, block_indices, cache_seqlens, block_table): num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") output = self.kernel( query, @@ -276,14 +234,13 @@ def forward(self, query, key, value, block_indices, cache_seqlens, block_table): return output -def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, - block_table, page_block_size, block_size): +def 
ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size): """ Paged version of sparse attention reference implementation. - + Args: query: [batch, heads, dim] - key_cache: [num_pages, page_block_size, heads_kv, dim] + key_cache: [num_pages, page_block_size, heads_kv, dim] value_cache: [num_pages, page_block_size, heads_kv, dim] block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices cache_seqlens: [batch] - actual sequence lengths @@ -299,12 +256,8 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ # Reconstruct the full key and value tensors from paged cache max_cache_seqlen = max(cache_seqlens).item() - key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), - dtype=key_cache.dtype, - device=key_cache.device) - value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), - dtype=value_cache.dtype, - device=value_cache.device) + key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device) + value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device) # Reconstruct full tensors from paged cache using block_table for b in range(batch): @@ -320,20 +273,14 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ actual_block_size = end_token - start_token # Copy from paged cache to full tensors - key_full[b, :, start_token:end_token, :] = key_cache[ - physical_block_idx, :actual_block_size, :, :].transpose(0, 1) - value_full[b, :, start_token:end_token, :] = value_cache[ - physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) # Reshape query for grouped attention - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] # Compute attention scores - scores = einsum( - query, key_full, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key_full, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] # Create sparse mask based on block_indices sparse_mask = torch.zeros_like(scores) @@ -349,24 +296,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ sparse_mask[b, :, h, start_pos:end_pos] = 1 # Apply sparse mask - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) # Apply causal mask based on actual sequence lengths range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) + scores = scores.masked_fill(pad_mask, float("-inf")) # Compute attention weights attention = F.softmax(scores / scale, dim=-1) # Apply attention to values - out = einsum(attention, value_full, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] + out = einsum(attention, value_full, "b g h s, b h s d -> 
b g h d") # [batch_size, num_head_groups, heads_kv, dim] # Reshape output back to original format - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -374,17 +320,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) - output = flash_attn_with_kvcache( - query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) + output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) output = output.squeeze(1) return output def main(args): - - batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( + args.batch, + args.heads, + args.heads_kv, + args.max_cache_seqlen, + args.dim, + args.dim_v, + ) sparse_ratio = args.sparse_ratio block_N = args.block_N page_block_size = args.page_block_size @@ -396,35 +348,30 @@ def main(args): dtype = torch.float16 # Generate random inputs - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - cache_seqlens = torch.randint( - max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") print("cache_seqlens: ", cache_seqlens) - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") # Create paged KV cache - K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device='cuda') - V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), - dtype=dtype, - device='cuda') + K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") # Create block table and block indices for dense case (all blocks selected) max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) print("max_num_blocks_per_seq: ", max_num_blocks_per_seq) - block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device='cuda') - block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), - dtype=torch.int32, - device='cuda') + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") # Fill block table and block indices and cache # Create a pool of available physical blocks - total_blocks_needed = sum( - int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in 
range(batch)) available_blocks = list(range(total_blocks_needed)) import random + random.seed(42) # For reproducibility random.shuffle(available_blocks) @@ -459,10 +406,8 @@ def main(args): actual_block_size = end_token - start_token # Copy K and V data to the paged cache - K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, - start_token:end_token, :, :] - V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, - start_token:end_token, :, :] + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :] # Fill block_indices for sparse attention # For dense case (verification), we select all blocks in reverse order @@ -497,10 +442,9 @@ def main(args): remaining_blocks = [b for b in all_blocks if b not in selected_blocks] if remaining_blocks: import random + random.seed(42) # For reproducibility - additional_blocks = random.sample( - remaining_blocks, - min(num_selected - recent_blocks, len(remaining_blocks))) + additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks))) selected_blocks.extend(additional_blocks) # Sort selected blocks in reverse order (most recent first) @@ -513,25 +457,20 @@ def main(args): block_indices[seq_idx, head_idx, i] = -1 # Initialize sparse attention module - sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, - num_blocks) - output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, - block_table) + sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) + output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) import flash_attn # noqa: F401 - output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, - block_table, page_block_size, block_N) + output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N) output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) # Check correctness if sparse_ratio == 0.0: max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item() mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item() - assert torch.allclose( - output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" + assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" 
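For context on the paged reference being reformatted above: `ref_program_torch_paged` first gathers the paged KV cache back into contiguous per-batch tensors through `block_table`, and only then applies the sparse block mask and dense attention. The sketch below shows that gather step in isolation; the helper name `gather_paged_kv` and the shapes in the usage comment are assumptions for illustration, not code from this patch.

import math
import torch

def gather_paged_kv(cache, block_table, cache_seqlens, page_block_size):
    # cache:        [num_pages, page_block_size, heads_kv, dim] paged KV storage
    # block_table:  [batch, max_blocks_per_seq] logical-block -> physical-page mapping
    # cache_seqlens:[batch] number of valid tokens per sequence
    batch = block_table.shape[0]
    heads_kv, dim = cache.shape[2], cache.shape[3]
    max_seqlen = int(cache_seqlens.max().item())
    full = torch.zeros((batch, heads_kv, max_seqlen, dim), dtype=cache.dtype, device=cache.device)
    for b in range(batch):
        seqlen = int(cache_seqlens[b].item())
        for blk in range(math.ceil(seqlen / page_block_size)):
            page = int(block_table[b, blk].item())
            start = blk * page_block_size
            end = min(start + page_block_size, seqlen)
            # cache slice is [tokens, heads_kv, dim]; transpose to [heads_kv, tokens, dim]
            full[b, :, start:end, :] = cache[page, : end - start].transpose(0, 1)
    return full

# Hypothetical usage mirroring the reference code:
# k_full = gather_paged_kv(K_cache, block_table, cache_seqlens, page_block_size=256)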
else: - max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item() mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item() @@ -573,18 +512,140 @@ def main(args): print(f"Speedup: {kernel_time_fa / kernel_time:.2f}x") +def run_regression_perf(args): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( + args.batch, + args.heads, + args.heads_kv, + args.max_cache_seqlen, + args.dim, + args.dim_v, + ) + sparse_ratio = args.sparse_ratio + block_N = args.block_N + page_block_size = args.page_block_size + num_blocks = args.num_pages + max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N)) + dtype = torch.float16 + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") + max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") + total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + available_blocks = list(range(total_blocks_needed)) + import random + + random.seed(42) + random.shuffle(available_blocks) + block_assignment = {} + block_idx_counter = 0 + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + for block_idx in range(num_blocks_needed): + physical_block_idx = available_blocks[block_idx_counter] + block_table[seq_idx, block_idx] = physical_block_idx + block_assignment[(seq_idx, block_idx)] = physical_block_idx + block_idx_counter += 1 + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + for block_idx in range(num_blocks_needed): + physical_block_idx = block_assignment[(seq_idx, block_idx)] + start_token = block_idx * page_block_size + end_token = min(start_token + page_block_size, seq_len) + actual_block_size = end_token - start_token + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :] + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_tile = int(math.ceil(seq_len / block_N)) + if sparse_ratio == 0.0: + selected_blocks = min(num_tile, max_selected_blocks) + for head_idx in range(heads_kv): + for i in range(selected_blocks): + block_indices[seq_idx, head_idx, i] = num_tile - 1 - i + for i in range(selected_blocks, max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + else: + num_selected = int(num_tile * (1.0 - sparse_ratio)) + num_selected = max(1, min(num_selected, max_selected_blocks)) + all_blocks = list(range(num_tile)) + for head_idx in range(heads_kv): + selected_blocks = [] + recent_blocks = 1 + selected_blocks.append(num_tile - 1) + if num_selected > recent_blocks: + remaining_blocks = [b for b in all_blocks 
if b not in selected_blocks] + if remaining_blocks: + import random + + random.seed(42) + additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks))) + selected_blocks.extend(additional_blocks) + + selected_blocks.sort(reverse=True) + + for i in range(len(selected_blocks)): + block_indices[seq_idx, head_idx, i] = selected_blocks[i] + for i in range(len(selected_blocks), max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + + sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) + kernel = sparse_attn.kernel + batch = sparse_attn.batch + heads = sparse_attn.heads + heads_kv = sparse_attn.heads_kv + dim_v = sparse_attn.dim_v + dim = sparse_attn.dim + block_size = sparse_attn.block_N + max_selected_blocks = block_indices.shape[-1] + + num_m_blocks = 1 * (heads // heads_kv + sparse_attn.block_H - 1) // sparse_attn.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + + num_sm = sparse_attn.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + def run_kernel_only(): + kernel( + Q, + K_cache, + V_cache, + block_indices, + cache_seqlens, + block_table, + glse, + output_partial, + ) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.0, help='sparse ratio') - parser.add_argument('--block_N', type=int, default=64, help='block_N') - parser.add_argument('--page_block_size', type=int, default=256, help='block size of pages') - parser.add_argument('--num_pages', type=int, default=1024, help='total number of pages') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio") + parser.add_argument("--block_N", type=int, default=64, help="block_N") + parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages") + parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages") args = parser.parse_args() main(args) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py 
index ae3004267..d6cf7d917 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -7,20 +7,22 @@ import time import math from heuristic import num_splits_heuristic +from tilelang.profiler import do_bench def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // heads_kv @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, - max_selected_blocks): + }, + ) + def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks): shape_q = [batch, heads, dim] shape_k = [batch, max_cache_seqlen, heads_kv, dim] shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] @@ -29,19 +31,21 @@ def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seql part_shape = [batch, heads, num_split, dim_v] valid_block_H = min(block_H, kv_group_num) - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) + # flash_attn_split + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype) @@ -62,7 +66,7 @@ def flash_attn_split( sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -70,7 +74,7 @@ def flash_attn_split( num_blocks = max_selected_blocks blocks_per_split = T.floordiv(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) start = blocks_per_split * sid + T.min(sid, remaining_blocks) has_valid_block = False @@ -78,27 +82,18 @@ def flash_attn_split( i_s = block_indices[bid, cur_kv_head, start + k] if i_s >= 0: has_valid_block = True - T.copy(K[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(K[bid, i_s 
* block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, - j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i], - scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -107,7 +102,7 @@ def flash_attn_split( T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim_v): acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], V_shared) + T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_valid_block: for i, j in T.Parallel(block_H, dim_v): @@ -124,74 +119,47 @@ def flash_attn_split( if i < valid_block_H: Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): + # combine with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - max_split = T.alloc_local([1], "int32") - - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_local_split[0] = glse[bz, by, k] - if (lse_local_split[0] != 0): - max_split[0] = k - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split[0]: - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - 
lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): - if k <= max_split[0]: + if k <= max_split: for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim_v): Output[bz, by, i] = o_accum_local[i] - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) - flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial) - combine(glse, Output_partial, Output) - return main return kernel_func class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -210,7 +178,8 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks")) + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -233,25 +202,17 @@ def forward(self, query, key, value, block_indices, cache_seqlens): num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) return output -def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, - max_cache_seqlen, block_size): +def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size): """ Args: query: [batch, heads, dim] @@ -273,31 +234,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql block_H = 64 actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32) - actual_num_blocks = actual_num_blocks[:, - 0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks # get num_split num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // 
block_H - num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 132 num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, block_H=block_H, @@ -305,29 +259,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks")) + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) return output -def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values based on block_indices @@ -336,28 +285,26 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache valid_indices = block_indices[b, h] # Extract indices for this batch and head for idx in valid_indices: if idx >= 0: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], 
device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out -def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) @@ -369,23 +316,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): print(name + " all_close={}".format(all_close)) if not all_close: diff = (expect - actual).abs() - print("all_close={}, max={}, min={}, mean={}".format(all_close, - diff.max().item(), - diff.min().item(), - diff.mean().item())) + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) first_index = tuple(max_indices[0].tolist()) print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -def main(batch=8, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -393,10 +330,10 @@ def main(batch=8, print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # # Ensure at least one element equals cache_seqlen # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index @@ -407,10 +344,7 @@ def main(batch=8, max_valid_num_blocks = 
torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_indices with -1 (for padding blocks) - block_indices = torch.full((batch, heads_kv, max_selected_blocks), - -1, - dtype=torch.int32, - device='cuda') + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') @@ -419,10 +353,9 @@ def main(batch=8, max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - valid_indices = torch.randperm( - max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] - block_indices[b, h, :len(valid_indices)] = valid_indices + block_indices[b, h, : len(valid_indices)] = valid_indices # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) @@ -435,8 +368,7 @@ def main(batch=8, print("max_num_blocks: ", max_num_blocks) # parity reference - ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) @@ -446,13 +378,11 @@ def main(batch=8, ## latency reference for _ in range(10): - ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, - max_num_blocks, block_size) + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) torch.cuda.synchronize() start = time.time() for _ in range(100): - ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, - max_num_blocks, block_size) + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) torch.cuda.synchronize() print("dense time: ", (time.time() - start) / 100 * 1000) @@ -468,17 +398,67 @@ def main(batch=8, print("sparse time: ", (time.time() - start) / 100 * 1000) +def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + dtype = torch.float16 + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") + + for b in 
range(batch): + max_valid_block = max_valid_num_blocks[b].item() + if max_valid_block > 0: + for h in range(heads_kv): + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices + + block_indices, _ = block_indices.sort(dim=-1, descending=True) + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_size + max_selected_blocks = block_indices.shape[-1] + + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = sparse_kernel.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = sparse_kernel.kernel + + def run_kernel_only(): + kernel(Q, K, V, block_indices, cache_seqlens, glse, output_partial) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index ad62817dd..e48428fb8 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -5,22 +5,24 @@ import tilelang.language as T from einops import rearrange, einsum import 
argparse - import time import math from heuristic import num_splits_heuristic +from tilelang.profiler import do_bench def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // heads_kv @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): shape_q = [batch, heads, dim] shape_k = [batch, max_cache_seqlen, heads_kv, dim] @@ -30,22 +32,21 @@ def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seql part_shape = [batch, heads, num_split, dim_v] valid_block_H = min(block_H, kv_group_num) - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype) - # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) @@ -62,38 +63,31 @@ def flash_attn_split( sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) blocks_per_split = T.floordiv(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) start = blocks_per_split * sid + T.min(sid, remaining_blocks) has_valid_block = False for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[bid, hid, start + k]: has_valid_block = True - T.copy( - K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], - K_shared) + T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else((start + k) * block_N + j - >= cache_seqlens[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = 
T.if_then_else( + (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] + ) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -102,9 +96,7 @@ def flash_attn_split( T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim_v): acc_o[i, j] *= scores_scale[i] - T.copy( - V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], - V_shared) + T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_valid_block: for i, j in T.Parallel(block_H, dim_v): @@ -120,65 +112,39 @@ def flash_attn_split( if i < valid_block_H: Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim_v): Output[bz, by, i] = o_accum_local[i] - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - 
Output: T.Tensor(shape_o, dtype), - ): - flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) - combine(glse, Output_partial, Output) - return main return kernel_func class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -197,7 +163,8 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks")) + num_blocks=T.dynamic("num_blocks"), + ) props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -216,24 +183,16 @@ def forward(self, query, key, value, block_mask, cache_seqlens): num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks # num_sm = 132 num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # print("num_split: ", num_split) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) return output @@ -258,26 +217,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_H = 64 actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32) - actual_num_blocks = actual_num_blocks[:, - 0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks max_selected_blocks = actual_num_blocks.max().item() # get num_split num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 132 num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, @@ -286,11 +240,10 @@ def sparse_gqa_decode_varlen_mask(query, key, value, 
block_mask, cache_seqlens, num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks")) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + num_blocks=T.dynamic("num_blocks"), + ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") # print(kernel.get_kernel_source()) output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) @@ -298,24 +251,18 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, return output -def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values @@ -323,29 +270,27 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se for h in range(heads_kv): for idx in range(num_blocks): if block_mask[b, h, idx]: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out -def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, 
- block_size): +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) @@ -359,23 +304,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): # print(expect[3, 28]) # print(actual[3, 28]) diff = (expect - actual).abs() - print("all_close={}, max={}, min={}, mean={}".format(all_close, - diff.max().item(), - diff.min().item(), - diff.mean().item())) + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) first_index = tuple(max_indices[0].tolist()) print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -def main(batch=8, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -383,14 +318,13 @@ def main(batch=8, print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') print("cache_seqlens: ", cache_seqlens) @@ -402,7 +336,7 @@ def main(batch=8, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_mask with false (for padding blocks) - block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -410,13 +344,12 @@ def main(batch=8, valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch if valid_num_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - 
perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True # print("block_mask: ", block_mask) # parity reference - ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = model(Q, K, V, block_mask, cache_seqlens) @@ -426,13 +359,11 @@ def main(batch=8, ## latency reference for _ in range(10): - ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) torch.cuda.synchronize() start = time.time() for _ in range(100): - ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) torch.cuda.synchronize() print("dense time: ", (time.time() - start) / 100 * 1000) @@ -449,17 +380,72 @@ def main(batch=8, print("sparse time: ", (time.time() - start) / 100 * 1000) +def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + dtype = torch.float16 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + random_index = torch.randint(0, batch, (1,), device="cuda").item() + cache_seqlens[random_index] = max_cache_seqlen + + num_blocks = (max_cache_seqlen + block_size - 1) // block_size + + valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int() + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") + + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() + valid_num_block = valid_num_blocks[b].item() + if valid_num_block > 0: + for h in range(heads_kv): + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] + block_mask[b, h, perm] = True + + model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = model.batch + heads = model.heads + heads_kv = model.heads_kv + dim_v = model.dim_v + dim = model.dim + block_size = model.block_size + block_H = model.block_H + max_cache_seqlen = K.shape[1] + max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = model.num_sm + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, 
num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = model.kernel + + def run_kernel_only(): + kernel(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py index 85b72b775..01695742b 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -5,19 +5,15 @@ import argparse from einops import rearrange, einsum import torch.nn.functional as F - import math import time from heuristic import num_splits_heuristic +from tilelang.profiler import do_bench @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], ) @triton.jit def _split_kernel( @@ -79,16 +75,11 @@ def _split_kernel( loop_range = blocks_per_split q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h - k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ - None, :] * stride_k_s + offs_d[:, None] * stride_k_d - v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, - None] * stride_v_s + offs_d[ - None, :] * stride_v_d + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d 
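Aside: the pointer arithmetic being reflowed in this kernel builds a whole K tile of addresses at once by broadcasting two 1-D index vectors. A standalone NumPy sketch of the same pattern, using made-up tile sizes and strides rather than values from this diff:

import numpy as np

# Hypothetical tile sizes and element strides, for illustration only.
BLOCK_N, BLOCK_D = 4, 3
stride_k_s, stride_k_d = 8, 1

offs_n = np.arange(BLOCK_N)  # positions along the KV sequence axis
offs_d = np.arange(BLOCK_D)  # positions along the head-dim axis

# Broadcasting a row vector against a column vector gives a (BLOCK_D, BLOCK_N)
# grid of flat element offsets, the same grid the kernel adds to k_cache_ptr
# (and, transposed, to v_cache_ptr) before tl.load.
k_tile_offsets = offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
print(k_tile_offsets)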
+ v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h - q = tl.load( - q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, - mask=offs_h[:, None] < gqa_group_size) + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) for i in range(loop_range): block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s) @@ -119,23 +110,18 @@ def _split_kernel( acc = acc * l_recip acc = acc.to(o_partial_ptr.dtype.element_ty) - lse_partial_ptr += batch_idx * stride_lse_b + ( - head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) - o_partial_ptr += batch_idx * stride_o_b + ( - head_idx_q + - offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], ) @triton.jit def _merge_kernel( @@ -163,18 +149,15 @@ def _merge_kernel( offs_d = tl.arange(0, BLOCK_D) lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h - lse = tl.load( - lse_offsets + offs_splits * lse_partial_stride_split, - mask=offs_splits < num_splits, - other=float("-inf")) + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) lse_max = tl.max(lse) o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_partial = tl.load( - o_offsets + offs_splits[:, None] * o_partial_stride_split + - offs_d[None, :] * o_partial_stride_d, - mask=offs_splits[:, None] < num_splits) + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) @@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton( num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 64 # num_sm = self.num_sm num_splits = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # 
print("num_splits:", num_splits, "num_blocks:", num_n_blocks) @@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton( return output -def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] dim_v = value.shape[-1] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values based on block_indices @@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache valid_indices = block_indices[b, h] # Extract indices for this batch and head for idx in valid_indices: if idx >= 0: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def ref_program_fa(query, key, value, cache_seqlens): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def main(batch=64, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): - +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): 
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -369,34 +331,29 @@ def main(batch=64, dtype = torch.float16 block_H = 64 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence print("cache_seqlens: ", cache_seqlens) max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_indices with -1 (for padding blocks) - block_indices = torch.full((batch, heads_kv, max_selected_blocks), - -1, - dtype=torch.int32, - device='cuda') + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - valid_indices = torch.randperm( - max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] - block_indices[b, h, :len(valid_indices)] = valid_indices + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) @@ -408,8 +365,7 @@ def main(batch=64, max_num_blocks = torch.max(max_valid_num_blocks).item() print("max_num_blocks: ", max_num_blocks) - ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) triton_out = block_sparse_flash_decode_gqa_indice_triton( Q, @@ -423,8 +379,7 @@ def main(batch=64, ) print("max difference: ", torch.max(torch.abs(ref - triton_out))) - assert torch.allclose( - ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" print("Passed the ref test!") # Measure performance @@ -466,15 +421,13 @@ def main(batch=64, if __name__ == "__main__": parser = argparse.ArgumentParser() - 
parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py index 348572526..232bcacaf 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py @@ -4,19 +4,14 @@ import argparse from einops import rearrange, einsum import torch.nn.functional as F - import math import time from heuristic import num_splits_heuristic @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], ) @triton.jit def _split_kernel( @@ -77,16 +72,11 @@ def _split_kernel( loop_range = blocks_per_split q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h - k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ - None, :] * stride_k_s + offs_d[:, None] * stride_k_d - v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, - None] * stride_v_s + offs_d[ - None, :] * stride_v_d + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h - q = tl.load( - q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, - mask=offs_h[:, None] < gqa_group_size) + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) for block_idx in 
range(loop_range): start_n = (start + block_idx) * BLOCK_N @@ -117,23 +107,18 @@ def _split_kernel( acc = acc * l_recip acc = acc.to(o_partial_ptr.dtype.element_ty) - lse_partial_ptr += batch_idx * stride_lse_b + ( - head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) - o_partial_ptr += batch_idx * stride_o_b + ( - head_idx_q + - offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], ) @triton.jit def _merge_kernel( @@ -161,18 +146,15 @@ def _merge_kernel( offs_d = tl.arange(0, BLOCK_D) lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h - lse = tl.load( - lse_offsets + offs_splits * lse_partial_stride_split, - mask=offs_splits < num_splits, - other=float("-inf")) + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) lse_max = tl.max(lse) o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_partial = tl.load( - o_offsets + offs_splits[:, None] * o_partial_stride_split + - offs_d[None, :] * o_partial_stride_d, - mask=offs_splits[:, None] < num_splits) + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) @@ -207,19 +189,13 @@ def block_sparse_flash_decode_gqa_mask_triton( num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 64 # num_sm = self.num_sm num_splits = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) @@ -292,24 +268,18 @@ def block_sparse_flash_decode_gqa_mask_triton( return output -def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, 
heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values @@ -317,43 +287,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se for h in range(heads_kv): for idx in range(num_blocks): if block_mask[b, h, idx]: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def ref_program_fa(query, key, value, cache_seqlens): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def main(batch=64, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): - +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v block_size = block_size sparse_ratio = sparse_ratio @@ -363,14 +324,13 @@ def main(batch=64, dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = 
torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence num_blocks = (max_cache_seqlen + block_size - 1) // block_size @@ -379,7 +339,7 @@ def main(batch=64, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_mask with false (for padding blocks) - block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -387,11 +347,10 @@ def main(batch=64, valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch if valid_num_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True - ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) triton_out = block_sparse_flash_decode_gqa_mask_triton( Q, @@ -404,8 +363,7 @@ def main(batch=64, ) # print("max difference: ", torch.max(torch.abs(ref - triton_out))) - assert torch.allclose( - ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" print("Passed the ref test!") # Measure performance @@ -448,15 +406,13 @@ def main(batch=64, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", 
type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/heuristic.py b/examples/blocksparse_attention/heuristic.py index b60a81dc3..0e6fc5281 100644 --- a/examples/blocksparse_attention/heuristic.py +++ b/examples/blocksparse_attention/heuristic.py @@ -1,8 +1,7 @@ import math -def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, - is_causal_or_local, max_splits): +def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits): """ Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency. diff --git a/examples/blocksparse_attention/regression_example_blocksparse_attention.py b/examples/blocksparse_attention/regression_example_blocksparse_attention.py new file mode 100644 index 000000000..26fa60df5 --- /dev/null +++ b/examples/blocksparse_attention/regression_example_blocksparse_attention.py @@ -0,0 +1,20 @@ +import tilelang.testing +import example_tilelang_block_sparse_attn +import example_tilelang_sparse_gqa_decode_varlen_indice +import example_tilelang_sparse_gqa_decode_varlen_mask + + +def regression_example_tilelang_block_sparse_attn(): + tilelang.testing.process_func(example_tilelang_block_sparse_attn.run_regression_perf) + + +def regression_example_tilelang_sparse_gqa_decode_varlen_indice(): + tilelang.testing.process_func(example_tilelang_sparse_gqa_decode_varlen_indice.run_regression_perf, batch=1, max_cache_seqlen=2048) + + +def regression_example_tilelang_sparse_gqa_decode_varlen_mask(): + tilelang.testing.process_func(example_tilelang_sparse_gqa_decode_varlen_mask.run_regression_perf, batch=1, max_cache_seqlen=2048) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/blocksparse_attention/test_example_blocksparse_attention.py b/examples/blocksparse_attention/test_example_blocksparse_attention.py index 88527f7b3..dd33f46c4 100644 --- a/examples/blocksparse_attention/test_example_blocksparse_attention.py +++ b/examples/blocksparse_attention/test_example_blocksparse_attention.py @@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask(): def test_example_triton_sparse_gqa_decode_varlen_indice(): example_triton_sparse_gqa_decode_varlen_indice.main( - batch=16, - heads=16, - heads_kv=8, - max_cache_seqlen=4096, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32) + batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) def test_example_triton_sparse_gqa_decode_varlen_mask(): example_triton_sparse_gqa_decode_varlen_mask.main( - batch=16, - heads=16, - heads_kv=8, - max_cache_seqlen=4096, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32) + batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) if __name__ == "__main__": diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 7b9cff7c1..178cc5984 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ 
b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -6,6 +6,7 @@ from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType import torch from typing import List +from tilelang.profiler import do_bench DEFAULT_BLOCK_M = 128 DEFAULT_BLOCK_N = 128 @@ -19,8 +20,7 @@ parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") -parser.add_argument( - "--use_autotune", action="store_true", default=False, help="Whether to use autotune") +parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune") args, _ = parser.parse_known_args() M, N, K = args.m, args.n, args.k @@ -41,17 +41,19 @@ def get_configs(): thread_num = [128, 256] enable_rasterization = [True, False] - _configs = list( - itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) - return [{ - "block_M": c[0], - "block_N": c[1], - "block_K": c[2], - "num_stages": c[3], - "thread_num": c[4], - "enable_rasteration": c[5], - } for c in _configs] + return [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], + } + for c in _configs + ] def ref_program(A, B, BlockMask, block_M, block_N, block_K): @@ -61,12 +63,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if BlockMask[i, j, k]: - accu += ( - A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32)) - ref_c[i * block_M:(i + 1) * block_M, - j * block_N:(j + 1) * block_N] = accu.to(torch.float16) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c @@ -89,28 +89,21 @@ def supply_program(params: List[KernelParam]): return input_tensors -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit(out_idx=[-1]) -def blocksparse_matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): - +def blocksparse_matmul( + M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 +): block_mask_shape = (M // block_M, N // block_N, K // block_K) @T.prim_func def block_sparse_matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -134,7 +127,6 @@ def block_sparse_matmul( def main(): - # Initialize input matrices A and B on the GPU with half precision a = torch.randn(M, K).cuda().half() b = 
torch.randn(K, N).cuda().half() @@ -147,8 +139,7 @@ def main(): best_config = kernel.config best_latency = kernel.latency - block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[ - "block_K"] + block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"] print(f"Best Config: {best_config}") print(f"Sparsity Ratio: {sparsity}") @@ -163,10 +154,10 @@ def main(): block_K=DEFAULT_BLOCK_K, num_stages=DEFAULT_NUM_STAGES, thread_num=DEFAULT_THREAD_NUM, - enable_rasteration=DEFAULT_ENABLE_RASTERIZATION) + enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + ) block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") - # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity @@ -185,5 +176,32 @@ def main(): print(e) +def run_regression_perf(): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + kernel = blocksparse_matmul( + M, + N, + K, + block_M=DEFAULT_BLOCK_M, + block_N=DEFAULT_BLOCK_N, + block_K=DEFAULT_BLOCK_K, + num_stages=DEFAULT_NUM_STAGES, + thread_num=DEFAULT_THREAD_NUM, + enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + ) + block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + + def run_kernel_only(): + kernel(a, b, block_mask) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/blocksparse_gemm/regression_example_blocksparse_gemm.py b/examples/blocksparse_gemm/regression_example_blocksparse_gemm.py new file mode 100644 index 000000000..81900a00c --- /dev/null +++ b/examples/blocksparse_gemm/regression_example_blocksparse_gemm.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_blocksparse_gemm + + +def regression_example_blocksparse_gemm(): + tilelang.testing.process_func(example_blocksparse_gemm.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py index 4c2f574c0..db6beab1e 100644 --- a/examples/cast/example_group_per_split_token_cast_to_fp8.py +++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -5,8 +5,8 @@ from tilelang.utils.tensor import torch_assert_close # support bfloat16, float, float16 -dtype = "bfloat16" -accum_dtype = "float" +dtype = T.bfloat16 +accum_dtype = T.float32 @tilelang.jit(out_idx=[2, 3]) @@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): fp8_max = 448.0 @T.prim_func - def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor( - (BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor( - (BG, M_max, T.ceildiv(N, group_size)), accum_dtype)): - with T.Kernel( - T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): + def group_per_split_token_cast( + X: T.Tensor((M, N), dtype), + batch_sizes: T.Tensor((BG,), T.int32), + X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn), + X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype), + ): + with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) 
as (bx, by, bz): row = bx row_g_id = by bg = bz @@ -28,39 +30,29 @@ def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor y_amax_local = T.alloc_fragment((blk_m,), accum_dtype) y_s_local = T.alloc_fragment((blk_m,), accum_dtype) y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype) - y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") - row_offset = T.alloc_fragment((1,), "int32") + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) + row_offset = T.alloc_var(dtype=T.int32) - T.annotate_layout({ - y_local: - T.Fragment( - y_local.shape, - forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), - }) - - row_offset[0] = 0 + row_offset = 0 for i in T.serial(bg): - row_offset[0] += batch_sizes[i] + row_offset += batch_sizes[i] T.copy( - X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size], y_local) + X[row_offset + row * blk_m : row_offset + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], + y_local, + ) T.reduce_absmax(y_local, y_amax_local, dim=1) for i in T.Parallel(blk_m): y_amax_local[i] = T.max(y_amax_local[i], 1e-4) - y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], - y_amax_local[i] / fp8_max, 0) + y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0) for i, j in T.Parallel(blk_m, group_size): y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) T.copy(y_q_local, y_q_local_fp8) for i, j in T.Parallel(blk_m, group_size): - y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], - y_q_local[i, j], 0) + y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0) for i in T.Parallel(blk_m): X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i] - T.copy( - y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size]) + T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) return group_per_split_token_cast @@ -127,8 +119,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: return x.squeeze(0) if remove_dim else x # Normal layout requires transposing - aligned_x = torch.transpose( - torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) aligned_x[:, :m, :] = x aligned_x = aligned_x[:, :m, :] return aligned_x.squeeze(0) if remove_dim else aligned_x @@ -146,31 +137,35 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() return x_fp8, (x_amax / 448.0).view(m, -1) -def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \ - Tuple[torch.Tensor, torch.Tensor]: + +def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # assert x.shape[0] == batch_sizes.sum() M_max = ceil_div(batch_sizes.max(), 128) * 128 split_x = torch.split(x, batch_sizes.tolist(), dim=0) padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x] num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1] - x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn), - torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float)) + x_fp8 = ( + torch.empty((num_groups, m, n), device="cuda", 
dtype=torch.float8_e4m3fn), + torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float), + ) for i in range(num_groups): x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i]) x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) return x_fp8 -def main(M=8192, N=8192, BG=2, blk_m=8): - if dtype == "float": +def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [2048, 6144] + if dtype == T.float: x = torch.randn(M, N, device="cuda", dtype=torch.float32) - elif dtype == "float16": + elif dtype == T.float16: x = torch.randn(M, N, device="cuda", dtype=torch.float16) - elif dtype == "bfloat16": + elif dtype == T.bfloat16: x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) else: raise ValueError(f"Unsupported dtype: {dtype}") - batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32) + batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32) M_max = int(ceil_div(batch_sizes.max(), 128) * 128) print("batch_sizes:", batch_sizes) @@ -204,5 +199,35 @@ def run_torch(): print("Torch: {:.2f} ms".format(latency)) +def run_regression_perf(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [2048, 6144] + if dtype == "float": + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + elif dtype == "float16": + x = torch.randn(M, N, device="cuda", dtype=torch.float16) + elif dtype == "bfloat16": + x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32) + M_max = int(ceil_div(batch_sizes.max(), 128) * 128) + + kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m) + + x_fp8, x_amax = kernel(x, batch_sizes) + x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes) + + torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01) + torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01) + + from tilelang.profiler import do_bench + + def run_tilelang(): + kernel(x, batch_sizes) + + return do_bench(run_tilelang, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py index 484a092f0..4b3730b4b 100644 --- a/examples/cast/example_per_token_cast_to_fp8.py +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -7,14 +7,15 @@ @tilelang.jit(out_idx=[1, 2]) def per_token_cast_to_fp8(M, N, blk_m): - dtype = "float" + dtype = T.float group_size = 128 fp8_min = -448.0 fp8_max = 448.0 @T.prim_func - def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), - X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)): + def per_token_cast( + X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), T.float8_e4m3fn), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype) + ): with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): row = bx row_g_id = by @@ -22,18 +23,9 @@ def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e y_amax_local = T.alloc_fragment((blk_m,), dtype) y_s_local = T.alloc_fragment((blk_m,), dtype) y_q_local = T.alloc_fragment((blk_m, group_size), dtype) - y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") - - T.annotate_layout({ - y_local: - T.Fragment( - y_local.shape, - forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), - }) - - T.copy( - 
X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size], - y_local) + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) + + T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local) T.reduce_absmax(y_local, y_amax_local, dim=1) for i in T.Parallel(blk_m): y_amax_local[i] = T.max(y_amax_local[i], 1e-4) @@ -43,9 +35,7 @@ def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e T.copy(y_q_local, y_q_local_fp8) for i in T.Parallel(blk_m): X_amax[row * blk_m + i, row_g_id] = y_s_local[i] - T.copy( - y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size]) + T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) return per_token_cast @@ -102,16 +92,32 @@ def main(M=8192, N=8192, blk_m=8): print("Tile-lang: {:.2f} ms".format(latency)) from tilelang.profiler import do_bench - from example_triton_cast_to_fp8 import per_token_group_quant_fp8 - def run_triton(): - x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8( - x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) - return x_fp8_triton_, x_amax_triton_ + # Triton fp8e4nv is only supported on Hopper (SM90) and later + major, _ = torch.cuda.get_device_capability() + if major >= 9: + from example_triton_cast_to_fp8 import per_token_group_quant_fp8 + + def run_triton(): + x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) + return x_fp8_triton_, x_amax_triton_ + + x_fp8_triton, x_amax_triton = run_triton() + latency = do_bench(run_triton) + print("Triton: {:.2f} ms".format(latency)) + else: + print("Triton fp8e4nv benchmark skipped (requires SM90+)") + + +def run_regression_perf(M=8192, N=8192, blk_m=8): + kernel = per_token_cast_to_fp8(M, N, blk_m) + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(x) - x_fp8_triton, x_amax_triton = run_triton() - latency = do_bench(run_triton) - print("Triton: {:.2f} ms".format(latency)) + return do_bench(run_kernel_only, backend="cupti") if __name__ == "__main__": diff --git a/examples/cast/example_triton_cast_to_fp8.py b/examples/cast/example_triton_cast_to_fp8.py index cc56defe7..1859433f1 100644 --- a/examples/cast/example_triton_cast_to_fp8.py +++ b/examples/cast/example_triton_cast_to_fp8.py @@ -128,9 +128,7 @@ def per_token_group_quant_fp8( Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. 
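The quantization rule that both the Tile-lang kernel above and this Triton kernel implement is small enough to restate in plain PyTorch. A simplified reference (the function name is ours; it assumes N is already a multiple of the 128-wide group and a PyTorch build with float8_e4m3fn support):

import torch

def per_token_group_fp8_ref(x: torch.Tensor, group_size: int = 128, eps: float = 1e-4):
    # Quantize each (row, group_size-wide) slice independently.
    m, n = x.shape
    assert n % group_size == 0, "simplified reference: padding is not handled here"
    g = x.view(m, n // group_size, group_size).float()
    amax = g.abs().amax(dim=-1).clamp_min(eps)   # per-group absolute maximum
    scale = amax / 448.0                         # 448 is the largest finite float8_e4m3fn value
    q = (g / scale.unsqueeze(-1)).clamp(-448.0, 448.0)
    return q.to(torch.float8_e4m3fn).view(m, n), scale  # scale: (m, n // group_size)

x = torch.randn(8, 256, dtype=torch.bfloat16)
x_fp8, x_scale = per_token_group_fp8_ref(x)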
""" - assert (x.shape[-1] % - group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible " - f"by `group_size` {group_size}") + assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}" assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) diff --git a/examples/cast/regression_example_cast.py b/examples/cast/regression_example_cast.py new file mode 100644 index 000000000..4bdfb99e7 --- /dev/null +++ b/examples/cast/regression_example_cast.py @@ -0,0 +1,17 @@ +import tilelang.testing +import example_group_per_split_token_cast_to_fp8 +import example_per_token_cast_to_fp8 + + +def regression_example_group_per_split_token_cast_to_fp8(): + tilelang.testing.process_func( + example_group_per_split_token_cast_to_fp8.run_regression_perf, M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896] + ) + + +def regression_example_per_token_cast_to_fp8(): + tilelang.testing.process_func(example_per_token_cast_to_fp8.run_regression_perf, M=2048, N=512, blk_m=8) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/cast/test_example_cast.py b/examples/cast/test_example_cast.py index 2f978c1d4..e8b10a797 100644 --- a/examples/cast/test_example_cast.py +++ b/examples/cast/test_example_cast.py @@ -4,11 +4,11 @@ def test_example_group_per_split_token_cast_to_fp8(): - example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8) + example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896]) def test_example_per_token_cast_to_fp8(): - example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8) + example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8) if __name__ == "__main__": diff --git a/examples/compile_flags/usecase.py b/examples/compile_flags/usecase.py index 8451b04fc..80e2b784b 100644 --- a/examples/compile_flags/usecase.py +++ b/examples/compile_flags/usecase.py @@ -4,12 +4,11 @@ # @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -36,8 +35,7 @@ def main( func = matmul(M, N, K, block_M, block_N, block_K) -jit_kernel = tilelang.compile( - func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr") +jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr") # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"]) diff --git a/examples/conftest.py b/examples/conftest.py index 9f49d40a9..4010e0d83 100644 --- a/examples/conftest.py +++ b/examples/conftest.py @@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): "warnings", "error", } - if (sum( - len(terminalreporter.stats.get(k, [])) - for k in known_types.difference({"skipped", "deselected"})) == 0): + if sum(len(terminalreporter.stats.get(k, [])) for k in 
known_types.difference({"skipped", "deselected"})) == 0: terminalreporter.write_sep( "!", - (f"Error: No tests were collected. " - f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), + (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), ) pytest.exit("No tests were collected.", returncode=5) diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index b2696ba8f..1599d3464 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -14,7 +14,6 @@ def check_hopper(): def ref_program(stride, padding, dilation): - def main(A, B): A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W @@ -26,38 +25,21 @@ def main(A, B): @tilelang.jit(out_idx=[2]) -def convolution(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 is_hopper = check_hopper() @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -66,12 +48,6 @@ def main( kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - data_shared: tilelang.layout.make_swizzled_layout(data_shared), - kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), - }) - T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): if is_hopper: @@ -82,10 +58,8 @@ def main( m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) @@ -97,15 +71,15 @@ def main( def main(argv=None): parser = argparse.ArgumentParser() - parser.add_argument('--n', type=int, default=128, help='n') - parser.add_argument('--c', type=int, default=128, help='c') - parser.add_argument('--h', type=int, default=64, help='h') - parser.add_argument('--w', type=int, default=64, help='w') - parser.add_argument('--f', type=int, 
default=128, help='f') - parser.add_argument('--k', type=int, default=3, help='k') - parser.add_argument('--s', type=int, default=1, help='s') - parser.add_argument('--d', type=int, default=1, help='d') - parser.add_argument('--p', type=int, default=1, help='p') + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") args = parser.parse_args(argv) N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p @@ -125,5 +99,30 @@ def main(argv=None): print("All checks passed.✅") +def run_regression_perf(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + + args = parser.parse_args(argv) + N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p + + block_m = 64 + block_n = 128 + block_k = 32 + num_stages = 3 + threads = 256 + kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 393677489..c0c666402 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -14,7 +14,6 @@ def check_hopper(): def ref_program(stride, padding, dilation): - def main(A, B): A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W @@ -40,7 +39,8 @@ def get_configs(): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -50,7 +50,8 @@ def get_configs(): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs @@ -64,69 +65,32 @@ def get_heuristic_config() -> dict: sm_version = sm_major * 10 + sm_minor print(f"CUDA device capability: {sm_version}") if sm_version in {80}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 2, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} elif sm_version in {90}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 64, - "num_stages": 3, - "thread_num": 256, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 64, 
"num_stages": 3, "thread_num": 256, "enable_rasteration": True} else: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 0, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} @tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[2]) -def convolution(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): +def convolution( + N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 +): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 is_hopper = check_hopper() @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -135,11 +99,6 @@ def main( kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - if is_hopper: - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - }) - T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): if is_hopper: @@ -150,10 +109,8 @@ def main( m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) @@ -166,17 +123,19 @@ def main( return main -def main(n: int = 128, - c: int = 128, - h: int = 64, - w: int = 64, - f: int = 128, - k: int = 3, - s: int = 1, - d: int = 1, - p: int = 1, - use_autotune: bool = False, - with_roller: bool = True): +def main( + n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True, +): N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p ref_prog = ref_program(S, P, D) @@ -194,27 +153,38 @@ def main(n: int = 128, print(f"Ref latency: {ref_latency}") +def run_regression_perf( + n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True, +): + N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p + config = 
get_heuristic_config() + kernel = convolution(N, C, H, W, F, K, S, D, P, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument('--n', type=int, default=128, help='n') - parser.add_argument('--c', type=int, default=128, help='c') - parser.add_argument('--h', type=int, default=64, help='h') - parser.add_argument('--w', type=int, default=64, help='w') - parser.add_argument('--f', type=int, default=128, help='f') - parser.add_argument('--k', type=int, default=3, help='k') - parser.add_argument('--s', type=int, default=1, help='s') - parser.add_argument('--d', type=int, default=1, help='d') - parser.add_argument('--p', type=int, default=1, help='p') - parser.add_argument( - "--use_autotune", - action="store_true", - default=False, - help="Whether to use autotune for matmul configs") - parser.add_argument( - "--with_roller", - action="store_true", - default=True, - help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space") args = parser.parse_args() - main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, - args.with_roller) + main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller) diff --git a/examples/convolution/regression_example_convolution.py b/examples/convolution/regression_example_convolution.py new file mode 100644 index 000000000..18d4bcb68 --- /dev/null +++ b/examples/convolution/regression_example_convolution.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_convolution +import example_convolution_autotune + + +def regression_example_convolution(): + tilelang.testing.process_func(example_convolution.run_regression_perf) + + +def regression_example_convolution_autotune(): + tilelang.testing.process_func(example_convolution_autotune.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index 715f09a9b..18467a811 100644 --- a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -20,11 +20,11 @@ def tl_gemm( accum_dtype, ): assert in_dtype in [ - "float8_e4m3", + T.float8_e4m3fn, ], "Currently only float8_e4m3 is supported" assert out_dtype in [ - "bfloat16", - "float32", + T.bfloat16, + T.float32, ], "Currently only float16 and float32 are supported" group_size = 128 @@ -41,18 +41,17 @@ def tl_gemm( @T.prim_func def main( - A: T.Tensor(A_shape, 
in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - scales_a: T.Tensor(Scales_A_shape, "float32"), - scales_b: T.Tensor(Scales_B_shape, "float32"), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + scales_a: T.Tensor(Scales_A_shape, T.float32), + scales_b: T.Tensor(Scales_B_shape, T.float32), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype) - Scale_C_shared = T.alloc_shared((block_M), "float32") + Scale_C_shared = T.alloc_shared((block_M), T.float32) C_local = T.alloc_fragment(C_shared_shape, accum_dtype) C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype) @@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: m, n = x.shape x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( - m, n), (x_amax / 448.0).view(m, -1) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros( - ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) + x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( - x_view.size(0), x_view.size(2)) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): @@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): c_acc.zero_() for k in range(ceildiv(K, 128)): c = torch._scaled_mm( - A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128], - B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T, + A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128], + B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T, scale_a=A_scales[i, k].view(128, 1).contiguous(), scale_b=B_scales[j, k].view(1, 128).contiguous(), - out_dtype=torch.bfloat16) + out_dtype=torch.bfloat16, + ) c_acc += c.to(torch.float32) - C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype) + C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype) return C @@ -179,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp def main(): - assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32") + assert_tl_gemm_correctness(1024, 1024, 8192, 128, T.float8_e4m3fn, T.bfloat16, T.float32) if __name__ == "__main__": - for dtype in ["float8_e4m3"]: - for out_dtype in ["bfloat16", "float32"]: + for dtype in [T.float8_e4m3fn]: + for out_dtype in [T.bfloat16, T.float32]: for block_N in [16, 32, 64, 128]: - assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32") + assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, T.float32) diff 
--git a/examples/deepseek_mla/README.md b/examples/deepseek_mla/README.md index e64b1c37d..bd3539d26 100644 --- a/examples/deepseek_mla/README.md +++ b/examples/deepseek_mla/README.md @@ -24,14 +24,14 @@ We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashIn
Figure 2: Performance under batch size=128
-As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. +As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this. ## Implementation First, let's review the core computation logic of traditional FlashAttention: -```python +```python # acc_s: [block_M, block_N] # scores_max: [block_M] # scores_scale: [block_M] @@ -54,7 +54,7 @@ Compared to traditional attention operators like MHA (Multi-Headed Attention) or This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling. -Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. +Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory. @@ -96,7 +96,6 @@ T.use_swizzle(panel_size: int, order: str = "row") Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col". - ### Shared Memory Swizzling In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance. @@ -113,17 +112,14 @@ T.annotate_layout({ Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout. - ### Warp-Specialization The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects. 
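To make that bookkeeping concrete, here is a minimal CPU-side sketch of the producer/consumer hand-off, written with ordinary Python threads and a two-slot buffer. None of the names below are TileLang or CUDA APIs; the semaphores only stand in for the `mbarrier` arrive/wait protocol a hand-written warp-specialized kernel has to manage, with the producer playing the role of the TMA warpgroup and the consumers played by a single compute thread.

```python
import threading

NUM_TILES = 8   # how many "tiles" the producer streams in
NUM_SLOTS = 2   # double buffering: two shared-memory slots

# One "full"/"empty" semaphore pair per slot plays the role of an mbarrier:
# the producer "arrives" on full after filling a slot, the consumer "arrives"
# on empty after it has finished reading it.
full = [threading.Semaphore(0) for _ in range(NUM_SLOTS)]
empty = [threading.Semaphore(1) for _ in range(NUM_SLOTS)]
slots = [None] * NUM_SLOTS
results = []

def producer():
    for k in range(NUM_TILES):
        s = k % NUM_SLOTS
        empty[s].acquire()                 # wait until the consumer has drained the slot
        slots[s] = list(range(k, k + 4))   # stand-in for a TMA copy into shared memory
        full[s].release()                  # "arrive": tile k is ready

def consumer():
    for k in range(NUM_TILES):
        s = k % NUM_SLOTS
        full[s].acquire()                  # wait for the producer's arrival on this slot
        results.append(sum(slots[s]))      # stand-in for the MMA work on the tile
        empty[s].release()                 # hand the slot back to the producer

t_prod = threading.Thread(target=producer)
t_cons = threading.Thread(target=consumer)
t_prod.start(); t_cons.start()
t_prod.join(); t_cons.join()
print(results)  # one partial result per tile, produced in pipeline order
```

Even in this toy form, every pipeline slot needs a matched acquire/release pair on both sides, and the slot indexing has to stay consistent between producer and consumer; scaling that to multiple stages and warpgroups is where hand-written warp-specialized kernels become error-prone.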
In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation. - ### Pipeline - Pipeline is a technique used to improve memory access efficiency by overlapping memory access and computation. In TileLang, pipeline can be implemented through the `T.pipelined` annotation: ```python @@ -132,9 +128,8 @@ T.pipelined(range: int, stage: int) Here, `range` specifies the range of the pipeline, and `stage` specifies the stage of the pipeline. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases. - ### Split-KV We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results. -In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. \ No newline at end of file +In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py index db460437f..dccf333ad 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -8,6 +8,7 @@ def get_configs(): import itertools + BLOCK_N = [16, 32, 64, 128] BLOCK_H = [16, 32, 64, 128] num_split = [1, 2, 4, 8, 16, 32] @@ -15,45 +16,44 @@ def get_configs(): _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads)) - return [{ - "block_N": c[0], - "block_H": c[1], - "num_split": c[2], - "threads": c[3], - } for c in _configs] + return [ + { + "block_N": c[0], + "block_H": c[1], + "num_split": c[2], + "threads": c[3], + } + for c in _configs + ] @tilelang.autotune(configs=get_configs()) @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashmla_decode(batch, - heads, - kv_head_num, - seqlen_kv, - dim, - pe_dim, - block_N, - block_H, - num_split, - threads=128): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + }, +) +def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" - @T.macro - def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + @T.prim_func + def main_split( + Q: 
T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by): + # flash_attn_split + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=threads) as (bx, by, bz): Q_local = T.alloc_fragment([block_H, dim], dtype) Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) KV_shared = T.alloc_shared([block_N, dim], dtype) @@ -69,34 +69,31 @@ def flash_attn( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv(seqlen_kv, block_N) + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=0): - T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) + kv_start = (seqlen_kv // num_split) * bz + k * block_N + kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N + T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.gemm( - Q_pe_local, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) - # T.copy(acc_s, S_shared) T.copy(acc_s, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] @@ -105,20 +102,50 @@ def flash_attn( T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = 
T.alloc_var(accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] - @T.macro - def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by): Q_local = T.alloc_fragment([block_H, dim], dtype) Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) KV_shared = T.alloc_shared([block_N, dim], dtype) @@ -134,34 +161,31 @@ def flash_attn_split( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=0): - kv_start = (seqlen_kv // num_split) * bz + k * block_N - kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N - T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared) - T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.gemm( - Q_pe_local, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = 
T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) + # T.copy(acc_s, S_shared) T.copy(acc_s, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] @@ -170,72 +194,7 @@ def flash_attn_split( T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] - for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) - T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim], dtype) - o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dim): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim): - Output[bz, by, i] = o_accum_local[i] - - @T.prim_func - def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn(Q, Q_pe, KV, K_pe, Output) + T.copy(acc_o, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) if num_split > 1: return main_split @@ -258,43 +217,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - 
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') - parser.add_argument('--autotune', action='store_true', help='auto tune') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + parser.add_argument("--autotune", action="store_true", help="auto tune") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim enable_autotune = args.autotune @@ -310,17 +262,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): if enable_autotune: kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim) else: - kernel = flashmla_decode( - batch, - heads, - kv_heads, - kv_ctx, - dim, - pe_dim, - BLOCK_N, - BLOCK_H, - num_split, - threads=threads) + kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, threads=threads) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) input_tensors 
= profiler._get_inputs() tilelang_output = kernel(*input_tensors) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py index 0006d9468..18c0a5f86 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py @@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -94,8 +93,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -141,9 +139,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -309,24 +305,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = 
torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" if target not in ["flash_mla_triton"]: @@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -429,26 
+422,22 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] def get_args(): @@ -470,26 +459,54 @@ def get_args(): for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py index 644f97da1..861e841c4 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py @@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def 
run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -91,8 +90,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -138,9 +136,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -306,24 +302,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, 
dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" if target not in ["flash_mla_triton"]: @@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -426,26 +419,22 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [64, 128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]] +shape_configs = [ + { + "b": 
batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [64, 128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] def get_args(): @@ -467,26 +456,54 @@ def get_args(): for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/benchmark_mla.py b/examples/deepseek_mla/benchmark_mla.py index a542ff611..544b5e128 100644 --- a/examples/deepseek_mla/benchmark_mla.py +++ b/examples/deepseek_mla/benchmark_mla.py @@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -61,8 +60,7 @@ def ref_mla(): @torch.inference_mode() -def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, 
cache_seqlens, h_q, h_kv, d, dv, causal, dtype): from flash_mla import flash_mla_with_kvcache, get_mla_metadata blocked_v = blocked_k[..., :dv] @@ -87,14 +85,13 @@ def flash_mla(): @torch.inference_mode() -def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, - h_q, h_kv, d, dv, causal, dtype): +def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): # pip install flashinfer-python import flashinfer + assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() kv_indptr = [0] kv_indices = [] @@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") mla_wrapper.plan( q_indptr, kv_indptr, @@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q ) def flashinfer(): - output, lse = mla_wrapper.run( - q_nope.view(-1, h_q, dv), - q_pe.view(-1, h_q, d - dv), - blocked_k_nope, - blocked_k_pe, - return_lse=True) + output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True) return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) out_flash, lse_flash = flashinfer() @@ -177,8 +168,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -224,9 +214,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -393,24 +381,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., 
dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -419,13 +413,10 @@ def flash_mla_triton(): @torch.inference_mode() -def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() dpe = d - dv num_kv_splits = 1 @@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) def flash_mla_tilelang(): out = kernel( @@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" - if target not in ["flashinfer", "flash_mla_triton", "tilelang" - ] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: + if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: # flashinfer has a different lse return value # flash_mla_triton and flash_mla_tilelang doesn't return lse torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" 
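+    # Performance model (derived from the two lines below): FLOPS counts the QK^T (d dims) and PV (dv dims) multiply-adds per KV element per head; `bytes` counts KV-cache reads plus Q reads and O writes at the element width of `dtype`.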
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -558,26 +538,22 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] + for head in [128] +] def get_args(): @@ -599,26 +575,54 @@ def get_args(): for shape in shape_configs: if args.all: for target in available_targets: - perf = 
compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index e1dd0b4d6..7de4faf08 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -10,27 +10,31 @@ @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, - softmax_scale): + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): scale = float(softmax_scale * 1.44269504) # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" - @T.macro - def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, 
seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): + # flash_attn_split + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -38,6 +42,7 @@ def flash_attn( K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) O_shared = T.alloc_shared([block_H, dim], dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_o = T.alloc_fragment([block_H, dim], accum_dtype) scores_max = T.alloc_fragment([block_H], accum_dtype) scores_max_prev = T.alloc_fragment([block_H], accum_dtype) @@ -46,64 +51,87 @@ def flash_attn( logsum = T.alloc_fragment([block_H], accum_dtype) cur_kv_head = hid // (kv_group_num // block_H) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) + T.use_swizzle(10) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv(seqlen_kv, block_N) + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=2): - T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + kv_start = (seqlen_kv // num_split) * bz + k * block_N + kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N + T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + 
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, :]) + + # combine + with T.Kernel(heads, batch, threads=128) as (hid, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local = T.max(lse_max_local, glse[bz, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split = glse[bz, hid, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, hid, k, i] + lse_local_split = glse[bz, hid, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim): + Output[bz, hid, i] = o_accum_local[i] - @T.macro - def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=256) as (bid, hid, bz): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): Q_shared = T.alloc_shared([block_H, dim], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -111,7 +139,6 @@ def flash_attn_split( K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) O_shared = T.alloc_shared([block_H, dim], dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_o = T.alloc_fragment([block_H, dim], accum_dtype) scores_max = T.alloc_fragment([block_H], accum_dtype) scores_max_prev = T.alloc_fragment([block_H], accum_dtype) @@ -120,118 +147,39 @@ def flash_attn_split( logsum = T.alloc_fragment([block_H], accum_dtype) cur_kv_head = hid // (kv_group_num // block_H) - T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) 
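+            # Also stage the rope (pe) slice of Q; it feeds the Q_pe x K_pe GEMM that accumulates into the same acc_s below.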
+ T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=2): - kv_start = (seqlen_kv // num_split) * bz + k * block_N - kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N - T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) - T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) - T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) T.copy(acc_s, S_shared) - T.copy(S_shared, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] - for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, :]) - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - with T.Kernel(heads, batch, threads=128) as (hid, bz): - po_local = T.alloc_fragment([dim], dtype) - o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, hid, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dim): - po_local[i] = Output_partial[bz, hid, k, i] - lse_local_split[0] = glse[bz, 
hid, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim): - Output[bz, hid, i] = o_accum_local[i] - - @T.prim_func - def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn(Q, Q_pe, KV, K_pe, Output) + T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) if num_split > 1: return main_split @@ -254,31 +202,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -296,10 +237,9 @@ def main( BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) num_split = 1 - softmax_scale = (dim + pe_dim)**-0.5 + softmax_scale = (dim + pe_dim) ** -0.5 - kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, - softmax_scale) + kernel = 
flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) latency = profiler.do_bench(warmup=500) @@ -307,14 +247,33 @@ def main( print(f"TFlops: {total_flops / latency * 1e-9} TFlops") +def run_regression_perf( + batch=1, + heads=128, + kv_heads=1, + kv_ctx=8192, + dim=512, + pe_dim=64, +): + BLOCK_N = 64 + BLOCK_H = min(64, heads // kv_heads) + num_split = 1 + softmax_scale = (dim + pe_dim) ** -0.5 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=132, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py index fe50d4d4f..2e1911028 100644 --- a/examples/deepseek_mla/example_mla_decode_paged.py +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -8,41 +8,36 @@ @tilelang.jit( - out_idx=[8], pass_configs={ + out_idx=[8], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def mla_decode_tilelang(batch, - h_q, - h_kv, - max_seqlen_pad, - dv, - dpe, - block_N, - block_H, - num_split, - block_size, - softmax_scale=None): + }, +) +def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None): if softmax_scale is None: - softmax_scale = (dv + dpe)**-0.5 + softmax_scale = (dv + dpe) ** -0.5 scale = float(softmax_scale * 1.44269504) # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = h_q // h_kv VALID_BLOCK_H = min(block_H, kv_group_num) assert h_kv == 1, "h_kv must be 1" assert block_size >= block_N and block_size % block_N == 0, "block_size must be larger than block_N and a multiple of block_N" - @T.macro - def flash_mla_kernel( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // 
block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), - Output: T.Tensor([batch, h_q, dv], dtype), + @T.prim_func + def main_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): - with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): + # split kv + with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dv], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) @@ -50,6 +45,7 @@ def flash_mla_kernel( K_pe_shared = T.alloc_shared([block_N, dpe], dtype) O_shared = T.alloc_shared([block_H, dv], dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_o = T.alloc_fragment([block_H, dv], accum_dtype) scores_max = T.alloc_fragment([block_H], accum_dtype) scores_max_prev = T.alloc_fragment([block_H], accum_dtype) @@ -59,69 +55,94 @@ def flash_mla_kernel( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N) - for kr in T.Pipelined(loop_range, num_stages=2): - k = loop_range - 1 - kr - kv_start = BLOCK_TABLE[bx, (k * block_N) // - block_size] * block_size + (k * block_N) % block_size - T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + total_blocks = T.ceildiv(cache_seqlens[bx], block_N) + blocks_per_split = T.floordiv(total_blocks, num_split) + remaining_blocks = T.floormod(total_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0) + start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N + + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = block_table[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) - if kr == 0: - for i, j in 
T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(start + k * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dv): acc_o[i, j] *= scores_scale[i] - T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) for i, j in T.Parallel(block_H, dv): acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) - - @T.macro - def flash_mla_split_kv_kernel( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) + + # combine + with T.Kernel(h_q, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dv], dtype) + o_accum_local = T.alloc_fragment([dv], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + for i in T.Parallel(dv): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dv): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dv): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): - 
with T.Kernel( - batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): Q_shared = T.alloc_shared([block_H, dv], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) @@ -129,7 +150,6 @@ def flash_mla_split_kv_kernel( K_pe_shared = T.alloc_shared([block_N, dpe], dtype) O_shared = T.alloc_shared([block_H, dv], dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_o = T.alloc_fragment([block_H, dv], accum_dtype) scores_max = T.alloc_fragment([block_H], accum_dtype) scores_max_prev = T.alloc_fragment([block_H], accum_dtype) @@ -139,129 +159,45 @@ def flash_mla_split_kv_kernel( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N) - blocks_per_split = T.floordiv(total_blocks, num_split) - remaining_blocks = T.floormod(total_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)) - start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N - - for k in T.Pipelined(loop_range, num_stages=2): - kv_start = BLOCK_TABLE[bx, (start + k * block_N) // - block_size] * block_size + (k * block_N) % block_size - T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + loop_range = T.ceildiv(cache_seqlens[bx], block_N) + for kr in T.Pipelined(loop_range, num_stages=2): + k = loop_range - 1 - kr + kv_start = block_table[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + if kr == 0: + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, 
j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) T.copy(acc_s, S_shared) - T.copy(S_shared, acc_s_cast) for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dv): acc_o[i, j] *= scores_scale[i] - T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) for i, j in T.Parallel(block_H, dv): acc_o[i, j] /= logsum[i] - for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) - - @T.macro - def combine( - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), - ): - with T.Kernel(h_q, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dv], dtype) - o_accum_local = T.alloc_fragment([dv], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dv): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dv): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dv): - Output[bz, by, i] = o_accum_local[i] - - @T.prim_func - def main_split( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), - ): - flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, - Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def main_no_split( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), - ): - flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) if 
num_split > 1: return main_split @@ -280,8 +216,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) - temp_mask = torch.ones( - s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -291,8 +226,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): # q: [b, s_q, h_q, d] # block_table: [b, max_seqlen_pad // block_size] # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] @@ -321,13 +255,10 @@ def ref_mla(): return out_torch -def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, - h_q, h_kv, d, dv, causal, dtype): - +def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() dpe = d - dv num_kv_splits = 1 @@ -337,8 +268,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size, softmax_scale) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) def flash_mla_tilelang(): @@ -356,8 +286,7 @@ def flash_mla_tilelang(): out_flash = flash_mla_tilelang() t = do_bench(flash_mla_tilelang) - out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01) print("All close") return out_flash, t @@ -365,12 +294,12 @@ def flash_mla_tilelang(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--h_q', type=int, default=128, help='q heads number') - parser.add_argument('--h_kv', type=int, default=1, help='kv heads number') - parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length') - parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe') - parser.add_argument('--dv', type=int, default=512, help='value head dim') + parser.add_argument("--batch", 
type=int, default=128, help="batch size") + parser.add_argument("--h_q", type=int, default=128, help="q heads number") + parser.add_argument("--h_kv", type=int, default=1, help="kv heads number") + parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length") + parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe") + parser.add_argument("--dv", type=int, default=512, help="value head dim") args = parser.parse_args() b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv @@ -379,9 +308,7 @@ def flash_mla_tilelang(): s_q = 1 # for decode, s_q = 1 block_size = 64 - cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], - dtype=torch.int32, - device=device) + cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device) dpe = d - dv causal = True @@ -393,12 +320,11 @@ def flash_mla_tilelang(): total_flops = s_q * total_seqlens * h_q * d * 2 q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32, - device=device).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) - out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_flash, latency = run_tilelang_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py b/examples/deepseek_mla/example_mla_decode_persistent.py index 3f57ea051..74d974fbb 100644 --- a/examples/deepseek_mla/example_mla_decode_persistent.py +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -9,13 +9,15 @@ @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" @@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_split_persistent( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), 
+ Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(sm_num, threads=256) as (block_id): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -48,16 +50,11 @@ def main_split_persistent( logsum = T.alloc_fragment([block_H], accum_dtype) po_local = T.alloc_fragment([dim], dtype) o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - # O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + T.use_swizzle(10) total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split @@ -70,8 +67,8 @@ def main_split_persistent( cur_kv_head = hid // (kv_group_num // block_H) if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split: - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -83,24 +80,15 @@ def main_split_persistent( T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -115,11 +103,9 @@ def main_split_persistent( acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid]) + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid]) # T.copy(acc_o, O_shared) - T.copy( - acc_o, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - sid, :]) + T.copy(acc_o, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid, :]) T.sync_grid() waves = T.ceildiv(heads * batch, sm_num) @@ -130,20 +116,20 @@ def main_split_persistent( if bid < batch and hid < heads: T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = 
-T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bid, hid, k]) + lse_max_local = T.max(lse_max_local, glse[bid, hid, k]) for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bid, hid, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + lse_local_split = glse[bid, hid, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): for i in T.Parallel(dim): po_local[i] = Output_partial[bid, hid, k, i] - lse_local_split[0] = glse[bid, hid, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bid, hid, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim): Output[bid, hid, i] = o_accum_local[i] @@ -165,42 +151,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def main(): parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + 
parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py index 6554d57de..32eb0d475 100644 --- a/examples/deepseek_mla/example_mla_decode_ws.py +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -13,30 +13,38 @@ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", ], ) -def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, - softmax_scale): +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): sm_scale = float(softmax_scale * 1.44269504) # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" - @T.macro - def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid): + # flash_attn_split + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=384) as (bid, hid, bz): Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -75,16 +83,16 @@ def flash_attn( tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + T.copy(Q[bid, 
hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) @@ -105,6 +113,8 @@ def flash_attn( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -137,6 +147,8 @@ def flash_attn( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -162,8 +174,8 @@ def flash_attn( for h_i in T.Parallel(block_H): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - 0:dim // 2]) + T.copy(O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, 0 : dim // 2]) + T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -193,8 +205,7 @@ def flash_attn( acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - dim // 2:dim]) + T.copy(O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, dim // 2 : dim]) elif tx >= 256: # producer @@ -203,59 +214,82 @@ def flash_attn( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2 + 1) * block_N + r 
* 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) - @T.macro - def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + # combine + with T.Kernel(heads, batch, threads=128) as (hid, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local = T.max(lse_max_local, glse[bz, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split = glse[bz, hid, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, hid, k, i] + lse_local_split = glse[bz, hid, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim): + Output[bz, hid, i] = o_accum_local[i] + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=384) as (bid, hid, bz): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid): Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -294,16 +328,16 @@ def flash_attn_split( tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l) - T.copy(Q[bid, hid * 
VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) @@ -323,7 +357,9 @@ def flash_attn_split( T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) T.copy(m_i, m_i_prev) - T.reduce_max(acc_s, m_i, dim=1, clear=False) + T.reduce_max(acc_s, out=m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -356,6 +392,8 @@ def flash_attn_split( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -381,10 +419,7 @@ def flash_attn_split( for h_i in T.Parallel(block_H): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy( - O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, 0:dim // 2]) - T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz]) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -414,9 +449,7 @@ def flash_attn_split( acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy( - O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, dim // 2:dim]) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim]) elif tx >= 256: # producer @@ -425,111 +458,43 @@ def flash_attn_split( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (seqlen_kv // num_split) * bz + ( - i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + 
bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (seqlen_kv // num_split) * bz + ( - i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - with T.Kernel(heads, batch, threads=128) as (hid, bz): - po_local = T.alloc_fragment([dim], dtype) - o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) - for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, hid, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dim): - po_local[i] = Output_partial[bz, hid, k, i] - lse_local_split[0] = glse[bz, hid, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim): - Output[bz, hid, i] = o_accum_local[i] - - @T.prim_func - def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - 
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn(Q, Q_pe, KV, K_pe, Output) - if num_split > 1: return main_split else: @@ -551,31 +516,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -593,10 +551,9 @@ def main( BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) num_split = 1 - softmax_scale = (dim + pe_dim)**-0.5 + softmax_scale = (dim + pe_dim) ** -0.5 - kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, - softmax_scale) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) latency = profiler.do_bench(warmup=500) @@ -606,12 +563,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=132, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", 
type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py index 1b1447e88..e70c35349 100644 --- a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py +++ b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py @@ -8,25 +8,27 @@ @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) - dtype = "float16" - q_dtype = "float8_e4m3" - accum_dtype = "float" + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + q_dtype = T.float8_e4m3fn + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -46,34 +48,27 @@ def main_no_split( cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) T.disable_warp_group_reg_alloc() loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=2): - T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared) - T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], qKV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.copy(qKV_shared, KV_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, 
policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -88,7 +83,7 @@ def main_no_split( for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) return main_no_split @@ -106,42 +101,35 @@ def ref_program(q, q_pe, kv, k_pe): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + 
parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) diff --git a/examples/deepseek_mla/regression_example_mla_decode.py b/examples/deepseek_mla/regression_example_mla_decode.py new file mode 100644 index 000000000..64e1c436a --- /dev/null +++ b/examples/deepseek_mla/regression_example_mla_decode.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_mla_decode + + +def regression_example_mla_decode(): + tilelang.testing.process_func(example_mla_decode.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_mla/test_example_mla_decode.py b/examples/deepseek_mla/test_example_mla_decode.py index 66a750f7d..a269ea57a 100644 --- a/examples/deepseek_mla/test_example_mla_decode.py +++ b/examples/deepseek_mla/test_example_mla_decode.py @@ -1,5 +1,4 @@ import tilelang.testing - import example_mla_decode diff --git a/examples/deepseek_mla/torch_refs.py b/examples/deepseek_mla/torch_refs.py index 4b4c888cd..aae6c7cd2 100644 --- a/examples/deepseek_mla/torch_refs.py +++ b/examples/deepseek_mla/torch_refs.py @@ -11,7 +11,7 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): block_N = 64 seqlen_kv = KV.size(1) - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float) acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float) @@ -31,18 +31,20 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bhd,bkhd->bhk', Q_, - KV_[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + acc_s = torch.einsum( + "bhd,bkhd->bhk", + Q_, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] acc_s += torch.einsum( - 'bhd,bkhd->bhk', Q_pe_, - K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhd,bkhd->bhk", + Q_pe_, + K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] @@ -50,9 +52,10 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): acc_s = torch.exp2(acc_s - scores_max[:, :, None]) acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_o += torch.einsum( - 'bhk,bkhd->bhd', acc_s_cast, - KV_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhk,bkhd->bhd", + acc_s_cast, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = 
acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum acc_o /= logsum[:, :, None] diff --git a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py index daee39865..ca98d01be 100644 --- a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py +++ b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py @@ -14,21 +14,44 @@ from fla.utils import autocast_custom_fwd, contiguous -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -40,20 +63,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -66,7 +87,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -87,7 +108,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ 
-100,8 +120,7 @@ def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -172,7 +191,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -195,7 +213,8 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -207,18 +226,20 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -258,44 +279,44 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o -def naive_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -335,26 +356,24 @@ def naive_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - raise RuntimeError( - "Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") dtype = q.dtype G = q.shape[2] // k.shape[2] BS = block_size S = block_indices.shape[-1] - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) if isinstance(block_counts, torch.Tensor): - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) @@ -364,14 +383,11 @@ def naive_nsa(q: torch.Tensor, if cu_seqlens is None: varlen = False B, T = q.shape[:2] - cu_seqlens = torch.cat( - [block_indices.new_tensor(range(0, B * T, T)), - block_indices.new_tensor([B * T])]) + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) for i in range(len(cu_seqlens) - 1): if not varlen: - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[ - i], block_indices[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] if isinstance(block_counts, torch.Tensor): s_b = block_counts[i] else: @@ -379,10 +395,10 @@ def naive_nsa(q: torch.Tensor, else: T = cu_seqlens[i + 1] - cu_seqlens[i] q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]], - (q, k, v, g_slc, g_swa, block_indices)) + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]] + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] else: s_b = block_counts @@ -404,71 +420,58 @@ def naive_nsa(q: torch.Tensor, else: s_i = s_b # [S*BS, HQ, -1] - k_i_slc, v_i_slc = map( - lambda x: x.gather( - 0, - i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) # [S*BS, HQ] - attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( - torch.logical_or(i_i < 0, i_i > i_q) | - (c >= s_i if block_counts is not None else False), float('-inf')).softmax(0) + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) if not varlen: - o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * 
g_slc_i.unsqueeze(-1) + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) else: - o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) if window_size > 0: - k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], - (k_b, v_b)) - attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) if not varlen: - o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) else: - o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) if head_first: - o_slc = rearrange(o_slc, 'b t h d -> b h t d') - o_swa = rearrange(o_swa, 'b t h d -> b h t d') + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) def get_configs(): import itertools + iter_params = dict( block_T=[128, 256, 512], num_stages=[0, 1, 2, 4, 5], threads=[32, 64, 128, 256, 512], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def tilelang_sparse_attention(batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - block_T=128, - num_stages=2, - threads=32): + } +) +def tilelang_sparse_attention( + batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32 +): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -476,9 +479,9 @@ def tilelang_sparse_attention(batch, q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(block_T, tilelang.math.next_power_of_2(dim)) @@ -493,11 +496,11 @@ def tilelang_sparse_attention(batch, @T.prim_func def tilelang_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), ): 
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -514,13 +517,11 @@ def tilelang_sparse_attention( scores_sum = T.alloc_fragment([G], accum_dtype) logsum = T.alloc_fragment([G], accum_dtype) - T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)}) - i_t, i_v, i_bh = bx, by, bz i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -530,21 +531,15 @@ def tilelang_sparse_attention( i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -564,45 +559,33 @@ def tilelang_sparse_attention( acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return tilelang_sparse_attention def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size): """Generate random block indices for the benchmark.""" - block_indices = torch.full((batch, seq_len, heads, selected_blocks), - seq_len, - dtype=torch.long, - device='cuda') + block_indices = torch.full((batch, seq_len, heads, selected_blocks), seq_len, dtype=torch.long, device="cuda") for b in range(batch): for t in range(seq_len): for h in range(heads): i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i return block_indices.sort(-1)[0] -def benchmark_nsa(batch_size, - seq_len, - heads, - head_query, - dim, - selected_blocks, - block_size, - dtype, - scale, - warmup=10, - iterations=100, - validate=False): +def benchmark_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): """Benchmark the TileLang Sparse Attention implementation.""" # Set random seed for reproducibility @@ -628,14 +611,13 @@ def benchmark_nsa(batch_size, print(f"Profiler latency: {profiler_latency} ms") # Create input tensors - Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + 
V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") # Generate block indices - block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, - block_size).to(torch.int32) + block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size).to(torch.int32) # Warmup for _ in range(warmup): @@ -666,10 +648,9 @@ def benchmark_nsa(batch_size, # Validate result against reference if requested if validate: - g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - block_counts = torch.randint( - 1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda') + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") ref = naive_nsa( q=Q, @@ -700,22 +681,13 @@ def benchmark_nsa(batch_size, "head_query": head_query, "dim": dim, "selected_blocks": selected_blocks, - "block_size": block_size + "block_size": block_size, } -def benchmark_triton_nsa(batch_size, - seq_len, - heads, - head_query, - dim, - selected_blocks, - block_size, - dtype, - scale, - warmup=10, - iterations=100, - validate=False): +def benchmark_triton_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): """Benchmark the Triton-based TileLang Sparse Attention implementation.""" # Set random seed for reproducibility @@ -723,18 +695,17 @@ def benchmark_triton_nsa(batch_size, torch.random.manual_seed(0) # Create input tensors - Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") # Generate block indices block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size) - block_counts = torch.randint( - 1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda') - o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device='cuda') + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") + o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device="cuda") # Warmup for _ in range(warmup): @@ -750,7 +721,8 @@ def benchmark_triton_nsa(batch_size, block_counts=block_counts, block_size=block_size, window_size=0, 
- scale=scale) + scale=scale, + ) # Synchronize before timing torch.cuda.synchronize() @@ -770,7 +742,8 @@ def benchmark_triton_nsa(batch_size, block_counts=block_counts, block_size=block_size, window_size=0, - scale=scale) + scale=scale, + ) torch.cuda.synchronize() end_time = time.time() @@ -815,54 +788,28 @@ def benchmark_triton_nsa(batch_size, "head_query": head_query, "dim": dim, "selected_blocks": selected_blocks, - "block_size": block_size + "block_size": block_size, } -def run_benchmark_suite(impl='all'): +def run_benchmark_suite(impl="all"): """Run a suite of benchmarks with different configurations.""" # Define configurations to benchmark configs = [ # Small model config - Note: head_query must be a multiple of heads*16 for Triton - { - "batch_size": 2, - "seq_len": 1024, - "heads": 8, - "head_query": 8 * 16, - "dim": 64, - "selected_blocks": 8, - "block_size": 32 - }, - + {"batch_size": 2, "seq_len": 1024, "heads": 8, "head_query": 8 * 16, "dim": 64, "selected_blocks": 8, "block_size": 32}, # Medium model config - { - "batch_size": 2, - "seq_len": 2048, - "heads": 16, - "head_query": 16 * 16, - "dim": 64, - "selected_blocks": 16, - "block_size": 64 - }, - + {"batch_size": 2, "seq_len": 2048, "heads": 16, "head_query": 16 * 16, "dim": 64, "selected_blocks": 16, "block_size": 64}, # Large model config - { - "batch_size": 1, - "seq_len": 4096, - "heads": 32, - "head_query": 32 * 16, - "dim": 128, - "selected_blocks": 32, - "block_size": 128 - }, + {"batch_size": 1, "seq_len": 4096, "heads": 32, "head_query": 32 * 16, "dim": 128, "selected_blocks": 32, "block_size": 128}, ] results = [] for config in configs: print(f"Running benchmark with config: {config}") - if impl in ['all', 'tilelang']: + if impl in ["all", "tilelang"]: print("Benchmarking TileLang implementation:") result = benchmark_nsa( batch_size=config["batch_size"], @@ -874,12 +821,13 @@ def run_benchmark_suite(impl='all'): block_size=config["block_size"], dtype=torch.float16, scale=0.1, - validate=False) + validate=False, + ) results.append({"impl": "tilelang", **result}) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") - if impl in ['all', 'triton']: + if impl in ["all", "triton"]: print("Benchmarking Triton implementation:") result = benchmark_triton_nsa( batch_size=config["batch_size"], @@ -891,19 +839,24 @@ def run_benchmark_suite(impl='all'): block_size=config["block_size"], dtype=torch.float16, scale=0.1, - validate=False) + validate=False, + ) results.append({"impl": "triton", **result}) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") - if impl in ['all']: + if impl in ["all"]: # Print comparison if both implementations were run tilelang_result = next( - r for r in results if r["impl"] == "tilelang" and - r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]) + r + for r in results + if r["impl"] == "tilelang" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) triton_result = next( - r for r in results if r["impl"] == "triton" and - r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]) + r + for r in results + if r["impl"] == "triton" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) speedup = tilelang_result["avg_time_ms"] / triton_result["avg_time_ms"] print(f"Speedup (Triton vs TileLang): {speedup:.2f}x") @@ -921,8 +874,7 @@ def run_benchmark_suite(impl='all'): 
parser.add_argument("--dim", type=int, default=128, help="Head dimension") parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks") parser.add_argument("--block_size", type=int, default=32, help="Block size") - parser.add_argument( - "--dtype", type=str, default="float16", help="Data type (float16 or float32)") + parser.add_argument("--dtype", type=str, default=T.float16, help="Data type (float16 or float32)") parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor") parser.add_argument("--iterations", type=int, default=100, help="Number of iterations") parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") @@ -933,7 +885,8 @@ def run_benchmark_suite(impl='all'): type=str, default="all", choices=["tilelang", "triton", "all"], - help="Implementation to benchmark (tilelang, triton, or all)") + help="Implementation to benchmark (tilelang, triton, or all)", + ) args = parser.parse_args() @@ -941,13 +894,12 @@ def run_benchmark_suite(impl='all'): if args.impl in ["triton", "all"] and args.head_query % (args.heads * 16) != 0: # Adjust head_query to nearest valid value args.head_query = ((args.head_query // (args.heads * 16)) + 1) * (args.heads * 16) - print( - f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") + print(f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") if args.suite: run_benchmark_suite(impl=args.impl) else: - dtype = torch.float16 if args.dtype == "float16" else torch.float32 + dtype = torch.float16 if args.dtype == T.float16 else torch.float32 if args.impl in ["tilelang", "all"]: print("Benchmarking TileLang implementation:") @@ -963,12 +915,14 @@ def run_benchmark_suite(impl='all'): scale=args.scale, warmup=args.warmup, iterations=args.iterations, - validate=args.validate) + validate=args.validate, + ) print("\nBenchmark Results (TileLang):") print( - f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + - f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + - f"block_size={args.block_size}") + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") @@ -986,11 +940,13 @@ def run_benchmark_suite(impl='all'): scale=args.scale, warmup=args.warmup, iterations=args.iterations, - validate=args.validate) + validate=args.validate, + ) print("\nBenchmark Results (Triton):") print( - f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + - f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + - f"block_size={args.block_size}") + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py index 8387d2271..3da285a9b 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -7,6 +7,7 @@ import triton import fla + if 
parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -22,7 +23,8 @@ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + } +) def tilelang_kernel_fwd( batch, heads, @@ -34,11 +36,10 @@ def tilelang_kernel_fwd( groups=1, selected_blocks=16, ): - from tilelang import language as T if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -48,9 +49,9 @@ def tilelang_kernel_fwd( o_slc_shape = [batch, seq_len, heads, dim] lse_slc_shape = [batch, seq_len, heads] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) @@ -67,12 +68,12 @@ def tilelang_kernel_fwd( @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - O_slc: T.Tensor(o_slc_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -93,7 +94,7 @@ def native_sparse_attention( i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -103,12 +104,11 @@ def native_sparse_attention( i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: - for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) @@ -124,21 +124,21 @@ def native_sparse_attention( T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=True) - for i in T.Parallel(G): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + for k in T.Parallel(G): + scores_scale[k] = T.exp2(scores_max_prev[k] * scale - scores_max[k] * scale) + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.exp2(acc_s[k, j] * scale - scores_max[k] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(G): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for k in T.Parallel(G): + logsum[k] = logsum[k] * scores_scale[k] + scores_sum[k] T.copy(acc_s, acc_s_cast) # Rescale - for i, j in T.Parallel(G, BV): - acc_o[i, j] *= scores_scale[i] + for k, j in T.Parallel(G, BV): + acc_o[k, j] *= scores_scale[k] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + 
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): @@ -146,18 +146,20 @@ def native_sparse_attention( T.copy(acc_o, O_shared) T.copy( O_shared, - O_slc[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV], + O_slc[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV], ) for i in T.Parallel(G): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, LSE_slc[i_b, i_t, i_h * G:(i_h + 1) * G]) + T.copy(logsum, LSE_slc[i_b, i_t, i_h * G : (i_h + 1) * G]) return native_sparse_attention -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def tilelang_kernel_bwd_dkv( batch, heads, @@ -168,11 +170,11 @@ def tilelang_kernel_bwd_dkv( block_size=64, groups=1, selected_blocks=16, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): if scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 else: sm_scale = scale @@ -207,15 +209,15 @@ def tilelang_kernel_bwd_dkv( @T.prim_func def flash_bwd_dkv( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(k_shape, dtype), - V: T.Tensor(v_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), - Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), - DO_slc: T.Tensor(do_slc_shape, dtype), - DK: T.Tensor(dk_shape, dtype), - DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, T.int32), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -238,31 +240,25 @@ def flash_bwd_dkv( i_b, i_h = i_bh // H, i_bh % H - T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared) - T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared) + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) # [BS, BK] T.clear(dk) # [BS, BV] T.clear(dv) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - loop_st = i_s * BS loop_ed = seq_len for i in T.Pipelined( - start=loop_st, - stop=loop_ed, - num_stages=0, + start=loop_st, + stop=loop_ed, + num_stages=0, ): b_m_slc = BlockMask[i_b, i, i_h, i_s] if b_m_slc != 0: # [G, BK] - T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.clear(qkT) # [BS, BK] @ [G, BK] -> [BS, G] T.gemm( @@ -273,7 +269,7 @@ def flash_bwd_dkv( policy=T.GemmWarpPolicy.FullRow, ) # [G] - T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared) + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) for _i, _j in T.Parallel(BS, G): qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) @@ -282,7 +278,7 @@ def flash_bwd_dkv( qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) # [G, BV] - T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do) + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) 
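A note on the `T.exp2(qkT * scale - lse_shared)` line just above: the backward kernels do not store the attention probabilities; they rebuild them from the log-sum-exp saved by the forward pass, keeping everything in log2 space via the 1.44269504 (= log2 e) factor. A dense, single-tile sketch of that recomputation, with illustrative names that do not appear in the diff:

import torch

def recompute_probs(scores: torch.Tensor, lse: torch.Tensor, sm_scale: float) -> torch.Tensor:
    # scores: [BS, G] raw q.k dot products for one key tile; lse: [G] log2-sum-exp from the forward pass.
    log2e = 1.44269504
    return torch.exp2(scores.float() * sm_scale * log2e - lse.float().unsqueeze(0))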
T.clear(dsT) # [BS, BV] @ [G, BV] -> [BS, G] T.gemm( @@ -296,7 +292,7 @@ def flash_bwd_dkv( # [BS, G] @ [G, BV] -> [BS, BV] T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) # [G] - T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta) + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) for i, j in T.Parallel(BS, G): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -305,8 +301,8 @@ def flash_bwd_dkv( T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV]) - T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK]) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) return flash_bwd_dkv @@ -321,9 +317,11 @@ def make_dq_layout(dQ): ) -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def tilelang_kernel_bwd_dqkv( batch, heads, @@ -334,11 +332,11 @@ def tilelang_kernel_bwd_dqkv( block_size=64, groups=1, selected_blocks=16, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): if scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 else: sm_scale = scale @@ -373,16 +371,16 @@ def tilelang_kernel_bwd_dqkv( @T.prim_func def flash_bwd_dqkv( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(k_shape, dtype), - V: T.Tensor(v_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), - Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), - DO_slc: T.Tensor(do_slc_shape, dtype), - DQ: T.Tensor(dq_shape, dtype), - DK: T.Tensor(dk_shape, dtype), - DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DQ: T.Tensor(dq_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, T.int32), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -406,31 +404,25 @@ def flash_bwd_dqkv( i_b, i_h = i_bh // H, i_bh % H - T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared) - T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared) + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) # [BS, BK] T.clear(dk) # [BS, BV] T.clear(dv) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - loop_st = i_s * BS loop_ed = seq_len for i in T.Pipelined( - start=loop_st, - stop=loop_ed, - num_stages=0, + start=loop_st, + stop=loop_ed, + num_stages=0, ): b_m_slc = BlockMask[i_b, i, i_h, i_s] if b_m_slc != 0: # [G, BK] - T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.clear(qkT) # [BS, BK] @ [G, BK] -> [BS, G] T.gemm( @@ -441,7 +433,7 @@ def flash_bwd_dqkv( policy=T.GemmWarpPolicy.FullRow, ) # [G] - T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared) + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) for _i, _j in T.Parallel(BS, G): qkT[_i, _j] = T.exp2(qkT[_i, _j] 
* scale - lse_shared[_j]) @@ -450,7 +442,7 @@ def flash_bwd_dqkv( qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) # [G, BV] - T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do) + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) T.clear(dsT) # [BS, BV] @ [G, BV] -> [BS, G] T.gemm( @@ -464,9 +456,9 @@ def flash_bwd_dqkv( # [BS, G] @ [G, BV] -> [BS, BV] T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) # [G] - T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta) - for i, j in T.Parallel(BS, G): - dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) + for _i, _j in T.Parallel(BS, G): + dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale # [BS, G] @ [G, BK] -> [BS, BK] T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow) @@ -480,23 +472,25 @@ def flash_bwd_dqkv( T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV]) - T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK]) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) return flash_bwd_dqkv @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def tilelang_kernel_preprocess( batch, heads, seq_len, dim, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, blk=32, ): from tilelang import language as T @@ -505,9 +499,9 @@ def tilelang_kernel_preprocess( @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -516,27 +510,29 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, by * blk:(by + 1) * blk, bx]) + T.copy(delta, Delta[bz, by * blk : (by + 1) * blk, bx]) return flash_bwd_prep @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def tilelang_kernel_block_mask( batch, heads, seq_len, selected_blocks, block_size, - dtype="int32", + dtype=T.int32, ): from tilelang import language as T @@ -551,9 +547,9 @@ def tilelang_kernel_block_mask( @T.prim_func def flash_bwd_block_mask( - BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore - BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore - BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore + BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore + BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore + BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore ): with 
T.Kernel(seq_len, batch, heads * S) as (bx, by, bz): i_t, i_b, i_hs = bx, by, bz @@ -603,9 +599,7 @@ def parallel_nsa_bwd( dk = torch.empty(NV, *k.shape, dtype=k.dtype, device=q.device) dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) - block_mask = tilelang_kernel_block_mask(B, H, T, S, - BS)(block_indices.to(torch.int32), - block_counts.to(torch.int32)).to(torch.bool) + block_mask = tilelang_kernel_block_mask(B, H, T, S, BS)(block_indices.to(torch.int32), block_counts.to(torch.int32)).to(torch.bool) fused_qkv_bwd_kernel = tilelang_kernel_bwd_dqkv( batch=B, @@ -618,8 +612,7 @@ def parallel_nsa_bwd( selected_blocks=S, scale=scale, ) - fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, - block_mask.to(torch.int32)) + fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, block_mask.to(torch.int32)) dq = dq.sum(0) dk = dk.sum(0) @@ -628,7 +621,6 @@ def parallel_nsa_bwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -773,23 +765,21 @@ def parallel_nsa( Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), - (q, k, v, block_indices)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): block_counts = rearrange(block_counts, "b h t -> b t h") - assert (q.shape[2] % (k.shape[2] * 16) == 0), "Group size must be a multiple of 16 in NSA" + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: @@ -814,7 +804,7 @@ def parallel_nsa( for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index 58f435509..381d92493 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -16,7 +16,8 @@ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def native_sparse_attention( batch, heads, @@ -25,18 +26,18 @@ def native_sparse_attention( scale=None, block_size=64, # Tile size for attention computation groups=1, # Grouped query attention (GQA) groups - selected_blocks=16 # Number of blocks to select per attention head + selected_blocks=16, # Number of blocks to select per attention head ): if scale is None: - scale = 
(1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups # Modified shapes for inference (q has seq_len=1)a q_shape = [batch, 1, heads, dim] # Changed seq_len to 1 kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1 - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) @@ -53,12 +54,11 @@ def native_sparse_attention( @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] - K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] - V: T.Tensor(kv_shape, dtype), # Same shape as K - BlockIndices: T.Tensor(block_indices_shape, - block_indices_dtype), # Selected block indices - Output: T.Tensor(q_shape, dtype), # Output attention tensor + Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] + K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] + V: T.Tensor(kv_shape, dtype), # Same shape as K + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), # Selected block indices + Output: T.Tensor(q_shape, dtype), # Output attention tensor ): with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz): # Shared memory allocations for tile storage @@ -82,7 +82,7 @@ def native_sparse_attention( NS = S # Copy Q for the single position - T.copy(Q[i_b, 0, i_h * G:(i_h + 1) * G, :], Q_shared) # Changed i_t to 0 + T.copy(Q[i_b, 0, i_h * G : (i_h + 1) * G, :], Q_shared) # Changed i_t to 0 T.fill(acc_o, 0) T.fill(logsum, 0) @@ -93,16 +93,11 @@ def native_sparse_attention( i_s = BlockIndices[i_b, 0, i_h, i] * BS # Get block offset if i_s >= 0: # Skip invalid/padding blocks # Load current key block to shared memory - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) # Compute QK^T attention scores T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Online softmax with numerical stability # 1. 
Compute max for scaling @@ -122,15 +117,14 @@ def native_sparse_attention( T.copy(acc_s, acc_s_cast) # Accumulate attention-weighted values - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) # Final normalization and output for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] # Normalize by logsum T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, 0, i_h * G:(i_h + 1) * G, - i_v * BV:(i_v + 1) * BV]) # Changed i_t to 0 + T.copy(O_shared, Output[i_b, 0, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) # Changed i_t to 0 return native_sparse_attention @@ -149,21 +143,21 @@ def main(): selected_blocks=S, ) - Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device='cuda') - DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda') + mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda") + DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda") - block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device='cuda') + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda") for b in range(B): for t in range(SEQ_LEN_Q): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda") out = kernel(Q, K, V, block_indices.to(torch.int32)) @@ -178,5 +172,38 @@ def main(): torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) +def run_regression_perf(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16 + groups = HQ // H + SEQ_LEN_Q = 1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + ) + + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN_Q): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, block_indices.to(torch.int32)) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index f8a7ebfb0..7b36d6e26 100644 --- 
a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -14,18 +14,11 @@ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def native_sparse_attention(batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16): + }, +) +def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -33,9 +26,9 @@ def native_sparse_attention(batch, q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) @@ -52,11 +45,11 @@ def native_sparse_attention(batch, @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -77,7 +70,7 @@ def native_sparse_attention( i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -87,21 +80,15 @@ def native_sparse_attention( i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -121,13 +108,13 @@ def native_sparse_attention( acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return native_sparse_attention @@ -148,21 +135,22 @@ def main(): ) print(kernel.get_kernel_source()) torch.random.manual_seed(0) - Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, 
device='cuda').requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda') + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") for b in range(B): for t in range(SEQ_LEN): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, SEQ_LEN, H), device='cuda') out = kernel(Q, K, V, block_indices.to(torch.int32)) @@ -183,5 +171,43 @@ def main(): torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) +def run_regression_perf(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + scale=scale, + ) + torch.random.manual_seed(0) + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() + block_indices = block_indices.sort(-1)[0] + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, block_indices.to(torch.int32)) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py index d365e7a5f..b52ebe42e 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -8,6 +8,7 @@ import tilelang.testing import fla + if 
parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -21,18 +22,11 @@ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def native_sparse_attention_varlen(batch, - heads, - c_seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16): + } +) +def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [c_seq_len, heads, dim] kv_shape = [c_seq_len, head_kv, dim] @@ -44,12 +38,12 @@ def native_sparse_attention_varlen(batch, block_counts_shape = [c_seq_len, head_kv] offsets_shape = [batch + 1] token_indices_shape = [c_seq_len, 2] - block_indices_dtype = "int32" - block_counts_dtype = "int32" - offsets_dtype = "int32" - token_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + block_counts_dtype = T.int32 + offsets_dtype = T.int32 + token_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) @@ -66,14 +60,14 @@ def native_sparse_attention_varlen(batch, @T.prim_func def native_sparse_attention_varlen( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - O_slc: T.Tensor(o_slc_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), - Offsets: T.Tensor(offsets_shape, offsets_dtype), - TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), + Offsets: T.Tensor(offsets_shape, offsets_dtype), + TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), ): with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -100,7 +94,7 @@ def native_sparse_attention_varlen( current_seq_len = eos - bos NS = BlockCounts[i_t, i_h] - T.copy(Q[bos + i_t, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -112,21 +106,15 @@ def native_sparse_attention_varlen( # [BS, BK] # Lei: may have some padding issues # we should learn from mha varlen templates to handle this - T.copy(K[bos + i_s:bos + i_s + BS, i_h, :BK], K_shared) + T.copy(K[bos + i_s : bos + i_s + BS, i_h, :BK], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -146,13 +134,13 @@ def native_sparse_attention_varlen( acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[bos + i_s:bos + i_s + BS, i_h, i_v * BV:(i_v + 
1) * BV], V_shared) + T.copy(V[bos + i_s : bos + i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, O_slc[bos + i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, O_slc[bos + i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return native_sparse_attention_varlen @@ -190,17 +178,20 @@ def parallel_nsa_fwd( o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device) kernel( - q.view(C_SEQ_LEN, HQ, D), k.view(C_SEQ_LEN, H, D), v.view(C_SEQ_LEN, H, D), + q.view(C_SEQ_LEN, HQ, D), + k.view(C_SEQ_LEN, H, D), + v.view(C_SEQ_LEN, H, D), o_slc.view(C_SEQ_LEN, HQ, V), block_indices.to(torch.int32).view(C_SEQ_LEN, H, S), - block_counts.to(torch.int32).view(C_SEQ_LEN, H), offsets.to(torch.int32), - token_indices.to(torch.int32)) + block_counts.to(torch.int32).view(C_SEQ_LEN, H), + offsets.to(torch.int32), + token_indices.to(torch.int32), + ) return o_slc @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): ctx.dtype = q.dtype @@ -221,22 +212,25 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) return o_slc.to(q.dtype) -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -276,29 +270,27 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
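For orientation on the variable-length path above: `cu_seqlens`/`offsets` are cumulative sequence boundaries, and `token_indices` maps each token of the flattened batch back to a (sequence index, in-sequence position) pair, which is what the kernel reads per program id. A rough pure-PyTorch sketch of that mapping (not the `fla` implementation itself):

import torch

def token_indices_from_offsets(offsets: torch.Tensor) -> torch.Tensor:
    # offsets: cumulative boundaries, e.g. [0, 3, 5] -> rows (0,0), (0,1), (0,2), (1,0), (1,1)
    lengths = (offsets[1:] - offsets[:-1]).tolist()
    pairs = [(seq_id, pos) for seq_id, n in enumerate(lengths) for pos in range(int(n))]
    return torch.tensor(pairs, dtype=torch.long)

With `offsets = torch.tensor([0, 3, 5])` this yields five rows, matching the `_, t = token_indices[i]` unpacking used when building `block_indices` in the test harness.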
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, - scale, cu_seqlens) + o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: assert False, "Window size is not supported yet" else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o @@ -306,41 +298,57 @@ def parallel_nsa(q: torch.Tensor, N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 torch.manual_seed(42) # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[:N - 1]], - torch.tensor([C_SEQ_LEN], dtype=torch.long) - ], 0).cuda().sort()[0] + offsets = ( + torch.cat( + [ + torch.tensor([0], dtype=torch.long), + torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[: N - 1]], + torch.tensor([C_SEQ_LEN], dtype=torch.long), + ], + 0, + ) + .cuda() + .sort()[0] + ) # seq-first required for inputs with variable lengths - perm_q = torch.randperm(C_SEQ_LEN, device='cuda') - perm_k = torch.randperm(C_SEQ_LEN, device='cuda') - perm_v = torch.randperm(C_SEQ_LEN, device='cuda') - q = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_q].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, HQ, - D).clone().requires_grad_(True) - k = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_k].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H, - D).clone().requires_grad_(True) - v = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_v].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H, - D).clone().requires_grad_(True) - g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device='cuda') + perm_q = torch.randperm(C_SEQ_LEN, device="cuda") + perm_k = torch.randperm(C_SEQ_LEN, device="cuda") + perm_v = torch.randperm(C_SEQ_LEN, device="cuda") + q = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_q] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, HQ, D) + .clone() + .requires_grad_(True) + ) + k = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_k] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + v = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, 
dtype=dtype, device="cuda")[perm_v] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device="cuda") token_indices = prepare_token_indices(offsets).tolist() - block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device='cuda') + block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device="cuda") for i in range(C_SEQ_LEN): _, t = token_indices[i] for h in range(H): i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S] - block_indices[0, i, h, :len(i_i)] = i_i + block_indices[0, i, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device='cuda') + block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device="cuda") ref = naive_nsa( q=q, @@ -351,7 +359,8 @@ def parallel_nsa(q: torch.Tensor, block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) tri = parallel_nsa( q=q, @@ -362,7 +371,8 @@ def parallel_nsa(q: torch.Tensor, block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) print("tri", tri) print("ref", ref) diff --git a/examples/deepseek_nsa/example_triton_nsa_bwd.py b/examples/deepseek_nsa/example_triton_nsa_bwd.py index e912794a4..af05bfa70 100644 --- a/examples/deepseek_nsa/example_triton_nsa_bwd.py +++ b/examples/deepseek_nsa/example_triton_nsa_bwd.py @@ -8,6 +8,7 @@ import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,21 +18,44 @@ from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc # else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), 
- (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -105,8 +126,7 @@ def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -134,7 +154,8 @@ def backward(ctx, do_slc, do_swa): window_size=ctx.window_size, scale=ctx.scale, offsets=ctx.offsets, - token_indices=ctx.token_indices) + token_indices=ctx.token_indices, + ) return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None @@ -199,37 +220,56 @@ def parallel_nsa_fwd( return o_slc, lse_slc, o_swa, lse_swa -@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) +@triton.heuristics({"USE_OFFSETS": lambda args: args["offsets"] is not None}) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) -@triton.jit(do_not_specialize=['T']) -def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, do_slc, do_swa, dk, - dv, block_mask, offsets, chunk_indices, scale, T, B: tl.constexpr, - H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, - V: tl.constexpr, M: tl.constexpr, BS: tl.constexpr, - WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dkv( + q, + k, + v, + lse_slc, + lse_swa, + delta_slc, + delta_swa, + do_slc, + do_swa, + dk, + dv, + block_mask, + offsets, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: 
tl.constexpr, + USE_OFFSETS: tl.constexpr, +): i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + - 1).to(tl.int32) + i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T - p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), - (1, 0)) - p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), - (BS, BV), (1, 0)) - p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), - (i_s * BS, 0), (BS, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), - (BS, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) # [BS, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) @@ -241,14 +281,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, for i in range(i_s * BS, T): b_m_slc = tl.load(block_mask + (bos + i) * H * M + i_h * M + i_s) if b_m_slc: - p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta_slc = delta_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] @@ -272,14 +310,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, if WS > 0: o_s = i_s * BS + tl.arange(0, BS) if max(i_s * BS, i - WS + 1) < min((i_s + 1) * BS, i + 1): - p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), - (G, BK), (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta_swa = delta_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] @@ -304,12 +340,19 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) -@triton.heuristics( - {'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)}) +@triton.heuristics({"USE_BLOCK_COUNTS": lambda 
args: isinstance(args["block_counts"], torch.Tensor)}) @triton.jit -def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.constexpr, - H: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, NS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_kernel_mask( + block_indices, + block_counts, + block_mask, + T: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + NS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_h, i_s = i_hs // S, i_hs % S @@ -320,31 +363,56 @@ def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.cons b_m = b_i * BS <= i_t if b_i < NS and b_i >= 0: - tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, - b_m.to(block_mask.dtype.element_ty)) + tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty)) -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) -@triton.jit(do_not_specialize=['T']) -def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, delta_swa, do_swa, dq, - scale, block_indices, block_counts, offsets, token_indices, T, - B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, - K: tl.constexpr, V: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, - WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dq( + q, + k, + v, + lse_slc, + delta_slc, + do_slc, + lse_swa, + delta_swa, + do_swa, + dq, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -449,27 +517,49 @@ def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, del tl.store(p_dq, (b_dq_slc + b_dq_swa).to(p_dq.dtype.element_ty), boundary_check=(0, 1)) -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) 
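The mask kernel above writes, for every (token, head), one boolean per selected block: true when the slot is within the per-token block count and the block starts at or before the query position. A dense PyTorch sketch with the same semantics (illustrative only, not the kernel's implementation):

import torch

def reference_block_mask(block_indices, block_counts, block_size, num_key_blocks):
    # block_indices: [B, T, H, S] (long), block_counts: [B, T, H] -> bool mask of shape [B, T, H, NS]
    B, T, H, S = block_indices.shape
    device = block_indices.device
    t = torch.arange(T, device=device).view(1, T, 1, 1)
    s = torch.arange(S, device=device).view(1, 1, 1, S)
    in_count = s < block_counts.unsqueeze(-1)                      # slot is among the counted selections
    in_range = (block_indices >= 0) & (block_indices < num_key_blocks)
    causal = block_indices * block_size <= t                       # block starts at or before the query token
    keep = in_count & in_range & causal
    mask = torch.zeros(B, T, H, num_key_blocks, dtype=torch.bool, device=device)
    b, ti, h, si = torch.nonzero(keep, as_tuple=True)
    mask[b, ti, h, block_indices[b, ti, h, si]] = True
    return mask

This mirrors how `parallel_nsa_block_mask` allocates a `[B, T, H, NS]` boolean tensor before launching the Triton kernel.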
@triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -484,20 +574,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -510,7 +598,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -529,13 +617,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) if WS > 0: - p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_swa = tl.zeros([G, BV], dtype=tl.float32) - b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_swa = tl.zeros([G], dtype=tl.float32) for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * 
K), (0, i_s), (BK, BS), (0, 1)) @@ -546,7 +633,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) # [G, BS] b_s_swa = tl.dot(b_q, b_k_swa) - b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf')) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) # [G] b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa @@ -593,14 +680,8 @@ def parallel_nsa_block_mask( block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) parallel_nsa_kernel_mask[(T, B, H * S)]( - block_indices=block_indices, - block_counts=block_counts, - block_mask=block_mask, - T=T, - H=H, - S=S, - BS=BS, - NS=NS) + block_indices=block_indices, block_counts=block_counts, block_mask=block_mask, T=T, H=H, S=S, BS=BS, NS=NS + ) return block_mask @@ -676,7 +757,8 @@ def parallel_nsa_bwd( BS=BS, WS=WS, BK=BK, - BV=BV) + BV=BV, + ) dq = dq.sum(0) if offsets is not None: @@ -719,14 +801,14 @@ def parallel_nsa_bwd( BS=BS, WS=WS, BK=BK, - BV=BV) + BV=BV, + ) dk = dk.sum(0) return dq, dk, dv @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -749,7 +831,8 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -781,22 +864,25 @@ def backward(ctx, do_slc, do_swa): window_size=ctx.window_size, scale=ctx.scale, offsets=ctx.offsets, - token_indices=ctx.token_indices) + token_indices=ctx.token_indices, + ) return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -836,51 +922,49 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
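Editor's note (not part of the patch): for reviewers who want to sanity-check the reformatted `parallel_nsa` wrapper without replaying the whole hunk, here is a minimal call sketch. Shapes and dtypes mirror the file's own `__main__` self-test; it assumes a CUDA device and the definitions already present in this module, and is illustrative rather than authoritative.

```python
import torch

# Sizes taken from the example's self-test: HQ = 16 query heads share H = 1 KV head.
B, T, H, HQ, D, S, block_size = 2, 64, 1, 16, 32, 1, 32
q = torch.randn(B, T, HQ, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
g_slc = torch.ones(B, T, HQ, device="cuda", dtype=torch.float16)  # gate for the selected-block branch
g_swa = torch.ones(B, T, HQ, device="cuda", dtype=torch.float16)  # gate for the sliding-window branch
# One selected KV block per (batch, token, kv-head); all entries point at block 0 here.
block_indices = torch.zeros(B, T, H, S, dtype=torch.long, device="cuda")
block_counts = torch.ones(B, T, H, dtype=torch.long, device="cuda")

o = parallel_nsa(q, k, v, g_slc, g_swa, block_indices,
                 block_counts=block_counts, block_size=block_size,
                 window_size=0)  # output shape: [B, T, HQ, D]
```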
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o if __name__ == "__main__": B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 torch.random.manual_seed(0) - q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda') + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") for b in range(B): for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") ref = naive_nsa( q=q, diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd.py b/examples/deepseek_nsa/example_triton_nsa_fwd.py index 2c740013a..c9ab28daa 100644 --- a/examples/deepseek_nsa/example_triton_nsa_fwd.py +++ b/examples/deepseek_nsa/example_triton_nsa_fwd.py @@ -8,6 +8,7 @@ import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,21 +18,44 @@ from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda 
args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc # else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -105,8 +126,7 @@ def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices 
= block_indices ctx.block_size = block_size @@ -177,7 +197,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -200,7 +219,8 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -212,18 +232,20 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -263,51 +285,49 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o if __name__ == "__main__": B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 torch.random.manual_seed(0) - q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) 
- g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda') + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") for b in range(B): for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") ref = naive_nsa( q=q, diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py index 9ccbff6a4..cb4eb6d7b 100644 --- a/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py @@ -8,6 +8,7 @@ import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,27 +18,49 @@ from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -52,20 +75,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc else: NS = S - p_q = 
tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -78,7 +99,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -97,13 +118,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) if WS > 0: - p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_swa = tl.zeros([G, BV], dtype=tl.float32) - b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_swa = tl.zeros([G], dtype=tl.float32) for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) @@ -114,7 +134,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) # [G, BS] b_s_swa = tl.dot(b_q, b_k_swa) - b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf')) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) # [G] b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa @@ -196,7 +216,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -219,7 +238,8 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -231,18 +251,20 @@ def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - 
block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -282,29 +304,27 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o @@ -312,38 +332,35 @@ def parallel_nsa(q: torch.Tensor, N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 torch.manual_seed(42) # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N - 1]], - torch.tensor([T], dtype=torch.long) - ], 0).cuda().sort()[0] + offsets = ( + torch.cat( + [torch.tensor([0], dtype=torch.long), torch.arange(16, T)[torch.randperm(T - 1)[: N - 1]], torch.tensor([T], dtype=torch.long)], + 0, + ) + .cuda() + .sort()[0] + ) # offsets.shape is [N+1] # seq-first required for inputs with variable lengths - perm_q = torch.randperm(T, device='cuda') - perm_k = torch.randperm(T, device='cuda') - perm_v = torch.randperm(T, device='cuda') - q = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) - k = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - v = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - g_slc = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((1, T, HQ, 
D), dtype=dtype, device='cuda') + perm_q = torch.randperm(T, device="cuda") + perm_k = torch.randperm(T, device="cuda") + perm_v = torch.randperm(T, device="cuda") + q = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) + k = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + v = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + g_slc = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, T, HQ, D), dtype=dtype, device="cuda") token_indices = prepare_token_indices(offsets).tolist() - block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='cuda') + block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device="cuda") for i in range(T): _, t = token_indices[i] for h in range(H): i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] - block_indices[0, i, h, :len(i_i)] = i_i + block_indices[0, i, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (1, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (1, T, H), device="cuda") ref = naive_nsa( q=q, @@ -354,7 +371,8 @@ def parallel_nsa(q: torch.Tensor, block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) tri = parallel_nsa( q=q, @@ -365,7 +383,8 @@ def parallel_nsa(q: torch.Tensor, block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) print("tri", tri) print("ref", ref) diff --git a/examples/deepseek_nsa/reference.py b/examples/deepseek_nsa/reference.py index 958d0c19e..58083108e 100644 --- a/examples/deepseek_nsa/reference.py +++ b/examples/deepseek_nsa/reference.py @@ -6,18 +6,20 @@ from einops import rearrange, repeat -def naive_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -57,26 +59,24 @@ def naive_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
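Editor's note (not part of the patch): `naive_nsa` expands each selected block id into `block_size` token positions, then masks positions that are non-causal or that fall beyond the per-query `block_counts`. A small self-contained sketch of that indexing (variable names are illustrative, not the reference's own):

```python
import torch

block_size, S = 4, 3                      # BS tokens per block, S selected blocks
block_idx = torch.tensor([0, 2, 5])       # selected block ids for one (batch, token, head)
block_count = 2                           # only the first 2 selected blocks are live
t_q = 10                                  # causal query position

# Token positions covered by the selected blocks: [S * BS]
i_i = (block_idx * block_size).repeat_interleave(block_size) + torch.arange(block_size).repeat(S)
# Block rank of each expanded position, compared against block_count
c = torch.arange(S).repeat_interleave(block_size)

keep = (i_i <= t_q) & (c < block_count)   # complement of the -inf mask used in the reference
print(i_i.tolist())    # [0, 1, 2, 3, 8, 9, 10, 11, 20, 21, 22, 23]
print(keep.tolist())   # position 11 is non-causal; the third block exceeds block_count
```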
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - raise RuntimeError( - "Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") dtype = q.dtype G = q.shape[2] // k.shape[2] BS = block_size S = block_indices.shape[-1] - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) if isinstance(block_counts, torch.Tensor): - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) @@ -86,14 +86,11 @@ def naive_nsa(q: torch.Tensor, if cu_seqlens is None: varlen = False B, T = q.shape[:2] - cu_seqlens = torch.cat( - [block_indices.new_tensor(range(0, B * T, T)), - block_indices.new_tensor([B * T])]) + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) for i in range(len(cu_seqlens) - 1): if not varlen: - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[ - i], block_indices[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] if isinstance(block_counts, torch.Tensor): s_b = block_counts[i] else: @@ -101,10 +98,10 @@ def naive_nsa(q: torch.Tensor, else: T = cu_seqlens[i + 1] - cu_seqlens[i] q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]], - (q, k, v, g_slc, g_swa, block_indices)) + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]] + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] else: s_b = block_counts @@ -126,34 +123,28 @@ def naive_nsa(q: torch.Tensor, else: s_i = s_b # [S*BS, HQ, -1] - k_i_slc, v_i_slc = map( - lambda x: x.gather( - 0, - i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) # [S*BS, HQ] - attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( - torch.logical_or(i_i < 0, i_i > i_q) | - (c >= s_i if block_counts is not None else False), float('-inf')).softmax(0) + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) if not varlen: - o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * 
g_slc_i.unsqueeze(-1) + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) else: - o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) if window_size > 0: - k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], - (k_b, v_b)) - attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) if not varlen: - o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) else: - o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) if head_first: - o_slc = rearrange(o_slc, 'b t h d -> b h t d') - o_swa = rearrange(o_swa, 'b t h d -> b h t d') + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) @@ -187,7 +178,7 @@ def naive_nsa_simple( o (torch.Tensor): Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 dtype = q.dtype HQ = q.shape[2] @@ -197,8 +188,8 @@ def naive_nsa_simple( BS = block_size S = block_indices.shape[-1] SELECTED_BLOCKS_SIZE = S * BS - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) o = torch.zeros_like(v) @@ -228,10 +219,10 @@ def naive_nsa_simple( v_i[t, h] = v_b[selected_block_index, h, :] # [S*BS, HQ] - attn = torch.einsum('h d, n h d -> n h', q_i, k_i) - attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float('-inf')) + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float("-inf")) attn = torch.softmax(attn, dim=0) - o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i) + o[i, i_q] = torch.einsum("n h, n h v -> h v", attn, v_i) return o.to(dtype) @@ -265,7 +256,7 @@ def naive_nsa_simple_inference( o (torch.Tensor): Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
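Editor's note (not part of the patch): all of the reference paths first broadcast the KV heads to the query heads (grouped-query attention), so the dense einsums can treat every query head independently. A minimal illustration of that `einops.repeat` step:

```python
import torch
from einops import repeat

B, T, H, HQ, D = 1, 8, 2, 16, 4
G = HQ // H                                    # query heads served by each KV head
k = torch.randn(B, T, H, D)
k_expanded = repeat(k, "b t h d -> b t (h g) d", g=G)

print(k_expanded.shape)                        # torch.Size([1, 8, 16, 4])
# Query heads 0..G-1 all read KV head 0, heads G..2G-1 read KV head 1, and so on.
assert torch.equal(k_expanded[:, :, 0], k_expanded[:, :, G - 1])
assert torch.equal(k_expanded[:, :, G], k_expanded[:, :, 2 * G - 1])
```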
""" - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 dtype = q.dtype HQ = q.shape[2] @@ -275,8 +266,8 @@ def naive_nsa_simple_inference( BS = block_size S = block_indices.shape[-1] SELECTED_BLOCKS_SIZE = S * BS - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) o = torch.zeros_like(q) @@ -306,9 +297,9 @@ def naive_nsa_simple_inference( v_i[t, h] = v_b[selected_block_index, h, :] # [S*BS, HQ] - attn = torch.einsum('h d, n h d -> n h', q_i, k_i) - attn = attn.masked_fill((c >= s_i), float('-inf')) + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((c >= s_i), float("-inf")) attn = torch.softmax(attn, dim=0) - o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i) + o[i, 0] = torch.einsum("n h, n h v -> h v", attn, v_i) return o.to(dtype) diff --git a/examples/deepseek_nsa/regression_example_tilelang_nsa.py b/examples/deepseek_nsa/regression_example_tilelang_nsa.py new file mode 100644 index 000000000..1858f045a --- /dev/null +++ b/examples/deepseek_nsa/regression_example_tilelang_nsa.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_tilelang_nsa_fwd +import example_tilelang_nsa_decode + + +def regression_example_tilelang_nsa_fwd(): + tilelang.testing.process_func(example_tilelang_nsa_fwd.run_regression_perf) + + +def regression_example_tilelang_nsa_fwd_decode(): + tilelang.testing.process_func(example_tilelang_nsa_decode.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_nsa/requirements.txt b/examples/deepseek_nsa/requirements.txt index 777c2ad4c..e096dfd7d 100644 --- a/examples/deepseek_nsa/requirements.txt +++ b/examples/deepseek_nsa/requirements.txt @@ -1 +1 @@ -git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e \ No newline at end of file +git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e diff --git a/examples/deepseek_v32/README.md b/examples/deepseek_v32/README.md index 8457745b0..01a14b6b2 100644 --- a/examples/deepseek_v32/README.md +++ b/examples/deepseek_v32/README.md @@ -121,7 +121,7 @@ for i_i in T.Pipelined(NI, num_stages=num_stages): # ... compute attention over selected tokens ``` -This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid: +This reduces compute from O(seq_len *seq_len_kv) to O(seq_len* topk). 
The causal mask is enforced by checking whether each index position is valid: ```python for bi_i in T.Parallel(BI): @@ -193,10 +193,10 @@ for i_i in T.Pipelined(NI, num_stages=num_stages): # Load KV data for selected indices for bi_i, d_i in T.Parallel(BI, D): KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BI + bi_i], bz, d_i] - + # Recompute attention scores for backward T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - + # Apply softmax gradient: dP = P * (dP_raw - Delta) for h_i, bi_i in T.Parallel(padded_H, BI): acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale @@ -204,7 +204,7 @@ for i_i in T.Pipelined(NI, num_stages=num_stages): The key gradient computations are: - **dQ = dP @ K** (query gradients) -- **dK = dP^T @ Q** (key gradients) +- **dK = dP^T @ Q** (key gradients) - **dV = P^T @ dO** (value gradients) **3. Atomic Sparse Updates**: Uses atomic operations for dKV accumulation: @@ -212,7 +212,7 @@ The key gradient computations are: ```python # Atomically update dKV at selected indices for bi_i, d_i in T.Parallel(BI // split_store, D // 4): - T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4], + T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4]) ``` diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index 21baa8fa8..03e88dd97 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -28,11 +28,11 @@ def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_rai if should_raise: assert False if not torch.isclose( - a.masked_fill(a_finite, 0), - b.masked_fill(b_finite, 0), - rtol=0, - atol=0, - equal_nan=True, + a.masked_fill(a_finite, 0), + b.masked_fill(b_finite, 0), + rtol=0, + atol=0, + equal_nan=True, ).all(): display_error_message(f"{tensor_name} Error: nonfinite value mismatch") if should_raise: @@ -55,13 +55,10 @@ def get_configs(): threads=[128, 256], block_Q=[1, 2, 4], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] class SupplyProg: - def __init__(self): self.tensors_dict = {} @@ -88,7 +85,8 @@ def supply_prog(self, params): @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - },) + }, +) def mqa_attn_return_logits( heads, index_dim, @@ -99,9 +97,9 @@ def mqa_attn_return_logits( ): if block_Q is None: block_Q = 128 // heads - dtype = "float8_e4m3" - accum_dtype = "float" - index_dtype = "int32" + dtype = T.float8_e4m3fn + accum_dtype = T.float32 + index_dtype = T.int32 seq_len = T.dynamic("seq_len") seq_len_kv = T.dynamic("seq_len_kv") @@ -113,46 +111,42 @@ def mqa_attn_return_logits( @T.prim_func def mqa_attn_return_logits_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore - IndexK: T.Tensor(index_k_shape, dtype), # type: ignore - IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore - Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore - Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + IndexQ: T.Tensor(index_q_shape, 
dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: - index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) index_k_shared = T.alloc_shared([block_N, index_dim], dtype) index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) - s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype) + s_reshaped = T.reshape(s, (block_N, block_Q, heads)) logits = T.alloc_fragment([block_N, block_Q], accum_dtype) weights = T.alloc_fragment([block_Q, heads], accum_dtype) seq_len_i = bx * block_Q - cu_k_s_min = T.alloc_local([1], index_dtype) - cu_k_e_max = T.alloc_local([1], index_dtype) + cu_k_s_min = T.alloc_var(index_dtype) + cu_k_e_max = T.alloc_var(index_dtype) - cu_k_s_min[0] = 2147483647 - cu_k_e_max[0] = -2147483648 + cu_k_s_min = 2147483647 + cu_k_e_max = -2147483648 for bq_i in T.serial(block_Q): - cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], - seq_len_kv)) + cu_k_s_min = T.min(cu_k_s_min, T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) for bq_i in T.serial(block_Q): - cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], - seq_len_kv)) + cu_k_e_max = T.max(cu_k_e_max, T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) T.copy(Weights[seq_len_i, 0], weights) - for nbn_i in T.Pipelined( - T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): - T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) - T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) + for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max - cu_k_s_min, block_N), num_stages=num_stages): + T.copy(IndexK[cu_k_s_min + nbn_i * block_N, 0], index_k_shared) + T.copy(IndexKScale[cu_k_s_min + nbn_i * block_N], index_k_scale_fragment) T.gemm( index_k_shared, @@ -164,15 +158,14 @@ def mqa_attn_return_logits_kernel( ) for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): - s_reshaped[bn_i, bq_i, - h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) * - weights[bq_i, h_i]) * index_k_scale_fragment[bn_i] + s_reshaped[bn_i, bq_i, h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i]) * index_k_scale_fragment[ + bn_i + ] T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) for bq_i, bn_i in T.Parallel(block_Q, block_N): - Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = ( - logits[bn_i, bq_i]) + Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] return mqa_attn_return_logits_kernel @@ -185,38 +178,30 @@ def clean_logits_( seq_len = T.dynamic("seq_len") seq_len_kv = T.dynamic("seq_len_kv") - dtype = "float" - indices_dtype = "int32" + dtype = T.float + indices_dtype = T.int32 @T.prim_func def clean_logits_kernel( - Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: 
ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore ): with T.Kernel(seq_len, threads=threads) as bx: tx = T.thread_binding(0, threads, thread="threadIdx.x") - cu_k_s = T.alloc_local([1], indices_dtype) - cu_k_e = T.alloc_local([1], indices_dtype) - cu_k_s[0] = CuSeqLenKS[bx] - cu_k_e[0] = CuSeqLenKE[bx] + cu_k_s = CuSeqLenKS[bx] + cu_k_e = CuSeqLenKE[bx] for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): for k_i in T.serial(block_K // threads): idx = n_i * block_K + k_i * threads + tx - if idx < cu_k_s[0] or idx >= cu_k_e[0]: + if idx < cu_k_s or idx >= cu_k_e: Logits[bx, idx] = -T.infinity(dtype) return clean_logits_kernel -def mqa_attn_return_logits_interface(q, - kv, - kv_scales, - weights, - cu_seqlen_ks, - cu_seqlen_ke, - clean_logits=True): +def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True): seq_len, heads, index_dim = q.shape seq_len_kv = kv.shape[0] @@ -238,57 +223,48 @@ def mqa_attn_return_logits_interface(q, return logits -def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): k = kv q = q.float() k = k.float() seq_len_kv = kv.shape[0] - mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] - mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] mask = mask_lo & mask_hi - score = torch.einsum('mhd,nd->hmn', q, k) + score = torch.einsum("mhd,nd->hmn", q, k) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float('-inf')) + logits = logits.masked_fill(~mask, float("-inf")) cost = mask.sum() return logits, cost def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): + # initial random seed to make the performance reproducible + torch.manual_seed(0) q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) weights = torch.randn(S, H, device="cuda", dtype=torch.float32) p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) - ks, ke = generate_random_cu_seqlens( - per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) - logits_ref, cost_ref = ref_fp8_mqa_logits( - q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) q_fp8 = q.to(torch.float8_e4m3fn) kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) - logits_tl = mqa_attn_return_logits_interface( - q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - diff = validate_tensor_match( - logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) + logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, 
tensor_name="logits", should_raise=False) print(f"diff: {diff}") from tilelang.profiler import do_bench def logits_fn(): - return mqa_attn_return_logits_interface( - q=q_fp8, - kv=kv_fp8, - kv_scales=kv_scales, - weights=weights, - cu_seqlen_ks=ks, - cu_seqlen_ke=ke) + return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: logits_fn() @@ -302,5 +278,35 @@ def logits_fn(): print(f"cost_ref: {cost_ref}") +def run_regression_perf(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): + torch.manual_seed(0) + q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + weights = torch.randn(S, H, device="cuda", dtype=torch.float32) + p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) + + ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + + logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) + + logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) + + from tilelang.profiler import do_bench + + def logits_fn(): + return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + logits_fn() + + print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)) + + return do_bench(logits_fn, backend="cupti") + + if __name__ == "__main__": test_fp8_lighting_indexer() diff --git a/examples/deepseek_v32/inference/README.md b/examples/deepseek_v32/inference/README.md index fe4cc21bb..60afe7ceb 100644 --- a/examples/deepseek_v32/inference/README.md +++ b/examples/deepseek_v32/inference/README.md @@ -11,4 +11,4 @@ Launch the interactive chat interface and start exploring DeepSeek's capabilitie ```bash export CONFIG=config_671B_v3.2.json torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive -``` \ No newline at end of file +``` diff --git a/examples/deepseek_v32/inference/config_671B_v3.2.json b/examples/deepseek_v32/inference/config_671B_v3.2.json index be88f1cca..375aa9aa2 100644 --- a/examples/deepseek_v32/inference/config_671B_v3.2.json +++ b/examples/deepseek_v32/inference/config_671B_v3.2.json @@ -23,4 +23,4 @@ "index_n_heads": 64, "index_head_dim": 128, "index_topk": 2048 -} \ No newline at end of file +} diff --git a/examples/deepseek_v32/inference/convert.py b/examples/deepseek_v32/inference/convert.py index df7943918..090be7145 100644 --- a/examples/deepseek_v32/inference/convert.py +++ b/examples/deepseek_v32/inference/convert.py @@ -42,7 +42,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): save_path (str): Path to the directory where the converted checkpoint files will be saved. n_experts (int): Total number of experts in the model. mp (int): Model parallelism factor. 
- + Returns: None """ diff --git a/examples/deepseek_v32/inference/kernel.py b/examples/deepseek_v32/inference/kernel.py index 262343536..25abf15d5 100644 --- a/examples/deepseek_v32/inference/kernel.py +++ b/examples/deepseek_v32/inference/kernel.py @@ -11,21 +11,21 @@ tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, } -FP8 = "float8_e4m3" -BF16 = "bfloat16" -FP32 = "float32" +FP8 = T.float8_e4m3fn +BF16 = T.bfloat16 +FP32 = T.float32 def fast_log2_ceil(x): - bits_x = T.reinterpret("uint32", x) + bits_x = T.reinterpret(T.uint32, x) exp_x = (bits_x >> 23) & 0xFF man_bits = bits_x & ((1 << 23) - 1) - return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + return T.Cast(T.int32, exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) def fast_pow2(x): bits_x = (x + 127) << 23 - return T.reinterpret("float32", bits_x) + return T.reinterpret(T.float32, bits_x) def fast_round_scale(amax, fp8_max_inv): @@ -107,8 +107,8 @@ def act_quant(x: torch.Tensor, @tilelang.jit(pass_configs=pass_configs) -def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): - assert out_dtype in [BF16, "float32"] +def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=T.float32): + assert out_dtype in [BF16, T.float32] M = T.dynamic("M") group_size = 128 diff --git a/examples/deepseek_v32/inference/requirements.txt b/examples/deepseek_v32/inference/requirements.txt index 604fed552..8c208a8b1 100644 --- a/examples/deepseek_v32/inference/requirements.txt +++ b/examples/deepseek_v32/inference/requirements.txt @@ -2,4 +2,4 @@ torch transformers safetensors fast_hadamard_transform -tilelang==0.1.6 \ No newline at end of file +tilelang==0.1.6 diff --git a/examples/deepseek_v32/regression_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/regression_tilelang_example_deepseek_v32.py new file mode 100644 index 000000000..0610002a6 --- /dev/null +++ b/examples/deepseek_v32/regression_tilelang_example_deepseek_v32.py @@ -0,0 +1,30 @@ +import tilelang.testing +import fp8_lighting_indexer +import sparse_mla_bwd +import sparse_mla_fwd +import sparse_mla_fwd_pipelined +import topk_selector + + +def regression_topk_selector(): + tilelang.testing.process_func(topk_selector.run_regression_perf) + + +def regression_fp8_lighting_indexer(): + tilelang.testing.process_func(fp8_lighting_indexer.run_regression_perf, S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1) + + +def regression_sparse_mla_fwd(): + tilelang.testing.process_func(sparse_mla_fwd.run_regression_perf, S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256) + + +def regression_sparse_mla_fwd_pipelined(): + tilelang.testing.process_func(sparse_mla_fwd_pipelined.run_regression_perf, S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256) + + +def regression_sparse_mla_bwd(): + tilelang.testing.process_func(sparse_mla_bwd.run_regression_perf, S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index e7f9c6093..527de22b3 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -13,18 +13,18 @@ def preprocess( D, block_ND=32, num_stages=5, - dtype="bfloat16", - accum_dtype="float", + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 shape = [B, S, H, D] @T.prim_func def preprocess_kernel( - O: T.Tensor(shape, dtype), - dO: 
T.Tensor(shape, dtype), - Delta: T.Tensor([B, S, H], accum_dtype), + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([B, S, H], accum_dtype), ): with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): o = T.alloc_fragment([block_ND, block_ND], accum_dtype) @@ -33,16 +33,12 @@ def preprocess_kernel( acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) T.clear(acc) for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): - T.copy( - O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - o) - T.copy( - dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - do) + T.copy(O[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) for i, j in T.Parallel(block_ND, block_ND): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, by * block_ND:(by + 1) * block_ND, bx]) + T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx]) return preprocess_kernel @@ -56,22 +52,22 @@ def postprocess( kv_group=1, block_N=64, threads=128, - dtype="bfloat16", - accum_dtype="float", + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 dkv_shape = [B, S_kv, kv_group, D + D_tail] @T.prim_func def postprocess_kernel( - dKV: T.Tensor(dkv_shape, accum_dtype), - dKV_out: T.Tensor(dkv_shape, dtype), + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), ): with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz): T.copy( - dKV[bz, bx * block_N:(bx + 1) * block_N, by, :], - dKV_out[bz, bx * block_N:(bx + 1) * block_N, by, :], + dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], ) return postprocess_kernel @@ -82,7 +78,9 @@ def postprocess_kernel( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, + }, +) def bwd( B, S, @@ -97,18 +95,18 @@ def bwd( block_size=32, num_stages=0, threads=256, - indices_dtype="int32", - dtype="bfloat16", - accum_dtype="float", + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert is_causal == True, 'non-casual is not supported now' - assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' - assert dtype == "bfloat16" - assert accum_dtype == "float" - assert indices_dtype == "int32" + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 if sm_scale is None: - sm_scale = (D + D_tail)**(-0.5) + sm_scale = (D + D_tail) ** (-0.5) sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) H_kv = H // kv_group @@ -118,12 +116,15 @@ def bwd( indices_shape = [B, S, kv_group, topk] delta_shape = [B, S, H] lse_shape = [B, S, H] - assert indices_dtype == "int32" - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 H = H_kv padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + block_H = 
min(64, padded_H) + assert padded_H % block_H == 0 + NH = padded_H // block_H BS = block_size NS = tilelang.cdiv(topk, block_size) @@ -131,122 +132,85 @@ def bwd( @T.prim_func def sparse_mla_bwd_kernel( - Q: T.Tensor(q_shape, dtype), - KV: T.Tensor(k_shape, dtype), - dO: T.Tensor(o_shape, dtype), - Indices: T.Tensor(indices_shape, indices_dtype), - Lse: T.Tensor(lse_shape, accum_dtype), - Delta: T.Tensor(delta_shape, accum_dtype), - dQ: T.Tensor(q_shape, dtype), - dKV: T.Tensor(k_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), ): - with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz): - Q_shared = T.alloc_shared([padded_H, D], dtype) - Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + with T.Kernel(S, B, kv_group * NH, threads=threads) as (s_i, by, bz): + Q_shared = T.alloc_shared([block_H, D], dtype) + Q_tail_shared = T.alloc_shared([block_H, D_tail], dtype) KV_shared = T.alloc_shared([BS, D], dtype) KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) - dO_shared = T.alloc_shared([padded_H, D], dtype) + dO_shared = T.alloc_shared([block_H, D], dtype) mask = T.alloc_fragment([BS], "bool") - P_shared_cast = T.alloc_shared([padded_H, BS], dtype) - dP_shared_cast = T.alloc_shared([padded_H, BS], dtype) - dQ_shared = T.alloc_shared([padded_H, D], dtype) - dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + P_shared_cast = T.alloc_shared([block_H, BS], dtype) + dP_shared_cast = T.alloc_shared([block_H, BS], dtype) + dQ_shared = T.alloc_shared([block_H, D], dtype) + dQ_tail_shared = T.alloc_shared([block_H, D_tail], dtype) - acc_p = T.alloc_fragment([padded_H, BS], accum_dtype) - acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype) - acc_dq = T.alloc_fragment([padded_H, D], accum_dtype) - acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) + acc_p = T.alloc_fragment([block_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([block_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([block_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([block_H, D_tail], accum_dtype) acc_dkv = T.alloc_fragment([BS, D], accum_dtype) acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) - acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) - acc_dkv_tail_shared = T.view( - KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) + acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype) + acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype) max_kv_i = s_i - T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared) - T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared) - T.copy(dO[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared) + T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, :D], Q_shared) + T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, D:], Q_tail_shared) + T.copy(dO[by, s_i, bz * block_H : (bz + 1) * block_H, :D], dO_shared) T.clear(acc_dq) T.clear(acc_dq_tail) - T.annotate_layout({ - dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), - dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), - }) - # Process each block of indices for i_i in T.Pipelined(NS, num_stages=num_stages): # Check which indices are valid for bi_i in 
T.Parallel(BS): - mask[bi_i] = Indices[by, s_i, bz, i_i * BS + bi_i] <= max_kv_i + mask[bi_i] = Indices[by, s_i, bz // NH, i_i * BS + bi_i] <= max_kv_i # Compute attention scores - for h_i, bi_i in T.Parallel(padded_H, BS): + for h_i, bi_i in T.Parallel(block_H, BS): acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) # Load KV, V for this block of indices for bi_i, d_i in T.Parallel(BS, D): - KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i] + KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i], bz // NH, d_i] - T.gemm( - Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) for bi_i, d_i in T.Parallel(BS, D_tail): - KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, - D + d_i] - T.gemm( - Q_tail_shared, - KV_tail_shared, - acc_p, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) - - for h_i, bi_i in T.Parallel(padded_H, BS): - acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - - Lse[by, s_i, bz * padded_H + h_i]) + KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i], bz // NH, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for h_i, bi_i in T.Parallel(block_H, BS): + acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - Lse[by, s_i, bz * block_H + h_i]) T.copy(acc_p, P_shared_cast) - T.gemm( - dO_shared, - KV_shared, - acc_dp, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) - for h_i, bi_i in T.Parallel(padded_H, BS): - acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * ( - acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale + for h_i, bi_i in T.Parallel(block_H, BS): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * block_H + h_i]) * sm_scale T.copy(acc_dp, dP_shared_cast) T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - dP_shared_cast, - Q_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - P_shared_cast, - dO_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) T.clear(acc_dkv_tail) - T.gemm( - dP_shared_cast, - Q_tail_shared, - acc_dkv_tail, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) for s in range(split_store): for bi_i, d_i in T.Parallel(BS, D): @@ -255,41 +219,32 @@ def sparse_mla_bwd_kernel( for bi_i, d_i in T.Parallel(BS, D_tail): if bi_i < BS // split_store: - acc_dkv_tail_shared[bi_i, - d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), - d_i] + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] for bi_i, d_i in T.Parallel(BS // split_store, D // 4): T.atomic_addx4( - dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], - bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4]) + dKV[by, Indices[by, s_i, bz 
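# [Illustration, not part of the diff] Dense reference for the per-row attention
# backward that the block loop above computes in the log2 domain (exp2 plus a
# precomputed Lse); here p is formed with an ordinary softmax instead:
import torch

def ref_attention_bwd_row(q, k, v, do, scale):
    s = (q @ k.T) * scale
    p = torch.softmax(s, dim=-1)
    dp = do @ v.T
    delta = (p * dp).sum(dim=-1, keepdim=True)   # equals rowsum(o * do) from preprocess
    ds = p * (dp - delta) * scale                # matches the acc_dp update above
    return ds @ k, ds.T @ q, p.T @ do            # dq, dk, dv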
// NH, i_i * BS + bi_i + s * (BS // split_store)], bz // NH, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) # Atomically update dKV, dKV_tail tensors for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): T.atomic_addx4( - dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], - bz, D + d_i * 4], acc_dkv_tail_shared[bi_i, d_i * 4]) + dKV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i + s * (BS // split_store)], bz // NH, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) # Store the accumulated dQ T.copy(acc_dq, dQ_shared) T.copy(acc_dq_tail, dQ_tail_shared) - T.copy(dQ_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D]) - T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:]) + T.copy(dQ_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, :D]) + T.copy(dQ_tail_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, D:]) return sparse_mla_bwd_kernel -def sparse_mla_bwd(q, - kv, - o, - do, - indices, - lse, - sm_scale=None, - is_casual=True, - return_kernel=False, - delta=None): +def sparse_mla_bwd(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None): assert q.is_contiguous() assert kv.is_contiguous() assert indices.is_contiguous() @@ -322,6 +277,7 @@ def sparse_mla_bwd(q, def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True): from sparse_mla_fwd import ref_sparse_mla_fwd_interface + q = q.detach().clone() kv = kv.detach().clone() q.requires_grad = True @@ -331,30 +287,22 @@ def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_c return q.grad, kv.grad -def test_sparse_mla_bwd(B=1, - S=4096, - SKV=8192, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True): +def test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True): # Prepare data - q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, S, H, DV), dtype=dtype, device='cuda') + q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda") - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") for b in range(B): for t in range(S): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i # Forward from sparse_mla_fwd import sparse_mla_fwd_interface + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse) @@ -365,13 +313,15 @@ def test_sparse_mla_bwd(B=1, assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") print("assert_tensors_similar passed") - per_token_flop = 2 * sum([ - H * DV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DV * topk, - ]) + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) from tilelang.profiler import do_bench def fn(): @@ -379,20 +329,44 @@ def fn(): ms = do_bench(fn, rep=100, warmup=250) print(f"Average time: {ms:.3f} ms") - print(f'bwd io bandwidth = ', - (B * S * max(DQKV * 
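# [Illustration, not part of the diff] The atomic_addx4 scatter above accumulates
# per-block dKV contributions into an fp32 buffer at the gathered kv positions;
# on the host the same accumulation (duplicates included) can be written with
# index_add_, with the postprocess kernel handling the final cast to bf16:
import torch

def scatter_dkv_row(dkv_fp32, indices, contrib, b, g):
    # dkv_fp32: [B, S_kv, G, D_total] float32 accumulator
    # indices:  [topk] int32 kv positions selected for one query row
    # contrib:  [topk, D_total] per-index gradient contributions for that row
    dkv_fp32[b, :, g, :].index_add_(0, indices.long(), contrib.float())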
2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) - print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12) + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) + + +def run_regression_perf(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda") + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + from sparse_mla_fwd import sparse_mla_fwd_interface + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) + B, S, H, dim_plus_tail_dim = q.shape + _, S_kv, kv_group, _ = kv.shape + D = 512 + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + preprocess_kernel = preprocess(B, S, H, D) + bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, None, True) + delta = preprocess_kernel(tl_out, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + return bwd_kernel(q, kv, do, indices, tl_lse, delta, dkv) + + return do_bench(run_kernel_only, backend="cupti") if __name__ == "__main__": - test_sparse_mla_bwd( - B=1, - S=4096, - SKV=8192, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True) + test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index a39c72c40..2c8bf7fc7 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -25,15 +25,12 @@ def sparse_mla_fwd( num_stages=2, threads=256, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" assert is_causal == True, "non-casual is not supported" - assert (topk % - block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) else: sm_scale = sm_scale * 1.44269504 # log2(e) @@ -47,17 +44,17 @@ def sparse_mla_fwd( o_shape = [batch, seq_len, heads, dim] indices_shape = [batch, seq_len, kv_group, topk] lse_shape = [batch, seq_len, heads] - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 G = kv_group H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert ( - kv_group == 1 - ), "here we solve the H 
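# [Illustration, not part of the diff] The 1.44269504 factor folded into sm_scale
# above is log2(e); pre-scaling the scores by it lets the kernels use exp2 in
# place of exp:
import math
import torch

x = torch.randn(8)
scale = (1.0 / (512 + 64)) ** 0.5          # same form as the default sm_scale
assert torch.allclose(torch.exp(x * scale),
                      torch.exp2(x * scale * 1.44269504), atol=1e-6)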
padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) D = dim @@ -73,18 +70,17 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( - bx, - by, - bz, - ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) KV_shared = T.alloc_shared([BI, D], dtype) @@ -118,16 +114,13 @@ def main( T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) for i_i in T.Pipelined(NI, num_stages=num_stages): - for bi_i in T.Parallel(BI): mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - d_i] + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - D + d_i] + K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) @@ -147,6 +140,8 @@ def main( ) T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(H_per_block): alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): @@ -174,15 +169,7 @@ def main( return main -def sparse_mla_fwd_interface(q, - kv, - indices, - sm_scale=None, - return_p_sum: bool = False, - d_v=512, - block_I=64, - num_stages=2, - threads=256): +def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=64, num_stages=2, threads=256): is_casual = True assert return_p_sum == False, "This kernel file is for fwd only" assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() @@ -199,16 +186,8 @@ def sparse_mla_fwd_interface(q, assert indices.shape == (batch, seq_len, kv_group, topk) kernel = sparse_mla_fwd( - heads, - dim, - tail_dim, - topk, - kv_group, - sm_scale, - is_casual, - block_I=block_I, - num_stages=num_stages, - threads=threads) + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) out, lse = kernel(q, kv, indices) return out, lse @@ -228,14 +207,14 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, 
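# [Illustration, not part of the diff] The added T.max(m_i, m_i_prev) lines keep
# the running row maximum monotone, which is the invariant the online-softmax
# recurrence below relies on (the rescaling factor stays <= 1):
import math

def online_softmax(blocks, m_init=-2.0 ** 30):
    # blocks: iterable of lists of already-scaled scores for one query row;
    # m_init mirrors the kernel's finite -2**30 fill that avoids inf - inf.
    m, l = m_init, 0.0
    for block in blocks:
        m_new = max(m, max(block))               # running max may only grow
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    return m, l                                  # lse = m + log(l)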
is_casual=True): b, _, _, dim_v = v.shape g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - 0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( - 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :1 - 1, 0] = True + mask[:, :, : 1 - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -250,19 +229,21 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): return o.to(torch.bfloat16) -def test_sparse_mla_fwd(B=1, - S=4096, - SKV=8192, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True, - block_I=64, - num_stages=2, - threads=256): +def test_sparse_mla_fwd( + B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): torch.random.manual_seed(0) q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) @@ -272,10 +253,9 @@ def test_sparse_mla_fwd(B=1, for t in range(S): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i - tl_out, tl_lse = sparse_mla_fwd_interface( - q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) if check_correctness: # otherwise may cause out of memory @@ -284,8 +264,7 @@ def test_sparse_mla_fwd(B=1, print("assert_tensors_similar passed") def fn(): - return sparse_mla_fwd_interface( - q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + return sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) from tilelang.profiler import do_bench @@ -299,6 +278,36 @@ def fn(): print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) +def run_regression_perf( + B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, block_I=64, num_stages=2, threads=256 +): + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + is_casual = True + _, _, heads, dim_plus_tail_dim = q.shape + _, _, kv_group, _ = kv.shape + dim = 512 + tail_dim = dim_plus_tail_dim - dim + _, _, _, topk = indices.shape + kernel = sparse_mla_fwd(heads, dim, tail_dim, topk, kv_group, None, is_casual, block_I=block_I, num_stages=num_stages, threads=threads) + + def run_kernel_only(): + kernel(q, kv, indices) + + from tilelang.profiler import do_bench + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": test_sparse_mla_fwd( B=1, @@ -313,4 +322,5 @@ 
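# [Illustration, not part of the diff] The reference above densifies the
# [B, S, G, topk] index tensor by scattering into an extra "trash" column
# (padding entries point at index SKV) and then dropping that column:
import torch

B, S, G, SKV, topk = 1, 4, 1, 8, 3
indices = torch.full((B, S, G, topk), SKV)         # SKV marks unused topk slots
indices[0, 2, 0, :2] = torch.tensor([1, 5])        # query row 2 selects kv 1 and 5
mask = torch.zeros(B, G, S, SKV + 1, dtype=torch.bool)
mask = mask.scatter(3, indices.transpose(1, 2).long(), 1)[..., :-1]
assert mask[0, 0, 2].nonzero().flatten().tolist() == [1, 5]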
def fn(): check_correctness=True, block_I=64, num_stages=2, - threads=256) + threads=256, + ) diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 96dda7df5..7e664d11b 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -9,10 +9,16 @@ @tilelang.jit( out_idx=[-2, -1], compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", ], ) def sparse_mla_fwd( @@ -32,14 +38,12 @@ def sparse_mla_fwd( num_stages=0, threads=384, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" - assert is_causal == True, 'non-casual is not supported' - assert topk % block_I == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) else: sm_scale = sm_scale * 1.44269504 # log2(e) @@ -49,23 +53,25 @@ def sparse_mla_fwd( o_shape = [batch, seq_len, heads, dim] indices_shape = [batch, seq_len, kv_group, topk] lse_shape = [batch, seq_len, heads] - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 G = kv_group H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert kv_group == 1, 'here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)' + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) - assert NI % 2 == 0, 'NI should be a multiple of 2' + assert NI % 2 == 0, "NI should be a multiple of 2" D = dim D_tail = tail_dim KV_stride = kv_stride if head_kv > 64: - assert head_kv % 64 == 0, 'head_kv should be a multiple of 64' + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" REPLICATE_H = head_kv // 64 else: REPLICATE_H = 1 @@ -74,18 +80,14 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: 
T.Tensor(indices_shape, indices_dtype), # type: ignore - q_start_index_s: T.Tensor(1, indices_dtype), - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - (seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, - batch, - kv_group, - threads=threads) as (bx, by, bz): + with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz): Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) @@ -110,7 +112,7 @@ def main( alpha_local = T.alloc_fragment([H_per_block], accum_dtype) m_i = T.alloc_fragment([H_per_block], accum_dtype) m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) - indices_local = T.alloc_local([1], indices_dtype) + indices_local = T.alloc_var(indices_dtype) # TODO: Multi buffer bar_q = T.alloc_barrier(arrive_count=384) @@ -122,8 +124,7 @@ def main( bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) b_i, g_i = by, bz - s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else ( - bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) q_i = q_start_index_s[0] + s_i max_kv_i = (q_i + 1 - KV_stride) // KV_stride @@ -132,26 +133,24 @@ def main( tx = T.get_thread_binding() - T.copy(Q[b_i, s_i, H0:H1, 0:D // 2], Q_shared_l) - T.copy(Q[b_i, s_i, H0:H1, D // 2:D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) for i_i in T.serial(T.ceildiv(NI, 2)): - # Buffer 0 T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - -T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) @@ -164,6 +163,8 @@ def main( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(H_per_block): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): @@ -185,8 +186,7 @@ def main( T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - -T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_1_r, 
acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) @@ -198,6 +198,8 @@ def main( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(H_per_block): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): @@ -223,7 +225,7 @@ def main( for h_i in T.Parallel(H_per_block): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0:D // 2]) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -253,7 +255,7 @@ def main( acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2:D]) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D]) elif tx >= 256: # producer T.set_max_nreg(80, 0) @@ -261,70 +263,58 @@ def main( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2) * BI + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + indices_local = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, D + (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + indices_local = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 
+ (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local, g_i, D + (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) return main -def sparse_mla_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride, - sm_scale=None, - is_casual=True, - return_kernel=False, - print_kernel=False): +def sparse_mla_fwd_interface( + q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False +): assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() batch, seq_len, heads, dim_plus_tail_dim = q.shape _, seq_len_kv, kv_group, _ = kv.shape - assert dim_plus_tail_dim == 576, 'you should assign dim otherwise' + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" dim = 512 assert kv.shape[-1] == dim_plus_tail_dim @@ -334,29 +324,23 @@ def sparse_mla_fwd_interface(q, assert indices.shape == (batch, seq_len, kv_group, topk) if q_start_index_s != 0: - assert q_start_index_s > kv_stride, "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + assert q_start_index_s > kv_stride, ( + "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + ) CP0 = q_start_index_s == 0 - kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, - kv_group, sm_scale, is_casual, CP0) + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) if print_kernel: print(kernel.get_kernel_source()) - out, lse = kernel(q, kv, indices, - torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) if return_kernel: return kernel if q_start_index_s == 0 and kv_stride > 1: - out[:, :kv_stride - 1, :, :] = 0 + out[:, : kv_stride - 1, :, :] = 0 return out, lse -def ref_sparse_mla_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride=4, - sm_scale=None, - is_casual=True): +def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True): q = q.float() kv = kv.float() indices = indices.transpose(1, 2) @@ -365,7 +349,7 @@ def ref_sparse_mla_fwd_interface(q, if q_start_index_s is None: q_start_index_s = sk * kv_stride - sq - assert kv.shape[-1] == 576, 'you should assign dim otherwise' + assert kv.shape[-1] == 576, "you should assign dim otherwise" dim = 512 k = kv v = kv[..., :dim] @@ -374,15 +358,14 @@ def ref_sparse_mla_fwd_interface(q, num_kv_per_index = 1 g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - q_start_index_s, sq + q_start_index_s, dtype=torch.int32, - device="cuda").view(-1, 1) >= 
torch.arange( - kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view( + -1, 1 + ) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :kv_stride - 1, 0] = True + mask[:, :, : kv_stride - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -397,41 +380,32 @@ def ref_sparse_mla_fwd_interface(q, return o.to(torch.bfloat16) -def test_sparse_mla_fwd_pipelined(B=1, - S=4096, - SKV=8192, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - q_start_s_index=1024, - check_correctness=True): +def test_sparse_mla_fwd_pipelined( + B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024, check_correctness=True +): KV_stride = 1 torch.random.manual_seed(0) - q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 - kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") q.clamp_(-10, 10) kv.clamp_(-10, 10) - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") for b in range(B): for t in range(S): for h in range(HKV): i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i - kernel = sparse_mla_fwd_interface( - q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) def fn(): out, lse = kernel(q, kv, indices, q_start_s_index_t) if q_start_s_index == 0 and KV_stride > 1: - out[:, :KV_stride - 1, :, :] = 0 + out[:, : KV_stride - 1, :, :] = 0 return out, lse tl_out, tl_lse = fn() @@ -442,14 +416,46 @@ def fn(): torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3) from tilelang.profiler import do_bench + ms = do_bench( fn, rep=10, warmup=10, ) print(f"Average time: {ms:.3f} ms") - print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) - print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +def run_regression_perf(B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024): + KV_stride = 1 + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in 
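# [Illustration, not part of the diff] The strided causal condition rebuilt in
# the reference above lets a query at absolute position q attend to compressed
# kv slot j only once the whole stride-sized block is in the past, i.e.
# q >= (j + 1) * kv_stride - 1:
import torch

def compressed_causal_mask(sq, sk, q_start, kv_stride):
    q_pos = torch.arange(q_start, q_start + sq).view(-1, 1)
    kv_last = torch.arange(kv_stride - 1, sk * kv_stride, kv_stride).view(1, -1)
    return q_pos >= kv_last                      # [sq, sk] boolean mask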
range(HKV): + i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + dim = 512 + tail_dim = dim_plus_tail_dim - dim + CP0 = q_start_s_index == 0 + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, KV_stride, kv_group, None, True, CP0) + + def run_kernel_only(): + kernel(q, kv, indices, torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda")) + + from tilelang.profiler import do_bench + + return do_bench(run_kernel_only, backend="cupti") if __name__ == "__main__": @@ -460,5 +466,4 @@ def fn(): B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 else: B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 - test_sparse_mla_fwd_pipelined( - B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) + test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) diff --git a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py index 971a3206c..983798f9f 100644 --- a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py +++ b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py @@ -1,42 +1,43 @@ # ruff: noqa +import tilelang import tilelang.testing -from topk_selector import test_topk_selector -from fp8_lighting_indexer import test_fp8_lighting_indexer -from sparse_mla_fwd import test_sparse_mla_fwd -from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined -from sparse_mla_bwd import test_sparse_mla_bwd +import topk_selector +import fp8_lighting_indexer +import sparse_mla_fwd +import sparse_mla_fwd_pipelined +import sparse_mla_bwd def test_example_topk_selector(): - test_topk_selector() + topk_selector.test_topk_selector() def test_example_fp8_lighting_indexer(): - test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1) + fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_fwd(): # small shapes for testing - test_sparse_mla_fwd( - S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + sparse_mla_fwd.test_sparse_mla_fwd(S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_fwd_pipelined(): # small shapes for testing - test_sparse_mla_fwd_pipelined( - S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_bwd(): - test_sparse_mla_bwd( - S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) + sparse_mla_bwd.test_sparse_mla_bwd(S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) + sparse_mla_bwd.test_sparse_mla_bwd( + S=256, SKV=512, H=128, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False + ) # test for large H if __name__ == "__main__": diff --git 
a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py index 4a4b43277..078eb2686 100644 --- a/examples/deepseek_v32/topk_selector.py +++ b/examples/deepseek_v32/topk_selector.py @@ -8,24 +8,24 @@ def convert_to_uint16(x): - hval = T.Cast("float16", x) - bits_uint = T.reinterpret("uint16", hval) + hval = T.Cast(T.float16, x) + bits_uint = T.reinterpret(T.uint16, hval) bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000)) return bits_uint >> 8 def convert_to_uint32(x): - bits_uint = T.reinterpret("uint32", x) + bits_uint = T.reinterpret(T.uint32, x) bits_uint = T.if_then_else( x < 0, - ~bits_uint & T.Cast("uint32", (0xFFFFFFFF)), - bits_uint | T.Cast("uint32", (0x80000000)), + ~bits_uint & T.Cast(T.uint32, (0xFFFFFFFF)), + bits_uint | T.Cast(T.uint32, (0x80000000)), ) return bits_uint @tilelang.jit(pass_configs=pass_configs) -def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"): +def tl_topk_impl(topk, in_dtype=T.float32, out_dtype=T.int32): batch = T.dynamic("batch") seq_len = T.dynamic("seq_len") RADIX = 1 << 8 @@ -42,20 +42,20 @@ def tl_topk_kernel( with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): tx = T.get_thread_binding() - s_threshold_bin_id = T.alloc_shared([1], "int32") - s_histogram = T.alloc_shared([RADIX + 1], "int32") - s_num_input = T.alloc_shared([2], "int32") - s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32") - - l_threshold_bin_id = T.alloc_var("int32") - l_new_topk = T.alloc_var("int32") - l_num_input = T.alloc_var("int32") - l_bin_id32 = T.alloc_var("int32") - l_val = T.alloc_var("int32") - l_start_pos = T.alloc_var("int32") - l_start_idx = T.alloc_var("int32") - l_end_idx = T.alloc_var("int32") - l_out_pos = T.alloc_var("int32") + s_threshold_bin_id = T.alloc_shared([1], T.int32) + s_histogram = T.alloc_shared([RADIX + 1], T.int32) + s_num_input = T.alloc_shared([2], T.int32) + s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], T.int32) + + l_threshold_bin_id = T.alloc_var(T.int32) + l_new_topk = T.alloc_var(T.int32) + l_num_input = T.alloc_var(T.int32) + l_bin_id32 = T.alloc_var(T.int32) + l_val = T.alloc_var(T.int32) + l_start_pos = T.alloc_var(T.int32) + l_start_idx = T.alloc_var(T.int32) + l_end_idx = T.alloc_var(T.int32) + l_out_pos = T.alloc_var(T.int32) l_new_topk = topk l_start_idx = starts[bx] @@ -99,7 +99,7 @@ def tl_topk_kernel( input_idx = s * BLOCK_SIZE + tx if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: bin_id = convert_to_uint16(input[bx, input_idx]) - l_bin_id32 = T.Cast("int32", bin_id) + l_bin_id32 = T.Cast(T.int32, bin_id) if l_bin_id32 > l_threshold_bin_id: # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1) pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) @@ -127,9 +127,9 @@ def tl_topk_kernel( l_num_input = s_num_input[r_idx] for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast("int32", (( - convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF)) + l_bin_id32 = T.Cast( + T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) T.atomic_add(s_histogram[l_bin_id32], 1) T.sync_threads() # cumsum @@ -156,23 +156,20 @@ def tl_topk_kernel( for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): T.sync_threads() if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast("int32", (( - convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - 
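# [Illustration, not part of the diff] convert_to_uint16 / convert_to_uint32
# above apply the usual order-preserving float-to-unsigned-key mapping used by
# radix select: flip every bit of negative values and only the sign bit of
# non-negative ones, so unsigned comparison of keys matches float comparison.
import struct

def float_to_ordered_u32(x: float) -> int:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (~bits & 0xFFFFFFFF) if x < 0 else (bits | 0x80000000)

vals = [-3.5, -0.25, 0.0, 1.0, 7.5]
keys = [float_to_ordered_u32(v) for v in vals]
assert keys == sorted(keys)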
round * 8)) & 0xFF)) + l_bin_id32 = T.Cast( + T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) if l_bin_id32 > l_threshold_bin_id: - pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: if round == 3: - l_out_pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos if l_out_pos < topk: index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] else: pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True) - s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, - s * BLOCK_SIZE + tx] + s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] return tl_topk_kernel @@ -186,7 +183,6 @@ def tl_topk(input, starts, ends, topk): def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): - batch = 64 seq_len = 32 * 1024 topk = 2048 @@ -212,8 +208,7 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): set_ref = set(ref_np) set_trt = set(trt_np) intersection = set_ref & set_trt - print("selected/all:", len(intersection), "/", len(set_ref), "=", - len(intersection) / len(set_ref)) + print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) # Performance test with CUDA events @@ -245,5 +240,35 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms") +def run_regression_perf(batch=64, seq_len=32 * 1024, topk=2048): + batch = 64 + seq_len = 32 * 1024 + topk = 2048 + torch.manual_seed(1) + input = torch.randn(batch, seq_len, dtype=torch.float32).cuda() + starts = torch.zeros(batch, dtype=torch.int32).cuda() + ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len + + indexes = tl_topk(input, starts, ends, topk) + + indexes_ref = torch.topk(input, topk, dim=-1)[1] + + for i in range(batch): + ref_np = indexes_ref[i].cpu().to(torch.int32).numpy() + trt_np = indexes[i].cpu().to(torch.int32).numpy() + + set_ref = set(ref_np) + set_trt = set(trt_np) + intersection = set_ref & set_trt + print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + tl_topk(input, starts, ends, topk) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": test_topk_selector() diff --git a/examples/deepseek_v32/utils.py b/examples/deepseek_v32/utils.py index 2ea34b14a..d7252e171 100644 --- a/examples/deepseek_v32/utils.py +++ b/examples/deepseek_v32/utils.py @@ -23,8 +23,7 @@ def _is_equal(a, b): if isinstance(a, torch.Tensor): return a is b # Whitelist of types that are safe to compare by value for caching. - if isinstance(a, (int, float, str, bool, type(None))) and isinstance( - b, (int, float, str, bool, type(None))): + if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))): return a == b # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check. 
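# [Sketch, not part of the diff] The _is_equal / wrapper pair above is a
# one-entry memo that compares tensor arguments by identity and plain scalars by
# value; a stripped-down, positional-args-only version of the same idea:
import torch
from functools import wraps

def last_call_cache(fn):
    state = {"args": None, "out": None}

    @wraps(fn)
    def wrapper(*args):
        prev = state["args"]
        hit = prev is not None and len(prev) == len(args) and all(
            (a is b) if isinstance(a, torch.Tensor) else (a == b)
            for a, b in zip(args, prev))
        if not hit:
            state["args"], state["out"] = args, fn(*args)
        return state["out"]

    return wrapper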
return False @@ -58,9 +57,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): # For Tensors, check for object identity. For other types, check for equality. # Python caches small integers, so `is` works for them but not for large integers like 4096. - if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \ - set(kwargs.keys()) == set(last_kwargs.keys()) and \ - all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()): + if ( + all(_is_equal(a, b) for a, b in zip(args, last_args)) + and set(kwargs.keys()) == set(last_kwargs.keys()) + and all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()) + ): return last_result result = fn(*args, **kwargs) @@ -79,73 +80,68 @@ def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int): @tensor_cache -def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - seq_len: int) -> torch.IntTensor: - seq_idx_for_q = torch.full((seq_len,), - len(cu_seqlens_qs), - dtype=torch.int32, - device=cu_seqlens_qs.device) +def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor: + seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): - seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i + seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i return seq_idx_for_q @tensor_cache -def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor: +def cal_cu_seqlen_ks_for_q( + cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int +) -> torch.IntTensor: cu_seqlen_ks_for_each_q = torch.gather( - input=torch.cat([ - cu_seqlens_ks, - torch.full((1,), - torch.iinfo(torch.int32).max, - dtype=torch.int32, - device=cu_seqlens_qs.device) - ]), + input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) return cu_seqlen_ks_for_each_q.int() @tensor_cache -def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor, - q_start_idxs: torch.LongTensor, seq_len: int, - kv_stride: int) -> torch.IntTensor: +def cal_cu_seqlen_ke_for_q( + cu_seqlens_qs: torch.LongTensor, + cu_seqlens_qe: torch.LongTensor, + cu_seqlens_ks: torch.LongTensor, + cu_seqlens_ke: torch.LongTensor, + q_start_idxs: torch.LongTensor, + seq_len: int, + kv_stride: int, +) -> torch.IntTensor: cu_seqlen_ke_for_each_q = torch.gather( - input=torch.cat( - [cu_seqlens_ke, - torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), + input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) - casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), - dtype=torch.int32, - device=cu_seqlens_qs.device) + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + 
casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): - casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange( - q_start_idxs[i], - q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], - dtype=torch.int32, - device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i] + casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = ( + torch.arange( + q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device + ) + + 1 + ) // kv_stride + cu_seqlens_ks[i] cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) return cu_seqlen_ke_for_each_q.int() @tensor_cache -def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, - cu_seqlens_k: torch.LongTensor = None, - offs_q: torch.LongTensor = None, - *, - seq_len: int, - kv_stride: int = 1, - cp_rank: int = 0, - cp_size: int = 1, - balanced_cp=False): - ''' +def cal_ks_ke_from_cu_seqlen_qk( + cu_seqlens_q: torch.LongTensor, + cu_seqlens_k: torch.LongTensor = None, + offs_q: torch.LongTensor = None, + *, + seq_len: int, + kv_stride: int = 1, + cp_rank: int = 0, + cp_size: int = 1, + balanced_cp=False, +): + """ seq_len: seq len per cp rank balanced cp slice assignment: 0 1 2 3 3 2 1 0 - ''' + """ n_seq = len(cu_seqlens_q) - 1 assert n_seq > 0 assert cu_seqlens_q.shape == (n_seq + 1,) @@ -170,10 +166,12 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, def f(x: torch.Tensor): chunks = x.chunk(cp_size * 2) - return torch.cat([ - chunks[cp_rank], - chunks[cp_size - cp_rank - 1], - ]) + return torch.cat( + [ + chunks[cp_rank], + chunks[cp_size - cp_rank - 1], + ] + ) ks = f(ks) ke = f(ke) @@ -189,8 +187,7 @@ def ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) -def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], - use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) sf = x_amax / 448.0 @@ -239,14 +236,18 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, total_seqlen - (cp_rank + 1) * per_chunk_seqlen, total_seqlen - cp_rank * per_chunk_seqlen, ) - ks = torch.cat([ - cu_seqlens_ks_for_each_q[slice_short], - cu_seqlens_ks_for_each_q[slice_long], - ]) - ke = torch.cat([ - cu_seqlens_ke_for_each_q[slice_short], - cu_seqlens_ke_for_each_q[slice_long], - ]) + ks = torch.cat( + [ + cu_seqlens_ks_for_each_q[slice_short], + cu_seqlens_ks_for_each_q[slice_long], + ] + ) + ke = torch.cat( + [ + cu_seqlens_ke_for_each_q[slice_short], + cu_seqlens_ke_for_each_q[slice_long], + ] + ) assert len(ks) == len(ke) == per_cp_seqlen return ks, ke @@ -302,11 +303,9 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): raise_assert: Whether to raise assertion error on failure """ sim = calculate_tensor_similarity(x, y, name) - diff = 1. 
- sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print( - f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m" - ) + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") if raise_assert: assert False # noqa: B011 @@ -316,11 +315,8 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) - cu_seqlens_qs = torch.cat( - [torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) - cu_seqlens_qe = torch.cat( - [cu_seqlens_cumsum, - torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) + cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) + cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) from tilelang.profiler import do_bench diff --git a/examples/dequantize_gemm/README.md b/examples/dequantize_gemm/README.md index 0c6116775..25ef617a2 100644 --- a/examples/dequantize_gemm/README.md +++ b/examples/dequantize_gemm/README.md @@ -19,7 +19,7 @@ def dequant_matmul( T.clear(Ct_local) for k in T.Pipelined( - T.ceildiv(K, block_K), + T.ceildiv(K, block_K), num_stages=num_stages ): T.copy(A[by * block_M, k * block_K], A_shared) diff --git a/examples/dequantize_gemm/dequantize_utils.py b/examples/dequantize_gemm/dequantize_utils.py index b14c0aee6..90a6265ff 100644 --- a/examples/dequantize_gemm/dequantize_utils.py +++ b/examples/dequantize_gemm/dequantize_utils.py @@ -39,12 +39,10 @@ def torch_convert_bit_twiddling(tensor): res0 = val_concat_expanded & mask res1 = (val_concat_expanded << 3) & mask res2 = (val_concat_expanded << 6) & mask - res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ( - (val_concat_expanded >> 7) & mask3) + res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ((val_concat_expanded >> 7) & mask3) # Select the correct result based on position - bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, - torch.where(pos == 2, res2, res3))) + bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3))) # Convert to uint16 for .view(torch.bfloat16) bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16) @@ -110,7 +108,7 @@ def print_bit(name, val): val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown. 
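# [Illustration, not part of the diff] calc_sim / assert_similar above use the
# same measure as assert_tensors_similar in utils.py:
#     sim = 2 * <x, y> / (|x|^2 + |y|^2),  diff = 1 - sim
# sim is 1 exactly when x == y elementwise and decreases as they diverge.
import torch

def similarity(x: torch.Tensor, y: torch.Tensor) -> float:
    x, y = x.double(), y.double()
    denom = (x * x + y * y).sum()
    return 1.0 if denom == 0 else (2 * (x * y).sum() / denom).item()

x = torch.randn(1024)
assert abs(1.0 - similarity(x, x.clone())) < 1e-12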
""" val_cpu = val.cpu().item() - binary_repr = f'{val_cpu:032b}' + binary_repr = f"{val_cpu:032b}" print(name, binary_repr) @@ -122,7 +120,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -132,21 +130,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): x_mask = torch.isfinite(x) y_mask = torch.isfinite(y) if not torch.all(x_mask == y_mask): - print_red_warning(f'{name} Error: isfinite mask mismatch') + print_red_warning(f"{name} Error: isfinite mask mismatch") if raise_assert: raise AssertionError - if not torch.isclose( - x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, - equal_nan=True).all(): - print_red_warning(f'{name} Error: nonfinite value mismatch') + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") if raise_assert: raise AssertionError x = x.masked_fill(~x_mask, 0) y = y.masked_fill(~y_mask, 0) sim = calc_sim(x, y, name) - diff = (1. - sim).item() - print(f'{diff=}') + diff = (1.0 - sim).item() + print(f"{diff=}") if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff=}') + print_red_warning(f"{name} Error: {diff=}") if raise_assert: raise AssertionError diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py index e30845b8d..36b32c0a8 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -24,6 +24,7 @@ def get_configs(): the parameter name to its chosen value. """ import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -32,65 +33,64 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit( out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - fast_dequant=True, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + fast_dequant=True, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. + + This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: + - A: dense input of shape (M, K) with dtype `in_dtype`. 
+ - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. + - C: output of shape (M, N) with dtype `out_dtype`. + + The generated kernel supports two dequantization paths: + - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. + - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. + + Important behavior and requirements: + - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. + - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. + - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. + - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. + - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. + + Parameters that alter kernel layout/behavior (brief): + - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. + - num_stages: number of software pipeline stages for the K-loop. + - threads: number of threads used per kernel block. + - split: extra K-splitting factor; K must be divisible by block_K * split. + - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. + + Returns: + A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. """ - Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. - - This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: - - A: dense input of shape (M, K) with dtype `in_dtype`. - - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. - - C: output of shape (M, N) with dtype `out_dtype`. - - The generated kernel supports two dequantization paths: - - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. - - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. - - Important behavior and requirements: - - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. - - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. - - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. - - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. - - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. - - Parameters that alter kernel layout/behavior (brief): - - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. - - num_stages: number of software pipeline stages for the K-loop. - - threads: number of threads used per kernel block. 
- - split: extra K-splitting factor; K must be divisible by block_K * split. - - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. - - Returns: - A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. - """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte @@ -121,7 +121,7 @@ def matmul(M, assert func_name is not None, "mxfp_intrin_info is not found" import_source = import_source - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin. @@ -131,13 +131,13 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel. Notes and preconditions: - - Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`. + - Asserts that `in_dtype == "fp4"` and `out_dtype == T.bfloat16`. - The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel. - The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly. - The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -189,12 +189,11 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared): # Finally, store the dequantized data to shared memory. for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16. @@ -205,7 +204,7 @@ def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the dequantized bfloat16 block into B_dequantize_shared. Constraints: - - Supports only in_dtype="fp4" and out_dtype="bfloat16". + - Supports only in_dtype="fp4" and out_dtype=T.bfloat16. - The helper assumes nbit == 4 and produces bfloat16 values. - The macro uses a fixed test-scale of 0 (no per-element scaling) as written. @@ -213,49 +212,49 @@ def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16. 
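A worked instance of the shape bookkeeping set up just above (num_bits = 4, uint8 storage), using the K = 256, block_K = 128, split = 1 shapes exercised by this example's main(); all names below are local to the sketch:

num_bits, K, block_K, split = 4, 256, 128, 1
num_elems_per_byte = 8 // num_bits         # 2 packed FP4 codes per uint8 byte
QK = K // num_elems_per_byte               # 128: packed K extent, so B has shape (N, QK)
Block_QK = block_K // num_elems_per_byte   # 64: packed K extent of each B tile in shared memory
assert K % (block_K * split) == 0          # the same precondition the kernel asserts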
""" assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] - def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, - scale: tir.PrimExpr, dtype: str): + def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. - - This helper extracts the 4-bit field located at the bit position `pos` within the - byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an - exponent `scale` offset to align it with bfloat16 exponent bias, clamps the - resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. - - Parameters: - nbit (int): Number of bits in the packed element; must be 4. - val (tir.PrimExpr): A uint8 value containing packed FP4 elements. - pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. - scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. - dtype (str): Target dtype string; must be "bfloat16". - - Returns: - tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. - - Notes: - - The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8". - - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 - bit fields and clamps the computed exponent to fit into 8 bits. + Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. + + This helper extracts the 4-bit field located at the bit position `pos` within the + byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an + exponent `scale` offset to align it with bfloat16 exponent bias, clamps the + resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. + + Parameters: + nbit (int): Number of bits in the packed element; must be 4. + val (tir.PrimExpr): A uint8 value containing packed FP4 elements. + pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. + scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. + dtype (str): Target dtype string; must be T.bfloat16. + + Returns: + tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. + + Notes: + - The function asserts `nbit == 4`, `dtype == T.bfloat16`, and that `val.dtype` is T.uint8. + - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 + bit fields and clamps the computed exponent to fit into 8 bits. 
""" assert nbit == 4 - assert dtype == "bfloat16" - assert val.dtype == "uint8" - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 - e_bf16 = e_f4 + tir.const(126, "uint16") + e_bf16 = e_f4 + tir.const(126, T.uint16) # Scale is the exponential part, within the representation of uint8 # To handle the overflow, we use the max function to limit the exponential part to 8 bits - e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") + e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16)) + m_f4 = f4 & tir.const(1, T.uint16) val_bf16 = tir.reinterpret( - "bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) return val_bf16 @T.macro @@ -292,32 +291,32 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared): @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Kernel entry for the tiled, pipelined matmul used by the generated prim_func. - - This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: - - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. - - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. - - Pipelines over K in chunks of `block_K` for `num_stages` stages: - - Loads A and packed B tiles into shared memory. - - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. - - Performs a GEMM accumulating into C_local with B transposed. - - Stores the accumulated block from C_local back to the global output C via C_shared. - - Parameters: - - A: input tile of shape (M, K) with dtype `in_dtype`. - - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). - - C: output tensor of shape (M, N) with dtype `out_dtype`. - - Side effects: - - Writes the computed output block into the global tensor `C`. - - Uses and updates shared memory buffers and per-thread accumulators. - - No value is returned. + Kernel entry for the tiled, pipelined matmul used by the generated prim_func. + + This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: + - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. + - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. 
+ - Pipelines over K in chunks of `block_K` for `num_stages` stages: + - Loads A and packed B tiles into shared memory. + - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. + - Performs a GEMM accumulating into C_local with B transposed. + - Stores the accumulated block from C_local back to the global output C via C_shared. + + Parameters: + - A: input tile of shape (M, K) with dtype `in_dtype`. + - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). + - C: output tensor of shape (M, N) with dtype `out_dtype`. + + Side effects: + - Writes the computed output block into the global tensor `C`. + - Uses and updates shared memory buffers and per-thread accumulators. + + No value is returned. """ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -327,10 +326,6 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) - T.annotate_layout({ - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) - T.clear(C_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) @@ -344,7 +339,7 @@ def main( T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -363,7 +358,7 @@ def ref_program_twiddling(A, qB): Returns: torch.Tensor: Result matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -383,7 +378,7 @@ def ref_program_simple(A, qB): Returns: torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N). 
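ref_program_simple relies on torch_convert to expand qB from its packed (N, K // 2) uint8 form back to N x K values. As a rough standalone sketch of just the packing convention (two 4-bit codes per byte with the even column in the low nibble, matching how the simple dequant path indexes j % num_elems_per_byte), and not the library's torch_convert, which additionally maps each code to its bf16 value:

import torch

def unpack_fp4_codes(qB: torch.Tensor) -> torch.Tensor:
    # qB: (N, K // 2) uint8, two 4-bit codes per byte, low nibble first.
    lo = qB & 0xF                  # codes for even output columns
    hi = (qB >> 4) & 0xF           # codes for odd output columns
    out = torch.empty(qB.shape[0], qB.shape[1] * 2, dtype=torch.uint8, device=qB.device)
    out[:, 0::2] = lo
    out[:, 1::2] = hi
    return out                     # raw 4-bit codes; the value mapping is a separate step

The twiddling path (ref_program_twiddling / torch_convert_bit_twiddling) assumes a permuted bit layout and is not covered by this sketch.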
""" - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -409,16 +404,15 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): """ total_flops = 2 * m * n * k if tune: - kernel = matmul( - m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant) + kernel = matmul(m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, fast_dequant=fast_dequant) else: kernel = matmul( m, n, k, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=4, fast_dequant=fast_dequant, block_M=256, @@ -426,7 +420,8 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): block_K=128, num_stages=2, threads=256, - split=1) + split=1, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) if fast_dequant: profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) @@ -437,6 +432,27 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(m=4096, n=4096, k=4096, fast_dequant=True): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + fast_dequant=fast_dequant, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main(256, 256, 256, True) main(256, 256, 256, False) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py index ac1417aeb..cc37c8bc4 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -7,45 +7,45 @@ from dequantize_utils import torch_convert_bit_twiddling, torch_convert -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, - dtype: str): +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. - This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its - bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by - `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. - Parameters: - nbit (int): Number of bits in the packed field (must be 4). - val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. - pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). - scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). + Parameters: + nbit (int): Number of bits in the packed field (must be 4). 
+ val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be T.bfloat16). - Returns: - tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. - Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - """ + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8. + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ assert nbit == 4 - assert dtype == "bfloat16" - assert val.dtype == "uint8" - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 - e_bf16 = e_f4 + tir.const(126, "uint16") + e_bf16 = e_f4 + tir.const(126, T.uint16) # Scale is the exponential part, within the representation of uint8 # To handle the overflow, we may use the min function to limit the exponential part to 8 bits # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret("bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret( + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) return val_bf16 @@ -65,6 +65,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. 
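get_configs() above expands a dict of candidate values into the full Cartesian product of tuning configurations; a trimmed-down illustration of that pattern, with made-up parameter lists:

import itertools

iter_params = dict(block_M=[64, 128], threads=[128, 256])   # illustrative subset only
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
# configs == [{'block_M': 64, 'threads': 128}, {'block_M': 64, 'threads': 256},
#             {'block_M': 128, 'threads': 128}, {'block_M': 128, 'threads': 256}]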
""" import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -73,70 +74,74 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1],) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ - Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - - The generated kernel accepts: - - A: dense matrix with element type `in_dtype`. - - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). - - Scale: per-block scale/exponent information used to dequantize B. - The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: - - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. - - fast_dequant (False): uses a simple elementwise dequantization helper. - - Parameters: - M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). - in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). - accum_dtype (str): accumulation type used for the inner GEMM. - source_format (str, optional): format string passed to intrinsic selector (default "uint"). - num_bits (int, optional): number of bits per quantized element in B (default 4). - scale_size (int, optional): number of elements grouped per scale entry (default 32). - fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). - block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). - num_stages (int, optional): pipelining stages for K loop (default 2). - threads (int, optional): threads per block used by the kernel (default 256). - split (int, optional): split factor along K used by the scheduler (default 1). - with_bias (bool, optional): whether to add Bias to the output (default False). - - Returns: - A T.prim_func implementing the tiled, pipelined GEMM that: - - loads tiled blocks of A and packed B to shared memory, - - dequantizes B via the chosen path into a shared dequantized tile, - - performs a tiled GEMM accumulating into local fragments, - - writes the final MxN block to the global output tensor. + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. 
- Notes: - - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - - An assertion enforces that K % (block_K * split) == 0. + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., T.bfloat16). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte A_shape = (M, K) @@ -150,6 +155,7 @@ def matmul(M, assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -164,7 +170,7 @@ def matmul(M, assert func_name is not None, "mxfp_intrin_info is not found" import_source = import_source - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. 
@@ -175,12 +181,12 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the scaled BF16 results into B_dequantize_shared. Notes: - - This factory only supports in_dtype="fp4" and out_dtype="bfloat16". + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -252,24 +258,23 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k): for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16. Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared. Notes: - - Only supports in_dtype="fp4" and out_dtype="bfloat16". + - Only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion. - Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper. 
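Numerically, the Scale handling described above is the same per-group exponent scaling the reference programs apply with B *= 2 ** Scale[:, arange(K) // 32]: every run of scale_size consecutive K columns shares one uint8 exponent. A standalone sketch, assuming scale_size = 32 as in this example (all names local to the sketch):

import torch

N, K, scale_size = 4, 128, 32
B = torch.randn(N, K)                                         # dequantized values before scaling
Scale = torch.randint(0, 4, (N, K // scale_size), dtype=torch.uint8)
group = torch.arange(K) // scale_size                         # column -> scale-group index
B_scaled = B * (2.0 ** Scale[:, group].to(torch.float32))     # one shared exponent per 32 columns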
""" assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] @T.macro def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): @@ -301,33 +306,32 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale[ - bx * block_N + i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + bx * block_N + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, - ) * T.shift_left( - 1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) + ) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) return simple_dequant_bf16_fp4 @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Scale: T.Tensor(Scale_shape, storage_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. - This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. + This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. - Parameters are self-descriptive in the signature; notable behaviors: - - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - - The function writes results in-place into C. + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). + - The function writes results in-place into C. 
""" with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -337,23 +341,24 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) if with_bias: - T.annotate_layout({ - Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), - }) + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) if threads == 512: T.disable_warp_group_reg_alloc() if with_bias: - T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], - Bias_shared) + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], Bias_shared) T.copy(Bias_shared, C_local) else: T.clear(C_local) @@ -368,7 +373,7 @@ def main( T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -387,9 +392,9 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): Returns: torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -410,9 +415,9 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): Returns: torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -434,9 +439,9 @@ def ref_program_simple(A, qB, Scale, Bias=None): No in-place modification is performed on inputs (a local floating copy of B is scaled). """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -462,9 +467,9 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): No in-place modification is performed on inputs (a local floating copy of B is scaled). 
""" - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -491,24 +496,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, if tune: kernel = matmul( - m, - n, - k, - "bfloat16", - "bfloat16", - "float32", - num_bits=4, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias) + m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) else: kernel = matmul( m, n, k, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=4, scale_size=scale_size, block_M=256, @@ -518,7 +515,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, threads=256, split=1, fast_dequant=fast_dequant, - with_bias=with_bias) + with_bias=with_bias, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) @@ -538,6 +536,29 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, fast_dequant=True, with_bias=False): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": M, N, K = 256, 256, 256 scale_size = 32 diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py index 7dad79597..12395df0a 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py @@ -7,29 +7,28 @@ from dequantize_utils import torch_convert_bit_twiddling, torch_convert -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, - dtype: str): +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. - This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its - bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by - `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. - Parameters: - nbit (int): Number of bits in the packed field (must be 4). - val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. - pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). 
- scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). + Parameters: + nbit (int): Number of bits in the packed field (must be 4). + val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be "bfloat16"). - Returns: - tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. - Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - """ + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ assert nbit == 4 assert dtype == "bfloat16" assert val.dtype == "uint8" @@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale # To handle the overflow, we may use the min function to limit the exponential part to 8 bits # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret("bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + val_bf16 = tir.reinterpret( + "bfloat16", + ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + ) return val_bf16 @@ -65,6 +65,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. """ import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -73,67 +74,71 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1],) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format="uint", + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ - Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - - The generated kernel accepts: - - A: dense matrix with element type `in_dtype`. - - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). - - Scale: per-block scale/exponent information used to dequantize B. 
- The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: - - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. - - fast_dequant (False): uses a simple elementwise dequantization helper. - - Parameters: - M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). - in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). - accum_dtype (str): accumulation type used for the inner GEMM. - source_format (str, optional): format string passed to intrinsic selector (default "uint"). - num_bits (int, optional): number of bits per quantized element in B (default 4). - scale_size (int, optional): number of elements grouped per scale entry (default 32). - fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). - block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). - num_stages (int, optional): pipelining stages for K loop (default 2). - threads (int, optional): threads per block used by the kernel (default 256). - split (int, optional): split factor along K used by the scheduler (default 1). - with_bias (bool, optional): whether to add Bias to the output (default False). - - Returns: - A T.prim_func implementing the tiled, pipelined GEMM that: - - loads tiled blocks of A and packed B to shared memory, - - dequantizes B via the chosen path into a shared dequantized tile, - - performs a tiled GEMM accumulating into local fragments, - - writes the final MxN block to the global output tensor. + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - Notes: - - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - - An assertion enforces that K % (block_K * split) == 0. + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., "bfloat16"). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). 
+ block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. """ num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -150,6 +155,7 @@ def matmul(M, assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -252,8 +258,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling @@ -301,8 +306,8 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale_shared[ - i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) @@ -311,22 +316,22 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Scale: T.Tensor(Scale_shape, storage_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. - This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. 
The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. + This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. - Parameters are self-descriptive in the signature; notable behaviors: - - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - - The function writes results in-place into C. + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). + - The function writes results in-place into C. """ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -339,16 +344,20 @@ def main( # May use much more shared memory than necessary Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) if with_bias: - T.annotate_layout({ - Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), - }) + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) if threads == 512: T.disable_warp_group_reg_alloc() @@ -357,26 +366,24 @@ def main( # T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], # Bias_shared) # T.copy(Bias_shared, C_local) - T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], - C_local) + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], C_local) else: T.clear(C_local) # Use 1D TMA to load Scale - T.copy(Scale[bx * block_N:(bx + 1) * block_N, :], Scale_shared) + T.copy(Scale[bx * block_N : (bx + 1) * block_N, :], Scale_shared) for k in T.Pipelined(K // block_K, num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) if fast_dequant: - get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, - k) + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) else: 
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -399,7 +406,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): B = torch_convert_bit_twiddling(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -424,7 +431,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): B = torch_convert_bit_twiddling(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -450,7 +457,7 @@ def ref_program_simple(A, qB, Scale, Bias=None): B = torch_convert(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -480,7 +487,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): B = torch_convert(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -507,16 +514,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, if tune: kernel = matmul( - m, - n, - k, - "bfloat16", - "bfloat16", - "float32", - num_bits=4, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias) + m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) else: kernel = matmul( m, @@ -534,7 +533,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, threads=256, split=1, fast_dequant=fast_dequant, - with_bias=with_bias) + with_bias=with_bias, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index 727d6d3b6..37826874b 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -24,8 +24,9 @@ def matmul( num_bits=4, ): from tilelang.quantize import _tir_packed_to_unsigned_convert + num_elems_per_byte = 8 // num_bits - storage_dtype = "int8" + storage_dtype = T.int8 storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) A_shape = (M, K) @@ -39,9 +40,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = 
T.alloc_shared(A_shared_shape, in_dtype) @@ -58,21 +59,19 @@ def main( T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) - for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): index = i * threads * local_size_compressed + tx * local_size_compressed + v vi = index // (block_K // num_elems_per_byte) vj = index % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] for v in T.serial(0, local_size): - B_dequantize_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit)( - num_bits, - B_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) + B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + B_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v vi = index // block_K @@ -121,9 +120,7 @@ def run_gemm( def ref_program(A, qB): import torch - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for i in range(B.shape[0]): for j in range(B.shape[1]): B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -146,25 +143,27 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ): from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitterWithLadderTransform,) + TensorCoreIntrinEmitterWithLadderTransform, + ) from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" num_bits = 4 num_elems_per_byte = 8 // num_bits - storage_dtype = "int8" + storage_dtype = T.int8 micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -183,7 +182,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles - block_K = 32 if in_dtype == "float16" else 64 + block_K = 32 if in_dtype == T.float16 else 64 chunk = block_K // reduce_k is_smooth_a = False @@ -192,8 +191,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( pad_factor = 8 A_shape = (M, K) - B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, - micro_size_k // num_elems_per_byte) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte) A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) B_shared_shape = ( block_N // micro_size_y, @@ -228,7 +226,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( chunk=chunk, reduce_k=reduce_k, transform_kind_b=transform_b, - num_elems_per_byte=num_elems_per_byte) + num_elems_per_byte=num_elems_per_byte, + ) vec_load_qb = 16 if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb: 
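# Illustrative aside, not part of this patch: `ref_program` earlier in this file
# unpacks two unsigned 4-bit values per int8 byte with a Python loop,
# ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF). Assuming that packing convention
# (even column in the low nibble, odd column in the high nibble), a vectorized
# PyTorch sketch of the same unpacking looks like the helper below; the helper
# name is ours, purely for illustration.
import torch

def unpack_uint4_reference(qB: torch.Tensor) -> torch.Tensor:
    low = (qB & 0xF).to(torch.half)          # columns with j % 2 == 0 (low nibble)
    high = ((qB >> 4) & 0xF).to(torch.half)  # columns with j % 2 == 1 (high nibble)
    # interleave as [low0, high0, low1, high1, ...] to reproduce the j ordering above
    return torch.stack((low, high), dim=-1).reshape(qB.shape[0], -1)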
@@ -236,14 +235,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, - prelude=decode_i4_to_f16) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -255,40 +251,36 @@ def main( thread_binding = T.get_thread_binding(0) rk = T.get_thread_binding(1) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + } + ) T.use_swizzle(panel_size=10) T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, (block_K // reduce_k)): vk = rk * (block_K // reduce_k) + k A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk] # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load - for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // - (threads * vec_load_qb)): + for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)): for v in T.vectorized(0, vec_load_qb): t = thread_binding idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v vkk = idx % (micro_size_k // num_elems_per_byte) vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y - vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % ( - block_K // micro_size_k) - vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // - (block_K // micro_size_k)) % ( - block_N // micro_size_y) - B_shared[vj, vk, vjj, - vkk] = B[bx * (block_N // micro_size_y) + vj, - ko * (block_K // micro_size_k) + vk, vjj, vkk] + vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k) + vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % ( + block_N // micro_size_y + ) + B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk] for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -307,9 +299,13 @@ def main( for j in T.serial(warp_cols): local_size_b = mma_emitter.local_size_b - T.call_extern('handle', 'decode_i4u_to_f16', - T.address_of(B_local[j * local_size_b // num_elems_per_byte]), - T.address_of(B_dequantize_local[j * local_size_b]), 8) + T.call_extern( + "handle", + "decode_i4u_to_f16", + T.address_of(B_local[j * local_size_b // num_elems_per_byte]), + T.address_of(B_dequantize_local[j * local_size_b]), + 8, + ) mma_emitter.mma(A_local, B_dequantize_local, C_local) @@ -328,7 +324,8 @@ def main( reduced_accum_res[0], rk, dtype="handle", - )) + ) + ) if rk == 0: C_local[n] = reduced_accum_res[0] @@ -340,9 +337,9 @@ def main( for i, j in T.Parallel(block_M, (block_N // reduce_k)): vj = rk * (block_N // reduce_k) + j - C[by * block_M + i, - bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y, 
- i % micro_size_x, vj % micro_size_y] + C[by * block_M + i, bx * block_N + vj] = C_shared[ + i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y + ] return main @@ -357,8 +354,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct transform_b, ): import bitblas - matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( - M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) + + matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) kernel = tilelang.compile(matmul, out_idx=[2]) src_code = kernel.get_kernel_source() @@ -368,11 +365,10 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct assert src_code is not None num_bits = 4 num_elems_per_byte = 8 // num_bits - storage_dtype = "int8" + storage_dtype = T.int8 A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - qB = torch.randint( - 0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) ladder_permutate_config = bitblas.ops.LadderPermutateConfig( @@ -407,9 +403,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct # Ensure that the latency is not None assert latency is not None - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for i in range(B.shape[0]): for j in range(B.shape[1]): B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -423,14 +417,13 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct @tilelang.testing.requires_package("bitblas") def test_run_dequantize_gemm(): - run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128) - run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128) + run_gemm(256, 256, 256, T.float16, T.float16, T.float16, 128, 128, 32, num_threads=128) + run_gemm(256, 256, 256, T.int8, T.int32, T.int32, 128, 128, 32, num_threads=128) @tilelang.testing.requires_package("bitblas") def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): - assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness( - 256, 1024, 512, "float16", "float16", "float16", 3) + assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, T.float16, T.float16, T.float16, 3) def main(): diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index c5588d516..2bdcbb068 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -9,30 +9,29 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 - assert dtype == "float16" - assert val.dtype == "uint8" + assert dtype == T.float16 + assert val.dtype == T.uint8 # e_f4 == 0 -> e_f16 = 0 # e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14 # s1e2m1 - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> 
tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") - e_f16 = e_f4 + tir.const(14, "uint16") - m_f4 = f4 & tir.const(1, "uint16") + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + e_f16 = e_f4 + tir.const(14, T.uint16) + m_f4 = f4 & tir.const(1, T.uint16) m_f16 = m_f4 - val_f16 = tir.reinterpret("float16", - ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") - | m_f16 << tir.const(9, "uint16")).astype("uint16")) - # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) + val_f16 = tir.reinterpret( + T.float16, ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16) | m_f16 << tir.const(9, T.uint16)).astype(T.uint16) + ) + # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, T.float16), val_f16) return val_f16 def torch_convert(tensor): - def print_bit(name, val): val_cpu = val.cpu().item() - binary_repr = f'{val_cpu:032b}' + binary_repr = f"{val_cpu:032b}" print(name, binary_repr) def _convert(val, pos): @@ -61,15 +60,15 @@ def _convert(val, pos): @tilelang.jit(out_idx=[1]) def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 B_shape = (N, K // num_elems_per_byte) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) @T.prim_func def main( - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((N, K), in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared(B_shared_shape, storage_dtype) @@ -99,7 +98,7 @@ def test_fp4_fp16_convert_close(): K, block_N, block_K, - "float16", + T.float16, ) B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8) @@ -118,23 +117,15 @@ def get_configs(): splits = [1] _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits)) - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'block_K': c[2], - 'num_stages': c[3], - 'threads': c[4], - 'split': c[5] - } for c in _configs] + configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs] return configs def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): - @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 A_shape = (M, K) B_shape = (N, K // num_elems_per_byte) A_shared_shape = (block_M, block_K) @@ -145,29 +136,24 @@ def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): @T.prim_func def main_split( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - SplitC = T.alloc_buffer([ - split, (N + block_N - 1) // block_N * block_N, - (M + block_M - 1) // block_M * block_M - ], out_dtype) - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, - threads=threads) as (bx, by, bz): + SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M 
* block_M], out_dtype) + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) - Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // (block_K * split), num_stages=num_stages): @@ -183,8 +169,7 @@ def main_split( ) T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) - T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): acc = T.alloc_fragment((block_N, block_M), out_dtype) T.clear(acc) @@ -195,12 +180,11 @@ def main_split( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -209,10 +193,11 @@ def main( Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -229,8 +214,7 @@ def main( T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.copy(Ct_local, Ct_shared) - T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) if split == 1: return main @@ -241,12 +225,7 @@ def main( @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[2]) - def kernel(block_M=None, - block_N=None, - block_K=None, - num_stages=None, - threads=None, - split=None): + def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None): return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func return kernel() @@ -259,7 +238,7 @@ def kernel(block_M, block_N, block_K, num_stages, threads, split=1): def ref_program(A, qB): - dtypeC = "float16" + dtypeC = T.float16 B = torch_convert(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = 
C.to(torch.__getattribute__(dtypeC)) @@ -269,10 +248,10 @@ def ref_program(A, qB): def main(m=256, n=256, k=256, tune=False): total_flops = 2 * m * n * k - if (not tune): - kernel = matmul( - m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)( - block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) + if not tune: + kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") @@ -283,7 +262,7 @@ def main(m=256, n=256, k=256, tune=False): print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune) + best_result = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune) best_latency = best_result.latency best_config = best_result.config print(f"Best latency: {best_latency}") @@ -291,12 +270,20 @@ def main(m=256, n=256, k=256, tune=False): print(f"Best config: {best_config}") +def run_regression_perf(m=4096, n=4096, k=4096): + kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=False)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--m', type=int, default=256, help='M') - parser.add_argument('--n', type=int, default=256, help='N') - parser.add_argument('--k', type=int, default=256, help='K') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--m", type=int, default=256, help="M") + parser.add_argument("--n", type=int, default=256, help="N") + parser.add_argument("--k", type=int, default=256, help="K") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() M, N, K = args.m, args.n, args.k main(M, N, K, args.tune) diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py index 52ee8216f..b1f8b1132 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py +++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -9,15 +9,15 @@ def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 - assert dtype == "int8" - assert val.dtype == "uint8" + assert dtype == T.int8 + assert val.dtype == T.uint8 - mask = tir.const((1 << nbit) - 1, "uint8") + mask = tir.const((1 << nbit) - 1, T.uint8) - i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask + i4 = (val >> (pos.astype(T.uint8) * tir.const(nbit, T.uint8))) & mask - i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8")) - i8 = i8_shifted >> tir.const(4, "int8") + i8_shifted = tir.reinterpret(T.int8, i4 << tir.const(4, T.uint8)) + i8 = i8_shifted >> tir.const(4, T.int8) return i8 @@ -35,15 +35,15 @@ def get_configs(): @tilelang.jit(out_idx=[1]) def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 B_shape = (N, K // num_elems_per_byte) B_shared_shape = (block_N, block_K // num_elems_per_byte) 
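# Worked example (illustrative comment, not part of the patch) of the nibble
# sign-extension used by `_tir_u8_to_i4_to_i8` above and by `torch_convert`
# below: the packed nibble 0b1111 must decode to the int4 value -1. Shifting it
# into the high half of an int8 and arithmetic-shifting back restores the sign:
#   >>> x = torch.tensor([0b1111], dtype=torch.int8)
#   >>> ((x << 4) >> 4).item()
#   -1
# (0b1111 << 4 wraps to -16 in int8, and -16 >> 4 == -1.)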
B_dequantize_shared_shape = (block_N, block_K) @T.prim_func def main( - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((N, K), in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared(B_shared_shape, storage_dtype) @@ -66,13 +66,12 @@ def main( def torch_convert(tensor): - def _convert(val, pos): assert val.dtype == torch.uint8 val = val.view(torch.int8) mask = (1 << 4) - 1 - i4_shifted = ((val >> (pos * 4)) & mask) - i4 = ((i4_shifted << 4) >> 4) + i4_shifted = (val >> (pos * 4)) & mask + i4 = (i4_shifted << 4) >> 4 return i4.view(torch.int8) @@ -86,7 +85,7 @@ def _convert(val, pos): def ref_program(A, qB): - dtypeC = "int32" + dtypeC = T.int32 B = torch_convert(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -94,11 +93,10 @@ def ref_program(A, qB): def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): - @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, threads): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 A_shape = (M, K) B_shape = (N, K // num_elems_per_byte) A_shared_shape = (block_M, block_K) @@ -109,12 +107,11 @@ def kernel_func(block_M, block_N, block_K, num_stages, threads): @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -123,10 +120,11 @@ def main( Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -143,8 +141,7 @@ def main( T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.copy(Ct_local, Ct_shared) - T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) return main @@ -167,10 +164,10 @@ def kernel(block_M, block_N, block_K, num_stages, threads): def main(m=128, n=256, k=256, tune=False): total_flops = 2 * m * n * k - if (not tune): - kernel = matmul_int8xint4( - m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)( - block_M=32, block_N=32, block_K=128, num_stages=1, threads=128) + if not tune: + kernel = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 + ) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2) print("All checks pass.") @@ -179,7 +176,7 @@ def main(m=128, n=256, k=256, tune=False): print(f"Tilelang: {latency} ms") else: - 
best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune) + best_result = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune) best_latency = best_result.latency best_config = best_result.config print(f"Bset latency: {best_latency}") @@ -187,6 +184,14 @@ def main(m=128, n=256, k=256, tune=False): print(f"Best tflops: {total_flops / best_latency * 1e-9}") +def run_regression_perf(m=4096, n=4096, k=4096): + kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=False)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 + ) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--m", type=int, default=512, help="Matrix dimension M") diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py index d3e90ec93..43e97f930 100644 --- a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -4,7 +4,8 @@ import torch from tilelang import DataType from tilelang.quantize import ( - _tir_packed_int_to_int_convert,) + _tir_packed_int_to_int_convert, +) @tilelang.jit @@ -16,7 +17,7 @@ def dequantize_gemv( out_dtype: str, accum_dtype: str, num_bits: int = 4, - storage_dtype: str = "int8", + storage_dtype: T.dtype = T.int8, source_format: str = "uint", n_partition: int = 4, reduce_thread: int = 32, @@ -26,11 +27,10 @@ def dequantize_gemv( group_size: int = -1, with_scaling: bool = False, ) -> Callable[..., Any]: - assert n_partition is not None, "n_partition must be provided" assert reduce_thread is not None, ( - "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" - "sch_outer_reduction_with_config is not implemented") + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) assert trans_A is False, "Dequantize only implement for trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" @@ -51,7 +51,7 @@ def dequantize_gemv( C_shape = (M, N) dp4a_size = 4 - use_dp4a = in_dtype == "int8" and accum_dtype == "int32" + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 import_source: Optional[str] = None func_name: str = "" @@ -81,12 +81,12 @@ def main( C: T.Tensor[C_shape, out_dtype], ): with T.Kernel( - T.ceildiv(N, n_partition), - M, - threads=(reduce_thread, n_partition), + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), ) as ( - bx, - by, + bx, + by, ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) @@ -107,8 +107,7 @@ def main( for v in T.vectorized(micro_size_k_compressed): B_quant_local[v] = B[ bx * n_partition + ni, - ko * (reduce_thread * micro_size_k_compressed) + - kr * micro_size_k_compressed + v, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, ] if fast_decoding: @@ -120,10 +119,9 @@ def main( ) else: for ki in T.serial(micro_size_k): - B_dequantize_local[ki] = _tir_packed_int_to_int_convert( - storage_type, - storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte], - ki % num_elems_per_byte, in_dtype) + B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, 
in_dtype + ) if use_dp4a: for ki in T.serial(micro_size_k // dp4a_size): @@ -137,9 +135,9 @@ def main( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -149,7 +147,8 @@ def main( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: C[by, bx * n_partition + ni] = reduced_accum_res[0] @@ -160,11 +159,11 @@ def main() -> None: M = 1 N = 1024 K = 1024 - in_dtype = "float16" - out_dtype = "float16" - accum_dtype = "float16" + in_dtype = T.float16 + out_dtype = T.float16 + accum_dtype = T.float16 num_bits = 4 - storage_dtype = "int8" + storage_dtype = T.int8 source_format = "uint" n_partition = 4 reduce_thread = 32 @@ -174,26 +173,39 @@ def main() -> None: group_size = -1 with_scaling = False - kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, - source_format, n_partition, reduce_thread, fast_decoding, trans_A, - trans_B, group_size, with_scaling) + kernel = dequantize_gemv( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + num_bits, + storage_dtype, + source_format, + n_partition, + reduce_thread, + fast_decoding, + trans_A, + trans_B, + group_size, + with_scaling, + ) storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) num_elems_per_byte = storage_nbit // num_bits A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() - qB = torch.randint( - 0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() if fast_decoding: from tilelang.quantize.utils import interleave_weight + qB = interleave_weight(qB, num_bits, in_dtype) kernel(A, qB, C) # int4 reference - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for j in range(B.shape[1]): B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -205,5 +217,62 @@ def main() -> None: torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1) +def run_regression_perf(): + M = 1 + N = 8192 + K = 8192 + in_dtype = "float16" + out_dtype = "float16" + accum_dtype = "float16" + num_bits = 4 + storage_dtype = "int8" + source_format = "uint" + n_partition = 4 + reduce_thread = 32 + fast_decoding = True + trans_A = False + trans_B = True + group_size = -1 + with_scaling = False + + kernel = dequantize_gemv( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + num_bits, + storage_dtype, + source_format, + n_partition, + reduce_thread, + fast_decoding, + trans_A, + trans_B, + group_size, + with_scaling, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() + + if fast_decoding: + from tilelang.quantize.utils import interleave_weight + + qB = interleave_weight(qB, num_bits, in_dtype) + kernel(A, qB, C) + + from 
tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(A, qB, C) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py index c4cf5fb50..6ee595921 100644 --- a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -25,6 +25,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. """ import itertools + iter_params = dict( block_M=[128], block_N=[64, 128, 256], @@ -33,33 +34,33 @@ def get_configs(): threads=[128, 256, 512], split=[1], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] @tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[-1]) -def matmul(M, - N, - K, - topk, - E, - padding_M, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=128, - block_N=256, - block_K=128, - num_stages=2, - threads=256, - split=1): +def matmul( + M, + N, + K, + topk, + E, + padding_M, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=128, + block_N=256, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype. @@ -82,8 +83,8 @@ def matmul(M, topk (int): number of experts selected per token. E (int): number of experts. padding_M (int): padded number of tokens after grouping and block alignment. - in_dtype (str): element type of A (e.g., "bfloat16"). - out_dtype (str): output tensor element type (e.g., "bfloat16"). + in_dtype (str): element type of A (e.g., T.bfloat16). + out_dtype (str): output tensor element type (e.g., T.bfloat16). accum_dtype (str): accumulation type used for the inner GEMM. source_format (str, optional): format string passed to intrinsic selector (default "uint"). num_bits (int, optional): number of bits per quantized element in B (default 4). @@ -110,16 +111,17 @@ def matmul(M, """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, Block_QK) - Bias_shared_shape = (block_N) + Bias_shared_shape = block_N B_dequantize_shared_shape = (block_N, block_K) assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -135,7 +137,7 @@ def matmul(M, import_source = import_source # the dequant part is the same as in dequant_gemm - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. 
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: @@ -145,12 +147,12 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the scaled BF16 results into B_dequantize_shared. Notes: - - This factory only supports in_dtype="fp4" and out_dtype="bfloat16". + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -221,19 +223,16 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): - + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] @T.macro def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): - B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) @@ -244,8 +243,8 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale_shared[ - i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) @@ -254,19 +253,17 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): @T.prim_func def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((E, N, QK), storage_dtype), - Scale: T.Tensor((E, N, K // scale_size), storage_dtype), - Bias: T.Tensor((E, N), out_dtype), - # Add fusedmoe tensors - topk_weights: T.Tensor((M * topk), out_dtype), - sorted_token_ids: T.Tensor((padding_M), "int32"), - expert_ids: T.Tensor((padding_M // block_M), "int32"), - C: T.Tensor((M, topk, N), out_dtype), + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((E, N, QK), storage_dtype), + Scale: T.Tensor((E, N, K // scale_size), storage_dtype), + Bias: T.Tensor((E, N), out_dtype), + # Add fusedmoe tensors + topk_weights: T.Tensor((M * topk), out_dtype), + sorted_token_ids: T.Tensor((padding_M), T.int32), + expert_ids: T.Tensor((padding_M // block_M), T.int32), + C: T.Tensor((M, topk, N), out_dtype), ): - - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as 
(bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) @@ -274,23 +271,23 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) topk_weights_shared = T.alloc_shared((block_M), out_dtype) - sorted_token_ids_shared = T.alloc_shared((block_M), "int32") - expert_id = T.alloc_local((1), "int32") # the expert id for the current block + sorted_token_ids_shared = T.alloc_shared((block_M), T.int32) + expert_id = T.alloc_local((1), T.int32) # the expert id for the current block # To use 1D TMA, the last dim of Scale_shared must have stride=1 # May use much more shared memory than necessary Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) T.use_swizzle(10) if threads == 512: T.disable_warp_group_reg_alloc() - T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared) + T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared) expert_id[0] = expert_ids[by] # Get the topk weights of each token in the current block @@ -300,11 +297,11 @@ def main( # Get bias and scale based on the expert id if with_bias: - T.copy(Bias[expert_id[0], bx * block_N:(bx + 1) * block_N], Bias_shared) + T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared) else: T.clear(Bias_shared) - T.copy(Scale[expert_id[0], bx * block_N:(bx + 1) * block_N, :], Scale_shared) + T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared) for i, j in T.Parallel(block_M, block_N): C_local[i, j] = Bias_shared[j] @@ -317,14 +314,13 @@ def main( base = copy_i * threads * 16 + tx * 16 if sorted_token_ids_shared[base // block_K] != -1: for copy_j in T.vectorized(16): - A_shared[base // block_K, base % block_K + - copy_j] = A[sorted_token_ids_shared[base // block_K] // topk, - k * block_K + base % block_K + copy_j] + A_shared[base // block_K, base % block_K + copy_j] = A[ + sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j + ] T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared) if fast_dequant: - get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, - k) + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) else: get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) @@ -338,16 +334,17 @@ def main( base = copy_i * threads * 16 + tx * 16 if sorted_token_ids_shared[base // block_N] != -1: for copy_j in T.vectorized(16): - C[sorted_token_ids_shared[base // block_N] // topk, - sorted_token_ids_shared[base // block_N] % topk, bx * block_N + - base % block_N + copy_j] = C_shared[base // block_N, - base % block_N + copy_j] + C[ + sorted_token_ids_shared[base // block_N] // topk, + sorted_token_ids_shared[base // block_N] % topk, + bx * block_N + base % block_N + copy_j, + ] = C_shared[base // block_N, base % block_N + copy_j] return main def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256): - dtypeC = "bfloat16" + dtypeC = T.bfloat16 M, K = A.shape E, N, QK = 
qB.shape topk = topk_weights.shape[0] // M @@ -355,7 +352,7 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc assert scale_size == 32 # MXFP4 # Initialize output tensor - C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device='cuda') + C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda") # Iterate over sorted_token_ids for idx in range(len(sorted_token_ids)): # padding_M @@ -370,14 +367,11 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc # Dequantize the expert weights B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K) - B *= 2**( - Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to( - torch.bfloat16)) + B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16)) # Compute the output for this token-expert pair # token_embedding @ B.T + bias - output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to( - torch.bfloat16)) + Bias[expert_id] + output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id] output = output.to(torch.__getattribute__(dtypeC)) # Apply the topk weight @@ -391,14 +385,12 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc def get_data(m, n, k, qk, scale_size, topk, E, block_M): - A = torch.empty(m, k, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) - qB = torch.randint( - 0, 256, (E, n, qk), dtype=torch.uint8, - device='cuda') # Quantized weight tensor for E experts. - Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device='cuda') - Bias = torch.empty(E, n, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) - - weights = torch.empty(m, E, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) + A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts. + Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda") + Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + + weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) # topk_weights: Router weights for the top-k experts for each token. # Shape: (m, topk) # tokens_experts: A flattened tensor of expert assignments for each token. @@ -420,10 +412,7 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt if pad_len > 0: # -1 for padding (`M` instead in vLLM moe_align_block_size()) - group_token_ids = torch.cat([ - group_token_ids, - torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device='cuda') - ]) + group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")]) padded_token_ids.append(group_token_ids) expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M)) start = end @@ -431,21 +420,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): # sorted_token_ids: The final flattened and padded tensor of token indices. sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,) # expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`. 
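# Illustrative trace of the grouping above (assumed toy example, not part of the
# patch): with m=2 tokens, topk=2, E=2, block_M=4, suppose both tokens route to
# experts (0, 1) in that slot order. Flattened slot ids are token * topk + slot,
# so expert 0 owns slots {0, 2} and expert 1 owns slots {1, 3}. Padding each
# expert's group to a multiple of block_M then yields
#   sorted_token_ids = [0, 2, -1, -1, 1, 3, -1, -1]   # padding_M = 8
#   expert_ids       = [0, 1]                         # one id per block_M-sized block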
- expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda') # (padding_M,) + expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M,) padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M -def main(m=256, - n=256, - k=256, - scale_size=32, - topk=4, - E=32, - fast_dequant=True, - with_bias=False, - tune=False): +def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): # Tunable parameters block_M, block_N, block_K = 128, 256, 128 # noqa: F841 num_stages = 1 # noqa: F841 @@ -456,8 +437,7 @@ def main(m=256, num_bits = 4 num_elems_per_byte = 8 // num_bits qk = k // num_elems_per_byte - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data( - m, n, k, qk, scale_size, topk, E, block_M) + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) if tune: with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): @@ -469,9 +449,9 @@ def main(m=256, topk, E, padding_M, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=num_bits, scale_size=scale_size, fast_dequant=fast_dequant, @@ -485,9 +465,9 @@ def main(m=256, topk, E, padding_M, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=num_bits, scale_size=scale_size, fast_dequant=fast_dequant, @@ -510,14 +490,11 @@ def main(m=256, expert_ids, ) - print('Tilelang kernel run finished.') + print("Tilelang kernel run finished.") - ref_output = ref_moe( - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, - block_M=block_M) # Maybe a little bit slow... + ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow... - latency = tilelang.profiler.do_bench( - lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) + latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) print("Tilelang: {:.2f} ms".format(latency)) print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) @@ -525,32 +502,72 @@ def main(m=256, max_val = diff.max() max_idx = diff.argmax() print(f"max abs diff: {max_val} at index: {max_idx}") - assert_similar( - output, ref_output, name="output", - eps=2e-5) # We care about the similarity rather than abs. difference + assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference print("All checks pass. 
✅") +def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): + block_M, block_N, block_K = 128, 256, 128 + num_stages = 1 + threads = 512 + split = 1 + num_bits = 4 + num_elems_per_byte = 8 // num_bits + qk = k // num_elems_per_byte + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) + + if tune: + with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + "bfloat16", + "bfloat16", + "float32", + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + else: + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + "bfloat16", + "bfloat16", + "float32", + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, + ) + + return tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm + parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm parser.add_argument("--N", type=int, default=5760, help="N") parser.add_argument("--K", type=int, default=2944, help="K") parser.add_argument("--scale_size", type=int, default=32, help="scale size") - parser.add_argument( - "--topk", type=int, default=4, help="topk") # experts activated for each token + parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token parser.add_argument("--E", type=int, default=32, help="E") # number of experts parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main( - args.M, - args.N, - args.K, - args.scale_size, - topk=args.topk, - E=args.E, - fast_dequant=True, - with_bias=True, - tune=args.tune) + main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune) diff --git a/examples/dequantize_gemm/regression_example_dequantize_gemm.py b/examples/dequantize_gemm/regression_example_dequantize_gemm.py new file mode 100644 index 000000000..4ab03784f --- /dev/null +++ b/examples/dequantize_gemm/regression_example_dequantize_gemm.py @@ -0,0 +1,35 @@ +import tilelang.testing +import example_dequant_gemm_bf16_fp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper +import example_dequant_gemm_fp4_hopper +import example_dequant_gemm_w4a8 +import example_dequant_gemv_fp16xint4 +import example_dequant_groupedgemm_bf16_mxfp4_hopper + + +def regression_example_dequant_gemv_fp16xint4(): + tilelang.testing.process_func(example_dequant_gemv_fp16xint4.run_regression_perf) + + +def regression_example_dequant_gemm_fp4_hopper(): + tilelang.testing.process_func(example_dequant_gemm_fp4_hopper.run_regression_perf) + + +def regression_example_dequant_gemm_bf16_fp4_hopper(): + tilelang.testing.process_func(example_dequant_gemm_bf16_fp4_hopper.run_regression_perf) + + +def regression_example_dequant_gemm_bf16_mxfp4_hopper(): + tilelang.testing.process_func(example_dequant_gemm_bf16_mxfp4_hopper.run_regression_perf) + + +def regression_example_dequant_groupedgemm_bf16_mxfp4_hopper(): + 
tilelang.testing.process_func(example_dequant_groupedgemm_bf16_mxfp4_hopper.run_regression_perf) + + +def regression_example_dequant_gemm_w4a8(): + tilelang.testing.process_func(example_dequant_gemm_w4a8.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/dequantize_gemm/test_example_dequantize_gemm.py b/examples/dequantize_gemm/test_example_dequantize_gemm.py index 01bc40e6c..a2f777222 100644 --- a/examples/dequantize_gemm/test_example_dequantize_gemm.py +++ b/examples/dequantize_gemm/test_example_dequantize_gemm.py @@ -3,7 +3,6 @@ import example_dequant_gemv_fp16xint4 import example_dequant_gemm_fp4_hopper import example_dequant_gemm_bf16_mxfp4_hopper -import example_dequant_gemm_bf16_mxfp4_hopper_tma import example_dequant_groupedgemm_bf16_mxfp4_hopper import example_dequant_gemm_w4a8 @@ -25,12 +24,6 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper(): example_dequant_gemm_bf16_mxfp4_hopper.main() -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_example_dequant_gemm_bf16_mxfp4_hopper_tma(): - example_dequant_gemm_bf16_mxfp4_hopper_tma.main() - - @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_dequant_groupedgemm_bf16_mxfp4_hopper(): diff --git a/examples/dequantize_gemm/utils.py b/examples/dequantize_gemm/utils.py index 7134ae6aa..da9ddb9f8 100644 --- a/examples/dequantize_gemm/utils.py +++ b/examples/dequantize_gemm/utils.py @@ -34,8 +34,7 @@ def _convert(val0, val1, pos) -> torch.bfloat16: mask1 = 0b1000000000000000 mask2 = 0b0000000110000000 mask3 = 0b0000000001000000 - bf16 = ((val_concat << 1) & mask1) | ((val_concat >> 3) & mask2) | ( - (val_concat >> 7) & mask3) + bf16 = ((val_concat << 1) & mask1) | ((val_concat >> 3) & mask2) | ((val_concat >> 7) & mask3) bf16_new = torch.tensor([bf16], dtype=torch.uint16, device=val0.device).view(torch.bfloat16) # Add bias for change from fp4 to bf16 bf16_new = bf16_new.item() * (2**126) @@ -104,5 +103,5 @@ def print_bit(name, val): val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown. """ val_cpu = val.cpu().item() - binary_repr = f'{val_cpu:032b}' + binary_repr = f"{val_cpu:032b}" print(name, binary_repr) diff --git a/examples/distributed/README.md b/examples/distributed/README.md index e73ae0fac..48cf85488 100644 --- a/examples/distributed/README.md +++ b/examples/distributed/README.md @@ -2,7 +2,7 @@ This directory contains examples demonstrating distributed computing capabilities using TileLang. -For example, +For example, ``` ./tilelang/distributed/launch.sh examples/distributed/example_allgather.py ``` @@ -11,7 +11,7 @@ For example, Before running the examples, you need to build NVSHMEM library for device-side code generation. -```bash +```bash export NVSHMEM_SRC="your_custom_nvshmem_dir" # default to 3rdparty/nvshmem_src cd tilelang/distributed source build_nvshmem.sh diff --git a/examples/distributed/deepseek_deepep/buffer.py b/examples/distributed/deepseek_deepep/buffer.py index f281f19e3..71f7f3faf 100644 --- a/examples/distributed/deepseek_deepep/buffer.py +++ b/examples/distributed/deepseek_deepep/buffer.py @@ -1,4 +1,4 @@ -""" The interface for DeepEP. 
""" +"""The interface for DeepEP.""" import torch import torch.distributed as dist @@ -27,14 +27,16 @@ class EPBuffer: num_sms: int = 20 symm_heap_size: int = 2**30 # size of the symm heap for allocators - def __init__(self, - group: dist.ProcessGroup, - num_nvl_bytes: int, - num_topk: int, - num_experts: int, - hidden: int, - dispatch_cfg: Optional[Config] = None, - combine_cfg: Optional[Config] = None): + def __init__( + self, + group: dist.ProcessGroup, + num_nvl_bytes: int, + num_topk: int, + num_experts: int, + hidden: int, + dispatch_cfg: Optional[Config] = None, + combine_cfg: Optional[Config] = None, + ): """ Initialize the communication buffer. @@ -70,7 +72,8 @@ def __init__(self, is_distributed=True, local_rank=self.rank, num_local_ranks=self.num_ranks, - group=group) + group=group, + ) self._pre_alloc_symm_buffers() self._prepare_counters() @@ -87,81 +90,70 @@ def _pre_alloc_symm_buffers(self): def _pre_alloc_symm_buffers_intranode(self): # barrier signal is always zeroed after each usage, so we can pre-init here - barrier_signal = tilelang.tensor((self.num_ranks), - dtype=torch.int32, - device='cuda', - allocator=self._allocator).zero_() - - per_rank_buffer = tilelang.tensor((self.num_ranks, self.num_ranks), - dtype=torch.int32, - device='cuda', - allocator=self._allocator) - per_expert_buffer = tilelang.tensor((self.num_ranks, self.num_local_experts), - dtype=torch.int32, - device='cuda', - allocator=self._allocator) - - channel_start_offset = tilelang.tensor([self.num_channels, self.num_ranks], - dtype=torch.int32, - device='cuda', - allocator=self._allocator) - channel_end_offset = tilelang.tensor([self.num_channels, self.num_ranks], - dtype=torch.int32, - device='cuda', - allocator=self._allocator) - channel_head_idx = tilelang.tensor([self.num_channels, self.num_ranks], - dtype=torch.int32, - device='cuda', - allocator=self._allocator) - channel_tail_idx = tilelang.tensor([self.num_channels, self.num_ranks], - dtype=torch.int32, - device='cuda', - allocator=self._allocator) + barrier_signal = tilelang.tensor((self.num_ranks), dtype=torch.int32, device="cuda", allocator=self._allocator).zero_() + + per_rank_buffer = tilelang.tensor((self.num_ranks, self.num_ranks), dtype=torch.int32, device="cuda", allocator=self._allocator) + per_expert_buffer = tilelang.tensor( + (self.num_ranks, self.num_local_experts), dtype=torch.int32, device="cuda", allocator=self._allocator + ) + + channel_start_offset = tilelang.tensor( + [self.num_channels, self.num_ranks], dtype=torch.int32, device="cuda", allocator=self._allocator + ) + channel_end_offset = tilelang.tensor( + [self.num_channels, self.num_ranks], dtype=torch.int32, device="cuda", allocator=self._allocator + ) + channel_head_idx = tilelang.tensor([self.num_channels, self.num_ranks], dtype=torch.int32, device="cuda", allocator=self._allocator) + channel_tail_idx = tilelang.tensor([self.num_channels, self.num_ranks], dtype=torch.int32, device="cuda", allocator=self._allocator) # NOTE: for each #ranks, dispatch and combine cfg have the same num_max_nvl_chunked_recv_tokens, so we can use the same buffer here - channel_x_buffers = tilelang.tensor([ - self.num_channels, self.num_ranks, self.dispatch_cfg.num_max_nvl_chunked_recv_tokens, - self.hidden - ], - dtype=torch.bfloat16, - device='cuda', - allocator=self._allocator) + channel_x_buffers = tilelang.tensor( + [self.num_channels, self.num_ranks, self.dispatch_cfg.num_max_nvl_chunked_recv_tokens, self.hidden], + dtype=torch.bfloat16, + device="cuda", + allocator=self._allocator, + ) 
channel_src_idx_buffers = tilelang.tensor( [self.num_channels, self.num_ranks, self.dispatch_cfg.num_max_nvl_chunked_recv_tokens], dtype=torch.int32, - device='cuda', - allocator=self._allocator) - channel_topk_idx_buffers = tilelang.tensor([ - self.num_channels, self.num_ranks, self.dispatch_cfg.num_max_nvl_chunked_recv_tokens, - self.num_topk - ], - dtype=torch.int64, - device='cuda', - allocator=self._allocator) - channel_topk_weights_buffers = tilelang.tensor([ - self.num_channels, self.num_ranks, self.dispatch_cfg.num_max_nvl_chunked_recv_tokens, - self.num_topk - ], - dtype=torch.float32, - device='cuda', - allocator=self._allocator) - - self._symm_buffers = (barrier_signal, per_rank_buffer, per_expert_buffer, - channel_start_offset, channel_end_offset, channel_head_idx, - channel_tail_idx, channel_x_buffers, channel_src_idx_buffers, - channel_topk_idx_buffers, channel_topk_weights_buffers) + device="cuda", + allocator=self._allocator, + ) + channel_topk_idx_buffers = tilelang.tensor( + [self.num_channels, self.num_ranks, self.dispatch_cfg.num_max_nvl_chunked_recv_tokens, self.num_topk], + dtype=torch.int64, + device="cuda", + allocator=self._allocator, + ) + channel_topk_weights_buffers = tilelang.tensor( + [self.num_channels, self.num_ranks, self.dispatch_cfg.num_max_nvl_chunked_recv_tokens, self.num_topk], + dtype=torch.float32, + device="cuda", + allocator=self._allocator, + ) + + self._symm_buffers = ( + barrier_signal, + per_rank_buffer, + per_expert_buffer, + channel_start_offset, + channel_end_offset, + channel_head_idx, + channel_tail_idx, + channel_x_buffers, + channel_src_idx_buffers, + channel_topk_idx_buffers, + channel_topk_weights_buffers, + ) def _pre_alloc_symm_buffers_internode(self): raise NotImplementedError("internode is not supported yet") def _prepare_counters(self): - self._moe_recv_counter, self._moe_recv_counter_mapped = create_mapped_tensor([1], - torch.int32) - self._moe_recv_expert_counter, self._moe_recv_expert_counter_mapped = create_mapped_tensor( - [self.num_local_experts], torch.int32) + self._moe_recv_counter, self._moe_recv_counter_mapped = create_mapped_tensor([1], torch.int32) + self._moe_recv_expert_counter, self._moe_recv_expert_counter_mapped = create_mapped_tensor([self.num_local_experts], torch.int32) if self.num_ranks > 8: # internode - self._moe_recv_rdma_counter, self._moe_recv_rdma_counter_mapped = create_mapped_tensor( - [1], torch.int32) + self._moe_recv_rdma_counter, self._moe_recv_rdma_counter_mapped = create_mapped_tensor([1], torch.int32) @staticmethod def set_num_sms(num_sms: int): @@ -204,19 +196,20 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor): num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert. is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank. 
""" - num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = get_dispatch_layout( - topk_idx, self.num_experts, self.num_ranks) + num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = get_dispatch_layout(topk_idx, self.num_experts, self.num_ranks) return num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank - def dispatch(self, - x: torch.Tensor, - handle: Optional[Tuple] = None, - num_tokens_per_rank: Optional[torch.Tensor] = None, - is_token_in_rank: Optional[torch.Tensor] = None, - num_tokens_per_expert: Optional[torch.Tensor] = None, - topk_idx: Optional[torch.Tensor] = None, - topk_weights: Optional[torch.Tensor] = None, - expert_alignment: int = 1): + def dispatch( + self, + x: torch.Tensor, + handle: Optional[Tuple] = None, + num_tokens_per_rank: Optional[torch.Tensor] = None, + is_token_in_rank: Optional[torch.Tensor] = None, + num_tokens_per_expert: Optional[torch.Tensor] = None, + topk_idx: Optional[torch.Tensor] = None, + topk_weights: Optional[torch.Tensor] = None, + expert_alignment: int = 1, + ): """ Dispatch tokens to different ranks, both intranode and internode settings are supported. Intranode kernels require all the ranks should be visible via NVLink. @@ -273,11 +266,24 @@ def dispatch(self, else: assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle = intranode_dispatch( - self.rank, self._allocator, self._symm_buffers, self._moe_recv_counter, - self._moe_recv_expert_counter, self._moe_recv_counter_mapped, - self._moe_recv_expert_counter_mapped, x, self.dispatch_cfg, handle, - num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, topk_idx, - topk_weights, expert_alignment, self.comm_stream) + self.rank, + self._allocator, + self._symm_buffers, + self._moe_recv_counter, + self._moe_recv_expert_counter, + self._moe_recv_counter_mapped, + self._moe_recv_expert_counter_mapped, + x, + self.dispatch_cfg, + handle, + num_tokens_per_rank, + is_token_in_rank, + num_tokens_per_expert, + topk_idx, + topk_weights, + expert_alignment, + self.comm_stream, + ) return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle def combine(self, x: torch.Tensor, handle: Tuple, topk_weights: torch.Tensor): @@ -298,7 +304,7 @@ def combine(self, x: torch.Tensor, handle: Tuple, topk_weights: torch.Tensor): recv_x: the reduced token from its dispatched ranks. recv_topk_weights: the reduced top-k weights from its dispatch ranks. """ - recv_x, recv_topk_weights = intranode_combine(self.rank, self._allocator, - self._symm_buffers, x, self.combine_cfg, - handle, topk_weights, self.comm_stream) + recv_x, recv_topk_weights = intranode_combine( + self.rank, self._allocator, self._symm_buffers, x, self.combine_cfg, handle, topk_weights, self.comm_stream + ) return recv_x, recv_topk_weights diff --git a/examples/distributed/deepseek_deepep/deepep.md b/examples/distributed/deepseek_deepep/deepep.md index d3cea90dc..620baf428 100644 --- a/examples/distributed/deepseek_deepep/deepep.md +++ b/examples/distributed/deepseek_deepep/deepep.md @@ -20,14 +20,12 @@ The table below shows a latency and bandwidth comparison for DeepEP and TileScal | DeepEP | 1.0045 | 328.97 | 1.1552 | 287.14 | | TileScale | 1.0720 | 308.25 | 1.0809 | 306.86 | - # Intra-node Introduction This example implements DeepEP’s intra‑node (NVLink) dispatch/combine using TileScale kernels. 
The intra‑node path lives under `intranode/` and provides a minimal public API that mirrors DeepEP’s behavior for NVLink‑connected ranks. - ## Overview - Scope: intra‑node (NVLink) only; all ranks must be within one node and NVLink‑visible.
@@ -35,7 +33,6 @@ The intra‑node path lives under `intranode/` and provides a minimal public API - Datatypes: inputs are `torch.bfloat16`; routing `topk_idx` is `torch.int64`; `topk_weights` is `torch.float32`. - Channels: each channel uses 2 SMs (send/recv). With default `num_sms=20`, there are `num_channels=10`. - ## Public API (intranode) - `intranode.get_dispatch_layout(topk_idx, num_experts, num_ranks)`
@@ -63,7 +60,6 @@ Convenience wrapper used by examples/tests: - Exposes the interface for the functions above via methods: `get_dispatch_layout`, `dispatch`, `combine`. - Manages TileScale allocator, symmetric buffers, and recommended kernel configs. - ## Core Data Structures and Handle - `rank_prefix_matrix` (num_ranks × num_ranks): cumulative per‑rank token counts; used to compute global offsets for receiver writes.
@@ -82,7 +78,6 @@ Dispatch returns the handle: `(rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)` which can be reused for cached re‑dispatch and is required by the combine stage. - ## Kernel Responsibilities (high level) - Layout
@@ -97,14 +92,12 @@ which can be reused for cached re‑dispatch and is required by the combine stag - `cached_notify_combine_kernel`: recalculates `send_head` expectations and zeros `channel_head_idx`/`channel_tail_idx` for the combine round. - `combine_kernel`: senders return expert outputs; receivers reduce by sum per token. `recv_topk_weights` is the sum of returned weights per token. Requires `hidden % 8 == 0` for vectorized access on the receiver side. - ## Configuration and Tuning - `utils.Config` provides recommended values for `num_max_nvl_chunked_send_tokens` and `num_max_nvl_chunked_recv_tokens` per `num_ranks`. These control per‑round chunk sizes and receiver buffer depth per channel. - `EPBuffer.num_sms` controls total SMs assigned to high‑throughput kernels. Channels = `num_sms // 2` (one send SM + one recv SM per channel). - `expert_alignment` pads per‑local‑expert MoE receive counters up to the specified multiple, which can be used to size per‑expert workspace. - ## Execution Flow (non‑cached) 1) Prepare group and buffers
@@ -138,7 +131,6 @@ which can be reused for cached re‑dispatch and is required by the combine stag 6) Cached re‑dispatch (optional) - For repeated communication with the same layout, pass `handle` back into `EPBuffer.dispatch(x, handle, ...)` to skip layout/notify work and return only `recv_x`. - ## Usage Quick start (intra‑node test):
@@ -174,7 +166,6 @@ recv_x, recv_topk_idx, recv_topk_weights, per_expert_counts, handle = buf.dispat reduced_x, reduced_weights = buf.combine(expert_out, handle, recv_topk_weights) ``` - ## Notes and Limits - Intra‑node only: ranks must be NVLink‑visible; current code asserts `num_ranks <= 8` and `num_experts % num_ranks == 0`.
@@ -184,7 +175,6 @@ reduced_x, reduced_weights = buf.combine(expert_out, handle, recv_topk_weights) - Ensure `topk_idx` is contiguous, 2D, and `torch.int64`. - Set `TILELANG_USE_DISTRIBUTED=1` to enable TileScale’s distributed runtime. - ## Files - `intranode/__init__.py` — re‑exports `get_dispatch_layout`, `intranode_dispatch`, `intranode_combine`.
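Tying the Usage, Configuration, and cached re‑dispatch notes above together, the sketch below shows the intended call pattern. It is illustrative only: the imports assume the example directory layout, `group` is an already initialized NVLink‑local `dist.ProcessGroup`, `new_x`, `expert_out`, and the buffer sizes are placeholders, and the cached call returning only `recv_x` follows the note above rather than a verified signature.

```python
from deepep_utils import Config
from buffer import EPBuffer

# Recommended per-#ranks chunk/queue sizes (see "Configuration and Tuning");
# 8 NVLink-connected ranks assumed here.
dispatch_cfg = Config.get_dispatch_config(8)
combine_cfg = Config.get_combine_config(8)

buf = EPBuffer(group, num_nvl_bytes=256 << 20, num_topk=8, num_experts=64,
               hidden=4096, dispatch_cfg=dispatch_cfg, combine_cfg=combine_cfg)

# Full (non-cached) dispatch: layout + notify + data movement; returns the handle.
num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = buf.get_dispatch_layout(topk_idx)
recv_x, recv_topk_idx, recv_topk_weights, per_expert_counts, handle = buf.dispatch(
    x,
    num_tokens_per_rank=num_tokens_per_rank,
    is_token_in_rank=is_token_in_rank,
    num_tokens_per_expert=num_tokens_per_expert,
    topk_idx=topk_idx,
    topk_weights=topk_weights,
)

# Cached re-dispatch: same routing as before, so reuse the handle and skip
# the layout/notify work; per the note above only recv_x is produced.
recv_x_cached = buf.dispatch(new_x, handle)

# The same handle drives the combine stage.
reduced_x, reduced_weights = buf.combine(expert_out, handle, recv_topk_weights)
```

Passing `dispatch_cfg`/`combine_cfg` explicitly is likely optional, since `EPBuffer` is described above as managing the recommended kernel configs itself.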
@@ -194,7 +184,6 @@ reduced_x, reduced_weights = buf.combine(expert_out, handle, recv_topk_weights) - `buffer.py` — EPBuffer wrapper: allocator and symmetric buffers, public methods. - `utils.py` — recommended configs and MoE counter helpers. - ## Implementation Notes - Negative offset encoding: senders write channel start/end offsets as `-value-1` so that a zero token count is distinguishable from an uninitialized `0`. diff --git a/examples/distributed/deepseek_deepep/deepep_utils.py b/examples/distributed/deepseek_deepep/deepep_utils.py index 1294acb31..288640295 100644 --- a/examples/distributed/deepseek_deepep/deepep_utils.py +++ b/examples/distributed/deepseek_deepep/deepep_utils.py @@ -30,7 +30,7 @@ def __post_init__(self): # 1 sm for send, 1 sm for recv in each channel @staticmethod - def get_dispatch_config(num_ranks: int) -> 'Config': + def get_dispatch_config(num_ranks: int) -> "Config": """ Get a recommended dispatch config. @@ -56,11 +56,11 @@ def get_dispatch_config(num_ranks: int) -> 'Config': 144: Config(num_sms, 32, 720, 12, 128), 160: Config(num_sms, 28, 720, 12, 128), } - assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}' + assert num_ranks in config_map, f"Unsupported number of EP ranks: {num_ranks}" return config_map[num_ranks] @staticmethod - def get_combine_config(num_ranks: int) -> 'Config': + def get_combine_config(num_ranks: int) -> "Config": """ Get a recommended combine config. @@ -86,33 +86,31 @@ def get_combine_config(num_ranks: int) -> 'Config': 144: Config(num_sms, 2, 720, 8, 128), 160: Config(num_sms, 2, 720, 8, 128), } - assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}' + assert num_ranks in config_map, f"Unsupported number of EP ranks: {num_ranks}" return config_map[num_ranks] # Only necessary in inter-node cases -def set_rdma_env_args(num_qps_per_rank: int = 24, - allow_nvlink_for_low_latency_mode: bool = True, - allow_mnnvl: bool = False): - os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1' - os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1' - os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}' +def set_rdma_env_args(num_qps_per_rank: int = 24, allow_nvlink_for_low_latency_mode: bool = True, allow_mnnvl: bool = False): + os.environ["NVSHMEM_DISABLE_P2P"] = "0" if allow_nvlink_for_low_latency_mode else "1" + os.environ["NVSHMEM_IB_ENABLE_IBGDA"] = "1" + os.environ["NVSHMEM_IBGDA_NUM_RC_PER_PE"] = f"{num_qps_per_rank}" # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check - nvshmem_qp_depth = int(os.environ.get('NVSHMEM_QP_DEPTH', '1024')) - os.environ['NVSHMEM_QP_DEPTH'] = str(nvshmem_qp_depth) + nvshmem_qp_depth = int(os.environ.get("NVSHMEM_QP_DEPTH", "1024")) + os.environ["NVSHMEM_QP_DEPTH"] = str(nvshmem_qp_depth) # Reduce gpu memory usage # 6 default teams + 1 extra team - os.environ['NVSHMEM_MAX_TEAMS'] = '7' + os.environ["NVSHMEM_MAX_TEAMS"] = "7" # Disable NVLink SHArP - os.environ['NVSHMEM_DISABLE_NVLS'] = '1' + os.environ["NVSHMEM_DISABLE_NVLS"] = "1" # NOTES: NVSHMEM initialization requires at least 256 MiB - os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}' + os.environ["NVSHMEM_CUMEM_GRANULARITY"] = f"{2**29}" if not allow_mnnvl: # Disable multi-node NVLink detection - os.environ['NVSHMEM_DISABLE_MNNVL'] = '1' + os.environ["NVSHMEM_DISABLE_MNNVL"] = "1" def unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): @@ -147,10 +145,10 @@ def gen_inputs(num_tokens: int, 
hidden: int, num_topk: int, num_experts: int, nu assert num_topk <= num_experts, "num_topk must be less than or equal to num_experts" assert num_experts % num_ranks == 0, "num_experts must be divisible by num_ranks" - x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 + x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1 topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1] - topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') + topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda") rank_idx = topk_idx // (num_experts // num_ranks) rank_idx.masked_fill_(topk_idx == -1, -1) inplace_unique(rank_idx, num_ranks) @@ -192,7 +190,7 @@ def ep_bench(fn, warmup: int = 50, rep: int = 50, post_fn=None): # Flush L2 cache with 256 MB data torch.cuda.synchronize() - cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") # Warmup for _ in range(warmup): @@ -248,8 +246,5 @@ def ep_bench(fn, warmup: int = 50, rep: int = 50, post_fn=None): """ ep_ext = load_inline( - name="ep_ext", - cpp_sources=_src, - functions=["wait_for_counters_ready"], - extra_cflags=["-O3", "-march=native"], - verbose=False) + name="ep_ext", cpp_sources=_src, functions=["wait_for_counters_ready"], extra_cflags=["-O3", "-march=native"], verbose=False +) diff --git a/examples/distributed/deepseek_deepep/intranode/combine.py b/examples/distributed/deepseek_deepep/intranode/combine.py index 17c5f175c..aa95b9339 100644 --- a/examples/distributed/deepseek_deepep/intranode/combine.py +++ b/examples/distributed/deepseek_deepep/intranode/combine.py @@ -11,7 +11,7 @@ import tilelang.language as T tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) @@ -19,15 +19,15 @@ def cached_notify_combine_kernel(num_ranks, num_sms): num_channels = num_sms // 2 threads = max(128, 32 * num_ranks) - num_recv_tokens = T.dynamic('num_recv_tokens') + num_recv_tokens = T.dynamic("num_recv_tokens") @T.prim_func def cached_notify_combine_main( - send_head: T.Tensor([num_recv_tokens, num_ranks], "int32"), - ##### symm buffers ##### - channel_head_idx: T.Tensor([num_channels, num_ranks], "int32"), - channel_tail_idx: T.Tensor([num_channels, num_ranks], "int32"), - barrier_signal: T.Tensor((num_ranks,), 'int32'), + send_head: T.Tensor([num_recv_tokens, num_ranks], "int32"), + ##### symm buffers ##### + channel_head_idx: T.Tensor([num_channels, num_ranks], "int32"), + channel_tail_idx: T.Tensor([num_channels, num_ranks], "int32"), + barrier_signal: T.Tensor((num_ranks,), "int32"), ): with T.Kernel(num_channels + 1, threads=threads) as bx: tx = T.get_thread_binding() @@ -48,17 +48,15 @@ def cached_notify_combine_main( token_start_idx = T.min(tokens_per_channel * channel_id, num_recv_tokens) token_end_idx = T.min(token_start_idx + tokens_per_channel, num_recv_tokens) - last_head = T.alloc_var('int32', init=2**25) # a heuristic large number - # todo: tilelang doesn't support reverse loop, we simulate this - for i in T.serial(0, token_end_idx - token_start_idx, 32): - token_idx_tail = 
token_end_idx - i - 1 + last_head = T.alloc_var("int32", init=2**25) # a heuristic large number + for token_idx_tail in T.serial(token_end_idx - 1, token_start_idx - 1, -32): token_idx = token_idx_tail - lane_id - current_head = T.alloc_var('int32') + current_head = T.alloc_var("int32") if token_idx >= token_start_idx: T.ld(send_head[token_idx, rank_id], current_head, nc=True) else: current_head = -1 - expected_head = T.alloc_var('int32') + expected_head = T.alloc_var("int32") expected_head = 0 for j in T.serial(T.min(32, token_idx_tail - token_start_idx + 1)): head = T.tvm_warp_shuffle(-1, current_head, j, 32, 32) @@ -74,31 +72,27 @@ def cached_notify_combine_main( def cached_notify_combine( - num_ranks, - num_sms, - ##### symm buffers ##### - send_head: torch.Tensor, - channel_head_idx: torch.Tensor, - channel_tail_idx: torch.Tensor, - barrier_signal: torch.Tensor, - allocator, - comm_stream=None): + num_ranks, + num_sms, + ##### symm buffers ##### + send_head: torch.Tensor, + channel_head_idx: torch.Tensor, + channel_tail_idx: torch.Tensor, + barrier_signal: torch.Tensor, + allocator, +): kernel = cached_notify_combine_kernel(num_ranks, num_sms) - kernel.initialize(allocator=allocator, stream=comm_stream.cuda_stream) + kernel.initialize(allocator=allocator) - kernel( - send_head, - channel_head_idx, - channel_tail_idx, - barrier_signal, - stream=comm_stream.cuda_stream, - skip_tensor_validation=True) # reduce runtime overhead + kernel(send_head, channel_head_idx, channel_tail_idx, barrier_signal) # reduce runtime overhead -@tilelang.jit(pass_configs={ - "tl.disable_tma_lower": True, # use TMA later - "tl.disable_warp_specialized": True -}) +@tilelang.jit( + pass_configs={ + "tl.disable_tma_lower": True, # use TMA later + "tl.disable_warp_specialized": True, + } +) def combine_kernel( num_ranks, num_max_send_tokens, # config.num_max_nvl_chunked_send_tokens @@ -106,10 +100,10 @@ def combine_kernel( hidden, num_topk, num_sms, - dtype: str = 'bfloat16', + dtype: str = "bfloat16", ): - num_tokens = T.dynamic('num_tokens') - num_recv_tokens = T.dynamic('num_recv_tokens') + num_tokens = T.dynamic("num_tokens") + num_recv_tokens = T.dynamic("num_recv_tokens") num_channels = num_sms // 2 threads = 768 # 24 warps @@ -140,12 +134,9 @@ def combine_main( # symm buffers channel_head_idx: T.Tensor([num_channels, num_ranks], "int32"), # reuse, already zeroed channel_tail_idx: T.Tensor([num_channels, num_ranks], "int32"), # reuse, already zeroed - channel_x_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, hidden], - dtype), - channel_src_idx_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens], - "int32"), - channel_topk_weights_buffers: T.Tensor( - [num_channels, num_ranks, num_recv_buffer_tokens, num_topk], "float32"), + channel_x_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, hidden], dtype), + channel_src_idx_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens], "int32"), + channel_topk_weights_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, num_topk], "float32"), ): with T.Kernel(num_sms, threads=threads) as bx: tx = T.get_thread_binding() @@ -158,85 +149,85 @@ def combine_main( send_warp_id_in_rank = warp_id // num_ranks # get tasks - rank_offset = T.if_then_else(send_rank_id > 0, rank_prefix_matrix[send_rank_id - 1, - rank], 0) + rank_offset = T.if_then_else(send_rank_id > 0, rank_prefix_matrix[send_rank_id - 1, rank], 0) num_rank_tokens = rank_prefix_matrix[send_rank_id, rank] - rank_offset channel_offset = 
channel_prefix_matrix[send_rank_id, responsible_channel] - num_channel_tokens = T.if_then_else( - responsible_channel == num_channels - 1, num_rank_tokens, - channel_prefix_matrix[send_rank_id, responsible_channel + 1]) - channel_offset + num_channel_tokens = ( + T.if_then_else( + responsible_channel == num_channels - 1, + num_rank_tokens, + channel_prefix_matrix[send_rank_id, responsible_channel + 1], + ) + - channel_offset + ) token_start_idx = rank_offset + channel_offset token_end_idx = token_start_idx + num_channel_tokens # Iterate over all tokens and send by trunk - current_channel_tail_idx = T.alloc_var('int32') + current_channel_tail_idx = T.alloc_var("int32") current_channel_tail_idx = 0 - token_idx = T.alloc_var('int32') + token_idx = T.alloc_var("int32") token_idx = token_start_idx - with T.While(token_idx < token_end_idx): + while token_idx < token_end_idx: # Check destination queue emptiness, or wait a buffer to be released (rare cases) num_round_tokens = T.min(num_max_send_tokens, token_end_idx - token_idx) - if T.elect_one_sync(): + if T.shuffle_elect(32): T.wait_ge( channel_head_idx[responsible_channel, rank], current_channel_tail_idx + num_round_tokens - num_recv_buffer_tokens, - peer=send_rank_id) + peer=send_rank_id, + ) T.sync_warp() # Send by trunk for i in T.serial(send_warp_id_in_rank, num_round_tokens, warps_per_rank): # Get an empty slot - dst_slot_idx = T.alloc_var('int32') + dst_slot_idx = T.alloc_var("int32") dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens # 1. copy data T.put_warp( T.address_of(x[token_idx + i, 0]), - T.address_of(channel_x_buffers[responsible_channel, rank, dst_slot_idx, - 0]), + T.address_of(channel_x_buffers[responsible_channel, rank, dst_slot_idx, 0]), hidden, dst_pe=send_rank_id, unroll_factor=4, - enable_aggressive_vectorize=True) + enable_aggressive_vectorize=True, + ) # 2. send src idx - idx = T.alloc_var('int32') - if T.elect_one_sync(): + idx = T.alloc_var("int32") + if T.shuffle_elect(32): T.ld(src_idx[token_idx + i], idx, nc=True) - T.st( - channel_src_idx_buffers[responsible_channel, rank, dst_slot_idx], - idx, - dst_pe=send_rank_id) + T.st(channel_src_idx_buffers[responsible_channel, rank, dst_slot_idx], idx, dst_pe=send_rank_id) # 3. send topk_weights if num_topk > 0 and lane_id < num_topk: - weight = T.alloc_var('float32') + weight = T.alloc_var("float32") T.ld(topk_weights[token_idx + i, lane_id], weight, nc=True) T.st( - channel_topk_weights_buffers[responsible_channel, rank, - dst_slot_idx, lane_id], - weight, - dst_pe=send_rank_id) + channel_topk_weights_buffers[responsible_channel, rank, dst_slot_idx, lane_id], weight, dst_pe=send_rank_id + ) token_idx += num_round_tokens current_channel_tail_idx += num_round_tokens # move tail index T.sync_threads(send_rank_id, threads_per_rank) - if send_warp_id_in_rank == 0 and T.elect_one_sync(): + if T.shuffle_elect(96): T.st( channel_tail_idx[responsible_channel, rank], current_channel_tail_idx, - scope='sys', - sem='release', - dst_pe=send_rank_id) + scope="sys", + sem="release", + dst_pe=send_rank_id, + ) else: # receiver - #? Why we must need scope='shared', not 'shared.dynamic' here? - warp_channel_head_idx = T.alloc_shared([warps, num_ranks], 'int32', scope='shared') - shared_channel_tail_idx = T.alloc_shared( - [32], 'int32', scope='shared') #! workaround for illegal address - warp_retired = T.alloc_shared([warps], 'bool', scope='shared') + # ? Why we must need scope='shared', not 'shared.dynamic' here? 
+ warp_channel_head_idx = T.alloc_shared([warps, num_ranks], "int32", scope="shared") + shared_channel_tail_idx = T.alloc_shared([32], "int32", scope="shared") #! workaround for illegal address + warp_retired = T.alloc_shared([warps], "bool", scope="shared") if tx < warps: warp_retired[tx] = False if lane_id < num_ranks: @@ -246,84 +237,66 @@ def combine_main( T.sync_threads() if tx < 32: # one warp for moving the queue head - last_head = T.alloc_var('int32') + last_head = T.alloc_var("int32") last_head = 0 - with T.While(lane_id < num_ranks): + while lane_id < num_ranks: # check retired - retired = T.alloc_var('bool') + retired = T.alloc_var("bool") retired = True for i in T.serial(1, warps): retired = retired and warp_retired[i] if retired: - T.loop_break() + break # Update queue tail - new_tail = T.alloc_var('int32') - T.ld( - channel_tail_idx[responsible_channel, lane_id], - new_tail, - sem="acquire", - scope="sys") + new_tail = T.alloc_var("int32") + T.ld(channel_tail_idx[responsible_channel, lane_id], new_tail, sem="acquire", scope="sys") # Use release semantics to ensure receiver warps see the update - T.st( - shared_channel_tail_idx[lane_id], new_tail, sem="release", - scope="cta") # todo: weaker sem pair + T.st(shared_channel_tail_idx[lane_id], new_tail, sem="release", scope="cta") # todo: weaker sem pair # Update minimum head - min_head = T.alloc_var('int32') + min_head = T.alloc_var("int32") min_head = 2**31 - 1 # int32 max for i in T.serial(1, warps): if not warp_retired[i]: min_head = T.min(min_head, warp_channel_head_idx[i, lane_id]) if min_head != 2**31 - 1 and min_head > last_head: last_head = min_head - T.st( - channel_head_idx[responsible_channel, lane_id], - min_head, - sem="relaxed", - scope="sys") + T.st(channel_head_idx[responsible_channel, lane_id], min_head, sem="relaxed", scope="sys") else: # other warps for reduction # All lanes will use data buffer, but only rank lane will use `head/tail/src_idx` # The same tokens as the dispatch process - num_tokens_per_channel = T.truncdiv(num_recv_tokens + num_channels - 1, - num_channels) + num_tokens_per_channel = T.truncdiv(num_recv_tokens + num_channels - 1, num_channels) # todo: this is a workaround, as TVM has a bug when calculating safe ceildiv for tir.Var - token_start_idx = T.min(num_tokens_per_channel * responsible_channel, - num_recv_tokens) + token_start_idx = T.min(num_tokens_per_channel * responsible_channel, num_recv_tokens) token_end_idx = T.min(token_start_idx + num_tokens_per_channel, num_recv_tokens) # Iterate over all tokens and combine - for token_idx in T.serial(token_start_idx + warp_id - 1, token_end_idx, - warps - 1): + for token_idx in T.serial(token_start_idx + warp_id - 1, token_end_idx, warps - 1): # Read expected head - expected_head = T.alloc_var('int32') + expected_head = T.alloc_var("int32") expected_head = -1 if lane_id < num_ranks: T.ld(send_head[token_idx, lane_id], expected_head, nc=True) - condvar = T.alloc_var('int32') + condvar = T.alloc_var("int32") T.ld(shared_channel_tail_idx[lane_id], condvar, sem="acquire", scope="cta") - with T.While(T.warp_any(condvar <= expected_head and expected_head >= 0)): - T.ld( - shared_channel_tail_idx[lane_id], - condvar, - sem="acquire", - scope="cta") - T.loop_continue() + while T.warp_any(condvar <= expected_head and expected_head >= 0): + T.ld(shared_channel_tail_idx[lane_id], condvar, sem="acquire", scope="cta") + continue # can we simplify this ? 
T.sync_warp() # Broadcast current heads - num_topk_ranks = T.alloc_var('int32') + num_topk_ranks = T.alloc_var("int32") num_topk_ranks = 0 - topk_ranks = T.alloc_local([num_ranks], 'int32') - slot_indices = T.alloc_local([num_ranks], 'int32') + topk_ranks = T.alloc_local([num_ranks], "int32") + slot_indices = T.alloc_local([num_ranks], "int32") for i in T.serial(num_ranks): expected_head_i = T.tvm_warp_shuffle(-1, expected_head, i, 32, 32) if expected_head_i >= 0: - slot_indices[ - num_topk_ranks] = expected_head_i % num_recv_buffer_tokens + slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens topk_ranks[num_topk_ranks] = i num_topk_ranks += 1 @@ -337,10 +310,10 @@ def combine_main( for j in T.serial(num_topk_ranks): for k in T.vectorized(8): T.ld( - channel_x_buffers[responsible_channel, topk_ranks[j], - slot_indices[j], i * 8 + k], + channel_x_buffers[responsible_channel, topk_ranks[j], slot_indices[j], i * 8 + k], recv_value[j, k], - nc=True) + nc=True, + ) # todo: support bias @@ -349,47 +322,52 @@ def combine_main( for k in T.vectorized(8): values[k] += recv_value[j, k] for j in T.vectorized(8): - recv_x[token_idx, - i * 8 + j] = values[j] # todo: further vectorize this + recv_x[token_idx, i * 8 + j] = values[j] # todo: further vectorize this # Reduce topk_weights if lane_id < num_topk: - weight_sum = T.alloc_var('float32') + weight_sum = T.alloc_var("float32") weight_sum = 0 for i in T.serial(num_topk_ranks): - weight = T.alloc_var('float32') + weight = T.alloc_var("float32") T.ld( - channel_topk_weights_buffers[responsible_channel, topk_ranks[i], - slot_indices[i], lane_id], + channel_topk_weights_buffers[responsible_channel, topk_ranks[i], slot_indices[i], lane_id], weight, - nc=True) + nc=True, + ) weight_sum += weight recv_topk_weights[token_idx, lane_id] = weight_sum # Update head if lane_id < num_ranks: warp_channel_head_idx[warp_id, lane_id] = T.if_then_else( - expected_head < 0, -expected_head - 1, expected_head + 1) + expected_head < 0, -expected_head - 1, expected_head + 1 + ) # Retired T.sync_warp() - if T.elect_one_sync(): + if T.shuffle_elect(32): warp_retired[warp_id] = True return combine_main -def intranode_combine(rank: int, - allocator, - symm_buffers, - x, - config, - handle, - topk_weights, - comm_stream=None): +def intranode_combine(rank: int, allocator, symm_buffers, x, config, handle, topk_weights, comm_stream=None): assert handle is not None rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, _, send_head = handle - barrier_signal, _, _, _, _, channel_head_idx, channel_tail_idx, channel_x_buffers, channel_src_idx_buffers, _, channel_topk_weights_buffers = symm_buffers + ( + barrier_signal, + _, + _, + _, + _, + channel_head_idx, + channel_tail_idx, + channel_x_buffers, + channel_src_idx_buffers, + _, + channel_topk_weights_buffers, + ) = symm_buffers # acquire_shapes _, hidden = x.shape @@ -398,19 +376,12 @@ def intranode_combine(rank: int, num_recv_tokens = send_head.shape[0] # notify combine - cached_notify_combine( - num_ranks, - config.num_sms, - send_head, - channel_head_idx, - channel_tail_idx, - barrier_signal, - allocator, - comm_stream=comm_stream) + with torch.cuda.stream(comm_stream): + cached_notify_combine(num_ranks, config.num_sms, send_head, channel_head_idx, channel_tail_idx, barrier_signal, allocator) # combine - recv_x = torch.empty((num_recv_tokens, hidden), dtype=x.dtype, device='cuda') - recv_topk_weights = torch.empty((num_recv_tokens, num_topk), dtype=torch.float32, device='cuda') + 
recv_x = torch.empty((num_recv_tokens, hidden), dtype=x.dtype, device="cuda") + recv_topk_weights = torch.empty((num_recv_tokens, num_topk), dtype=torch.float32, device="cuda") kernel = combine_kernel( num_ranks, @@ -419,25 +390,26 @@ def intranode_combine(rank: int, hidden, num_topk, config.num_sms, - dtype='bfloat16') - kernel.initialize(allocator=allocator, stream=comm_stream.cuda_stream) - kernel( - rank, - x, - topk_weights, - recv_src_idx, - recv_x, - recv_topk_weights, - rank_prefix_matrix, - recv_channel_prefix_matrix, - send_head, - channel_head_idx, - channel_tail_idx, - channel_x_buffers, - channel_src_idx_buffers, - channel_topk_weights_buffers, - stream=comm_stream.cuda_stream, - skip_tensor_validation=True) # reduce runtime overhead + dtype="bfloat16", + ) + with torch.cuda.stream(comm_stream): + kernel.initialize(allocator=allocator) + kernel( + rank, + x, + topk_weights, + recv_src_idx, + recv_x, + recv_topk_weights, + rank_prefix_matrix, + recv_channel_prefix_matrix, + send_head, + channel_head_idx, + channel_tail_idx, + channel_x_buffers, + channel_src_idx_buffers, + channel_topk_weights_buffers, + ) # reduce runtime overhead compute_stream = torch.cuda.current_stream() compute_stream.wait_stream(comm_stream) return recv_x, recv_topk_weights diff --git a/examples/distributed/deepseek_deepep/intranode/dispatch.py b/examples/distributed/deepseek_deepep/intranode/dispatch.py index 0811a4eb1..83912a089 100644 --- a/examples/distributed/deepseek_deepep/intranode/dispatch.py +++ b/examples/distributed/deepseek_deepep/intranode/dispatch.py @@ -11,9 +11,10 @@ import tilelang.language as T from typing import Optional, Tuple from deepep_utils import Config, ep_ext # noqa: F403 +import tvm_ffi # tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log # notify_dispatch is responsible for: @@ -30,26 +31,26 @@ def notify_dispatch_kernel( num_local_experts = num_experts // num_ranks num_warps = threads // 32 - num_tokens = T.dynamic('num_tokens') + num_tokens = T.dynamic("num_tokens") @T.prim_func def notify_dispatch_main( - rank: T.int32, - num_tokens_per_rank: T.Tensor((num_ranks,), 'int32'), - num_tokens_per_expert: T.Tensor((num_experts,), 'int32'), - is_token_in_rank: T.Tensor((num_tokens, num_ranks), 'bool'), - moe_recv_counter_mapped: T.Tensor((1,), 'int32'), - moe_recv_expert_counter_mapped: T.Tensor((num_local_experts,), 'int32'), - per_rank_buffer: T.Tensor((num_ranks, num_ranks), 'int32'), - per_expert_buffer: T.Tensor((num_ranks, num_local_experts), 'int32'), - barrier_signal: T.Tensor((num_ranks,), 'int32'), - rank_prefix_matrix: T.Tensor((num_ranks, num_ranks), 'int32'), - channel_prefix_matrix: T.Tensor((num_ranks, num_channels), 'int32'), - # 4 symm buffers to be zeroed - channel_start_offset: T.Tensor([num_channels, num_ranks], "int32"), - channel_end_offset: T.Tensor([num_channels, num_ranks], "int32"), - channel_head_idx: T.Tensor([num_channels, num_ranks], "int32"), - channel_tail_idx: T.Tensor([num_channels, num_ranks], "int32"), + rank: T.int32, + num_tokens_per_rank: T.Tensor((num_ranks,), "int32"), + num_tokens_per_expert: T.Tensor((num_experts,), "int32"), + is_token_in_rank: T.Tensor((num_tokens, num_ranks), "bool"), + moe_recv_counter_mapped: T.Tensor((1,), "int32"), + moe_recv_expert_counter_mapped: T.Tensor((num_local_experts,), "int32"), + per_rank_buffer: T.Tensor((num_ranks, num_ranks), "int32"), + per_expert_buffer: T.Tensor((num_ranks, num_local_experts), "int32"), + 
barrier_signal: T.Tensor((num_ranks,), "int32"), + rank_prefix_matrix: T.Tensor((num_ranks, num_ranks), "int32"), + channel_prefix_matrix: T.Tensor((num_ranks, num_channels), "int32"), + # 4 symm buffers to be zeroed + channel_start_offset: T.Tensor([num_channels, num_ranks], "int32"), + channel_end_offset: T.Tensor([num_channels, num_ranks], "int32"), + channel_head_idx: T.Tensor([num_channels, num_ranks], "int32"), + channel_tail_idx: T.Tensor([num_channels, num_ranks], "int32"), ): with T.Kernel(num_ranks + 1, threads=threads) as bx: tx = T.get_thread_binding() @@ -64,10 +65,7 @@ def notify_dispatch_main( if tx < num_ranks: T.st(per_rank_buffer[rank, tx], num_tokens_per_rank[tx], dst_pe=tx) for i in T.serial(num_local_experts): - T.st( - per_expert_buffer[rank, i], - num_tokens_per_expert[tx * num_local_experts + i], - dst_pe=tx) + T.st(per_expert_buffer[rank, i], num_tokens_per_expert[tx * num_local_experts + i], dst_pe=tx) T.barrier_blocks(barrier_signal) @@ -80,7 +78,7 @@ def notify_dispatch_main( # Sum per-expert cnts if tx < num_local_experts: - sum = T.alloc_local([1], 'int32') + sum = T.alloc_local([1], "int32") sum[0] = 0 for i in T.serial(0, num_ranks): sum[0] += per_expert_buffer[i, tx] @@ -106,12 +104,12 @@ def notify_dispatch_main( # todo: this is a workaround, as TVM has a bug when calculating safe ceildiv for tir.Var token_start_idx = T.min(num_tokens_per_channel * channel_id, num_tokens) token_end_idx = T.min(token_start_idx + num_tokens_per_channel, num_tokens) - cnt = T.alloc_var('int32') + cnt = T.alloc_var("int32") cnt = 0 for i in T.serial(token_start_idx + lane_id, token_end_idx, 32): cnt += is_token_in_rank[i, dst_rank] cnt = T.warp_reduce_sum(cnt) - if T.elect_one_sync(): + if T.shuffle_elect(32): channel_prefix_matrix[dst_rank, channel_id] = cnt T.sync_threads() @@ -149,7 +147,7 @@ def notify_dispatch( channel_tail_idx: torch.Tensor, # allocator allocator, - comm_stream=None, + comm_stream: torch.cuda.Stream = None, ): kernel = notify_dispatch_kernel( num_ranks, @@ -159,8 +157,8 @@ def notify_dispatch( ) kernel.initialize(allocator=allocator, stream=comm_stream.cuda_stream) - rank_prefix_matrix = torch.empty([num_ranks, num_ranks], dtype=torch.int32, device='cuda') - channel_prefix_matrix = torch.empty([num_ranks, num_channels], dtype=torch.int32, device='cuda') + rank_prefix_matrix = torch.empty([num_ranks, num_ranks], dtype=torch.int32, device="cuda") + channel_prefix_matrix = torch.empty([num_ranks, num_channels], dtype=torch.int32, device="cuda") # clear buffers and counters moe_recv_counter.fill_(-1) @@ -182,27 +180,22 @@ def notify_dispatch( channel_end_offset, channel_head_idx, channel_tail_idx, - stream=comm_stream.cuda_stream, - skip_tensor_validation=True # reduce runtime overhead ) - - num_recv_tokens, num_recv_tokens_per_expert_list = ep_ext.wait_for_counters_ready( - moe_recv_counter, moe_recv_expert_counter) + num_recv_tokens, num_recv_tokens_per_expert_list = ep_ext.wait_for_counters_ready(moe_recv_counter, moe_recv_expert_counter) return num_recv_tokens, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix # cached_notify_dispatch only needs to clear symm buffers @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) def cached_notify_dispatch_kernel(num_ranks: int, num_channels: int): - @T.prim_func def cached_notify_dispatch_main( - barrier_signal: T.Tensor((num_ranks,), 'int32'), - # 4 symm buffers to be zeroed - channel_start_offset: T.Tensor([num_channels, num_ranks], "int32"), - 
channel_end_offset: T.Tensor([num_channels, num_ranks], "int32"), - channel_head_idx: T.Tensor([num_channels, num_ranks], "int32"), - channel_tail_idx: T.Tensor([num_channels, num_ranks], "int32"), + barrier_signal: T.Tensor((num_ranks,), "int32"), + # 4 symm buffers to be zeroed + channel_start_offset: T.Tensor([num_channels, num_ranks], "int32"), + channel_end_offset: T.Tensor([num_channels, num_ranks], "int32"), + channel_head_idx: T.Tensor([num_channels, num_ranks], "int32"), + channel_tail_idx: T.Tensor([num_channels, num_ranks], "int32"), ): with T.Kernel(1, threads=128): T.sync_blocks(barrier_signal) @@ -232,22 +225,23 @@ def cached_notify_dispatch( comm_stream=None, ): kernel = cached_notify_dispatch_kernel(num_ranks, num_channels) - kernel.initialize( - allocator=allocator, stream=comm_stream.cuda_stream) # we still comm on barrier_signal - kernel( - barrier_signal, - channel_start_offset, - channel_end_offset, - channel_head_idx, - channel_tail_idx, - stream=comm_stream.cuda_stream, - skip_tensor_validation=True) # reduce runtime overhead + kernel.initialize(allocator=allocator, stream=comm_stream.cuda_stream) + with torch.cuda.stream(comm_stream): + kernel( + barrier_signal, + channel_start_offset, + channel_end_offset, + channel_head_idx, + channel_tail_idx, + ) -@tilelang.jit(pass_configs={ - "tl.disable_tma_lower": True, # enable TMA later - "tl.disable_warp_specialized": True -}) +@tilelang.jit( + pass_configs={ + "tl.disable_tma_lower": True, # enable TMA later + "tl.disable_warp_specialized": True, + } +) def dispatch_kernel( num_ranks, num_max_send_tokens, # config.num_max_nvl_chunked_send_tokens @@ -256,7 +250,7 @@ def dispatch_kernel( num_topk, num_experts, num_sms, - dtype: str = 'bfloat16', + dtype: str = "bfloat16", ): threads = 768 # 24 warps TMABytesPerWarp = 8192 @@ -269,17 +263,17 @@ def dispatch_kernel( num_warps = threads // 32 # 24 num_warps_per_rank = num_warps // num_ranks # 3 - num_tokens = T.dynamic('num_tokens') - num_recv_tokens = T.dynamic('num_recv_tokens') + num_tokens = T.dynamic("num_tokens") + num_recv_tokens = T.dynamic("num_recv_tokens") @T.prim_func def dispatch_main( rank: T.int32, # output recv_x: T.Tensor((num_recv_tokens, hidden), dtype), - recv_src_idx: T.Tensor((num_recv_tokens,), 'int32'), - recv_topk_idx: T.Tensor((num_recv_tokens, num_topk), 'int64'), - recv_topk_weights: T.Tensor((num_recv_tokens, num_topk), 'float'), + recv_src_idx: T.Tensor((num_recv_tokens,), "int32"), + recv_topk_idx: T.Tensor((num_recv_tokens, num_topk), "int64"), + recv_topk_weights: T.Tensor((num_recv_tokens, num_topk), "float"), recv_channel_offset: T.Tensor([num_ranks, num_channels], "int32"), send_head: T.Tensor([num_tokens, num_ranks], "int32"), # input @@ -297,14 +291,10 @@ def dispatch_main( channel_head_idx: T.Tensor([num_channels, num_ranks], "int32"), channel_tail_idx: T.Tensor([num_channels, num_ranks], "int32"), # channel data buffers, stored on the receiver side - channel_x_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, hidden], - dtype), - channel_src_idx_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens], - "int32"), - channel_topk_idx_buffers: T.Tensor( - [num_channels, num_ranks, num_recv_buffer_tokens, num_topk], "int64"), - channel_topk_weights_buffers: T.Tensor( - [num_channels, num_ranks, num_recv_buffer_tokens, num_topk], "float32"), + channel_x_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, hidden], dtype), + channel_src_idx_buffers: T.Tensor([num_channels, num_ranks, 
num_recv_buffer_tokens], "int32"), + channel_topk_idx_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, num_topk], "int64"), + channel_topk_weights_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, num_topk], "float32"), # channel_x_scales_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, num_scales], "float32"), ): with T.Kernel(num_sms, threads=threads) as bx: @@ -318,65 +308,53 @@ def dispatch_main( # send offset by `-value-1` e.g. 0->-1, 1->-2 # this is for distinguishing zero tokens - if send_warp_id_in_rank == 0 and T.elect_one_sync(): - value = T.alloc_var('int32') - value = T.if_then_else( - responsible_channel > 0, channel_prefix_matrix[responsible_rank, - responsible_channel - 1], 0) - T.st( - channel_start_offset[responsible_channel, rank], - -value - 1, - scope='sys', - sem='relaxed', - dst_pe=responsible_rank) + if send_warp_id_in_rank == 0 and T.shuffle_elect(32): + value = T.alloc_var("int32") + value = T.if_then_else(responsible_channel > 0, channel_prefix_matrix[responsible_rank, responsible_channel - 1], 0) + T.st(channel_start_offset[responsible_channel, rank], -value - 1, scope="sys", sem="relaxed", dst_pe=responsible_rank) value = channel_prefix_matrix[responsible_rank, responsible_channel] - T.st( - channel_end_offset[responsible_channel, rank], - -value - 1, - scope='sys', - sem='relaxed', - dst_pe=responsible_rank) + T.st(channel_end_offset[responsible_channel, rank], -value - 1, scope="sys", sem="relaxed", dst_pe=responsible_rank) T.sync_warp() # get task num_tokens_per_channel = T.truncdiv(num_tokens + num_channels - 1, num_channels) # todo: this is a workaround, as TVM has a bug when calculating safe ceildiv for tir.Var - token_start_idx = T.alloc_var('int32') + token_start_idx = T.alloc_var("int32") token_start_idx = T.min(num_tokens_per_channel * responsible_channel, num_tokens) - token_end_idx = T.alloc_var('int32') + token_end_idx = T.alloc_var("int32") token_end_idx = T.min(token_start_idx + num_tokens_per_channel, num_tokens) # sender mainloop: iterate over all tokens and send by trunk - cached_channel_tail_idx = T.alloc_var('int32') + cached_channel_tail_idx = T.alloc_var("int32") cached_channel_tail_idx = 0 - token_idx = T.alloc_var('int32') + token_idx = T.alloc_var("int32") token_idx = token_start_idx - with T.While(token_idx < token_end_idx): - if T.elect_one_sync(): + while token_idx < token_end_idx: + if T.shuffle_elect(32): T.wait_ge( channel_head_idx[responsible_channel, rank], num_max_send_tokens + cached_channel_tail_idx - num_recv_buffer_tokens, - responsible_rank) + responsible_rank, + ) T.sync_warp() - chunk_token_idx = T.alloc_var('int32') + chunk_token_idx = T.alloc_var("int32") chunk_token_idx = 0 while chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx: # for the same token, the warp assigned to save `send_head` may be different from the warp # assigned to send the following data - if token_idx % num_warps_per_rank == send_warp_id_in_rank and T.elect_one_sync( - ): + if token_idx % num_warps_per_rank == send_warp_id_in_rank and T.shuffle_elect(32): send_head[token_idx, responsible_rank] = T.if_then_else( - is_token_in_rank[token_idx, responsible_rank], - cached_channel_tail_idx, -1) + is_token_in_rank[token_idx, responsible_rank], cached_channel_tail_idx, -1 + ) # skip if not selected if not is_token_in_rank[token_idx, responsible_rank]: token_idx += 1 - T.loop_continue() + continue # selected, get an empty slot - dst_slot_idx = T.alloc_var('int32') + dst_slot_idx = 
T.alloc_var("int32") dst_slot_idx = cached_channel_tail_idx % num_recv_buffer_tokens cached_channel_tail_idx += 1 if cached_channel_tail_idx % num_warps_per_rank == send_warp_id_in_rank: @@ -384,20 +362,16 @@ def dispatch_main( # 1. copy data T.put_warp( T.address_of(x[token_idx, 0]), - T.address_of(channel_x_buffers[responsible_channel, rank, - dst_slot_idx, 0]), + T.address_of(channel_x_buffers[responsible_channel, rank, dst_slot_idx, 0]), hidden, dst_pe=responsible_rank, unroll_factor=4, - enable_aggressive_vectorize=True) + enable_aggressive_vectorize=True, + ) # 2. copy src idx - if T.elect_one_sync(): - T.st( - channel_src_idx_buffers[responsible_channel, rank, - dst_slot_idx], - token_idx, - dst_pe=responsible_rank) + if T.shuffle_elect(32): + T.st(channel_src_idx_buffers[responsible_channel, rank, dst_slot_idx], token_idx, dst_pe=responsible_rank) # 3. copy `topk_idx` and `topk_weights` with transformed index if lane_id < num_topk: @@ -405,26 +379,26 @@ def dispatch_main( recv_expert_begin = responsible_rank * num_local_experts recv_expert_end = recv_expert_begin + num_local_experts - idx_value = T.alloc_var('int64') + idx_value = T.alloc_var("int64") T.ld(topk_idx[token_idx, lane_id], idx_value, nc=True) idx_value = T.if_then_else( - recv_expert_begin <= T.cast(idx_value, 'int32') < - recv_expert_end, idx_value - recv_expert_begin, -1) + recv_expert_begin <= T.cast(idx_value, "int32") < recv_expert_end, idx_value - recv_expert_begin, -1 + ) T.st( - channel_topk_idx_buffers[responsible_channel, rank, - dst_slot_idx, lane_id], + channel_topk_idx_buffers[responsible_channel, rank, dst_slot_idx, lane_id], idx_value, - dst_pe=responsible_rank) + dst_pe=responsible_rank, + ) # topk_weights - weight_value = T.alloc_var('float32') + weight_value = T.alloc_var("float32") T.ld(topk_weights[token_idx, lane_id], weight_value, nc=True) weight_value = T.if_then_else(idx_value >= 0, weight_value, 0) T.st( - channel_topk_weights_buffers[responsible_channel, rank, - dst_slot_idx, lane_id], + channel_topk_weights_buffers[responsible_channel, rank, dst_slot_idx, lane_id], weight_value, - dst_pe=responsible_rank) + dst_pe=responsible_rank, + ) # 4. 
copy scale (support fp8 later) @@ -434,36 +408,30 @@ def dispatch_main( # move tail index # here all warps should share the same new tail T.sync_threads(responsible_rank, num_threads_per_rank) - if send_warp_id_in_rank == 0 and T.elect_one_sync(): + if send_warp_id_in_rank == 0 and T.shuffle_elect(32): T.st( channel_tail_idx[responsible_channel, rank], cached_channel_tail_idx, - scope='sys', - sem='release', - dst_pe=responsible_rank) + scope="sys", + sem="release", + dst_pe=responsible_rank, + ) else: # receiver recv_thread_id_in_rank = tx % num_threads_per_rank recv_warp_id_in_rank = recv_thread_id_in_rank // 32 # calculate offset first - rank_offset = T.if_then_else(responsible_rank > 0, - rank_prefix_matrix[responsible_rank - 1, rank], 0) + rank_offset = T.if_then_else(responsible_rank > 0, rank_prefix_matrix[responsible_rank - 1, rank], 0) # receive channel offset - total_offset = T.alloc_var('int32') - num_tokens_to_recv = T.alloc_var('int32') - if T.elect_one_sync(): + total_offset = T.alloc_var("int32") + num_tokens_to_recv = T.alloc_var("int32") + if T.shuffle_elect(32): T.wait_ne(channel_start_offset[responsible_channel, responsible_rank], 0) - T.ld( - channel_start_offset[responsible_channel, responsible_rank], - total_offset, - sem='volatile') + T.ld(channel_start_offset[responsible_channel, responsible_rank], total_offset, sem="volatile") T.wait_ne(channel_end_offset[responsible_channel, responsible_rank], 0) - T.ld( - channel_end_offset[responsible_channel, responsible_rank], - num_tokens_to_recv, - sem='volatile') + T.ld(channel_end_offset[responsible_channel, responsible_rank], num_tokens_to_recv, sem="volatile") total_offset = -total_offset - 1 num_tokens_to_recv = -num_tokens_to_recv - 1 if recv_warp_id_in_rank == 0: @@ -474,24 +442,20 @@ def dispatch_main( num_tokens_to_recv = T.tvm_warp_shuffle(-1, num_tokens_to_recv, 0, 32, 32) # Shared tail indices for different warps - shared_channel_tail_idx = T.alloc_shared([num_ranks], 'int32') + shared_channel_tail_idx = T.alloc_shared([num_ranks], "int32") - cached_channel_head_idx = T.alloc_var('int32') + cached_channel_head_idx = T.alloc_var("int32") cached_channel_head_idx = 0 - cached_channel_tail_idx = T.alloc_var('int32') + cached_channel_tail_idx = T.alloc_var("int32") cached_channel_tail_idx = 0 - with T.While(num_tokens_to_recv > 0): - with T.While(recv_thread_id_in_rank == 0): - T.ld( - channel_tail_idx[responsible_channel, responsible_rank], - cached_channel_tail_idx, - sem='acquire', - scope='sys') + while num_tokens_to_recv > 0: + while recv_thread_id_in_rank == 0: + T.ld(channel_tail_idx[responsible_channel, responsible_rank], cached_channel_tail_idx, sem="acquire", scope="sys") # read to copy if cached_channel_head_idx != cached_channel_tail_idx: shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx - T.loop_break() + break # sync queue tail T.sync_threads(responsible_rank, num_threads_per_rank) @@ -500,48 +464,42 @@ def dispatch_main( # copy data # 1. 
recv x num_cur_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx - for chunk_idx in T.serial(recv_warp_id_in_rank, num_cur_recv_tokens, - num_warps_per_rank): - token_idx_in_buffer = (cached_channel_head_idx + - chunk_idx) % num_recv_buffer_tokens + for chunk_idx in T.serial(recv_warp_id_in_rank, num_cur_recv_tokens, num_warps_per_rank): + token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens # T.copy(channel_x_buffers[responsible_channel, responsible_rank, token_idx_in_buffer, :], recv_x[total_offset+chunk_idx, :]) # todo: add ld_nc and st_na #! T.copy will cause layout inference error T.put_warp( - T.address_of(channel_x_buffers[responsible_channel, responsible_rank, - token_idx_in_buffer, 0]), + T.address_of(channel_x_buffers[responsible_channel, responsible_rank, token_idx_in_buffer, 0]), T.address_of(recv_x[total_offset + chunk_idx, 0]), hidden, -1, 5, - enable_aggressive_vectorize=True) + enable_aggressive_vectorize=True, + ) # 2. recv src_idx - for chunk_idx in T.serial(cached_channel_head_idx + recv_thread_id_in_rank, - cached_channel_tail_idx, num_threads_per_rank): - local_src_idx = T.alloc_var('int32') + for chunk_idx in T.serial( + cached_channel_head_idx + recv_thread_id_in_rank, cached_channel_tail_idx, num_threads_per_rank + ): + local_src_idx = T.alloc_var("int32") T.ld( - channel_src_idx_buffers[responsible_channel, responsible_rank, - chunk_idx % num_recv_buffer_tokens], + channel_src_idx_buffers[responsible_channel, responsible_rank, chunk_idx % num_recv_buffer_tokens], local_src_idx, - nc=True) - recv_src_idx[total_offset + chunk_idx - - cached_channel_head_idx] = local_src_idx + nc=True, + ) + recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = local_src_idx # 3. recv topk_idx and topk_weights - for idx in T.serial(recv_thread_id_in_rank, num_cur_recv_tokens * num_topk, - num_threads_per_rank): + for idx in T.serial(recv_thread_id_in_rank, num_cur_recv_tokens * num_topk, num_threads_per_rank): chunk_idx = idx // num_topk token_topk_idx = idx % num_topk - token_idx_in_buffer = (cached_channel_head_idx + - chunk_idx) % num_recv_buffer_tokens - recv_topk_idx[total_offset + chunk_idx, - token_topk_idx] = channel_topk_idx_buffers[ - responsible_channel, responsible_rank, - token_idx_in_buffer, token_topk_idx] - recv_topk_weights[total_offset + chunk_idx, - token_topk_idx] = channel_topk_weights_buffers[ - responsible_channel, responsible_rank, - token_idx_in_buffer, token_topk_idx] + token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens + recv_topk_idx[total_offset + chunk_idx, token_topk_idx] = channel_topk_idx_buffers[ + responsible_channel, responsible_rank, token_idx_in_buffer, token_topk_idx + ] + recv_topk_weights[total_offset + chunk_idx, token_topk_idx] = channel_topk_weights_buffers[ + responsible_channel, responsible_rank, token_idx_in_buffer, token_topk_idx + ] # 4. 
recv scale (support fp8 later) @@ -549,12 +507,8 @@ def dispatch_main( cached_channel_head_idx += num_cur_recv_tokens total_offset += num_cur_recv_tokens T.sync_threads(responsible_rank, num_threads_per_rank) - if recv_warp_id_in_rank == num_warps_per_rank - 1 and T.elect_one_sync(): - T.st( - channel_head_idx[responsible_channel, responsible_rank], - cached_channel_head_idx, - scope='sys', - sem='relaxed') + if recv_warp_id_in_rank == num_warps_per_rank - 1 and T.shuffle_elect(32): + T.st(channel_head_idx[responsible_channel, responsible_rank], cached_channel_head_idx, scope="sys", sem="relaxed") # Exit num_tokens_to_recv -= num_cur_recv_tokens @@ -562,10 +516,12 @@ def dispatch_main( return dispatch_main -@tilelang.jit(pass_configs={ - "tl.disable_tma_lower": True, # enable TMA later - "tl.disable_warp_specialized": True -}) +@tilelang.jit( + pass_configs={ + "tl.disable_tma_lower": True, # enable TMA later + "tl.disable_warp_specialized": True, + } +) def cached_dispatch_kernel( num_ranks, num_tokens, @@ -573,7 +529,7 @@ def cached_dispatch_kernel( num_recv_buffer_tokens, # config.num_max_nvl_chunked_recv_tokens hidden, num_sms, - dtype: str = 'bfloat16', + dtype: str = "bfloat16", ): threads = 768 # 24 warps TMABytesPerWarp = 8192 @@ -585,14 +541,14 @@ def cached_dispatch_kernel( num_warps = threads // 32 # 24 num_warps_per_rank = num_warps // num_ranks # 3 - num_recv_tokens = T.dynamic('num_recv_tokens') + num_recv_tokens = T.dynamic("num_recv_tokens") @T.prim_func def cached_dispatch_main( rank: T.int32, # output recv_x: T.Tensor((num_recv_tokens, hidden), dtype), - recv_src_idx: T.Tensor((num_recv_tokens,), 'int32'), + recv_src_idx: T.Tensor((num_recv_tokens,), "int32"), recv_channel_offset: T.Tensor([num_ranks, num_channels], "int32"), send_head: T.Tensor([num_tokens, num_ranks], "int32"), # input @@ -608,10 +564,8 @@ def cached_dispatch_main( channel_head_idx: T.Tensor([num_channels, num_ranks], "int32"), channel_tail_idx: T.Tensor([num_channels, num_ranks], "int32"), # channel data buffers, stored on the receiver side - channel_x_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, hidden], - dtype), - channel_src_idx_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens], - "int32"), + channel_x_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, hidden], dtype), + channel_src_idx_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens], "int32"), # channel_x_scales_buffers: T.Tensor([num_channels, num_ranks, num_recv_buffer_tokens, num_scales], "float32"), ): with T.Kernel(num_sms, threads=threads) as bx: @@ -624,65 +578,52 @@ def cached_dispatch_main( # send offset by `-value-1` e.g. 
0->-1, 1->-2 # this is for distinguishing zero tokens - if send_warp_id_in_rank == 0 and T.elect_one_sync(): - value = T.alloc_var('int32') - value = T.if_then_else( - responsible_channel > 0, channel_prefix_matrix[responsible_rank, - responsible_channel - 1], 0) - T.st( - channel_start_offset[responsible_channel, rank], - -value - 1, - scope='sys', - sem='relaxed', - dst_pe=responsible_rank) + if send_warp_id_in_rank == 0 and T.shuffle_elect(32): + value = T.alloc_var("int32") + value = T.if_then_else(responsible_channel > 0, channel_prefix_matrix[responsible_rank, responsible_channel - 1], 0) + T.st(channel_start_offset[responsible_channel, rank], -value - 1, scope="sys", sem="relaxed", dst_pe=responsible_rank) value = channel_prefix_matrix[responsible_rank, responsible_channel] - T.st( - channel_end_offset[responsible_channel, rank], - -value - 1, - scope='sys', - sem='relaxed', - dst_pe=responsible_rank) + T.st(channel_end_offset[responsible_channel, rank], -value - 1, scope="sys", sem="relaxed", dst_pe=responsible_rank) T.sync_warp() # get task - num_tokens_per_channel = T.alloc_var( - 'int32', init=T.ceildiv(num_tokens, num_channels)) - token_start_idx = T.alloc_var('int32') + num_tokens_per_channel = T.alloc_var("int32", init=T.ceildiv(num_tokens, num_channels)) + token_start_idx = T.alloc_var("int32") token_start_idx = T.min(num_tokens_per_channel * responsible_channel, num_tokens) - token_end_idx = T.alloc_var('int32') + token_end_idx = T.alloc_var("int32") token_end_idx = T.min(token_start_idx + num_tokens_per_channel, num_tokens) # sender mainloop: iterate over all tokens and send by trunk - cached_channel_tail_idx = T.alloc_var('int32') + cached_channel_tail_idx = T.alloc_var("int32") cached_channel_tail_idx = 0 - token_idx = T.alloc_var('int32') + token_idx = T.alloc_var("int32") token_idx = token_start_idx - with T.While(token_idx < token_end_idx): - if T.elect_one_sync(): + while token_idx < token_end_idx: + if T.shuffle_elect(32): T.wait_ge( channel_head_idx[responsible_channel, rank], num_max_send_tokens + cached_channel_tail_idx - num_recv_buffer_tokens, - responsible_rank) + responsible_rank, + ) T.sync_warp() - chunk_token_idx = T.alloc_var('int32') + chunk_token_idx = T.alloc_var("int32") chunk_token_idx = 0 while chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx: # for the same token, the warp assigned to save `send_head` may be different from the warp # assigned to send the following data - if token_idx % num_warps_per_rank == send_warp_id_in_rank and T.elect_one_sync( - ): + if token_idx % num_warps_per_rank == send_warp_id_in_rank and T.shuffle_elect(32): send_head[token_idx, responsible_rank] = T.if_then_else( - is_token_in_rank[token_idx, responsible_rank], - cached_channel_tail_idx, -1) + is_token_in_rank[token_idx, responsible_rank], cached_channel_tail_idx, -1 + ) # skip if not selected if not is_token_in_rank[token_idx, responsible_rank]: token_idx += 1 - T.loop_continue() + continue # selected, get an empty slot - dst_slot_idx = T.alloc_var('int32') + dst_slot_idx = T.alloc_var("int32") dst_slot_idx = cached_channel_tail_idx % num_recv_buffer_tokens cached_channel_tail_idx += 1 if cached_channel_tail_idx % num_warps_per_rank == send_warp_id_in_rank: @@ -690,20 +631,16 @@ def cached_dispatch_main( # 1. 
copy data T.put_warp( T.address_of(x[token_idx, 0]), - T.address_of(channel_x_buffers[responsible_channel, rank, - dst_slot_idx, 0]), + T.address_of(channel_x_buffers[responsible_channel, rank, dst_slot_idx, 0]), hidden, dst_pe=responsible_rank, unroll_factor=4, - enable_aggressive_vectorize=True) + enable_aggressive_vectorize=True, + ) # 2. copy src idx - if T.elect_one_sync(): - T.st( - channel_src_idx_buffers[responsible_channel, rank, - dst_slot_idx], - token_idx, - dst_pe=responsible_rank) + if T.shuffle_elect(32): + T.st(channel_src_idx_buffers[responsible_channel, rank, dst_slot_idx], token_idx, dst_pe=responsible_rank) # 4. copy scale (support fp8 later) @@ -713,36 +650,30 @@ def cached_dispatch_main( # move tail index # here all warps should share the same new tail T.sync_threads(responsible_rank, num_threads_per_rank) - if send_warp_id_in_rank == 0 and T.elect_one_sync(): + if T.shuffle_elect(96): T.st( channel_tail_idx[responsible_channel, rank], cached_channel_tail_idx, - scope='sys', - sem='release', - dst_pe=responsible_rank) + scope="sys", + sem="release", + dst_pe=responsible_rank, + ) else: # receiver recv_thread_id_in_rank = tx % num_threads_per_rank recv_warp_id_in_rank = recv_thread_id_in_rank // 32 # calculate offset first - rank_offset = T.if_then_else(responsible_rank > 0, - rank_prefix_matrix[responsible_rank - 1, rank], 0) + rank_offset = T.if_then_else(responsible_rank > 0, rank_prefix_matrix[responsible_rank - 1, rank], 0) # receive channel offset - total_offset = T.alloc_var('int32') - num_tokens_to_recv = T.alloc_var('int32') - if T.elect_one_sync(): + total_offset = T.alloc_var("int32") + num_tokens_to_recv = T.alloc_var("int32") + if T.shuffle_elect(32): T.wait_ne(channel_start_offset[responsible_channel, responsible_rank], 0) - T.ld( - channel_start_offset[responsible_channel, responsible_rank], - total_offset, - sem='volatile') + T.ld(channel_start_offset[responsible_channel, responsible_rank], total_offset, sem="volatile") T.wait_ne(channel_end_offset[responsible_channel, responsible_rank], 0) - T.ld( - channel_end_offset[responsible_channel, responsible_rank], - num_tokens_to_recv, - sem='volatile') + T.ld(channel_end_offset[responsible_channel, responsible_rank], num_tokens_to_recv, sem="volatile") total_offset = -total_offset - 1 num_tokens_to_recv = -num_tokens_to_recv - 1 if recv_warp_id_in_rank == 0: @@ -753,24 +684,20 @@ def cached_dispatch_main( num_tokens_to_recv = T.tvm_warp_shuffle(-1, num_tokens_to_recv, 0, 32, 32) # Shared tail indices for different warps - shared_channel_tail_idx = T.alloc_shared([num_ranks], 'int32') + shared_channel_tail_idx = T.alloc_shared([num_ranks], "int32") - cached_channel_head_idx = T.alloc_var('int32') + cached_channel_head_idx = T.alloc_var("int32") cached_channel_head_idx = 0 - cached_channel_tail_idx = T.alloc_var('int32') + cached_channel_tail_idx = T.alloc_var("int32") cached_channel_tail_idx = 0 - with T.While(num_tokens_to_recv > 0): - with T.While(recv_thread_id_in_rank == 0): - T.ld( - channel_tail_idx[responsible_channel, responsible_rank], - cached_channel_tail_idx, - sem='acquire', - scope='sys') + while num_tokens_to_recv > 0: + while recv_thread_id_in_rank == 0: + T.ld(channel_tail_idx[responsible_channel, responsible_rank], cached_channel_tail_idx, sem="acquire", scope="sys") # read to copy if cached_channel_head_idx != cached_channel_tail_idx: shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx - T.loop_break() + break # sync queue tail T.sync_threads(responsible_rank, 
num_threads_per_rank) @@ -779,31 +706,29 @@ def cached_dispatch_main( # copy data # 1. recv x num_cur_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx - for chunk_idx in T.serial(recv_warp_id_in_rank, num_cur_recv_tokens, - num_warps_per_rank): - token_idx_in_buffer = (cached_channel_head_idx + - chunk_idx) % num_recv_buffer_tokens + for chunk_idx in T.serial(recv_warp_id_in_rank, num_cur_recv_tokens, num_warps_per_rank): + token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens #! T.copy will cause layout inference error T.put_warp( - T.address_of(channel_x_buffers[responsible_channel, responsible_rank, - token_idx_in_buffer, 0]), + T.address_of(channel_x_buffers[responsible_channel, responsible_rank, token_idx_in_buffer, 0]), T.address_of(recv_x[total_offset + chunk_idx, 0]), hidden, -1, 5, - enable_aggressive_vectorize=True) + enable_aggressive_vectorize=True, + ) # 2. recv src_idx - for chunk_idx in T.serial(cached_channel_head_idx + recv_thread_id_in_rank, - cached_channel_tail_idx, num_threads_per_rank): - local_src_idx = T.alloc_var('int32') + for chunk_idx in T.serial( + cached_channel_head_idx + recv_thread_id_in_rank, cached_channel_tail_idx, num_threads_per_rank + ): + local_src_idx = T.alloc_var("int32") T.ld( - channel_src_idx_buffers[responsible_channel, responsible_rank, - chunk_idx % num_recv_buffer_tokens], + channel_src_idx_buffers[responsible_channel, responsible_rank, chunk_idx % num_recv_buffer_tokens], local_src_idx, - nc=True) - recv_src_idx[total_offset + chunk_idx - - cached_channel_head_idx] = local_src_idx + nc=True, + ) + recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = local_src_idx # 4. recv scale (support fp8 later) @@ -811,12 +736,8 @@ def cached_dispatch_main( cached_channel_head_idx += num_cur_recv_tokens total_offset += num_cur_recv_tokens T.sync_threads(responsible_rank, num_threads_per_rank) - if recv_warp_id_in_rank == num_warps_per_rank - 1 and T.elect_one_sync(): - T.st( - channel_head_idx[responsible_channel, responsible_rank], - cached_channel_head_idx, - scope='sys', - sem='relaxed') + if T.shuffle_elect(96): + T.st(channel_head_idx[responsible_channel, responsible_rank], cached_channel_head_idx, scope="sys", sem="relaxed") # Exit num_tokens_to_recv -= num_cur_recv_tokens @@ -848,8 +769,9 @@ def intranode_dispatch( # todo: support async functionality ): if handle is None: - assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None, \ - "num_tokens_per_rank, is_token_in_rank, and num_tokens_per_expert must be provided in non-cached mode" + assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None, ( + "num_tokens_per_rank, is_token_in_rank, and num_tokens_per_expert must be provided in non-cached mode" + ) else: rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle @@ -858,8 +780,19 @@ def intranode_dispatch( num_ranks = num_tokens_per_rank.shape[0] num_topk = topk_idx.shape[1] if handle is None else 0 - barrier_signal, per_rank_buffer, per_expert_buffer, channel_start_offset, channel_end_offset, channel_head_idx, channel_tail_idx, \ - channel_x_buffers, channel_src_idx_buffers, channel_topk_idx_buffers, channel_topk_weights_buffers = symm_buffers + ( + barrier_signal, + per_rank_buffer, + per_expert_buffer, + channel_start_offset, + channel_end_offset, + channel_head_idx, + channel_tail_idx, + channel_x_buffers, + 
channel_src_idx_buffers, + channel_topk_idx_buffers, + channel_topk_weights_buffers, + ) = symm_buffers if handle is None: num_recv_tokens, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix = notify_dispatch( @@ -895,76 +828,84 @@ def intranode_dispatch( channel_tail_idx, barrier_signal, allocator, - comm_stream=comm_stream) + comm_stream=comm_stream, + ) num_recv_tokens = recv_src_idx.size(0) - recv_x = torch.empty((num_recv_tokens, hidden), dtype=x.dtype, device='cuda') - recv_src_idx = torch.empty((num_recv_tokens,), dtype=torch.int32, device='cuda') + recv_x = torch.empty((num_recv_tokens, hidden), dtype=x.dtype, device="cuda") + recv_src_idx = torch.empty((num_recv_tokens,), dtype=torch.int32, device="cuda") if handle is None: - recv_topk_idx = torch.empty((num_recv_tokens, num_topk), dtype=torch.int64, device='cuda') - recv_topk_weights = torch.empty((num_recv_tokens, num_topk), - dtype=torch.float32, - device='cuda') - recv_channel_prefix_matrix = torch.empty((num_ranks, config.num_channels), - dtype=torch.int32, - device='cuda') - send_head = torch.empty((num_tokens, num_ranks), dtype=torch.int32, device='cuda') + recv_topk_idx = torch.empty((num_recv_tokens, num_topk), dtype=torch.int64, device="cuda") + recv_topk_weights = torch.empty((num_recv_tokens, num_topk), dtype=torch.float32, device="cuda") + recv_channel_prefix_matrix = torch.empty((num_ranks, config.num_channels), dtype=torch.int32, device="cuda") + send_head = torch.empty((num_tokens, num_ranks), dtype=torch.int32, device="cuda") # run dispatch if handle is None: - kernel = dispatch_kernel(num_ranks, config.num_max_nvl_chunked_send_tokens, - config.num_max_nvl_chunked_recv_tokens, hidden, num_topk, - num_experts, config.num_sms, 'bfloat16') - kernel.initialize(allocator=allocator) - kernel( - rank, - recv_x, - recv_src_idx, - recv_topk_idx, - recv_topk_weights, - recv_channel_prefix_matrix, - send_head, - x, - topk_idx, - topk_weights, - is_token_in_rank, - rank_prefix_matrix, - channel_prefix_matrix, - channel_start_offset, - channel_end_offset, - channel_head_idx, - channel_tail_idx, - channel_x_buffers, - channel_src_idx_buffers, - channel_topk_idx_buffers, - channel_topk_weights_buffers, - stream=comm_stream.cuda_stream, - skip_tensor_validation=True) # reduce runtime overhead - handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, - recv_src_idx, is_token_in_rank, send_head) + kernel = dispatch_kernel( + num_ranks, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens, + hidden, + num_topk, + num_experts, + config.num_sms, + "bfloat16", + ) + kernel.initialize(allocator=allocator, stream=comm_stream.cuda_stream) + with tvm_ffi.use_torch_stream(torch.cuda.stream(comm_stream)): + kernel( + rank, + recv_x, + recv_src_idx, + recv_topk_idx, + recv_topk_weights, + recv_channel_prefix_matrix, + send_head, + x, + topk_idx, + topk_weights, + is_token_in_rank, + rank_prefix_matrix, + channel_prefix_matrix, + channel_start_offset, + channel_end_offset, + channel_head_idx, + channel_tail_idx, + channel_x_buffers, + channel_src_idx_buffers, + channel_topk_idx_buffers, + channel_topk_weights_buffers, + ) + handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head) return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle else: - kernel = cached_dispatch_kernel(num_ranks, num_tokens, - config.num_max_nvl_chunked_send_tokens, - 
config.num_max_nvl_chunked_recv_tokens, hidden, - config.num_sms, 'bfloat16') + kernel = cached_dispatch_kernel( + num_ranks, + num_tokens, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens, + hidden, + config.num_sms, + "bfloat16", + ) kernel.initialize(allocator=allocator, stream=comm_stream.cuda_stream) - kernel( - rank, - recv_x, - recv_src_idx, - recv_channel_prefix_matrix, - send_head, - x, - is_token_in_rank, - rank_prefix_matrix, - channel_prefix_matrix, - channel_start_offset, - channel_end_offset, - channel_head_idx, - channel_tail_idx, - channel_x_buffers, - channel_src_idx_buffers, - stream=comm_stream.cuda_stream, - skip_tensor_validation=True) # reduce runtime overhead + with torch.cuda.stream(comm_stream): + kernel( + rank, + recv_x, + recv_src_idx, + recv_channel_prefix_matrix, + send_head, + x, + is_token_in_rank, + rank_prefix_matrix, + channel_prefix_matrix, + channel_start_offset, + channel_end_offset, + channel_head_idx, + channel_tail_idx, + channel_x_buffers, + channel_src_idx_buffers, + ) return recv_x diff --git a/examples/distributed/deepseek_deepep/intranode/example_intranode.py b/examples/distributed/deepseek_deepep/intranode/example_intranode.py index 8f555dfee..41ea25834 100644 --- a/examples/distributed/deepseek_deepep/intranode/example_intranode.py +++ b/examples/distributed/deepseek_deepep/intranode/example_intranode.py @@ -13,7 +13,7 @@ from deepep_utils import gen_inputs, ep_bench # tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log def test_intranode( @@ -37,170 +37,187 @@ def test_intranode( deepep_buffer = deep_ep.Buffer(group, num_nvl_bytes=2**30) # Generate inputs for testing - x, topk_idx, topk_weights, rank_idx = gen_inputs(num_tokens, hidden, num_topk, num_experts, - num_ranks) + x, topk_idx, topk_weights, rank_idx = gen_inputs(num_tokens, hidden, num_topk, num_experts, num_ranks) # 1. test get_dispatch_layout ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = deepep_buffer.get_dispatch_layout( - topk_idx, num_experts) - num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = ts_buffer.get_dispatch_layout( - topk_idx) + topk_idx, num_experts + ) + num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = ts_buffer.get_dispatch_layout(topk_idx) - assert torch.equal(num_tokens_per_expert, ref_num_tokens_per_expert), \ + assert torch.equal(num_tokens_per_expert, ref_num_tokens_per_expert), ( f"[rank {rank}] num_tokens_per_expert mismatch, max err: {(num_tokens_per_expert - ref_num_tokens_per_expert).abs().max()}" - assert torch.equal(is_token_in_rank, ref_is_token_in_rank), \ - f"[rank {rank}] is_token_in_rank mismatch" - assert torch.equal(num_tokens_per_rank, ref_num_tokens_per_rank), \ + ) + assert torch.equal(is_token_in_rank, ref_is_token_in_rank), f"[rank {rank}] is_token_in_rank mismatch" + assert torch.equal(num_tokens_per_rank, ref_num_tokens_per_rank), ( f"[rank {rank}] num_tokens_per_rank mismatch, max err: {(num_tokens_per_rank - ref_num_tokens_per_rank).abs().max()}" + ) group.barrier() if rank == 0: - print('Check passed for get_dispatch_layout. ✅') + print("Check passed for get_dispatch_layout. ✅") # 2. 
test dispatch # ref - ref_recv_x, ref_recv_topk_idx, ref_recv_topk_weights, ref_num_recv_tokens_per_expert_list, ref_handle, event = \ - deepep_buffer.dispatch(x, None, ref_num_tokens_per_rank, None, ref_is_token_in_rank, ref_num_tokens_per_expert, topk_idx, topk_weights, expert_alignment) + ref_recv_x, ref_recv_topk_idx, ref_recv_topk_weights, ref_num_recv_tokens_per_expert_list, ref_handle, event = deepep_buffer.dispatch( + x, None, ref_num_tokens_per_rank, None, ref_is_token_in_rank, ref_num_tokens_per_expert, topk_idx, topk_weights, expert_alignment + ) # ours if cached_dispatch: - recv_x = ts_buffer.dispatch(x, ref_handle, num_tokens_per_rank, is_token_in_rank, - num_tokens_per_expert, None, None, expert_alignment) + recv_x = ts_buffer.dispatch( + x, ref_handle, num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, None, None, expert_alignment + ) else: recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle = ts_buffer.dispatch( - x, None, num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, topk_idx, - topk_weights, expert_alignment) + x, None, num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, topk_idx, topk_weights, expert_alignment + ) # check dispatch output - assert torch.equal( - recv_x, - ref_recv_x), f'[rank {rank}] recv_x mismatch, max err: {(recv_x - ref_recv_x).abs().max()}' + assert torch.equal(recv_x, ref_recv_x), f"[rank {rank}] recv_x mismatch, max err: {(recv_x - ref_recv_x).abs().max()}" if not cached_dispatch: - assert torch.equal( - recv_topk_idx, ref_recv_topk_idx - ), f'[rank {rank}] recv_topk_idx mismatch, max err: {(recv_topk_idx - ref_recv_topk_idx).abs().max()}' - assert torch.equal( - recv_topk_weights, ref_recv_topk_weights - ), f'[rank {rank}] recv_topk_weights mismatch, max err: {(recv_topk_weights - ref_recv_topk_weights).abs().max()}' - assert num_recv_tokens_per_expert_list == ref_num_recv_tokens_per_expert_list, f'[rank {rank}] num_recv_tokens_per_expert_list mismatch' + assert torch.equal(recv_topk_idx, ref_recv_topk_idx), ( + f"[rank {rank}] recv_topk_idx mismatch, max err: {(recv_topk_idx - ref_recv_topk_idx).abs().max()}" + ) + assert torch.equal(recv_topk_weights, ref_recv_topk_weights), ( + f"[rank {rank}] recv_topk_weights mismatch, max err: {(recv_topk_weights - ref_recv_topk_weights).abs().max()}" + ) + assert num_recv_tokens_per_expert_list == ref_num_recv_tokens_per_expert_list, ( + f"[rank {rank}] num_recv_tokens_per_expert_list mismatch" + ) # check handle rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle - ref_rank_prefix_matrix, ref_channel_prefix_matrix, ref_recv_channel_prefix_matrix, ref_recv_src_idx, ref_is_token_in_rank, ref_send_head = ref_handle - assert torch.equal( - rank_prefix_matrix, ref_rank_prefix_matrix - ), f'[rank {rank}] rank_prefix_matrix mismatch, max err: {(rank_prefix_matrix - ref_rank_prefix_matrix).abs().max()}' - assert torch.equal( - channel_prefix_matrix, ref_channel_prefix_matrix - ), f'[rank {rank}] channel_prefix_matrix mismatch, max err: {(channel_prefix_matrix - ref_channel_prefix_matrix).abs().max()}' - assert torch.equal( - recv_channel_prefix_matrix, ref_recv_channel_prefix_matrix - ), f'[rank {rank}] recv_channel_prefix_matrix mismatch, max err: {(recv_channel_prefix_matrix - ref_recv_channel_prefix_matrix).abs().max()}' - assert torch.equal( - recv_src_idx, ref_recv_src_idx - ), f'[rank {rank}] recv_src_idx mismatch, max err: {(recv_src_idx - ref_recv_src_idx).abs().max()}' - 
assert torch.equal( - is_token_in_rank, ref_is_token_in_rank - ), f'[rank {rank}] is_token_in_rank mismatch, max err: {(is_token_in_rank - ref_is_token_in_rank).abs().max()}' - assert torch.equal( - send_head, ref_send_head - ), f'[rank {rank}] send_head mismatch, max err: {(send_head - ref_send_head).abs().max()}' + ( + ref_rank_prefix_matrix, + ref_channel_prefix_matrix, + ref_recv_channel_prefix_matrix, + ref_recv_src_idx, + ref_is_token_in_rank, + ref_send_head, + ) = ref_handle + assert torch.equal(rank_prefix_matrix, ref_rank_prefix_matrix), ( + f"[rank {rank}] rank_prefix_matrix mismatch, max err: {(rank_prefix_matrix - ref_rank_prefix_matrix).abs().max()}" + ) + assert torch.equal(channel_prefix_matrix, ref_channel_prefix_matrix), ( + f"[rank {rank}] channel_prefix_matrix mismatch, max err: {(channel_prefix_matrix - ref_channel_prefix_matrix).abs().max()}" + ) + assert torch.equal(recv_channel_prefix_matrix, ref_recv_channel_prefix_matrix), ( + f"[rank {rank}] recv_channel_prefix_matrix mismatch, max err: {(recv_channel_prefix_matrix - ref_recv_channel_prefix_matrix).abs().max()}" + ) + assert torch.equal(recv_src_idx, ref_recv_src_idx), ( + f"[rank {rank}] recv_src_idx mismatch, max err: {(recv_src_idx - ref_recv_src_idx).abs().max()}" + ) + assert torch.equal(is_token_in_rank, ref_is_token_in_rank), ( + f"[rank {rank}] is_token_in_rank mismatch, max err: {(is_token_in_rank - ref_is_token_in_rank).abs().max()}" + ) + assert torch.equal(send_head, ref_send_head), ( + f"[rank {rank}] send_head mismatch, max err: {(send_head - ref_send_head).abs().max()}" + ) group.barrier() if rank == 0: - print(f'Check passed for {"cached" if cached_dispatch else "non-cached"} dispatch. ✅') + print(f"Check passed for {'cached' if cached_dispatch else 'non-cached'} dispatch. ✅") # 3. test combine - ref_combined_x, ref_combined_topk_weights, _ = deepep_buffer.combine( - recv_x, ref_handle, ref_recv_topk_weights) + ref_combined_x, ref_combined_topk_weights, _ = deepep_buffer.combine(recv_x, ref_handle, ref_recv_topk_weights) if cached_dispatch: # acquire handle first recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle = ts_buffer.dispatch( - x, None, num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, topk_idx, - topk_weights, expert_alignment) + x, None, num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, topk_idx, topk_weights, expert_alignment + ) combined_x, combined_topk_weights = ts_buffer.combine(recv_x, handle, recv_topk_weights) - assert torch.equal( - combined_x, ref_combined_x - ), f'[rank {rank}] combined_x mismatch, max err: {(combined_x - ref_combined_x).abs().max()}' - assert torch.equal( - combined_topk_weights, ref_combined_topk_weights - ), f'[rank {rank}] combined_topk_weights mismatch, max err: {(combined_topk_weights - ref_combined_topk_weights).abs().max()}' + assert torch.equal(combined_x, ref_combined_x), ( + f"[rank {rank}] combined_x mismatch, max err: {(combined_x - ref_combined_x).abs().max()}" + ) + assert torch.equal(combined_topk_weights, ref_combined_topk_weights), ( + f"[rank {rank}] combined_topk_weights mismatch, max err: {(combined_topk_weights - ref_combined_topk_weights).abs().max()}" + ) group.barrier() if rank == 0: - print('Check passed for combine. ✅') + print("Check passed for combine. ✅") if rank == 0: - print('All checks passed for TileScale intranode DeepEP. ✅') + print("All checks passed for TileScale intranode DeepEP. 
✅") # benchmark if rank == 0: - print( - f'========== Benchmarking {"cached" if cached_dispatch else "non-cached"} dispatch ==========' - ) + print(f"========== Benchmarking {'cached' if cached_dispatch else 'non-cached'} dispatch ==========") if not cached_dispatch: group.barrier() deepep_dispatch_time = ep_bench( - lambda: deepep_buffer. - dispatch(x, None, ref_num_tokens_per_rank, None, ref_is_token_in_rank, - ref_num_tokens_per_expert, topk_idx, topk_weights, expert_alignment), + lambda: deepep_buffer.dispatch( + x, + None, + ref_num_tokens_per_rank, + None, + ref_is_token_in_rank, + ref_num_tokens_per_expert, + topk_idx, + topk_weights, + expert_alignment, + ), warmup=50, - rep=50) - print(f'[rank {rank}] DeepEP dispatch time: {deepep_dispatch_time:.4f}ms') + rep=50, + ) + print(f"[rank {rank}] DeepEP dispatch time: {deepep_dispatch_time:.4f}ms") group.barrier() ts_dispatch_time = ep_bench( - lambda: ts_buffer. - dispatch(x, None, num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, - topk_idx, topk_weights, expert_alignment), + lambda: ts_buffer.dispatch( + x, None, num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, topk_idx, topk_weights, expert_alignment + ), warmup=50, - rep=50) - print(f'[rank {rank}] TileScale dispatch time: {ts_dispatch_time:.4f}ms') + rep=50, + ) + print(f"[rank {rank}] TileScale dispatch time: {ts_dispatch_time:.4f}ms") group.barrier() else: group.barrier() deepep_dispatch_time = ep_bench( - lambda: deepep_buffer. - dispatch(x, ref_handle, ref_num_tokens_per_rank, None, ref_is_token_in_rank, - ref_num_tokens_per_expert, None, None, expert_alignment), + lambda: deepep_buffer.dispatch( + x, ref_handle, ref_num_tokens_per_rank, None, ref_is_token_in_rank, ref_num_tokens_per_expert, None, None, expert_alignment + ), warmup=50, - rep=50) - print(f'[rank {rank}] DeepEP dispatch time: {deepep_dispatch_time:.4f}ms') + rep=50, + ) + print(f"[rank {rank}] DeepEP dispatch time: {deepep_dispatch_time:.4f}ms") group.barrier() ts_dispatch_time = ep_bench( - lambda: ts_buffer.dispatch(x, ref_handle, num_tokens_per_rank, is_token_in_rank, - num_tokens_per_expert, None, None, expert_alignment), + lambda: ts_buffer.dispatch( + x, ref_handle, num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, None, None, expert_alignment + ), warmup=50, - rep=50) - print(f'[rank {rank}] TileScale dispatch time: {ts_dispatch_time:.4f}ms') + rep=50, + ) + print(f"[rank {rank}] TileScale dispatch time: {ts_dispatch_time:.4f}ms") group.barrier() if rank == 0: - print('========== Benchmarking combine ==========') + print("========== Benchmarking combine ==========") group.barrier() - deepep_combine_time = ep_bench( - lambda: deepep_buffer.combine(recv_x, ref_handle, ref_recv_topk_weights), warmup=50, rep=50) - print(f'[rank {rank}] DeepEP combine time: {deepep_combine_time:.4f}ms') + deepep_combine_time = ep_bench(lambda: deepep_buffer.combine(recv_x, ref_handle, ref_recv_topk_weights), warmup=50, rep=50) + print(f"[rank {rank}] DeepEP combine time: {deepep_combine_time:.4f}ms") group.barrier() - ts_combine_time = ep_bench( - lambda: ts_buffer.combine(recv_x, handle, recv_topk_weights), warmup=50, rep=50) - print(f'[rank {rank}] TileScale combine time: {ts_combine_time:.4f}ms') + ts_combine_time = ep_bench(lambda: ts_buffer.combine(recv_x, handle, recv_topk_weights), warmup=50, rep=50) + print(f"[rank {rank}] TileScale combine time: {ts_combine_time:.4f}ms") group.barrier() if rank == 0: - print('========== Benchmarking report ==========') + print("========== 
Benchmarking report ==========") dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes if rank == 0: print( - f'DeepEP dispatch time: {deepep_dispatch_time:.4f}ms, bandwidth: {dispatch_bf16_nvl_recv_bytes / deepep_dispatch_time / 1e6:.2f} GB/s (NVL)' + f"DeepEP dispatch time: {deepep_dispatch_time:.4f}ms, bandwidth: {dispatch_bf16_nvl_recv_bytes / deepep_dispatch_time / 1e6:.2f} GB/s (NVL)" ) print( - f'TileScale dispatch time: {ts_dispatch_time:.4f}ms, bandwidth: {dispatch_bf16_nvl_recv_bytes / ts_dispatch_time / 1e6:.2f} GB/s (NVL)' + f"TileScale dispatch time: {ts_dispatch_time:.4f}ms, bandwidth: {dispatch_bf16_nvl_recv_bytes / ts_dispatch_time / 1e6:.2f} GB/s (NVL)" ) print( - f'DeepEP combine time: {deepep_combine_time:.4f}ms, bandwidth: {combine_bf16_nvl_send_bytes / deepep_combine_time / 1e6:.2f} GB/s (NVL)' + f"DeepEP combine time: {deepep_combine_time:.4f}ms, bandwidth: {combine_bf16_nvl_send_bytes / deepep_combine_time / 1e6:.2f} GB/s (NVL)" ) print( - f'TileScale combine time: {ts_combine_time:.4f}ms, bandwidth: {combine_bf16_nvl_send_bytes / ts_combine_time / 1e6:.2f} GB/s (NVL)' + f"TileScale combine time: {ts_combine_time:.4f}ms, bandwidth: {combine_bf16_nvl_send_bytes / ts_combine_time / 1e6:.2f} GB/s (NVL)" ) @@ -227,12 +244,10 @@ def parse_args(): parser.add_argument("--num_ranks", type=int, default=8, help="Number of ranks") parser.add_argument("--num_tokens", type=int, default=4096, help="Number of tokens") parser.add_argument("--hidden", type=int, default=7168, help="Hidden size") - parser.add_argument( - "--num_topk", type=int, default=8, help="Number of top-k experts to select for each token") + parser.add_argument("--num_topk", type=int, default=8, help="Number of top-k experts to select for each token") parser.add_argument("--num_experts", type=int, default=32, help="Number of experts") parser.add_argument("--expert_alignment", type=int, default=1, help="Expert alignment") - parser.add_argument( - "--cached", action="store_true", default=False, help="Whether to use cached dispatch") + parser.add_argument("--cached", action="store_true", default=False, help="Whether to use cached dispatch") return parser.parse_args() diff --git a/examples/distributed/deepseek_deepep/intranode/get_dispatch_layout.py b/examples/distributed/deepseek_deepep/intranode/get_dispatch_layout.py index 97b67d1a4..c696297e1 100644 --- a/examples/distributed/deepseek_deepep/intranode/get_dispatch_layout.py +++ b/examples/distributed/deepseek_deepep/intranode/get_dispatch_layout.py @@ -15,8 +15,8 @@ # TODO(wt): Add async functionality def get_dispatch_layout( - topk_idx: torch.Tensor, num_experts: int, - num_ranks: int) -> Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor]: + topk_idx: torch.Tensor, num_experts: int, num_ranks: int +) -> Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor]: """Calculate the layout required for later communication. 
Arguments: @@ -42,9 +42,9 @@ def get_dispatch_layout( # Allocate tensors # TODO(wt): Wait on previous events and allocate on comm stream when adding async functionality num_tokens, num_topk = topk_idx.shape - num_tokens_per_rank = torch.empty(num_ranks, dtype=torch.int32, device='cuda') - num_tokens_per_expert = torch.empty(num_experts, dtype=torch.int32, device='cuda') - is_token_in_rank = torch.empty((num_tokens, num_ranks), dtype=torch.bool, device='cuda') + num_tokens_per_rank = torch.empty(num_ranks, dtype=torch.int32, device="cuda") + num_tokens_per_expert = torch.empty(num_experts, dtype=torch.int32, device="cuda") + is_token_in_rank = torch.empty((num_tokens, num_ranks), dtype=torch.bool, device="cuda") # Launch the kernel kernel = get_dispatch_layout_kernel(num_topk, num_experts, num_ranks) @@ -72,14 +72,14 @@ def get_dispatch_layout_kernel( num_sms = T.ceildiv(num_experts, experts_per_sm) + T.ceildiv(num_ranks, ranks_per_sm) experts_per_rank = num_experts // num_ranks - num_tokens = T.dynamic('num_tokens') + num_tokens = T.dynamic("num_tokens") @T.prim_func def get_dispatch_layout_main( - topk_idx: T.Tensor([num_tokens, num_topk], "int64"), # type: ignore - num_tokens_per_rank: T.Tensor([num_ranks], "int32"), # type: ignore - num_tokens_per_expert: T.Tensor([num_experts], "int32"), # type: ignore - is_token_in_rank: T.Tensor([num_tokens, num_ranks], "bool"), # type: ignore + topk_idx: T.Tensor([num_tokens, num_topk], "int64"), # type: ignore + num_tokens_per_rank: T.Tensor([num_ranks], "int32"), # type: ignore + num_tokens_per_expert: T.Tensor([num_experts], "int32"), # type: ignore + is_token_in_rank: T.Tensor([num_tokens, num_ranks], "bool"), # type: ignore ): with T.Kernel(num_sms, threads=threads) as bx: tx = T.get_thread_binding() diff --git a/examples/distributed/deepseek_deepep/intranode/test_intranode.py b/examples/distributed/deepseek_deepep/intranode/test_intranode.py index 317721996..c6f8a55c6 100644 --- a/examples/distributed/deepseek_deepep/intranode/test_intranode.py +++ b/examples/distributed/deepseek_deepep/intranode/test_intranode.py @@ -3,6 +3,7 @@ import example_intranode +@tilelang.testing.requires_distributed @tilelang.testing.requires_cuda def test_intranode(monkeypatch): monkeypatch.setattr("sys.argv", ["example_intranode.py"]) # optionally add testing params here diff --git a/examples/distributed/example_all_to_all.py b/examples/distributed/example_all_to_all.py index 328ebc86b..dd0157c89 100644 --- a/examples/distributed/example_all_to_all.py +++ b/examples/distributed/example_all_to_all.py @@ -11,7 +11,6 @@ def all_to_all(PE_num, TOKEN_NUM, TOPK, HIDDEN, EXPERT_NUM, dtype="float16"): - EXPERTS_PER_RANK = EXPERT_NUM // PE_num @T.prim_func @@ -37,8 +36,8 @@ def main( m_end[0] = splits_cumsum[(peer + 1) * EXPERTS_PER_RANK] T.putmem_nbi_block( - T.address_of(data_dst[0, 0]), T.address_of(data_src[m_start[0], 0]), - (m_end[0] - m_start[0]) * HIDDEN * 2, peer) + T.address_of(data_dst[0, 0]), T.address_of(data_src[m_start[0], 0]), (m_end[0] - m_start[0]) * HIDDEN * 2, peer + ) T.fence() @@ -119,7 +118,7 @@ def splits_to_cumsum(splits: torch.Tensor): # print("split_cumsum:", split_cumsum) data_src = pynvshmem.nvshmem_create_tensor([args.M * args.topk, args.N], torch.float16) -data_src[:].copy_(ref_tensor[args.M * args.topk * RANK:args.M * args.topk * (RANK + 1), :]) +data_src[:].copy_(ref_tensor[args.M * args.topk * RANK : args.M * args.topk * (RANK + 1), :]) splits_cumsum = pynvshmem.nvshmem_create_tensor([args.G + 1], torch.int32) 
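# NOTE: minimal illustrative sketch (an assumption, not the example's actual
# helper) of what `splits_to_cumsum` above is expected to produce: an exclusive
# prefix sum with a leading zero, so expert g's token rows occupy
# [cumsum[g], cumsum[g + 1]) and the all-to-all kernel can bound each peer's slice
# via splits_cumsum[peer * EXPERTS_PER_RANK] and
# splits_cumsum[(peer + 1) * EXPERTS_PER_RANK], which is why the buffer is
# allocated with length args.G + 1.
def _splits_to_cumsum_sketch(splits):
    # `splits` is a 1-D tensor of per-expert token counts (length G).
    import torch  # local import keeps this sketch copy-pasteable on its own

    out = torch.zeros(splits.numel() + 1, dtype=torch.int32, device=splits.device)
    out[1:] = torch.cumsum(splits, dim=0)  # e.g. [3, 1, 4] -> [0, 3, 4, 8]
    return out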
splits_cumsum[:].copy_(split_cumsum) diff --git a/examples/distributed/example_allgather.py b/examples/distributed/example_allgather.py index bc9cb3e1b..56e865391 100644 --- a/examples/distributed/example_allgather.py +++ b/examples/distributed/example_allgather.py @@ -13,8 +13,8 @@ def allgather(PE_num, M, N, dtype="float16", threads=128): @T.prim_func def a2a_split( - A: T.Tensor((M_per_rank, N), dtype), # type: ignore - B: T.Tensor((M, N), dtype), # type: ignore + A: T.Tensor((M_per_rank, N), dtype), # type: ignore + B: T.Tensor((M, N), dtype), # type: ignore ): # Each block is responsible for sending (block_M, N) to exact one rank. with T.Kernel(M_per_rank // block_M, PE_num - 1, threads=threads) as (bx, by): @@ -24,11 +24,9 @@ def a2a_split( A_shared = T.alloc_shared((block_M, N), dtype) local_base = bx * block_M global_base = M_per_rank * mype + local_base - T.copy(A[local_base:local_base + block_M, :], A_shared) + T.copy(A[local_base : local_base + block_M, :], A_shared) peer = (mype + by + 1) % npes - T.putmem_nbi_block( - T.address_of(B[global_base, 0]), T.address_of(A_shared[0, 0]), - block_M * N * dtype_map[dtype].itemsize, peer) + T.putmem_nbi_block(T.address_of(B[global_base, 0]), T.address_of(A_shared[0, 0]), block_M * N * dtype_map[dtype].itemsize, peer) return a2a_split @@ -37,8 +35,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--M", type=int, default=8192) parser.add_argument("--N", type=int, default=12288) - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) parser.add_argument("--threads", type=int, default=128, help="number of threads in a block") parser.add_argument("--print_source", action="store_true", help="print kernel source code") parser.add_argument("--warmup", type=int, default=1, help="number of warmup iterations") @@ -46,7 +43,7 @@ def parse_args(): return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) assert WORLD_SIZE <= 8, "This benchmark is designed for intra-node communication" @@ -82,7 +79,7 @@ def tilelang_ag(): ag_buffer = pynvshmem.nvshmem_create_tensor([M_per_rank, N], torch_dtype) ag_buffer.copy_(local_data) out = pynvshmem.nvshmem_create_tensor([M, N], torch_dtype) - out[RANK * M_per_rank:(RANK + 1) * M_per_rank, :].copy_(local_data) + out[RANK * M_per_rank : (RANK + 1) * M_per_rank, :].copy_(local_data) kernel(ag_buffer, out) pynvshmem.nvshmem_barrier_all() # Ensure all ranks have completed return out diff --git a/examples/distributed/example_allgather_gemm.py b/examples/distributed/example_allgather_gemm.py index 96f95a797..702f1264a 100644 --- a/examples/distributed/example_allgather_gemm.py +++ b/examples/distributed/example_allgather_gemm.py @@ -8,16 +8,15 @@ def allgather_gemm(PE_num, M, N, K, block_M, block_N, block_K, dtype="float16"): - accum_dtype = "float" @T.prim_func def main( - A: T.Buffer((M, K), dtype), - A_ag: T.Buffer((M * PE_num, K), dtype), - B: T.Buffer((K, N), dtype), - signal: T.Buffer((PE_num,), "uint64"), - C: T.Buffer((M * PE_num, N), dtype), + A: T.Buffer((M, K), dtype), + A_ag: T.Buffer((M * PE_num, K), dtype), + B: T.Buffer((K, N), dtype), + signal: T.Buffer((PE_num,), "uint64"), + C: T.Buffer((M * PE_num, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = 
T.alloc_shared((block_M, block_K), dtype) @@ -36,8 +35,14 @@ def main( for k in T.serial(PE_num - 1): peer[0] = (mype[0] + 1 + k) % npes[0] T.putmem_signal_nbi_block( - T.address_of(A_ag[mype[0] * M, 0]), T.address_of(A[0, 0]), - block_M * block_K * 2, T.address_of(signal[k]), k + 1, 9, peer[0]) + T.address_of(A_ag[mype[0] * M, 0]), + T.address_of(A[0, 0]), + block_M * block_K * 2, + T.address_of(signal[k]), + k + 1, + 9, + peer[0], + ) for k in T.serial(PE_num - 1): T.signal_wait_until(T.address_of(signal[k]), 0, k + 1) @@ -60,13 +65,7 @@ def main( WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) PE_num = WORLD_SIZE func = allgather_gemm(PE_num, M, N, K, block_M, block_N, block_K) -kernel = tilelang.compile( - func, - out_idx=-1, - pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) +kernel = tilelang.compile(func, out_idx=-1, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) # Get CUDA Source if RANK == 0: @@ -90,9 +89,9 @@ def ref_program(A, B): C_ref = ref_program(A_tensor, B_tensor) print("C_ref:", C_ref) -#profiler.init_distributed() +# profiler.init_distributed() A_local = pynvshmem.nvshmem_create_tensor([M, K], dtype) -A_local[:].copy_(A_tensor[M * RANK:M * (RANK + 1), :]) +A_local[:].copy_(A_tensor[M * RANK : M * (RANK + 1), :]) A_ag_local = pynvshmem.nvshmem_create_tensor([M * PE_num, K], dtype) A_ag_local.fill_(0) diff --git a/examples/distributed/example_allgather_gemm_overlapped.py b/examples/distributed/example_allgather_gemm_overlapped.py index cebf58ed1..309481967 100644 --- a/examples/distributed/example_allgather_gemm_overlapped.py +++ b/examples/distributed/example_allgather_gemm_overlapped.py @@ -12,6 +12,7 @@ cuda_python_version = importlib.metadata.version("cuda-python") from packaging import version + if version.parse(cuda_python_version) >= version.parse("12.8.0"): from cuda.bindings import driver as cuda else: @@ -19,14 +20,15 @@ from tilelang.distributed import perf_fn tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log @tilelang.jit(pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}) def set_signal_kernel(local_rank, num_local_ranks, threads): - @T.prim_func - def _set_signal_kernel(signal_buffer: T.Tensor((num_local_ranks), "uint32"),): + def _set_signal_kernel( + signal_buffer: T.Tensor((num_local_ranks), "uint32"), + ): with T.Kernel(1, threads=threads): tx = T.get_thread_binding(0) if tx < num_local_ranks: @@ -39,19 +41,9 @@ def _set_signal_kernel(signal_buffer: T.Tensor((num_local_ranks), "uint32"),): @tilelang.jit -def gemm_kernel(M, - N, - K, - local_rank, - num_local_rank, - block_M, - block_N, - block_K, - threads, - persistent=False, - dtype="float16", - accum_dtype="float"): - +def gemm_kernel( + M, N, K, local_rank, num_local_rank, block_M, block_N, block_K, threads, persistent=False, dtype="float16", accum_dtype="float" +): sm_num = driver.get_num_sms() m_blocks = T.ceildiv(M, block_M) n_blocks = T.ceildiv(N // num_local_rank, block_N) @@ -61,14 +53,12 @@ def gemm_kernel(M, @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N // num_local_rank), dtype), - signal_buffer: T.Tensor((num_local_rank), "uint32"), - C: T.Tensor((M, N // num_local_rank), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N // num_local_rank), dtype), + signal_buffer: T.Tensor((num_local_rank), "uint32"), + C: T.Tensor((M, N // 
num_local_rank), dtype), ): - with T.Kernel( - T.ceildiv(M, block_M) * T.ceildiv(N // num_local_rank, block_N), - threads=threads) as (bid): + with T.Kernel(T.ceildiv(M, block_M) * T.ceildiv(N // num_local_rank, block_N), threads=threads) as (bid): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), dtype) @@ -103,10 +93,10 @@ def main( @T.prim_func def main_persistent( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N // num_local_rank), dtype), - signal_buffer: T.Tensor((num_local_rank), "uint32"), - C: T.Tensor((M, N // num_local_rank), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N // num_local_rank), dtype), + signal_buffer: T.Tensor((num_local_rank), "uint32"), + C: T.Tensor((M, N // num_local_rank), dtype), ): with T.Kernel(sm_num, threads=threads) as (bid): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -162,8 +152,8 @@ def cp_engine_producer_all_gather_full_mesh_pull( for src_rank in rank_orders: if src_rank == local_rank: continue - dst = ag_buffer[local_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :] - src = ag_buffer[src_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :] + dst = ag_buffer[local_rank][src_rank * M_per_rank : (src_rank + 1) * M_per_rank, :] + src = ag_buffer[src_rank][src_rank * M_per_rank : (src_rank + 1) * M_per_rank, :] dst.copy_(src) (err,) = cuda.cuStreamWriteValue32( @@ -175,21 +165,33 @@ def cp_engine_producer_all_gather_full_mesh_pull( CUDA_CHECK(err) -def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, local_rank, - local_world_size, set_signal_kernel, gemm_kernel, gemm_stream, ag_stream): - +def ag_gemm_op( + A, + B, + C, + ag_buffer, + signal_buffer, + M_per_rank, + N, + signal_target, + local_rank, + local_world_size, + set_signal_kernel, + gemm_kernel, + gemm_stream, + ag_stream, +): with torch.cuda.stream(gemm_stream): - set_signal_kernel(signal_buffer[local_rank], stream=gemm_stream.cuda_stream) + set_signal_kernel(signal_buffer[local_rank]) ag_stream.wait_stream(gemm_stream) - cp_engine_producer_all_gather_full_mesh_pull(ag_buffer, signal_buffer, M_per_rank, - signal_target, local_rank, local_world_size, - ag_stream) + cp_engine_producer_all_gather_full_mesh_pull( + ag_buffer, signal_buffer, M_per_rank, signal_target, local_rank, local_world_size, ag_stream + ) with torch.cuda.stream(gemm_stream): - gemm_kernel( - ag_buffer[local_rank], B, signal_buffer[local_rank], C, stream=gemm_stream.cuda_stream) + gemm_kernel(ag_buffer[local_rank], B, signal_buffer[local_rank], C) gemm_stream.wait_stream(ag_stream) current_stream = torch.cuda.current_stream() @@ -225,14 +227,9 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now" allocator = tilelang.get_allocator( - size=2**30, - device="cuda", - is_distributed=True, - local_rank=local_rank, - num_local_ranks=num_local_ranks, - group=group) - gemm_func = gemm_kernel(M, N, K, local_rank, num_local_ranks, BLOCK_M, BLOCK_N, BLOCK_K, - threads, persistent) + size=2**30, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) + gemm_func = gemm_kernel(M, N, K, local_rank, num_local_ranks, BLOCK_M, BLOCK_N, BLOCK_K, threads, persistent) set_signal_func = set_signal_kernel( local_rank=local_rank, 
num_local_ranks=num_local_ranks, @@ -247,11 +244,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_() C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator) ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True) - A = ag_buffer[local_rank][M_per_rank * local_rank:M_per_rank * (local_rank + 1), :].normal_() - signal_buffer = tilelang.tensor((num_local_ranks,), - torch.uint32, - allocator=allocator, - return_peers=True) + A = ag_buffer[local_rank][M_per_rank * local_rank : M_per_rank * (local_rank + 1), :].normal_() + signal_buffer = tilelang.tensor((num_local_ranks,), torch.uint32, allocator=allocator, return_peers=True) gemm_stream = torch.cuda.Stream() ag_stream = torch.cuda.Stream(priority=-1) @@ -259,9 +253,22 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): dist.barrier() - tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, - local_rank, num_local_ranks, set_signal_func, gemm_func, gemm_stream, - ag_stream) + tilelang_C = ag_gemm_op( + A, + B, + C, + ag_buffer, + signal_buffer, + M_per_rank, + K, + signal_target, + local_rank, + num_local_ranks, + set_signal_func, + gemm_func, + gemm_stream, + ag_stream, + ) torch_ag_buffer = torch.empty([M, K], dtype=dtype, device="cuda") torch_C = torch_ag_gemm(group, A, B, torch_ag_buffer) @@ -273,27 +280,38 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): print(f"torch_C: {torch_C}, tilelang_C: {tilelang_C}") _, tl_t = perf_fn( - lambda: - ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, local_rank, - num_local_ranks, set_signal_func, gemm_func, gemm_stream, ag_stream), + lambda: ag_gemm_op( + A, + B, + C, + ag_buffer, + signal_buffer, + M_per_rank, + K, + signal_target, + local_rank, + num_local_ranks, + set_signal_func, + gemm_func, + gemm_stream, + ag_stream, + ), warmup=5, - rep=10) - - print( - f"rank {local_rank} tilelang ag_gemm time: {tl_t:.2f} ms, TFLOPS: {2*M*N*K/1e9/(tl_t)/num_local_ranks:.2f}" + rep=10, ) + print(f"rank {local_rank} tilelang ag_gemm time: {tl_t:.2f} ms, TFLOPS: {2 * M * N * K / 1e9 / (tl_t) / num_local_ranks:.2f}") + dist.destroy_process_group() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') - parser.add_argument('--M', type=int, default=8192, help='M dimension') - parser.add_argument('--N', type=int, default=28672, help='N dimension') - parser.add_argument('--K', type=int, default=8192, help='K dimension') - parser.add_argument('--persistent', action='store_true', help='Use persistent kernel') + parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)") + parser.add_argument("--M", type=int, default=8192, help="M dimension") + parser.add_argument("--N", type=int, default=28672, help="N dimension") + parser.add_argument("--K", type=int, default=8192, help="K dimension") + parser.add_argument("--persistent", action="store_true", help="Use persistent kernel") args = parser.parse_args() num_processes = args.num_processes diff --git a/examples/distributed/example_cannon.py b/examples/distributed/example_cannon.py index 649be6c4b..ad25a41e7 100644 --- a/examples/distributed/example_cannon.py +++ b/examples/distributed/example_cannon.py @@ -11,7 +11,6 @@ def cannon(MESH, M, N, K, block_M, 
block_N, block_K, dtype="float16", specialize=False): - M_local = T.ceildiv(M, MESH) N_local = T.ceildiv(N, MESH) K_local = T.ceildiv(K, MESH) @@ -22,13 +21,13 @@ def cannon(MESH, M, N, K, block_M, block_N, block_K, dtype="float16", specialize @T.prim_func def main( - A: T.Tensor((2, M_local, K_local), dtype), - B: T.Tensor((2, N_local, K_local), dtype), - A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"), - A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"), - B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"), - B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"), - C: T.Tensor((M_local, N_local), dtype), + A: T.Tensor((2, M_local, K_local), dtype), + B: T.Tensor((2, N_local, K_local), dtype), + A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"), + A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"), + B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"), + B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"), + C: T.Tensor((M_local, N_local), dtype), ): grid_size = T.min(sm_num, total_tiles) A_rows_per_block = T.ceildiv(M_local, grid_size) @@ -72,16 +71,23 @@ def main( T.address_of(A[(ko + 1) % 2, A_rows_per_block * block_id, 0]), T.address_of(A[ko % 2, A_rows_per_block * block_id, 0]), A_rows_per_block * K_local * dtype_map[dtype].itemsize, - T.address_of(A_signal_to[0]), 1, T.Amo.SIGNAL_ADD, a_peer_to[0]) + T.address_of(A_signal_to[0]), + 1, + T.Amo.SIGNAL_ADD, + a_peer_to[0], + ) if block_id < T.ceildiv(N_local, B_cols_per_block): T.putmem_signal_nbi_block( T.address_of(B[(ko + 1) % 2, B_cols_per_block * block_id, 0]), T.address_of(B[ko % 2, B_cols_per_block * block_id, 0]), B_cols_per_block * K_local * dtype_map[dtype].itemsize, - T.address_of(B_signal_to[0]), 1, T.Amo.SIGNAL_ADD, b_peer_to[0]) + T.address_of(B_signal_to[0]), + 1, + T.Amo.SIGNAL_ADD, + b_peer_to[0], + ) for w in T.serial(waves): - bx = (grid_size * w + block_id) // T.ceildiv(N_local, block_N) by = (grid_size * w + block_id) % T.ceildiv(N_local, block_N) @@ -122,13 +128,13 @@ def main( # TODO: fix correctness @T.prim_func def main_specialize( - A: T.Tensor((2, M_local, K_local), dtype), - B: T.Tensor((2, N_local, K_local), dtype), - A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"), - A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"), - B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"), - B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"), - C: T.Tensor((M_local, N_local), dtype), + A: T.Tensor((2, M_local, K_local), dtype), + B: T.Tensor((2, N_local, K_local), dtype), + A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"), + A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"), + B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"), + B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"), + C: T.Tensor((M_local, N_local), dtype), ): # 0-compute blocks: compute # compute_blocks-grid_size: copy @@ -172,21 +178,26 @@ def main_specialize( total_tiles * ko, ) T.putmem_signal_nbi_block( - T.address_of(A[(ko + 1) % 2, A_rows_per_block * (block_id - compute_blocks), - 0]), + T.address_of(A[(ko + 1) % 2, A_rows_per_block * (block_id - compute_blocks), 0]), T.address_of(A[ko % 2, A_rows_per_block * (block_id - compute_blocks), 0]), A_rows_per_block * K_local * dtype_map[dtype].itemsize, - T.address_of(A_signal_to[0]), 1, T.Amo.SIGNAL_ADD, a_peer_to[0]) + T.address_of(A_signal_to[0]), + 1, + T.Amo.SIGNAL_ADD, + a_peer_to[0], + ) T.putmem_signal_nbi_block( - T.address_of(B[(ko + 1) % 2, B_cols_per_block * 
(block_id - compute_blocks), - 0]), + T.address_of(B[(ko + 1) % 2, B_cols_per_block * (block_id - compute_blocks), 0]), T.address_of(B[ko % 2, B_cols_per_block * (block_id - compute_blocks), 0]), B_cols_per_block * K_local * dtype_map[dtype].itemsize, - T.address_of(B_signal_to[0]), 1, T.Amo.SIGNAL_ADD, b_peer_to[0]) + T.address_of(B_signal_to[0]), + 1, + T.Amo.SIGNAL_ADD, + b_peer_to[0], + ) if block_id < compute_blocks: for w in T.serial(waves): - bx = (compute_blocks * w + block_id) // T.ceildiv(N_local, block_N) by = (compute_blocks * w + block_id) % T.ceildiv(N_local, block_N) @@ -256,11 +267,7 @@ def parse_args(): K_local = math.ceil(K / MESH) func = cannon(MESH, M, N, K, block_M, block_N, block_K, args.dtype, specialize) - kernel = tilelang.compile( - func, pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) + kernel = tilelang.compile(func, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) # Get CUDA Source if RANK == 0: @@ -281,11 +288,9 @@ def parse_args(): b_scatter_list = [] for r in range(WORLD_SIZE): rr, cc = divmod(r, MESH) - c_tile = C[M_local * rr:M_local * (rr + 1), N_local * cc:N_local * (cc + 1)] - a_tile = A[M_local * rr:M_local * (rr + 1), - K_local * ((cc + rr) % MESH):K_local * ((cc + rr) % MESH + 1)] - b_tile = B[N_local * cc:N_local * (cc + 1), - K_local * ((cc + rr) % MESH):K_local * ((cc + rr) % MESH + 1)] + c_tile = C[M_local * rr : M_local * (rr + 1), N_local * cc : N_local * (cc + 1)] + a_tile = A[M_local * rr : M_local * (rr + 1), K_local * ((cc + rr) % MESH) : K_local * ((cc + rr) % MESH + 1)] + b_tile = B[N_local * cc : N_local * (cc + 1), K_local * ((cc + rr) % MESH) : K_local * ((cc + rr) % MESH + 1)] c_scatter_list.append(c_tile.contiguous()) a_scatter_list.append(a_tile.contiguous()) @@ -320,7 +325,7 @@ def parse_args(): dist.barrier() if r == RANK: if torch.allclose(C_tilelang, ref, rtol=1e-2, atol=1e-2): - print('-' * 100) + print("-" * 100) print(f"[Rank {RANK}] ✅ Tilelang and Torch match") else: abs_error = torch.abs(C_tilelang - ref) @@ -330,7 +335,7 @@ def parse_args(): max_rel_error = rel_error.max().item() mismatch_ratio = (abs_error > (1e-2 + 1e-2 * torch.abs(ref))).float().mean().item() - print('-' * 100) + print("-" * 100) print(f"[Rank {RANK}] ❌ Tilelang and Torch mismatch") print(f"[Rank {RANK}] ref:\n{ref}") print(f"[Rank {RANK}] tilelang:\n{C_tilelang}") @@ -381,8 +386,7 @@ def reduce_local_time(local_time): total_flops = 2 * M * N * K -avg_time = reduce_local_time( - bench(kernel, A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang)) +avg_time = reduce_local_time(bench(kernel, A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang)) if RANK == 0: print(f"avg time of RANK {RANK}: {avg_time} ms") diff --git a/examples/distributed/example_gemm_rs_overlapped.py b/examples/distributed/example_gemm_rs_overlapped.py index 4fb1c6d43..27c2278bd 100644 --- a/examples/distributed/example_gemm_rs_overlapped.py +++ b/examples/distributed/example_gemm_rs_overlapped.py @@ -14,19 +14,9 @@ @tilelang.jit -def gemm_kernel(M, - N, - K, - local_rank, - num_local_rank, - block_M, - block_N, - block_K, - threads, - persistent=False, - dtype="float16", - accum_dtype="float"): - +def gemm_kernel( + M, N, K, local_rank, num_local_rank, block_M, block_N, block_K, threads, persistent=False, dtype="float16", accum_dtype="float" +): M_per_rank = T.ceildiv(M, num_local_rank) GROUP_SIZE_M = 8 @@ -41,11 +31,11 @@ def swizzle_2d(tile_id, num_pid_m, num_pid_n): 
@T.prim_func def main( - A: T.Tensor((M, K // num_local_rank), dtype), - B: T.Tensor((K // num_local_rank, N), dtype), - scatter_signal_buf: T.Tensor((num_local_rank), "uint32"), - counter_signal_buf: T.Tensor((num_local_rank), "uint32"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K // num_local_rank), dtype), + B: T.Tensor((K // num_local_rank, N), dtype), + scatter_signal_buf: T.Tensor((num_local_rank), "uint32"), + counter_signal_buf: T.Tensor((num_local_rank), "uint32"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(M, block_M) * T.ceildiv(N, block_N), threads=threads) as (bid): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -87,27 +77,12 @@ def main( return main -def gemm_rs_op(A, - B, - C, - output, - ctx, - gemm_kernel, - gemm_stream, - rs_stream, - local_rank, - print_source=False): - +def gemm_rs_op(A, B, C, output, ctx, gemm_kernel, gemm_stream, rs_stream, local_rank, print_source=False): current_stream = torch.cuda.current_stream() rs_stream.wait_stream(gemm_stream) - gemm_kernel( - A, - B, - ctx.scatter_signal_bufs[local_rank], - ctx.counter_bufs[local_rank], - C, - stream=gemm_stream.cuda_stream) + with torch.cuda.stream(gemm_stream): + gemm_kernel(A, B, ctx.scatter_signal_bufs[local_rank], ctx.counter_bufs[local_rank], C) if print_source and local_rank == 1: print(gemm_kernel.get_kernel_source()) @@ -155,14 +130,9 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now" allocator = tilelang.get_allocator( - size=2**30, - device="cuda", - is_distributed=True, - local_rank=local_rank, - num_local_ranks=num_local_ranks, - group=group) - gemm_func = gemm_kernel(M, N, K, local_rank, num_local_ranks, BLOCK_M, BLOCK_N, BLOCK_K, - threads, persistent) + size=2**30, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) + gemm_func = gemm_kernel(M, N, K, local_rank, num_local_ranks, BLOCK_M, BLOCK_N, BLOCK_K, threads, persistent) gemm_func.initialize(allocator=allocator) A = tilelang.tensor((M, K_per_rank), dtype, allocator=allocator).normal_() / 10 @@ -172,20 +142,12 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): gemm_stream = torch.cuda.Stream() rs_stream = torch.cuda.Stream(priority=-1) ctx = create_reduce_scater_2d_ctx( - M, - N, - local_rank, - num_local_ranks, - num_local_ranks, - dtype, - allocator, - overlap_with_gemm=True, - num_reduction_sms=15) + M, N, local_rank, num_local_ranks, num_local_ranks, dtype, allocator, overlap_with_gemm=True, num_reduction_sms=15 + ) dist.barrier() - tilelang_out = gemm_rs_op( - A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank, print_source=True) + tilelang_out = gemm_rs_op(A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank, print_source=True) torch_out = torch_gemm_rs(group, A, B, None, num_local_ranks) atol = 1e-2 @@ -196,26 +158,20 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): print(f"rank {local_rank} check failed.❌") print(f"torch_out: {torch_out}, tilelang_out: {tilelang_out}") - _, tl_t = perf_fn( - lambda: gemm_rs_op(A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank), - warmup=5, - rep=5) + _, tl_t = perf_fn(lambda: gemm_rs_op(A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank), warmup=5, rep=5) - print( - f"rank {local_rank} tilelang gemm_rs time: 
{tl_t:.2f} ms, TFLOPS: {2*M*N*K/1e9/(tl_t)/num_local_ranks:.2f}" - ) + print(f"rank {local_rank} tilelang gemm_rs time: {tl_t:.2f} ms, TFLOPS: {2 * M * N * K / 1e9 / (tl_t) / num_local_ranks:.2f}") dist.destroy_process_group() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') - parser.add_argument('--M', type=int, default=8192, help='M dimension') - parser.add_argument('--N', type=int, default=8192, help='N dimension') - parser.add_argument('--K', type=int, default=29568, help='K dimension') - parser.add_argument('--persistent', action='store_true', help='Use persistent kernel') + parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)") + parser.add_argument("--M", type=int, default=8192, help="M dimension") + parser.add_argument("--N", type=int, default=8192, help="N dimension") + parser.add_argument("--K", type=int, default=29568, help="K dimension") + parser.add_argument("--persistent", action="store_true", help="Use persistent kernel") args = parser.parse_args() num_processes = args.num_processes diff --git a/examples/distributed/example_nvshmem.py b/examples/distributed/example_nvshmem.py index 6499a4648..8f8de69ed 100644 --- a/examples/distributed/example_nvshmem.py +++ b/examples/distributed/example_nvshmem.py @@ -29,11 +29,10 @@ def tilelang_callback_cuda_postproc(code, _): def dist_test(M, N, block_M, block_N, dtype="int16"): - @T.prim_func def main( - A: T.Buffer((M, N), dtype), - B: T.Buffer((M, N), dtype), + A: T.Buffer((M, N), dtype), + B: T.Buffer((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), dtype) diff --git a/examples/distributed/example_overlapping_allgather.py b/examples/distributed/example_overlapping_allgather.py index 13c3e6dac..281e07dee 100644 --- a/examples/distributed/example_overlapping_allgather.py +++ b/examples/distributed/example_overlapping_allgather.py @@ -19,28 +19,24 @@ def internode_gather(M, local_world_size, block_M, threads): - @T.prim_func def main( - dst: T.Tensor((M), "float32"), - src: T.Tensor((M), "float32"), + dst: T.Tensor((M), "float32"), + src: T.Tensor((M), "float32"), ): with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx): rank = T.alloc_local([1], "uint64") rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size) # 2 nodes - T.putmem_nbi_block( - T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]), block_M * 4, - rank[0]) + T.putmem_nbi_block(T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]), block_M * 4, rank[0]) return main def intranode_gather(M, world_size, block_M, threads): - @T.prim_func def main( - dst: T.Tensor((M * world_size), "float32"), - src: T.Tensor((M * 2), "float32"), + dst: T.Tensor((M * world_size), "float32"), + src: T.Tensor((M * 2), "float32"), ): with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx): rank = T.alloc_local([1], "uint64") @@ -68,24 +64,19 @@ def main( return main -if __name__ == '__main__': +if __name__ == "__main__": tilelang.disable_cache() M = 2 K = 12288 - #for 2 node(16 GPUs), world_size=16,rank is 0-15,local rank is 0-7 - WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed( - return_tp_group=True, return_lc_group=True) - local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1)) + # for 2 node(16 GPUs), world_size=16,rank is 0-15,local rank 
is 0-7 + WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(return_tp_group=True, return_lc_group=True) + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0)) allocator = tilelang.get_allocator( - size=2**25, - device="cuda", - is_distributed=True, - local_rank=LOCAL_RANK, - num_local_ranks=local_world_size, - group=LC_GROUP) + size=2**25, device="cuda", is_distributed=True, local_rank=LOCAL_RANK, num_local_ranks=local_world_size, group=LC_GROUP + ) print(local_world_size, LOCAL_RANK) # Each rank sends the local_tensor to ranks of other nodes with the same local_rank @@ -99,7 +90,7 @@ def main( print(interkernel.get_kernel_source()) src = pynvshmem.nvshmem_create_tensor([M], torch.float32) dst = pynvshmem.nvshmem_create_tensor([M], torch.float32) - input_data = torch.ones([M], dtype=torch.float32, device='cuda') * RANK + input_data = torch.ones([M], dtype=torch.float32, device="cuda") * RANK src.copy_(input_data) pynvshmem.nvshmem_barrier_all() @@ -119,20 +110,14 @@ def main( src_intra = tilelang.tensor((M * 2), torch.float32, allocator=allocator).normal_() dst_intra = tilelang.tensor((M * WORLD_SIZE), torch.float32, allocator=allocator) if RANK < WORLD_SIZE / 2: - cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4, - cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice) - cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4, - cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice) + cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice) + cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice) else: - cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4, - cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice) - cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4, - cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice) + cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice) + cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice) env.USE_NVSHMEM = False - intrakernel = tilelang.compile( - intranode_gather(M, WORLD_SIZE, M, 128), - pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True}) + intrakernel = tilelang.compile(intranode_gather(M, WORLD_SIZE, M, 128), pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True}) intrakernel.initialize(allocator=allocator) if LOCAL_RANK == 0: print(intrakernel.get_kernel_source()) diff --git a/examples/distributed/example_post_attn_all2all_transpose.py b/examples/distributed/example_post_attn_all2all_transpose.py index e17c55ad9..de2c43671 100644 --- a/examples/distributed/example_post_attn_all2all_transpose.py +++ b/examples/distributed/example_post_attn_all2all_transpose.py @@ -2,6 +2,7 @@ import torch.distributed as dist import pynvshmem import tilelang +import tilelang.testing import tilelang.language as T from tilelang.distributed import init_distributed, dtype_map import argparse @@ -43,21 +44,14 @@ def torch_reverse_all_to_all_transpose_reference(data_src, group): # Step 2: Prepare output list for all_to_all output_list = [] for _ in range(world_size): - recv_data = torch.empty( - batch_size, - heads_per_pe, - seq_per_pe, - head_dim, - dtype=data_src.dtype, - device=data_src.device) + recv_data = torch.empty(batch_size, heads_per_pe, seq_per_pe, head_dim, dtype=data_src.dtype, device=data_src.device) 
output_list.append(recv_data) # Step 3: Execute all_to_all dist.all_to_all(output_list, input_list, group=group) # Step 4: Reorganize received data - result = torch.empty( - batch_size, seq_per_pe, num_heads, head_dim, dtype=data_src.dtype, device=data_src.device) + result = torch.empty(batch_size, seq_per_pe, num_heads, head_dim, dtype=data_src.dtype, device=data_src.device) for pe_idx in range(world_size): head_start = pe_idx * heads_per_pe @@ -69,12 +63,7 @@ def torch_reverse_all_to_all_transpose_reference(data_src, group): return result -def sequence_parallel_reverse_all_to_all_transpose(PE_num, - BATCH_SIZE, - NUM_HEADS, - SEQ_LEN, - HEAD_DIM, - dtype="float16"): +def sequence_parallel_reverse_all_to_all_transpose(PE_num, BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, dtype="float16"): """ Reverse All-to-All: Convert from head parallel to sequence parallel Input: [BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM] @@ -88,9 +77,9 @@ def sequence_parallel_reverse_all_to_all_transpose(PE_num, @T.prim_func def main( - data_src: T.Tensor((BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM), dtype), - data_dst: T.Tensor((BATCH_SIZE, SEQ_PER_PE, NUM_HEADS, HEAD_DIM), dtype), - signal: T.Tensor((PE_num,), "uint64"), + data_src: T.Tensor((BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM), dtype), + data_dst: T.Tensor((BATCH_SIZE, SEQ_PER_PE, NUM_HEADS, HEAD_DIM), dtype), + signal: T.Tensor((PE_num,), "uint64"), ): with T.Kernel(NUM_BLOCKS_X, PE_num, threads=128) as (bx, target_pe): tx = T.thread_binding(128, thread="threadIdx.x") @@ -118,8 +107,10 @@ def main( T.putmem_nbi_block( T.address_of(data_dst[batch_idx, seq_idx, dst_head_idx, 0]), - T.address_of(data_src[batch_idx, head_idx, src_seq_idx, 0]), transfer_size, - target_pe) + T.address_of(data_src[batch_idx, head_idx, src_seq_idx, 0]), + transfer_size, + target_pe, + ) T.fence() @@ -129,7 +120,8 @@ def main( T.address_of(signal[mype[0]]), 1, # Signal the number of head chunks processed T.Amo.SIGNAL_ADD, - target_pe) + target_pe, + ) T.fence() # Wait for all blocks to complete all head transfers T.signal_wait_until(T.address_of(signal[target_pe]), T.CmpType.EQ, NUM_BLOCKS_X) @@ -177,6 +169,7 @@ def parse_args(): return parser.parse_args() +@tilelang.testing.requires_distributed def test_reverse_transpose_all_to_all_with_golden_reference(): args = parse_args() @@ -203,13 +196,8 @@ def test_reverse_transpose_all_to_all_with_golden_reference(): print("Converting from HEAD_PARALLEL to SEQUENCE_PARALLEL") # Compile TileLang kernel - func = sequence_parallel_reverse_all_to_all_transpose(PE_num, args.batch_size, args.num_heads, - args.seq_len, args.head_dim, args.dtype) - kernel = tilelang.compile( - func, pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) + func = sequence_parallel_reverse_all_to_all_transpose(PE_num, args.batch_size, args.num_heads, args.seq_len, args.head_dim, args.dtype) + kernel = tilelang.compile(func, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) if RANK == 0: print("\nTileLang Kernel Source:") @@ -219,9 +207,7 @@ def test_reverse_transpose_all_to_all_with_golden_reference(): dtype_torch = dtype_map[args.dtype] # Create input data: [BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM] - head parallel format - input_data = torch.rand([args.batch_size, HEADS_PER_PE, args.seq_len, args.head_dim], - dtype=dtype_torch, - device='cuda') + input_data = torch.rand([args.batch_size, HEADS_PER_PE, args.seq_len, args.head_dim], dtype=dtype_torch, device="cuda") print(f"PE {RANK} Input 
shape: {input_data.shape}") print(f"PE {RANK} Input sample: {input_data[0, 0, 0, :3]}") @@ -235,10 +221,8 @@ def test_reverse_transpose_all_to_all_with_golden_reference(): # === Test 2: TileLang NVSHMEM Implementation === def tilelang_reverse_all_to_all(): # Create NVSHMEM tensors - data_src = pynvshmem.nvshmem_create_tensor( - [args.batch_size, HEADS_PER_PE, args.seq_len, args.head_dim], dtype_torch) - data_dst = pynvshmem.nvshmem_create_tensor( - [args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim], dtype_torch) + data_src = pynvshmem.nvshmem_create_tensor([args.batch_size, HEADS_PER_PE, args.seq_len, args.head_dim], dtype_torch) + data_dst = pynvshmem.nvshmem_create_tensor([args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim], dtype_torch) signal = pynvshmem.nvshmem_create_tensor([PE_num], torch.uint64) # Initialize data @@ -268,6 +252,7 @@ def tilelang_reverse_all_to_all(): dist.destroy_process_group() +@tilelang.testing.requires_distributed def test_roundtrip_consistency(): """Test that forward + reverse all-to-all gives back original data""" args = parse_args() @@ -285,9 +270,7 @@ def test_roundtrip_consistency(): SEQ_PER_PE = args.seq_len // WORLD_SIZE # Create original data in sequence parallel format - original_data = torch.rand([args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim], - dtype=dtype_torch, - device='cuda') + original_data = torch.rand([args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim], dtype=dtype_torch, device="cuda") # Forward: sequence parallel -> head parallel head_parallel_data = torch_sequence_all_to_all_transpose_reference(original_data, TP_GROUP) diff --git a/examples/distributed/example_pre_attn_all2all.py b/examples/distributed/example_pre_attn_all2all.py index 53884f5b0..cb85a9389 100644 --- a/examples/distributed/example_pre_attn_all2all.py +++ b/examples/distributed/example_pre_attn_all2all.py @@ -2,6 +2,7 @@ import torch.distributed as dist import pynvshmem import tilelang +import tilelang.testing import tilelang.language as T from tilelang.distributed import init_distributed, dtype_map import argparse @@ -44,13 +45,7 @@ def torch_sequence_all_to_all_reference(data_src, group): output_list = [] for _ in range(world_size): # Receive [BATCH_SIZE, HEADS_PER_PE, SEQ_PER_PE, HEAD_DIM] from each PE - recv_data = torch.empty( - batch_size, - heads_per_pe, - seq_per_pe, - head_dim, - dtype=data_src.dtype, - device=data_src.device) + recv_data = torch.empty(batch_size, heads_per_pe, seq_per_pe, head_dim, dtype=data_src.dtype, device=data_src.device) output_list.append(recv_data) # Step 3: Execute all_to_all @@ -59,8 +54,7 @@ def torch_sequence_all_to_all_reference(data_src, group): # Step 4: Reorganize received data # output_list[pe_idx] contains data from PE pe_idx # Need to arrange by sequence dimension - result = torch.empty( - batch_size, heads_per_pe, seq_len, head_dim, dtype=data_src.dtype, device=data_src.device) + result = torch.empty(batch_size, heads_per_pe, seq_len, head_dim, dtype=data_src.dtype, device=data_src.device) for pe_idx in range(world_size): seq_start = pe_idx * seq_per_pe @@ -86,12 +80,12 @@ def sequence_parallel_all_to_all(PE_num, BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DI @T.prim_func def main( - # Input: [BATCH_SIZE, NUM_HEADS, SEQ_PER_PE, HEAD_DIM] - data_src: T.Tensor((BATCH_SIZE, NUM_HEADS, SEQ_PER_PE, HEAD_DIM), dtype), - # Output: [BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM] - data_dst: T.Tensor((BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM), dtype), - # Sync signals - signal: T.Tensor((PE_num,), 
"uint64"), + # Input: [BATCH_SIZE, NUM_HEADS, SEQ_PER_PE, HEAD_DIM] + data_src: T.Tensor((BATCH_SIZE, NUM_HEADS, SEQ_PER_PE, HEAD_DIM), dtype), + # Output: [BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM] + data_dst: T.Tensor((BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM), dtype), + # Sync signals + signal: T.Tensor((PE_num,), "uint64"), ): # Grid: (batch*head, target_pe) with T.Kernel(NUM_BLOCKS_X, PE_num, threads=128) as (bx, target_pe): @@ -116,7 +110,10 @@ def main( # Single block transfer for entire [SEQ_PER_PE, HEAD_DIM] data T.putmem_nbi_block( T.address_of(data_dst[batch_idx, head_idx, dst_seq_start, 0]), - T.address_of(data_src[batch_idx, src_head_idx, 0, 0]), transfer_size, target_pe) + T.address_of(data_src[batch_idx, src_head_idx, 0, 0]), + transfer_size, + target_pe, + ) # Memory fence T.fence() @@ -127,7 +124,8 @@ def main( T.address_of(signal[mype[0]]), 1, 10, # NVSHMEM_SIGNAL_ADD - target_pe) + target_pe, + ) T.fence() for k in T.serial(PE_num): T.signal_wait_until(T.address_of(signal[k]), 0, NUM_BLOCKS_X) @@ -165,8 +163,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--batch_size", type=int, default=2, help="Batch size") parser.add_argument("--seq_len", type=int, default=256, help="Sequence length") - parser.add_argument( - "--num_heads", type=int, default=16, help="Number of attention heads,combine QKV") + parser.add_argument("--num_heads", type=int, default=16, help="Number of attention heads,combine QKV") parser.add_argument("--head_dim", type=int, default=64, help="Head dimension") parser.add_argument("--dtype", default="float16", help="Data type") parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations") @@ -175,6 +172,7 @@ def parse_args(): return parser.parse_args() +@tilelang.testing.requires_distributed def test_all_to_all_with_golden_reference(): args = parse_args() @@ -200,13 +198,8 @@ def test_all_to_all_with_golden_reference(): print(f"Heads per PE: {HEADS_PER_PE}") # Compile TileLang kernel - func = sequence_parallel_all_to_all(PE_num, args.batch_size, args.num_heads, args.seq_len, - args.head_dim, args.dtype) - kernel = tilelang.compile( - func, pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) + func = sequence_parallel_all_to_all(PE_num, args.batch_size, args.num_heads, args.seq_len, args.head_dim, args.dtype) + kernel = tilelang.compile(func, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) if RANK == 0: print("\nTileLang Kernel Source:") @@ -216,9 +209,7 @@ def test_all_to_all_with_golden_reference(): dtype_torch = dtype_map[args.dtype] # Create input data (same for both implementations) - input_data = torch.rand([args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim], - dtype=dtype_torch, - device='cuda') + input_data = torch.rand([args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim], dtype=dtype_torch, device="cuda") print(f"PE {RANK} Input shape: {input_data.shape}") print(f"PE {RANK} Input sample: {input_data[0, 0, 0, :3]}") @@ -233,10 +224,8 @@ def test_all_to_all_with_golden_reference(): # === Test 2: TileLang NVSHMEM Implementation === def tilelang_all_to_all(): # Create NVSHMEM tensors - data_src = pynvshmem.nvshmem_create_tensor( - [args.batch_size, args.num_heads, SEQ_PER_PE, args.head_dim], dtype_torch) - data_dst = pynvshmem.nvshmem_create_tensor( - [args.batch_size, HEADS_PER_PE, args.seq_len, args.head_dim], dtype_torch) + data_src = pynvshmem.nvshmem_create_tensor([args.batch_size, args.num_heads, SEQ_PER_PE, 
args.head_dim], dtype_torch) + data_dst = pynvshmem.nvshmem_create_tensor([args.batch_size, HEADS_PER_PE, args.seq_len, args.head_dim], dtype_torch) signal = pynvshmem.nvshmem_create_tensor([PE_num], torch.uint64) # Initialize data @@ -246,7 +235,7 @@ def tilelang_all_to_all(): # Execute kernel kernel(data_src, data_dst, signal) - #pynvshmem.nvshmem_barrier_all() + # pynvshmem.nvshmem_barrier_all() return data_dst diff --git a/examples/distributed/example_pre_attn_all2all_transpose.py b/examples/distributed/example_pre_attn_all2all_transpose.py index f5c4b9fc3..80f6ef6b7 100644 --- a/examples/distributed/example_pre_attn_all2all_transpose.py +++ b/examples/distributed/example_pre_attn_all2all_transpose.py @@ -2,6 +2,7 @@ import torch.distributed as dist import pynvshmem import tilelang +import tilelang.testing import tilelang.language as T from tilelang.distributed import init_distributed, dtype_map import argparse @@ -41,21 +42,14 @@ def torch_sequence_all_to_all_transpose_reference(data_src, group): # Step 2: Prepare output list for all_to_all output_list = [] for _ in range(world_size): - recv_data = torch.empty( - batch_size, - seq_per_pe, - heads_per_pe, - head_dim, - dtype=data_src.dtype, - device=data_src.device) + recv_data = torch.empty(batch_size, seq_per_pe, heads_per_pe, head_dim, dtype=data_src.dtype, device=data_src.device) output_list.append(recv_data) # Step 3: Execute all_to_all dist.all_to_all(output_list, input_list, group=group) # Step 4: Reorganize received data with transpose - result = torch.empty( - batch_size, heads_per_pe, seq_len, head_dim, dtype=data_src.dtype, device=data_src.device) + result = torch.empty(batch_size, heads_per_pe, seq_len, head_dim, dtype=data_src.dtype, device=data_src.device) for pe_idx in range(world_size): seq_start = pe_idx * seq_per_pe @@ -67,12 +61,7 @@ def torch_sequence_all_to_all_transpose_reference(data_src, group): return result -def sequence_parallel_all_to_all_transpose(PE_num, - BATCH_SIZE, - NUM_HEADS, - SEQ_LEN, - HEAD_DIM, - dtype="float16"): +def sequence_parallel_all_to_all_transpose(PE_num, BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, dtype="float16"): """ Coarse-grained version with proper transpose handling Each block handles one (batch, head) combination and processes all sequence positions @@ -85,9 +74,9 @@ def sequence_parallel_all_to_all_transpose(PE_num, @T.prim_func def main( - data_src: T.Tensor((BATCH_SIZE, SEQ_PER_PE, NUM_HEADS, HEAD_DIM), dtype), - data_dst: T.Tensor((BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM), dtype), - signal: T.Tensor((PE_num,), "uint64"), + data_src: T.Tensor((BATCH_SIZE, SEQ_PER_PE, NUM_HEADS, HEAD_DIM), dtype), + data_dst: T.Tensor((BATCH_SIZE, HEADS_PER_PE, SEQ_LEN, HEAD_DIM), dtype), + signal: T.Tensor((PE_num,), "uint64"), ): with T.Kernel(NUM_BLOCKS_X, PE_num, threads=128) as (bx, target_pe): tx = T.thread_binding(128, thread="threadIdx.x") @@ -115,8 +104,10 @@ def main( T.putmem_nbi_block( T.address_of(data_dst[batch_idx, head_idx, dst_seq_idx, 0]), - T.address_of(data_src[batch_idx, seq_idx, src_head_idx, 0]), transfer_size, - target_pe) + T.address_of(data_src[batch_idx, seq_idx, src_head_idx, 0]), + transfer_size, + target_pe, + ) T.fence() @@ -126,7 +117,8 @@ def main( T.address_of(signal[mype[0]]), 1, # Signal the number of sequence positions processed T.Amo.SIGNAL_ADD, - target_pe) + target_pe, + ) T.fence() # Wait for all blocks to complete all sequence positions T.signal_wait_until(T.address_of(signal[target_pe]), T.CmpType.EQ, NUM_BLOCKS_X) @@ -173,6 +165,7 @@ def 
parse_args(): return parser.parse_args() +@tilelang.testing.requires_distributed def test_transpose_all_to_all_with_golden_reference(): args = parse_args() @@ -198,13 +191,8 @@ def test_transpose_all_to_all_with_golden_reference(): print(f"Heads per PE: {HEADS_PER_PE}") # Compile TileLang kernel - func = sequence_parallel_all_to_all_transpose(PE_num, args.batch_size, args.num_heads, - args.seq_len, args.head_dim, args.dtype) - kernel = tilelang.compile( - func, pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) + func = sequence_parallel_all_to_all_transpose(PE_num, args.batch_size, args.num_heads, args.seq_len, args.head_dim, args.dtype) + kernel = tilelang.compile(func, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) if RANK == 0: print("\nTileLang Kernel Source:") @@ -214,9 +202,7 @@ def test_transpose_all_to_all_with_golden_reference(): dtype_torch = dtype_map[args.dtype] # Create input data: [BATCH_SIZE, SEQ_PER_PE, NUM_HEADS, HEAD_DIM] - random like example - input_data = torch.rand([args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim], - dtype=dtype_torch, - device='cuda') + input_data = torch.rand([args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim], dtype=dtype_torch, device="cuda") print(f"PE {RANK} Input shape: {input_data.shape}") print(f"PE {RANK} Input sample: {input_data[0, 0, 0, :3]}") @@ -230,10 +216,8 @@ def test_transpose_all_to_all_with_golden_reference(): # === Test 2: TileLang NVSHMEM Implementation === def tilelang_all_to_all(): # Create NVSHMEM tensors - data_src = pynvshmem.nvshmem_create_tensor( - [args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim], dtype_torch) - data_dst = pynvshmem.nvshmem_create_tensor( - [args.batch_size, HEADS_PER_PE, args.seq_len, args.head_dim], dtype_torch) + data_src = pynvshmem.nvshmem_create_tensor([args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim], dtype_torch) + data_dst = pynvshmem.nvshmem_create_tensor([args.batch_size, HEADS_PER_PE, args.seq_len, args.head_dim], dtype_torch) signal = pynvshmem.nvshmem_create_tensor([PE_num], torch.uint64) # Initialize data diff --git a/examples/distributed/example_simple_shift.py b/examples/distributed/example_simple_shift.py index a837c4b8d..b1e69d960 100644 --- a/examples/distributed/example_simple_shift.py +++ b/examples/distributed/example_simple_shift.py @@ -5,11 +5,10 @@ def simple_shift(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Buffer((M, N), dtype), - B: T.Buffer((M, N), dtype), + A: T.Buffer((M, N), dtype), + B: T.Buffer((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): mype = T.alloc_local([1], "int32") @@ -19,8 +18,7 @@ def main( npes[0] = T.get_pe_num() peer[0] = (mype[0] + 1) % npes[0] - T.putmem_nbi_block( - T.address_of(B[0, 0]), T.address_of(A[0, 0]), block_M * block_N * 2, peer[0]) + T.putmem_nbi_block(T.address_of(B[0, 0]), T.address_of(A[0, 0]), block_M * block_N * 2, peer[0]) return main @@ -28,6 +26,7 @@ def main( WORLD_SIZE, RANK, LOCAL_RANK = init_distributed() func = simple_shift(128, 128, 128, 128) +# Auto-selects cython backend when TILELANG_USE_DISTRIBUTED=1 is set kernel = tilelang.compile(func, out_idx=-1) # Get CUDA Source diff --git a/examples/distributed/example_sp_ag_attention_intra_node.py b/examples/distributed/example_sp_ag_attention_intra_node.py index c4d120fea..5b893e4f2 100644 --- a/examples/distributed/example_sp_ag_attention_intra_node.py +++ 
b/examples/distributed/example_sp_ag_attention_intra_node.py @@ -17,7 +17,6 @@ class FusedSequenceParallelAttn(torch.nn.Module): - def __init__( self, pg: torch.distributed.ProcessGroup, @@ -47,8 +46,9 @@ def __init__( self.max_seqlen_k = max_seqlen_k self.head_dim = head_dim - assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size - == 0), f"sequence length should be multiple of world_size({self.world_size})" + assert max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size == 0, ( + f"sequence length should be multiple of world_size({self.world_size})" + ) self.max_q_shard_len = self.max_seqlen_q // self.world_size self.input_dtype = input_dtype @@ -101,7 +101,6 @@ def forward(self, q_shard, k_shards, v_shards, cu_seqlens_q, cu_seqlens_k, print class TorchSequenceParallelAttn(torch.nn.Module): - def __init__( self, pg: torch.distributed.ProcessGroup, @@ -138,8 +137,9 @@ def __init__( self.max_q_shard_len = max_seqlen_q // self.world_size self.max_kv_shard_ken = max_seqlen_q // self.world_size - assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size - == 0), f"sequence length should be multiple of world_size({self.world_size})" + assert max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size == 0, ( + f"sequence length should be multiple of world_size({self.world_size})" + ) self.ag_k_buffer: torch.Tensor = torch.empty( self.batch_size * self.max_seqlen_k, @@ -161,9 +161,9 @@ def forward(self, q_shard, k_shard, v_shard, cu_seqlens_q, cu_seqlens_k): def _gen_mask(offset, q_shard_len, kv_len): if self.is_causal: mask = torch.zeros((q_shard_len, kv_len), dtype=torch.bool, device=self.device) - mask[:, :offset + q_shard_len] = True + mask[:, : offset + q_shard_len] = True if offset < kv_len: - mask[:, offset:offset + q_shard_len].tril_() + mask[:, offset : offset + q_shard_len].tril_() return mask return None @@ -186,37 +186,27 @@ def _gen_mask(offset, q_shard_len, kv_len): half_q_shard_len = q_shard_len // 2 half_kv_shard_len = kv_shard_len // 2 - q0_shard = q_shard[cu_seqlens_q_start:cu_seqlens_q_start + - half_q_shard_len, :, :].clone() - q1_shard = q_shard[cu_seqlens_q_start + - half_q_shard_len:cu_seqlens_q_end, :, :].clone() - - q0_shard_permute = torch.permute( - q0_shard.reshape(1, half_q_shard_len, q_head, head_dim), - (0, 2, 1, 3)).contiguous() - q1_shard_permute = torch.permute( - q1_shard.reshape(1, half_q_shard_len, q_head, head_dim), - (0, 2, 1, 3)).contiguous() - - k0_shard = k_shard[cu_seqlens_k_start:cu_seqlens_k_start + - half_kv_shard_len, :, :].clone() - k1_shard = k_shard[cu_seqlens_k_start + - half_kv_shard_len:cu_seqlens_k_end, :, :].clone() - v0_shard = v_shard[cu_seqlens_k_start:cu_seqlens_k_start + - half_kv_shard_len, :, :].clone() - v1_shard = v_shard[cu_seqlens_k_start + - half_kv_shard_len:cu_seqlens_k_end, :, :].clone() - - buffer_size = (half_kv_shard_len * kv_head * head_dim * self.world_size) - - ag_k0 = self.ag_k_buffer.reshape(-1)[:buffer_size].reshape( - half_kv_shard_len * self.world_size, kv_head, head_dim) - ag_k1 = self.ag_k_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape( - half_kv_shard_len * self.world_size, kv_head, head_dim) - ag_v0 = self.ag_v_buffer.reshape(-1)[:buffer_size].reshape( - half_kv_shard_len * self.world_size, kv_head, head_dim) - ag_v1 = self.ag_v_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape( - half_kv_shard_len * self.world_size, kv_head, head_dim) + q0_shard = q_shard[cu_seqlens_q_start : cu_seqlens_q_start + half_q_shard_len, :, 
:].clone() + q1_shard = q_shard[cu_seqlens_q_start + half_q_shard_len : cu_seqlens_q_end, :, :].clone() + + q0_shard_permute = torch.permute(q0_shard.reshape(1, half_q_shard_len, q_head, head_dim), (0, 2, 1, 3)).contiguous() + q1_shard_permute = torch.permute(q1_shard.reshape(1, half_q_shard_len, q_head, head_dim), (0, 2, 1, 3)).contiguous() + + k0_shard = k_shard[cu_seqlens_k_start : cu_seqlens_k_start + half_kv_shard_len, :, :].clone() + k1_shard = k_shard[cu_seqlens_k_start + half_kv_shard_len : cu_seqlens_k_end, :, :].clone() + v0_shard = v_shard[cu_seqlens_k_start : cu_seqlens_k_start + half_kv_shard_len, :, :].clone() + v1_shard = v_shard[cu_seqlens_k_start + half_kv_shard_len : cu_seqlens_k_end, :, :].clone() + + buffer_size = half_kv_shard_len * kv_head * head_dim * self.world_size + + ag_k0 = self.ag_k_buffer.reshape(-1)[:buffer_size].reshape(half_kv_shard_len * self.world_size, kv_head, head_dim) + ag_k1 = self.ag_k_buffer.reshape(-1)[buffer_size : 2 * buffer_size].reshape( + half_kv_shard_len * self.world_size, kv_head, head_dim + ) + ag_v0 = self.ag_v_buffer.reshape(-1)[:buffer_size].reshape(half_kv_shard_len * self.world_size, kv_head, head_dim) + ag_v1 = self.ag_v_buffer.reshape(-1)[buffer_size : 2 * buffer_size].reshape( + half_kv_shard_len * self.world_size, kv_head, head_dim + ) torch.distributed.all_gather_into_tensor( ag_k0, k0_shard, @@ -238,19 +228,15 @@ def _gen_mask(offset, q_shard_len, kv_len): group=self.pg, ) ag_k1 = ag_k1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim) - ag_k1 = torch.flip(ag_k1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head, - head_dim) + ag_k1 = torch.flip(ag_k1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head, head_dim) ag_k = torch.cat((ag_k0, ag_k1), dim=0) - ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim), - (0, 2, 1, 3)).contiguous() + ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim), (0, 2, 1, 3)).contiguous() ag_k = ag_k.repeat_interleave(q_head // kv_head, -3) ag_v1 = ag_v1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim) - ag_v1 = torch.flip(ag_v1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head, - head_dim) + ag_v1 = torch.flip(ag_v1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head, head_dim) ag_v = torch.cat((ag_v0, ag_v1), dim=0) - ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim), - (0, 2, 1, 3)).contiguous() + ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim), (0, 2, 1, 3)).contiguous() ag_v = ag_v.repeat_interleave(q_head // kv_head, -3) offset_q0 = half_q_shard_len * self.rank @@ -258,16 +244,12 @@ def _gen_mask(offset, q_shard_len, kv_len): prefix = kv_len - q_len mask0 = _gen_mask(prefix + offset_q0, half_q_shard_len, kv_len) mask1 = _gen_mask(prefix + offset_q1, half_q_shard_len, kv_len) - out0 = torch.nn.functional.scaled_dot_product_attention( - q0_shard_permute, ag_k, ag_v, attn_mask=mask0) - out1 = torch.nn.functional.scaled_dot_product_attention( - q1_shard_permute, ag_k, ag_v, attn_mask=mask1) + out0 = torch.nn.functional.scaled_dot_product_attention(q0_shard_permute, ag_k, ag_v, attn_mask=mask0) + out1 = torch.nn.functional.scaled_dot_product_attention(q1_shard_permute, ag_k, ag_v, attn_mask=mask1) out = torch.cat((out0, out1), dim=2) # [1, q_head, q_shard_len, head_dim] else: cu_q_shard = q_shard[cu_seqlens_q_start:cu_seqlens_q_end, :, :].clone() - cu_q_shard_permute = torch.permute( - cu_q_shard.reshape(1, q_shard_len, q_head, head_dim), - (0, 2, 1, 3)).contiguous() + cu_q_shard_permute = 
torch.permute(cu_q_shard.reshape(1, q_shard_len, q_head, head_dim), (0, 2, 1, 3)).contiguous() total_size = kv_len * kv_head * head_dim ag_k = self.ag_k_buffer.reshape(-1)[:total_size].reshape(kv_len, kv_head, head_dim) @@ -284,19 +266,17 @@ def _gen_mask(offset, q_shard_len, kv_len): cu_v_shard, group=self.pg, ) - ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim), - (0, 2, 1, 3)).contiguous() + ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim), (0, 2, 1, 3)).contiguous() ag_k = ag_k.repeat_interleave(q_head // kv_head, -3) - ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim), - (0, 2, 1, 3)).contiguous() + ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim), (0, 2, 1, 3)).contiguous() ag_v = ag_v.repeat_interleave(q_head // kv_head, -3) offset = self.rank * q_shard_len prefix = kv_len - q_len mask = _gen_mask(prefix + offset, q_shard_len, kv_len) out = torch.nn.functional.scaled_dot_product_attention( - cu_q_shard_permute, ag_k, ag_v, - attn_mask=mask) # [1, q_head, q_shard_len, head_dim] + cu_q_shard_permute, ag_k, ag_v, attn_mask=mask + ) # [1, q_head, q_shard_len, head_dim] out = torch.permute(out.reshape(q_head, q_shard_len, head_dim), (1, 0, 2)).contiguous() out_list.append(out) @@ -327,29 +307,20 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now" allocator = tilelang.get_allocator( - size=2**30, - device=device, - is_distributed=True, - local_rank=local_rank, - num_local_ranks=num_local_ranks, - group=group) + size=2**30, device=device, is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) cu_seqlens_q = torch.tensor(cu_seqlens_q_list, dtype=torch.int32, device=device) cu_seqlens_q = cu_seqlens_q // num_local_ranks cu_seqlens_k = torch.tensor(cu_seqlens_k_list, dtype=torch.int32, device=device) - q_shard = tilelang.tensor((cu_seqlens_q[-1], q_head, head_dim), - dtype=dtype, - allocator=allocator).normal_( - mean=0.0, std=0.5) - k_shards = tilelang.tensor((cu_seqlens_k[-1] // num_local_ranks, kv_head, head_dim), - dtype=dtype, - allocator=allocator, - return_peers=True) - v_shards = tilelang.tensor((cu_seqlens_k[-1] // num_local_ranks, kv_head, head_dim), - dtype=dtype, - allocator=allocator, - return_peers=True) + q_shard = tilelang.tensor((cu_seqlens_q[-1], q_head, head_dim), dtype=dtype, allocator=allocator).normal_(mean=0.0, std=0.5) + k_shards = tilelang.tensor( + (cu_seqlens_k[-1] // num_local_ranks, kv_head, head_dim), dtype=dtype, allocator=allocator, return_peers=True + ) + v_shards = tilelang.tensor( + (cu_seqlens_k[-1] // num_local_ranks, kv_head, head_dim), dtype=dtype, allocator=allocator, return_peers=True + ) k_shards[local_rank].normal_(mean=0.0, std=0.5) v_shards[local_rank].normal_(mean=0.0, std=0.5) @@ -386,12 +357,10 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): enable_zig_zag, ) - tilescale_out = tilescale_module( - q_shard, k_shards, v_shards, cu_seqlens_q, cu_seqlens_k, print_source=True) + tilescale_out = tilescale_module(q_shard, k_shards, v_shards, cu_seqlens_q, cu_seqlens_k, print_source=True) print(f"tilescale_out: {tilescale_out.shape}") - torch_out = torch_module(q_shard, k_shards[local_rank], v_shards[local_rank], cu_seqlens_q, - cu_seqlens_k) + torch_out = torch_module(q_shard, k_shards[local_rank], v_shards[local_rank], cu_seqlens_q, cu_seqlens_k) 
print(f"torch_out: {torch_out.shape}") atol = 1e-2 @@ -402,10 +371,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): print(f"rank {local_rank} check failed.❌") print(f"torch_out: {torch_out}, tilelang_out: {tilescale_out}") - _, tl_t = perf_fn( - lambda: tilescale_module(q_shard, k_shards, v_shards, cu_seqlens_q, cu_seqlens_k), - warmup=5, - rep=5) + _, tl_t = perf_fn(lambda: tilescale_module(q_shard, k_shards, v_shards, cu_seqlens_q, cu_seqlens_k), warmup=5, rep=5) print(f"rank {local_rank} tilescale time: {tl_t:.2f} ms") @@ -414,20 +380,16 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--num-processes', type=int, default=1, help='Number of processes to spawn (default: 2)') + parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)") parser.add_argument("--batch_size", type=int, default=2, help="batch size") parser.add_argument("--q_head", type=int, default=32, help="num q heads") parser.add_argument("--kv_head", type=int, default=8, help="num kv heads") parser.add_argument("--max_seqlen_q", type=int, default=8192, help="max sequence length of q") - parser.add_argument( - "--max_seqlen_k", type=int, default=12288, help="max sequence length of k/v") + parser.add_argument("--max_seqlen_k", type=int, default=12288, help="max sequence length of k/v") parser.add_argument("--head_dim", type=int, default=128, help="head dim") - parser.add_argument( - "--seqlens_q", type=int, nargs='+', default=[4096, 8192], help="sequence lengths of q") - parser.add_argument( - "--seqlens_k", type=int, nargs='+', default=[6144, 12288], help="sequence lengths of k/v") - parser.add_argument('--is_causal', action='store_true', help='causal') + parser.add_argument("--seqlens_q", type=int, nargs="+", default=[4096, 8192], help="sequence lengths of q") + parser.add_argument("--seqlens_k", type=int, nargs="+", default=[6144, 12288], help="sequence lengths of k/v") + parser.add_argument("--is_causal", action="store_true", help="causal") parser.add_argument( "--zig-zag", "--no-zig-zag", diff --git a/examples/distributed/example_summa.py b/examples/distributed/example_summa.py index 168517c09..640a31de6 100644 --- a/examples/distributed/example_summa.py +++ b/examples/distributed/example_summa.py @@ -11,7 +11,6 @@ def summa(MESH, M, N, K, block_M, block_N, block_K, dtype="float16"): - M_local = T.ceildiv(M, MESH) N_local = T.ceildiv(N, MESH) K_local = T.ceildiv(K, MESH) @@ -22,13 +21,13 @@ def summa(MESH, M, N, K, block_M, block_N, block_K, dtype="float16"): @T.prim_func def main( - A: T.Tensor((2, M_local, K_local), dtype), - B: T.Tensor((2, N_local, K_local), dtype), - A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"), - A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"), - B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"), - B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"), - C: T.Tensor((M_local, N_local), dtype), + A: T.Tensor((2, M_local, K_local), dtype), + B: T.Tensor((2, N_local, K_local), dtype), + A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"), + A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"), + B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"), + B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"), + C: T.Tensor((M_local, N_local), dtype), ): grid_size = T.min(sm_num, total_tiles) A_rows_per_block = T.ceildiv(M_local, grid_size) @@ -63,8 
+62,11 @@ def main( T.address_of(A[(ko + 1) % 2, A_rows_per_block * block_id, 0]), T.address_of(A[ko % 2, A_rows_per_block * block_id, 0]), A_rows_per_block * K_local * dtype_map[dtype].itemsize, - T.address_of(A_signal_to[0]), 1, T.Amo.SIGNAL_ADD, - pe_mn * MESH + peer_k) + T.address_of(A_signal_to[0]), + 1, + T.Amo.SIGNAL_ADD, + pe_mn * MESH + peer_k, + ) # broadcast B if pe_k == ko: @@ -80,8 +82,11 @@ def main( T.address_of(B[(ko + 1) % 2, B_cols_per_block * block_id, 0]), T.address_of(B[ko % 2, B_cols_per_block * block_id, 0]), B_cols_per_block * K_local * dtype_map[dtype].itemsize, - T.address_of(B_signal_to[0]), 1, T.Amo.SIGNAL_ADD, - pe_mn * MESH + peer_k) + T.address_of(B_signal_to[0]), + 1, + T.Amo.SIGNAL_ADD, + pe_mn * MESH + peer_k, + ) # TODO: check if __syncthreads() is needed T.signal_wait_until( @@ -96,7 +101,6 @@ def main( ) for w in T.serial(waves): - bx = (grid_size * w + block_id) // T.ceildiv(N_local, block_N) by = (grid_size * w + block_id) % T.ceildiv(N_local, block_N) @@ -158,11 +162,7 @@ def parse_args(): K_local = math.ceil(K / MESH) func = summa(MESH, M, N, K, block_M, block_N, block_K, args.dtype) - kernel = tilelang.compile( - func, pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) + kernel = tilelang.compile(func, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) # Get CUDA Source if RANK == 0: @@ -183,9 +183,9 @@ def parse_args(): b_scatter_list = [] for r in range(WORLD_SIZE): rr, cc = divmod(r, MESH) - c_tile = C[M_local * rr:M_local * (rr + 1), N_local * cc:N_local * (cc + 1)] - a_tile = A[M_local * rr:M_local * (rr + 1), K_local * cc:K_local * (cc + 1)] - b_tile = B[N_local * cc:N_local * (cc + 1), K_local * rr:K_local * (rr + 1)] + c_tile = C[M_local * rr : M_local * (rr + 1), N_local * cc : N_local * (cc + 1)] + a_tile = A[M_local * rr : M_local * (rr + 1), K_local * cc : K_local * (cc + 1)] + b_tile = B[N_local * cc : N_local * (cc + 1), K_local * rr : K_local * (rr + 1)] c_scatter_list.append(c_tile.contiguous()) a_scatter_list.append(a_tile.contiguous()) @@ -220,7 +220,7 @@ def parse_args(): dist.barrier() if r == RANK: if torch.allclose(C_tilelang, ref, rtol=1e-2, atol=1e-2): - print('-' * 100) + print("-" * 100) print(f"[Rank {RANK}] ✅ Tilelang and Torch match") else: abs_error = torch.abs(C_tilelang - ref) @@ -230,7 +230,7 @@ def parse_args(): max_rel_error = rel_error.max().item() mismatch_ratio = (abs_error > (1e-2 + 1e-2 * torch.abs(ref))).float().mean().item() - print('-' * 100) + print("-" * 100) print(f"[Rank {RANK}] ❌ Tilelang and Torch mismatch") print(f"[Rank {RANK}] ref:\n{ref}") print(f"[Rank {RANK}] tilelang:\n{C_tilelang}") @@ -281,8 +281,7 @@ def reduce_local_time(local_time): total_flops = 2 * M * N * K -avg_time = reduce_local_time( - bench(kernel, A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang)) +avg_time = reduce_local_time(bench(kernel, A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang)) if RANK == 0: print(f"avg time of RANK {RANK}: {avg_time} ms") diff --git a/examples/distributed/gemm_rs_utils.py b/examples/distributed/gemm_rs_utils.py index 2d8141467..0a6634c39 100644 --- a/examples/distributed/gemm_rs_utils.py +++ b/examples/distributed/gemm_rs_utils.py @@ -79,16 +79,13 @@ def __post_init__(self): for buf in self.signal_bufs: assert buf.shape[0] >= 2 * self.world_size - self.scatter_signal_bufs = [buf[:self.world_size] for buf in self.signal_bufs] - self.rs_per_node_signal_bufs = [ - 
buf[self.world_size:self.world_size * 2] for buf in self.signal_bufs - ] + self.scatter_signal_bufs = [buf[: self.world_size] for buf in self.signal_bufs] + self.rs_per_node_signal_bufs = [buf[self.world_size : self.world_size * 2] for buf in self.signal_bufs] for node_id in range(self.nnodes): self.scatter_signal_buf_list_for_each_node.append( - self.scatter_signal_bufs[self.local_rank][node_id * - self.local_world_size:(node_id + 1) * - self.local_world_size]) + self.scatter_signal_bufs[self.local_rank][node_id * self.local_world_size : (node_id + 1) * self.local_world_size] + ) def reset_barriers(self) -> int: # self.scatter_signal_bufs[self.local_rank].fill_(0) @@ -101,9 +98,7 @@ def get_scatter_bufs_and_signal_for_each_node(self, input, node_id): M_per_node = M_per_rank * self.local_world_size M_start = node_id * M_per_node M_end = M_start + M_per_node - scatter_bufs_intra_node = [ - self.scatter_bufs[i][M_start:M_end] for i in range(self.local_world_size) - ] + scatter_bufs_intra_node = [self.scatter_bufs[i][M_start:M_end] for i in range(self.local_world_size)] return scatter_bufs_intra_node, self.scatter_signal_buf_list_for_each_node[node_id] @property @@ -131,36 +126,32 @@ def scatter_signal_buf(self) -> torch.Tensor: return self.scatter_signal_bufs[self.local_rank] -def create_reduce_scater_2d_ctx(max_M, - N, - rank, - world_size, - local_world_size, - dtype, - overlap_with_gemm=True, - num_reduction_sms=15) -> ReduceScatter2DContext: +def create_reduce_scater_2d_ctx( + max_M, N, rank, world_size, local_world_size, dtype, overlap_with_gemm=True, num_reduction_sms=15 +) -> ReduceScatter2DContext: """ - for num_reduction_sms: tunable param, 16 are enough for H800 - For H800, we overlap local reduce and inter-node p2p with intra-node scatter. - The reduction kernel bandwidth is not a bottleneck if it exceeds 450GB, so only a few SMs are needed. - For machines with higher intra_node bandwidth(e.g. H100), we may need to increase the number of SMs or redesign overlapping. + for num_reduction_sms: tunable param, 16 are enough for H800 + For H800, we overlap local reduce and inter-node p2p with intra-node scatter. + The reduction kernel bandwidth is not a bottleneck if it exceeds 450GB, so only a few SMs are needed. + For machines with higher intra_node bandwidth(e.g. H100), we may need to increase the number of SMs or redesign overlapping. 
""" assert world_size % local_world_size == 0 assert max_M % world_size == 0 scatter_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([max_M, N], dtype) - rs_per_node_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node( - [max_M // local_world_size, N], dtype) + rs_per_node_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([max_M // local_world_size, N], dtype) - p2p_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([max_M // local_world_size, N], - dtype) + p2p_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([max_M // local_world_size, N], dtype) # signal_buf: scatter_signal | rs_per_node_signal num_signal_bufs = 2 - signal_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([ - world_size * num_signal_bufs, - ], SIGNAL_DTYPE) + signal_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node( + [ + world_size * num_signal_bufs, + ], + SIGNAL_DTYPE, + ) # TODO: implement barrier_all_on_stream # barrier_all_on_stream(None, torch.cuda.current_stream()) @@ -187,7 +178,8 @@ def create_reduce_scater_2d_ctx(max_M, p2p_stream=p2p_stream, num_sync_sms=num_sync_sms, num_p2p_sms=num_p2p_sms, - num_reduction_sms=num_reduction_sms) + num_reduction_sms=num_reduction_sms, + ) return ctx @@ -211,14 +203,7 @@ class GEMMReduceScatterTensorParallelContext: GROUP_M: int = 8 stages: int = 3 - def update(self, - rs_stream, - output_dtype=None, - BLOCK_M=128, - BLOCK_N=256, - BLOCK_K=64, - GROUP_M=8, - stages=3): + def update(self, rs_stream, output_dtype=None, BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, GROUP_M=8, stages=3): self.rs_stream = rs_stream self.output_dtype = output_dtype self.BLOCK_M = BLOCK_M @@ -233,20 +218,10 @@ def get_gemm_out_buf(self, input): return self.gemm_out_bufs[local_rank][:M] -def create_gemm_rs_context(max_M, - N, - rank, - world_size, - local_world_size, - output_dtype, - rs_stream, - BLOCK_M=128, - BLOCK_N=256, - BLOCK_K=64, - GROUP_M=8, - stages=3) -> GEMMReduceScatterTensorParallelContext: - rs_ctx = create_reduce_scater_2d_ctx( - max_M, N, rank, world_size, local_world_size, output_dtype, overlap_with_gemm=True) +def create_gemm_rs_context( + max_M, N, rank, world_size, local_world_size, output_dtype, rs_stream, BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, GROUP_M=8, stages=3 +) -> GEMMReduceScatterTensorParallelContext: + rs_ctx = create_reduce_scater_2d_ctx(max_M, N, rank, world_size, local_world_size, output_dtype, overlap_with_gemm=True) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count num_gemm_sms = NUM_SMS - rs_ctx.num_rs_sms gemm_out_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([max_M, N], output_dtype) @@ -260,5 +235,6 @@ def create_gemm_rs_context(max_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_M=GROUP_M, - stages=stages) + stages=stages, + ) return ctx diff --git a/examples/distributed/primitives/example_get_block.py b/examples/distributed/primitives/example_get_block.py index 9039fbf6c..369e81032 100644 --- a/examples/distributed/primitives/example_get_block.py +++ b/examples/distributed/primitives/example_get_block.py @@ -8,15 +8,14 @@ from tilelang.distributed import init_dist tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log def get_kernel(M, num_rank, block_M, threads): - @T.prim_func def main( - dst: T.Tensor((M), "float32"), - src: T.Tensor((M), "float32"), + dst: T.Tensor((M), "float32"), + src: T.Tensor((M), "float32"), ): with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx): rank = T.alloc_local([1], "uint64") @@ -42,12 
+41,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) allocator = tilelang.get_allocator( - size=2**25, - device="cuda", - is_distributed=True, - local_rank=local_rank, - num_local_ranks=num_local_ranks, - group=group) + size=2**25, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) kernel = tilelang.compile(get_kernel(M, num_ranks, BLOCK_M, threads)) kernel.initialize(allocator=allocator) if local_rank == 0: @@ -78,9 +73,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') - parser.add_argument('--M', type=int, default=65536, help='M dimension') + parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)") + parser.add_argument("--M", type=int, default=65536, help="M dimension") args = parser.parse_args() num_processes = args.num_processes diff --git a/examples/distributed/primitives/example_get_warp.py b/examples/distributed/primitives/example_get_warp.py index 49b1fc02a..80d34d2ce 100644 --- a/examples/distributed/primitives/example_get_warp.py +++ b/examples/distributed/primitives/example_get_warp.py @@ -8,15 +8,14 @@ from tilelang.distributed import init_dist tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log def get_kernel(M, num_rank, block_M, threads): - @T.prim_func def main( - dst: T.Tensor((M), "float32"), - src: T.Tensor((M), "float32"), + dst: T.Tensor((M), "float32"), + src: T.Tensor((M), "float32"), ): with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx): rank = T.alloc_local([1], "uint64") @@ -31,7 +30,8 @@ def main( dst=T.address_of(dst[warp_start]), size=warp_copy_size, src_pe=rank[0] ^ 1, - unroll_factor=4) + unroll_factor=4, + ) T.fence_sys() return main @@ -45,12 +45,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) allocator = tilelang.get_allocator( - size=2**25, - device="cuda", - is_distributed=True, - local_rank=local_rank, - num_local_ranks=num_local_ranks, - group=group) + size=2**25, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) kernel = tilelang.compile(get_kernel(M, num_ranks, BLOCK_M, threads)) kernel.initialize(allocator=allocator) if local_rank == 0: @@ -81,9 +77,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') - parser.add_argument('--M', type=int, default=65536, help='M dimension') + parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)") + parser.add_argument("--M", type=int, default=65536, help="M dimension") args = parser.parse_args() num_processes = args.num_processes diff --git a/examples/distributed/primitives/example_put_block.py b/examples/distributed/primitives/example_put_block.py index 19e22b1ce..3b59c6c56 100644 --- a/examples/distributed/primitives/example_put_block.py +++ b/examples/distributed/primitives/example_put_block.py @@ -8,15 +8,14 @@ 
from tilelang.distributed import init_dist tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log def kernel_(M, num_rank, block_M, threads): - @T.prim_func def main( - dst: T.Tensor((M), "float32"), - src: T.Tensor((M), "float32"), + dst: T.Tensor((M), "float32"), + src: T.Tensor((M), "float32"), ): with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx): rank = T.alloc_local([1], "uint64") @@ -41,12 +40,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) allocator = tilelang.get_allocator( - size=2**25, - device="cuda", - is_distributed=True, - local_rank=local_rank, - num_local_ranks=num_local_ranks, - group=group) + size=2**25, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) kernel = tilelang.compile(kernel_(M, num_ranks, BLOCK_M, threads)) kernel.initialize(allocator=allocator) if local_rank == 0: @@ -77,9 +72,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') - parser.add_argument('--M', type=int, default=65536, help='M dimension') + parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)") + parser.add_argument("--M", type=int, default=65536, help="M dimension") args = parser.parse_args() num_processes = args.num_processes diff --git a/examples/distributed/primitives/example_put_warp.py b/examples/distributed/primitives/example_put_warp.py index a0351f6bf..4d397bc9d 100644 --- a/examples/distributed/primitives/example_put_warp.py +++ b/examples/distributed/primitives/example_put_warp.py @@ -8,15 +8,14 @@ from tilelang.distributed import init_dist tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log def kernel_(M, num_rank, block_M, threads): - @T.prim_func def main( - dst: T.Tensor((M), "bfloat16"), - src: T.Tensor((M), "bfloat16"), + dst: T.Tensor((M), "bfloat16"), + src: T.Tensor((M), "bfloat16"), ): with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx): rank = T.alloc_local([1], "uint64") @@ -31,7 +30,8 @@ def main( dst=T.address_of(dst[warp_start]), size=warp_copy_size, dst_pe=rank[0] ^ 1, - unroll_factor=4) + unroll_factor=4, + ) return main @@ -44,12 +44,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) allocator = tilelang.get_allocator( - size=2**25, - device="cuda", - is_distributed=True, - local_rank=local_rank, - num_local_ranks=num_local_ranks, - group=group) + size=2**25, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) kernel = tilelang.compile(kernel_(M, num_ranks, BLOCK_M, threads)) kernel.initialize(allocator=allocator) if local_rank == 0: @@ -80,9 +76,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') - parser.add_argument('--M', type=int, default=65536, help='M dimension') + parser.add_argument("--num-processes", type=int, default=2, 
help="Number of processes to spawn (default: 2)") + parser.add_argument("--M", type=int, default=65536, help="M dimension") args = parser.parse_args() num_processes = args.num_processes diff --git a/examples/distributed/primitives/example_remote_st.py b/examples/distributed/primitives/example_remote_st.py index 251e5e08b..05f95f50d 100644 --- a/examples/distributed/primitives/example_remote_st.py +++ b/examples/distributed/primitives/example_remote_st.py @@ -8,15 +8,14 @@ from tilelang.distributed import init_dist tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log def kernel_(M, num_rank, block_M, threads): - @T.prim_func def main( - dst: T.Tensor((M), "float32"), - src: T.Tensor((M), "float32"), + dst: T.Tensor((M), "float32"), + src: T.Tensor((M), "float32"), ): with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx): rank = T.alloc_local([1], "uint64") @@ -36,12 +35,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) allocator = tilelang.get_allocator( - size=2**25, - device="cuda", - is_distributed=True, - local_rank=local_rank, - num_local_ranks=num_local_ranks, - group=group) + size=2**25, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) kernel = tilelang.compile(kernel_(M, num_ranks, BLOCK_M, threads)) kernel.initialize(allocator=allocator) if local_rank == 0: @@ -72,9 +67,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') - parser.add_argument('--M', type=int, default=1024, help='M dimension') + parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)") + parser.add_argument("--M", type=int, default=1024, help="M dimension") args = parser.parse_args() num_processes = args.num_processes diff --git a/examples/distributed/primitives/example_sync.py b/examples/distributed/primitives/example_sync.py index fa5949a3f..eba17c442 100644 --- a/examples/distributed/primitives/example_sync.py +++ b/examples/distributed/primitives/example_sync.py @@ -7,7 +7,7 @@ from tilelang.distributed import init_dist tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log +os.environ["NCCL_DEBUG"] = "WARN" # silence NCCL log def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): @@ -16,12 +16,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) allocator = tilelang.get_allocator( - size=2**25, - device="cuda", - is_distributed=True, - local_rank=local_rank, - num_local_ranks=num_local_ranks, - group=group) + size=2**25, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) dst = tilelang.tensor((M), torch.float32, allocator=allocator) srcs = tilelang.tensor((M), torch.float32, allocator=allocator, return_peers=True) @@ -39,9 +35,8 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') - parser.add_argument('--M', type=int, default=65536, 
help='M dimension') + parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)") + parser.add_argument("--M", type=int, default=65536, help="M dimension") args = parser.parse_args() num_processes = args.num_processes diff --git a/examples/distributed/primitives/test_get_block.py b/examples/distributed/primitives/test_get_block.py index 6675965b0..63c52435a 100644 --- a/examples/distributed/primitives/test_get_block.py +++ b/examples/distributed/primitives/test_get_block.py @@ -5,6 +5,7 @@ import example_get_block +@tilelang.testing.requires_distributed @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_get_block(): diff --git a/examples/distributed/primitives/test_get_warp.py b/examples/distributed/primitives/test_get_warp.py index c482fa394..a542361fa 100644 --- a/examples/distributed/primitives/test_get_warp.py +++ b/examples/distributed/primitives/test_get_warp.py @@ -5,6 +5,7 @@ import example_get_warp +@tilelang.testing.requires_distributed @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_get_warp(): diff --git a/examples/distributed/primitives/test_put_block.py b/examples/distributed/primitives/test_put_block.py index 83ef08fb2..2e31de627 100644 --- a/examples/distributed/primitives/test_put_block.py +++ b/examples/distributed/primitives/test_put_block.py @@ -5,6 +5,7 @@ import example_put_block +@tilelang.testing.requires_distributed @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_put_block(): diff --git a/examples/distributed/primitives/test_put_warp.py b/examples/distributed/primitives/test_put_warp.py index de4cc1476..3b289cd27 100644 --- a/examples/distributed/primitives/test_put_warp.py +++ b/examples/distributed/primitives/test_put_warp.py @@ -5,6 +5,7 @@ import example_put_warp +@tilelang.testing.requires_distributed @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_put_warp(): diff --git a/examples/distributed/reduce_scatter.py b/examples/distributed/reduce_scatter.py index fcb8e997f..6ddc5707e 100644 --- a/examples/distributed/reduce_scatter.py +++ b/examples/distributed/reduce_scatter.py @@ -72,16 +72,13 @@ def __post_init__(self): for buf in self.signal_bufs: assert buf.shape[0] >= 2 * self.world_size - self.scatter_signal_bufs = [buf[:self.world_size] for buf in self.signal_bufs] - self.rs_per_node_signal_bufs = [ - buf[self.world_size:self.world_size * 2] for buf in self.signal_bufs - ] + self.scatter_signal_bufs = [buf[: self.world_size] for buf in self.signal_bufs] + self.rs_per_node_signal_bufs = [buf[self.world_size : self.world_size * 2] for buf in self.signal_bufs] for node_id in range(self.nnodes): self.scatter_signal_buf_list_for_each_node.append( - self.scatter_signal_bufs[self.local_rank][node_id * - self.local_world_size:(node_id + 1) * - self.local_world_size]) + self.scatter_signal_bufs[self.local_rank][node_id * self.local_world_size : (node_id + 1) * self.local_world_size] + ) def reset_barriers(self): self.signal_bufs[self.local_rank].fill_(0) @@ -93,9 +90,7 @@ def get_scatter_bufs_and_signal_for_each_node(self, input, node_id): M_per_node = M_per_rank * self.local_world_size M_start = node_id * M_per_node M_end = M_start + M_per_node - scatter_bufs_intra_node = [ - self.scatter_bufs[i][M_start:M_end] for i in range(self.local_world_size) - ] + scatter_bufs_intra_node = 
[self.scatter_bufs[i][M_start:M_end] for i in range(self.local_world_size)] return scatter_bufs_intra_node, self.scatter_signal_buf_list_for_each_node[node_id] @property @@ -123,50 +118,29 @@ def scatter_signal_buf(self) -> torch.Tensor: return self.scatter_signal_bufs[self.local_rank] -def create_reduce_scater_2d_ctx(max_M, - N, - rank, - world_size, - local_world_size, - dtype, - allocator, - overlap_with_gemm=True, - num_reduction_sms=15) -> ReduceScatter2DContext: +def create_reduce_scater_2d_ctx( + max_M, N, rank, world_size, local_world_size, dtype, allocator, overlap_with_gemm=True, num_reduction_sms=15 +) -> ReduceScatter2DContext: """ - for num_reduction_sms: tunable param, 16 are enough for H800 - For H800, we overlap local reduce and inter-node p2p with intra-node scatter. - The reduction kernel bandwidth is not a bottleneck if it exceeds 450GB, so only a few SMs are needed. - For machines with higher intra_node bandwidth(e.g. H100), we may need to increase the number of SMs or redesign overlapping. + num_reduction_sms: a tunable parameter; 16 SMs are enough for H800. + For H800, we overlap the local reduce and inter-node p2p with the intra-node scatter. + The reduction kernel bandwidth is not a bottleneck once it exceeds 450 GB/s, so only a few SMs are needed. + For machines with higher intra-node bandwidth (e.g. H100), we may need to increase the number of SMs or redesign the overlapping. """ assert world_size % local_world_size == 0 assert max_M % world_size == 0 scatter_bufs = tilelang.tensor((max_M, N), dtype, allocator=allocator, return_peers=True) - rs_per_node_bufs = tilelang.tensor((max_M // local_world_size, N), - dtype, - allocator=allocator, - return_peers=True) - p2p_bufs = tilelang.tensor((max_M // local_world_size, N), - dtype, - allocator=allocator, - return_peers=True) + rs_per_node_bufs = tilelang.tensor((max_M // local_world_size, N), dtype, allocator=allocator, return_peers=True) + p2p_bufs = tilelang.tensor((max_M // local_world_size, N), dtype, allocator=allocator, return_peers=True) # signal_buf: scatter_signal | rs_per_node_signal num_signal_bufs = 2 - signal_bufs = tilelang.tensor((world_size * num_signal_bufs), - dtype=torch.uint32, - allocator=allocator, - return_peers=True) - symm_barriers = tilelang.tensor((local_world_size,), - torch.int32, - allocator=allocator, - return_peers=True) + signal_bufs = tilelang.tensor((world_size * num_signal_bufs), dtype=torch.uint32, allocator=allocator, return_peers=True) + symm_barriers = tilelang.tensor((local_world_size,), torch.int32, allocator=allocator, return_peers=True) symm_barriers[rank] = 0 - counter_signal_buf = tilelang.tensor((local_world_size), - dtype=torch.uint32, - allocator=allocator, - return_peers=True) + counter_signal_buf = tilelang.tensor((local_world_size), dtype=torch.uint32, allocator=allocator, return_peers=True) dist.barrier() @@ -191,29 +165,21 @@ def create_reduce_scater_2d_ctx(max_M, reduction_stream=reduction_stream, num_sync_sms=num_sync_sms, num_p2p_sms=num_p2p_sms, - num_reduction_sms=num_reduction_sms) + num_reduction_sms=num_reduction_sms, + ) return ctx @tilelang.jit -def kernel_ring_reduce_tma(M_per_rank, - N, - block_M, - block_N, - begin_idx, - num_splits, - threads, - persistent=False, - dtype="float16", - accum_dtype="float"): - +def kernel_ring_reduce_tma( + M_per_rank, N, block_M, block_N, begin_idx, num_splits, threads, persistent=False, dtype="float16", accum_dtype="float" +): @T.prim_func def _kernel_ring_reduce_tma( - C: T.Tensor((M_per_rank * num_splits, N), dtype), - output: 
T.Tensor((M_per_rank, N), dtype), + C: T.Tensor((M_per_rank * num_splits, N), dtype), + output: T.Tensor((M_per_rank, N), dtype), ): - with T.Kernel( - T.ceildiv(M_per_rank, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(M_per_rank, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_N), dtype) init_shared = T.alloc_shared((block_M, block_N), dtype) data_local = T.alloc_fragment((block_M, block_N), dtype) @@ -233,10 +199,7 @@ def _kernel_ring_reduce_tma( return _kernel_ring_reduce_tma -def _wait_eq_cuda(signal_tensor: torch.Tensor, - signal: int, - stream: Optional[torch.cuda.Stream] = None, - require_i64=False): +def _wait_eq_cuda(signal_tensor: torch.Tensor, signal: int, stream: Optional[torch.cuda.Stream] = None, require_i64=False): stream = stream or torch.cuda.current_stream() if signal_tensor.dtype in (torch.int32, torch.uint32): (err,) = cuda.cuStreamWaitValue32( @@ -258,11 +221,13 @@ def _wait_eq_cuda(signal_tensor: torch.Tensor, raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}") -def intra_node_scatter(input_intra_node, - scatter_bufs_intra_node: List[torch.Tensor], - scatter_signal_buf_intra_node: torch.Tensor, - local_rank, - overlap_with_gemm=True): +def intra_node_scatter( + input_intra_node, + scatter_bufs_intra_node: List[torch.Tensor], + scatter_signal_buf_intra_node: torch.Tensor, + local_rank, + overlap_with_gemm=True, +): M, N = input_intra_node.shape local_world_size = len(scatter_bufs_intra_node) M_per_rank = M // local_world_size @@ -275,10 +240,8 @@ def intra_node_scatter(input_intra_node, # print(f"scatter_signal_buf_intra_node[remote_local_rank]: {scatter_signal_buf_intra_node[remote_local_rank]}") if overlap_with_gemm: _wait_eq_cuda(scatter_signal_buf_intra_node[remote_local_rank], 1, stream) - src = input_intra_node[remote_local_rank * M_per_rank:(remote_local_rank + 1) * - M_per_rank, :] - dst = scatter_bufs_intra_node[remote_local_rank][local_rank * M_per_rank:(local_rank + 1) * - M_per_rank, :] + src = input_intra_node[remote_local_rank * M_per_rank : (remote_local_rank + 1) * M_per_rank, :] + dst = scatter_bufs_intra_node[remote_local_rank][local_rank * M_per_rank : (local_rank + 1) * M_per_rank, :] with torch.cuda.stream(stream): dst.copy_(src) @@ -292,21 +255,15 @@ def ring_reduce_tma( ): total_M, N = input.shape M_per_split = total_M // num_splits - assert output.shape[ - 0] == M_per_split and total_M % num_splits == 0, f"{output.shape}, {total_M}, {num_splits}" + assert output.shape[0] == M_per_split and total_M % num_splits == 0, f"{output.shape}, {total_M}, {num_splits}" def alloc_fn(size, alignment, stream): return torch.empty(size, device="cuda", dtype=torch.int8) if num_sms == -1: ring_reduce_tma_func = kernel_ring_reduce_tma( - M_per_split, - N, - block_M=64, - block_N=64, - begin_idx=begin_idx, - num_splits=num_splits, - threads=128) + M_per_split, N, block_M=64, block_N=64, begin_idx=begin_idx, num_splits=num_splits, threads=128 + ) # if begin_idx == 0: # print(ring_reduce_tma_func.get_kernel_source()) ring_reduce_tma_func(input, output, stream=torch.cuda.current_stream().cuda_stream) @@ -345,9 +302,7 @@ def ring_reduce( raise NotImplementedError("Only Hopper ring reduce is implemented now.") -def reduce_scatter_for_each_node(input: torch.Tensor, - ctx: ReduceScatter2DContext, - output: Optional[torch.Tensor] = None): +def reduce_scatter_for_each_node(input: torch.Tensor, ctx: ReduceScatter2DContext, output: Optional[torch.Tensor] = 
None): world_size = ctx.world_size local_world_size = ctx.local_world_size local_rank = ctx.local_rank @@ -364,18 +319,14 @@ def reduce_scatter_for_each_node(input: torch.Tensor, stream = torch.cuda.current_stream() for n in range(0, nnodes): cur_node_id = (node_id + n + 1) % nnodes - input_intra_node = input[cur_node_id * M_per_node:(cur_node_id + 1) * M_per_node] - scatter_bufs_intra_node, scatter_signal_buf_intra_node = ctx.get_scatter_bufs_and_signal_for_each_node( - input, cur_node_id) + input_intra_node = input[cur_node_id * M_per_node : (cur_node_id + 1) * M_per_node] + scatter_bufs_intra_node, scatter_signal_buf_intra_node = ctx.get_scatter_bufs_and_signal_for_each_node(input, cur_node_id) intra_node_scatter( - input_intra_node, - scatter_bufs_intra_node, - scatter_signal_buf_intra_node, - local_rank, - overlap_with_gemm=ctx.overlap_with_gemm) + input_intra_node, scatter_bufs_intra_node, scatter_signal_buf_intra_node, local_rank, overlap_with_gemm=ctx.overlap_with_gemm + ) # ring reduce intra node - rs_buf_cur_node = rs_per_node_buf[M_per_rank * cur_node_id:(cur_node_id + 1) * M_per_rank] + rs_buf_cur_node = rs_per_node_buf[M_per_rank * cur_node_id : (cur_node_id + 1) * M_per_rank] # nvshmem_barrier_all_on_stream(stream) reduction_stream.wait_stream(stream) with torch.cuda.stream(reduction_stream): @@ -385,7 +336,8 @@ def reduce_scatter_for_each_node(input: torch.Tensor, reduce_out_buf, local_rank, local_world_size, - num_sms=-1 if n == nnodes - 1 else num_reduction_sms) + num_sms=-1 if n == nnodes - 1 else num_reduction_sms, + ) # inter node p2p if nnodes > 1: @@ -408,12 +360,10 @@ def reduce_scatter_for_each_node(input: torch.Tensor, stream.wait_stream(reduction_stream) if nnodes == 1: return output - return p2p_buf[:M_per_rank * nnodes] + return p2p_buf[: M_per_rank * nnodes] -def reduce_scatter_multi_node(input: torch.Tensor, - ctx: ReduceScatter2DContext, - output: Optional[torch.Tensor] = None): +def reduce_scatter_multi_node(input: torch.Tensor, ctx: ReduceScatter2DContext, output: Optional[torch.Tensor] = None): """ A hierarchical reduce-scatter implementation that overlaps the intra-node scatter with the local reduce and the inter-node p2p(after reduce). 
It also provides a rank-wise @@ -443,9 +393,7 @@ def reduce_scatter_multi_node(input: torch.Tensor, return output -def reduce_scatter_2d_op(input: torch.Tensor, - ctx: ReduceScatter2DContext, - output: Optional[torch.Tensor] = None): +def reduce_scatter_2d_op(input: torch.Tensor, ctx: ReduceScatter2DContext, output: Optional[torch.Tensor] = None): M, N = input.shape assert input.dtype == ctx.dtype assert ctx.max_M >= M and ctx.N == N diff --git a/examples/distributed/sp_ag_attention_intra_node.py b/examples/distributed/sp_ag_attention_intra_node.py index 421f13393..b66684ae6 100644 --- a/examples/distributed/sp_ag_attention_intra_node.py +++ b/examples/distributed/sp_ag_attention_intra_node.py @@ -10,10 +10,13 @@ @tilelang.jit -def barrier_all_blocks_sys_kernel(num_local_rank,): - +def barrier_all_blocks_sys_kernel( + num_local_rank, +): @T.prim_func - def main(barrier: T.Tensor((num_local_rank), "int32"),): + def main( + barrier: T.Tensor((num_local_rank), "int32"), + ): with T.Kernel(1, threads=32): T.barrier_blocks(barrier) @@ -25,28 +28,36 @@ def main(barrier: T.Tensor((num_local_rank), "int32"),): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", ], ) -def flashattn(batch_size, - groups, - UQ, - UKV, - heads, - dim, - is_causal, - enable_zig_zag, - enable_specialized, - rank, - num_ranks, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +def flashattn( + batch_size, + groups, + UQ, + UKV, + heads, + dim, + is_causal, + enable_zig_zag, + enable_specialized, + rank, + num_ranks, + block_M=64, + block_N=64, + num_stages=1, + threads=128, +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [UQ, heads, dim] kv_shape = [UKV, head_kv, dim] @@ -83,8 +94,7 @@ def inner( global_offset_q: T.int32, kv_len_per_sp_block: T.int32, ): - T.copy(Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], - Q_shared) + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -92,30 +102,30 @@ def inner( prefix_len = k_current_seqlen - q_current_seqlen * num_ranks loop_range = ( - T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) - if is_causal else T.ceildiv(k_current_seqlen, block_N)) + T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(k_current_seqlen, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): sp_block_idx = (k * block_N) // kv_len_per_sp_block - wait_rank = ( - sp_block_idx if sp_block_idx < num_ranks else 2 * num_ranks - sp_block_idx - 1) - kv_load_offset = ((k * block_N) % kv_len_per_sp_block + - sp_block_idx // num_ranks * kv_len_per_sp_block + wait_rank * - (k_current_seqlen // num_ranks)) - T.copy( - K_unpad[k_start_idx + kv_load_offset:k_start_idx + kv_load_offset + block_N, - 
kv_head_idx, :], K_shared) + wait_rank = sp_block_idx if sp_block_idx < num_ranks else 2 * num_ranks - sp_block_idx - 1 + kv_load_offset = ( + (k * block_N) % kv_len_per_sp_block + + sp_block_idx // num_ranks * kv_len_per_sp_block + + wait_rank * (k_current_seqlen // num_ranks) + ) + T.copy(K_unpad[k_start_idx + kv_load_offset : k_start_idx + kv_load_offset + block_N, kv_head_idx, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( - (prefix_len + global_offset_q + bx * block_M + i < k * block_N + j) or - (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), -1e9, 0) + (prefix_len + global_offset_q + bx * block_M + i < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), + -1e9, + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), -1e9, 0) + acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -1e9, 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -138,9 +148,7 @@ def inner( for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] - T.copy( - V_unpad[v_start_idx + kv_load_offset:v_start_idx + kv_load_offset + block_N, - kv_head_idx, :], V_shared) + T.copy(V_unpad[v_start_idx + kv_load_offset : v_start_idx + kv_load_offset + block_N, kv_head_idx, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @@ -154,17 +162,15 @@ def inner( @T.prim_func def main( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(kv_shape, dtype), - V_unpad: T.Tensor(kv_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -194,24 +200,46 @@ def main( global_offset_q = q_current_seqlen * rank kv_len_per_sp_block = k_current_seqlen // num_ranks - inner(Q_unpad, K_unpad, V_unpad, Output_unpad, Q_shared, K_shared, V_shared, O_shared, - acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum, q_start_idx, k_start_idx, v_start_idx, q_current_seqlen, k_current_seqlen, - bx, head_idx, kv_head_idx, global_offset_q, kv_len_per_sp_block) + inner( + Q_unpad, + K_unpad, + V_unpad, + Output_unpad, + Q_shared, + K_shared, + V_shared, + O_shared, + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + scores_scale, + scores_sum, + logsum, + q_start_idx, + k_start_idx, + v_start_idx, + q_current_seqlen, + k_current_seqlen, + bx, + head_idx, + kv_head_idx, + global_offset_q, + kv_len_per_sp_block, + ) @T.prim_func def main_zigzag( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(kv_shape, dtype), - V_unpad: T.Tensor(kv_shape, dtype), - 
cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -239,27 +267,51 @@ def main_zigzag( k_current_seqlen = k_end_idx - k_start_idx half_q_shard_len = q_current_seqlen // 2 - global_offset_q = rank * half_q_shard_len if bx * block_M < half_q_shard_len else \ - q_current_seqlen * num_ranks - (rank + 2) * half_q_shard_len + global_offset_q = ( + rank * half_q_shard_len if bx * block_M < half_q_shard_len else q_current_seqlen * num_ranks - (rank + 2) * half_q_shard_len + ) kv_len_per_sp_block = k_current_seqlen // (2 * num_ranks) - inner(Q_unpad, K_unpad, V_unpad, Output_unpad, Q_shared, K_shared, V_shared, O_shared, - acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum, q_start_idx, k_start_idx, v_start_idx, q_current_seqlen, k_current_seqlen, - bx, head_idx, kv_head_idx, global_offset_q, kv_len_per_sp_block) + inner( + Q_unpad, + K_unpad, + V_unpad, + Output_unpad, + Q_shared, + K_shared, + V_shared, + O_shared, + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + scores_scale, + scores_sum, + logsum, + q_start_idx, + k_start_idx, + v_start_idx, + q_current_seqlen, + k_current_seqlen, + bx, + head_idx, + kv_head_idx, + global_offset_q, + kv_len_per_sp_block, + ) @T.prim_func def main_specialized( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(kv_shape, dtype), - V_unpad: T.Tensor(kv_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=384) as (bx_, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=384) as (bx_, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -279,10 +331,12 @@ def main_specialized( bar_k_release = T.alloc_barrier(arrive_count=256) bar_v_release = T.alloc_barrier(arrive_count=256) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + } + ) batch_idx = bz head_idx = by @@ -311,7 +365,9 @@ def main_specialized( prefix_len = k_current_seqlen - q_current_seqlen * num_ranks loop_range = ( 
T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) - if is_causal else T.ceildiv(k_current_seqlen, block_N)) + if is_causal + else T.ceildiv(k_current_seqlen, block_N) + ) T.barrier_wait(bar_q_ready, 0) for k in T.serial(loop_range): @@ -319,21 +375,18 @@ def main_specialized( for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( (prefix_len + global_offset_q + bx * block_M + i < k * block_N + j) - or (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), -1e9, 0) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), + -1e9, + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -1e9, 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -1e9, 0 + ) T.barrier_wait(bar_k_ready, k % 2) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.barrier_arrive(bar_k_release) T.copy(scores_max, scores_max_prev) @@ -371,35 +424,30 @@ def main_specialized( prefix_len = k_current_seqlen - q_current_seqlen * num_ranks loop_range = ( T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) - if is_causal else T.ceildiv(k_current_seqlen, block_N)) - T.copy( - Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, - head_idx, :], Q_shared) + if is_causal + else T.ceildiv(k_current_seqlen, block_N) + ) + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) T.barrier_arrive(bar_q_ready) for k in T.serial(loop_range): T.barrier_wait(bar_k_release, (k + 1) % 2) - T.copy( - K_unpad[k_start_idx + (k * block_N):k_start_idx + (k * block_N) + block_N, - kv_head_idx, :], K_shared) + T.copy(K_unpad[k_start_idx + (k * block_N) : k_start_idx + (k * block_N) + block_N, kv_head_idx, :], K_shared) T.barrier_arrive(bar_k_ready) T.barrier_wait(bar_v_release, (k + 1) % 2) - T.copy( - V_unpad[v_start_idx + (k * block_N):v_start_idx + (k * block_N) + block_N, - kv_head_idx, :], V_shared) + T.copy(V_unpad[v_start_idx + (k * block_N) : v_start_idx + (k * block_N) + block_N, kv_head_idx, :], V_shared) T.barrier_arrive(bar_v_ready) @T.prim_func def main_specialized_zigzag( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(kv_shape, dtype), - V_unpad: T.Tensor(kv_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=384) as (bx_, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=384) as (bx_, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -420,10 +468,12 @@ def main_specialized_zigzag( bar_k_release = T.alloc_barrier(arrive_count=256) bar_v_release = T.alloc_barrier(arrive_count=256) - T.annotate_layout({ - 
O_shared: tilelang.layout.make_swizzled_layout(O_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + } + ) batch_idx = bz head_idx = by @@ -441,8 +491,9 @@ def main_specialized_zigzag( bx = T.ceildiv(max_seqlen_q, block_M) - bx_ - 1 half_q_shard_len = q_current_seqlen // 2 - global_offset_q = rank * half_q_shard_len if bx * block_M < half_q_shard_len else \ - q_current_seqlen * num_ranks - (rank + 2) * half_q_shard_len + global_offset_q = ( + rank * half_q_shard_len if bx * block_M < half_q_shard_len else q_current_seqlen * num_ranks - (rank + 2) * half_q_shard_len + ) kv_len_per_sp_block = k_current_seqlen // (2 * num_ranks) tid = T.get_thread_binding(0) @@ -455,7 +506,9 @@ def main_specialized_zigzag( prefix_len = k_current_seqlen - q_current_seqlen * num_ranks loop_range = ( T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) - if is_causal else T.ceildiv(k_current_seqlen, block_N)) + if is_causal + else T.ceildiv(k_current_seqlen, block_N) + ) T.barrier_wait(bar_q_ready, 0) for k in T.serial(loop_range): @@ -463,21 +516,18 @@ def main_specialized_zigzag( for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( (prefix_len + global_offset_q + bx * block_M + i < k * block_N + j) - or (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), -1e9, 0) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), + -1e9, + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -1e9, 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -1e9, 0 + ) T.barrier_wait(bar_k_ready, k % 2) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.barrier_arrive(bar_k_release) T.copy(scores_max, scores_max_prev) @@ -515,28 +565,24 @@ def main_specialized_zigzag( prefix_len = k_current_seqlen - q_current_seqlen * num_ranks loop_range = ( T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) - if is_causal else T.ceildiv(k_current_seqlen, block_N)) - T.copy( - Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, - head_idx, :], Q_shared) + if is_causal + else T.ceildiv(k_current_seqlen, block_N) + ) + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) T.barrier_arrive(bar_q_ready) for k in T.serial(loop_range): sp_block_idx = (k * block_N) // kv_len_per_sp_block - wait_rank = ( - sp_block_idx if sp_block_idx < num_ranks else 2 * num_ranks - sp_block_idx - - 1) - kv_load_offset = ((k * block_N) % kv_len_per_sp_block + - sp_block_idx // num_ranks * kv_len_per_sp_block + wait_rank * - (k_current_seqlen // num_ranks)) + wait_rank = sp_block_idx if sp_block_idx < num_ranks else 2 * num_ranks - sp_block_idx - 1 + kv_load_offset = ( + (k * block_N) % kv_len_per_sp_block + + sp_block_idx // num_ranks * kv_len_per_sp_block + + wait_rank * (k_current_seqlen // num_ranks) + ) T.barrier_wait(bar_k_release, (k + 1) % 2) - T.copy( - K_unpad[k_start_idx + kv_load_offset:k_start_idx + kv_load_offset + block_N, - kv_head_idx, :], K_shared) + T.copy(K_unpad[k_start_idx + kv_load_offset : 
k_start_idx + kv_load_offset + block_N, kv_head_idx, :], K_shared) T.barrier_arrive(bar_k_ready) T.barrier_wait(bar_v_release, (k + 1) % 2) - T.copy( - V_unpad[v_start_idx + kv_load_offset:v_start_idx + kv_load_offset + block_N, - kv_head_idx, :], V_shared) + T.copy(V_unpad[v_start_idx + kv_load_offset : v_start_idx + kv_load_offset + block_N, kv_head_idx, :], V_shared) T.barrier_arrive(bar_v_ready) if enable_specialized: @@ -571,16 +617,14 @@ def create_sp_ag_attention_context_intra_node( device, allocator, ): - ag_k_buffers = tilelang.tensor((batch_size * max_seqlen_k, kv_head, head_dim), - dtype=input_dtype, - allocator=allocator, - return_peers=True) + ag_k_buffers = tilelang.tensor( + (batch_size * max_seqlen_k, kv_head, head_dim), dtype=input_dtype, allocator=allocator, return_peers=True + ) ag_k_buffer = ag_k_buffers[rank] - ag_v_buffers = tilelang.tensor((batch_size * max_seqlen_k, kv_head, head_dim), - dtype=input_dtype, - allocator=allocator, - return_peers=True) + ag_v_buffers = tilelang.tensor( + (batch_size * max_seqlen_k, kv_head, head_dim), dtype=input_dtype, allocator=allocator, return_peers=True + ) ag_v_buffer = ag_v_buffers[rank] attn_output_buffer = torch.empty( @@ -603,14 +647,16 @@ def create_sp_ag_attention_context_intra_node( ag_v_buffer=ag_v_buffer, attn_output_buffer=attn_output_buffer, ag_stream=ag_stream, - barrier=barrier) + barrier=barrier, + ) return ctx def barrier_all_on_stream(barrier: torch.Tensor, stream: torch.cuda.Stream, world_size: int): barrier_all_blocks_sys_func = barrier_all_blocks_sys_kernel(world_size) - barrier_all_blocks_sys_func(barrier, stream=stream.cuda_stream) + with torch.cuda.stream(stream): + barrier_all_blocks_sys_func(barrier) def cp_engine_producer_kv_all_gather( @@ -681,12 +727,12 @@ def _cp_engine_copy_data(dst_ptr, src_ptr, cp_size, stream): for offset in range(1, world_size): src_rank = (rank + offset) % world_size - k_src_ptr = (k_shards[src_rank].data_ptr() + byte_start // world_size) - k_dst_ptr = (k_buffers[rank].data_ptr() + byte_start + src_rank * byte_per_rank) + k_src_ptr = k_shards[src_rank].data_ptr() + byte_start // world_size + k_dst_ptr = k_buffers[rank].data_ptr() + byte_start + src_rank * byte_per_rank _cp_engine_copy_data(k_dst_ptr, k_src_ptr, cp_size, ag_stream) - v_src_ptr = (v_shards[src_rank].data_ptr() + byte_start // world_size) - v_dst_ptr = (v_buffers[rank].data_ptr() + byte_start + src_rank * byte_per_rank) + v_src_ptr = v_shards[src_rank].data_ptr() + byte_start // world_size + v_dst_ptr = v_buffers[rank].data_ptr() + byte_start + src_rank * byte_per_rank _cp_engine_copy_data(v_dst_ptr, v_src_ptr, cp_size, ag_stream) barrier_all_on_stream(barrier, ag_stream, world_size) @@ -710,7 +756,6 @@ def fused_sp_ag_attn_intra_node( enable_specialized: bool = False, print_source: bool = False, ): - BLOCK_M = 128 BLOCK_N = 128 num_stages = 2 @@ -764,20 +809,14 @@ def fused_sp_ag_attn_intra_node( block_M=BLOCK_M, block_N=BLOCK_N, num_stages=num_stages, - threads=threads) + threads=threads, + ) if rank == 0 and print_source: print(kernel.get_kernel_source()) - kernel( - q_shard, - ag_k, - ag_v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - output, - stream=compute_stream.cuda_stream) + with torch.cuda.stream(compute_stream): + kernel(q_shard, ag_k, ag_v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, output) compute_stream.wait_stream(ctx.ag_stream) barrier_all_on_stream(ctx.barrier, compute_stream, world_size) diff --git a/examples/distributed/triton_sp.py b/examples/distributed/triton_sp.py index 
d8236259b..1b99a5fac 100644 --- a/examples/distributed/triton_sp.py +++ b/examples/distributed/triton_sp.py @@ -97,8 +97,7 @@ def store_v4_b32_cond(ptr, val0, val1, val2, val3, mask, _semantic=None): } """, constraints=("=r,l,r,r,r,r,r"), # no use output - args=[ptr, val0, val1, val2, val3, - mask.to(tl.int32, _semantic=_semantic)], + args=[ptr, val0, val1, val2, val3, mask.to(tl.int32, _semantic=_semantic)], dtype=tl.int32, is_pure=False, pack=1, @@ -125,7 +124,7 @@ def _matmul_launch_metadata(grid, kernel, args): bytes_per_elem = args["c_ptr"].element_size() else: bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 - ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) return ret @@ -138,13 +137,12 @@ def _kernel_consumer_gemm_persistent_repr(proxy): c_dtype = proxy.signature["c_ptr"].lstrip("*") BM, BN, BK = constexprs["BLOCK_SIZE_M"], constexprs["BLOCK_SIZE_N"], constexprs["BLOCK_SIZE_K"] - return f"cutlass_triton3x_sm{cap_major}{cap_minor}_a2a_consumer_gemm_persistent_tensorop_{a_dtype}_{b_dtype}_{c_dtype}_{BM}x{BN}x{BK}_ntn" + return ( + f"cutlass_triton3x_sm{cap_major}{cap_minor}_a2a_consumer_gemm_persistent_tensorop_{a_dtype}_{b_dtype}_{c_dtype}_{BM}x{BN}x{BK}_ntn" + ) -@triton.jit( - do_not_specialize=["sp_rank"], - launch_metadata=_matmul_launch_metadata, - repr=_kernel_consumer_gemm_persistent_repr) +@triton.jit(do_not_specialize=["sp_rank"], launch_metadata=_matmul_launch_metadata, repr=_kernel_consumer_gemm_persistent_repr) def matmul_kernel_descriptor_persistent( a_ptr, b_ptr, @@ -176,13 +174,10 @@ def matmul_kernel_descriptor_persistent( tl.static_assert(K % sp_size == 0, f"K {K} must be divisible by sp_size {sp_size}") K_per_sp_rank: tl.constexpr = K // sp_size - tl.static_assert( - K_per_sp_rank % BLOCK_SIZE_K == 0, - f"K_per_sp_rank {K_per_sp_rank} must be divisible by BLOCK_SIZE_K {BLOCK_SIZE_K}") + tl.static_assert(K_per_sp_rank % BLOCK_SIZE_K == 0, f"K_per_sp_rank {K_per_sp_rank} must be divisible by BLOCK_SIZE_K {BLOCK_SIZE_K}") k_tiles: tl.constexpr = K // BLOCK_SIZE_K - tl.static_assert(A2A_TILE_N % BLOCK_SIZE_K == 0, - f"A2A_TILE_N {A2A_TILE_N} must be divisible by BLOCK_SIZE_N {BLOCK_SIZE_K}") + tl.static_assert(A2A_TILE_N % BLOCK_SIZE_K == 0, f"A2A_TILE_N {A2A_TILE_N} must be divisible by BLOCK_SIZE_N {BLOCK_SIZE_K}") NUM_K_PER_TILE: tl.constexpr = A2A_TILE_N // BLOCK_SIZE_K # This is used for k-swizzle # k_tiles_per_rank: tl.constexpr = K_per_sp_rank // BLOCK_SIZE_K @@ -212,10 +207,8 @@ def matmul_kernel_descriptor_persistent( tile_id_c = start_pid - NUM_GEMM_SMS num_pid_in_group = GROUP_SIZE_M * num_pid_n - for tile_id in tl.range( - start_pid, num_tiles, NUM_GEMM_SMS, flatten=False, warp_specialize=WARP_SPECIALIZE): - pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, - NUM_GEMM_SMS) + for tile_id in tl.range(start_pid, num_tiles, NUM_GEMM_SMS, flatten=False, warp_specialize=WARP_SPECIALIZE): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_GEMM_SMS) offs_am = pid_m * BLOCK_SIZE_M offs_bn = pid_n * BLOCK_SIZE_N @@ -235,12 +228,12 @@ def matmul_kernel_descriptor_persistent( if ki % NUM_K_PER_TILE == 0: for chunk_id in range(chunk_beg, chunk_end + 1): token = dl.wait( - gemm_barrier_ptr + chunk_id * (k_tiles // NUM_K_PER_TILE) + - ki // NUM_K_PER_TILE, + gemm_barrier_ptr + chunk_id * (k_tiles // NUM_K_PER_TILE) + ki // NUM_K_PER_TILE, 1, scope="gpu", semantic="acquire", - waitValue=1) + 
waitValue=1, + ) a_desc = dl.consume_token(a_desc, token) offs_k = ki * BLOCK_SIZE_K a = a_desc.load([offs_am, offs_k]) @@ -248,15 +241,13 @@ def matmul_kernel_descriptor_persistent( accumulator = tl.dot(a, b.T, accumulator) tile_id_c += NUM_GEMM_SMS - pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, - NUM_GEMM_SMS) + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_GEMM_SMS) offs_cm = pid_m * BLOCK_SIZE_M offs_cn = pid_n * BLOCK_SIZE_N if HAS_BIAS: offs_bias_n = tl.arange(0, BLOCK_SIZE_N) - bias_data = tl.load( - bias_ptr + offs_cn + offs_bias_n, mask=(offs_cn + offs_bias_n < N)).to(tl.float32) + bias_data = tl.load(bias_ptr + offs_cn + offs_bias_n, mask=(offs_cn + offs_bias_n < N)).to(tl.float32) accumulator = accumulator + bias_data[None, :] if EPILOGUE_SUBTILE: @@ -272,15 +263,7 @@ def matmul_kernel_descriptor_persistent( c_desc.store([offs_cm, offs_cn], c) -def matmul_descriptor_persistent(sp_rank, - sp_size, - a, - b, - bias, - c, - gemm_barrier, - gemm_config: triton.Config, - warp_specialize: bool = False): +def matmul_descriptor_persistent(sp_rank, sp_size, a, b, bias, c, gemm_barrier, gemm_config: triton.Config, warp_specialize: bool = False): # Check constraints. assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed assert a.dtype == b.dtype, "Incompatible dtypes" @@ -295,8 +278,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): triton.set_allocator(alloc_fn) def grid(META): - return (min(META["NUM_GEMM_SMS"], - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])),) + return (min(META["NUM_GEMM_SMS"], triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])),) matmul_kernel_descriptor_persistent[grid]( a, @@ -350,8 +332,7 @@ def kernel_all2all_push_intra_node_nvl( if FUSE_SYNC: tl.static_assert(SUPPORT_ATOMIC, "FUSE_SYNC requires SUPPORT_ATOMIC to be True") - barrier_all_intra_node_atomic_cas_block(sp_rank, rank, sp_size, - intra_node_sync_buf_ptr + pid * sp_size) + barrier_all_intra_node_atomic_cas_block(sp_rank, rank, sp_size, intra_node_sync_buf_ptr + pid * sp_size) for i in tl.static_range(sp_size + 1): tl.store(cum_seqlen_gpu_ptr + i, cum_seqlen_cpu_tuple[i]) @@ -363,13 +344,11 @@ def kernel_all2all_push_intra_node_nvl( offs_n = tl.arange(0, BLOCK_N // VEC) if sp_size <= NUM_COMM_SM: - tl.static_assert(NUM_COMM_SM % sp_size == 0, - f"NUM_COMM_SM {NUM_COMM_SM} must be divisible by sp_size {sp_size}") + tl.static_assert(NUM_COMM_SM % sp_size == 0, f"NUM_COMM_SM {NUM_COMM_SM} must be divisible by sp_size {sp_size}") NUM_SM_PER_SP: tl.constexpr = NUM_COMM_SM // sp_size NUM_SP_PER_SM: tl.constexpr = 1 else: - tl.static_assert(sp_size % NUM_COMM_SM == 0, - f"sp_size {sp_size} must be divisible by NUM_COMM_SM {NUM_COMM_SM}") + tl.static_assert(sp_size % NUM_COMM_SM == 0, f"sp_size {sp_size} must be divisible by NUM_COMM_SM {NUM_COMM_SM}") NUM_SM_PER_SP: tl.constexpr = 1 NUM_SP_PER_SM: tl.constexpr = sp_size // NUM_COMM_SM @@ -384,8 +363,8 @@ def kernel_all2all_push_intra_node_nvl( remote_seq_len = seq_end - seq_beg num_tile_m = tl.cdiv(remote_seq_len, BLOCK_M) tl.static_assert( - local_head * head_dim % BLOCK_N == 0, - f"local_head * head_dim {local_head * head_dim} must be divisible by BLOCK_N {BLOCK_N}") + local_head * head_dim % BLOCK_N == 0, f"local_head * head_dim {local_head * head_dim} must be divisible by BLOCK_N {BLOCK_N}" + ) num_tile_n = local_head * head_dim // BLOCK_N for tile_id_m_outer_n_tail in range(0, 
tl.cdiv(num_tile_m, GROUP_SIZE_M) * num_tile_n): @@ -398,32 +377,32 @@ def kernel_all2all_push_intra_node_nvl( attn_mask_m = attn_offs_m < seq_end attn_offs_n = tile_id_n_tail * BLOCK_N + offs_n * VEC data0, data1, data2, data3 = load_v4_b32_cond( - attn_out_ptr + attn_offs_m[:, None] * local_head * head_dim + - attn_offs_n[None, :], - mask=attn_mask_m[:, None]) + attn_out_ptr + attn_offs_m[:, None] * local_head * head_dim + attn_offs_n[None, :], mask=attn_mask_m[:, None] + ) out_offs_m = tile_id_m_tail * BLOCK_M + offs_m out_mask_m = out_offs_m < remote_seq_len out_offs_n = sp_rank * local_head * head_dim + tile_id_n_tail * BLOCK_N + offs_n * VEC store_v4_b32_cond( - remote_a2a_out_ptr + out_offs_m[:, None] * global_head * head_dim + - out_offs_n[None, :], + remote_a2a_out_ptr + out_offs_m[:, None] * global_head * head_dim + out_offs_n[None, :], data0, data1, data2, data3, - mask=out_mask_m[:, None]) + mask=out_mask_m[:, None], + ) if not SKIP_BARRIER: __syncthreads() - notify_barrier_ptr = remote_barrier_ptr + tile_id_m_tail * num_tile_n * sp_size + sp_rank * num_tile_n + tile_id_n_tail + notify_barrier_ptr = ( + remote_barrier_ptr + tile_id_m_tail * num_tile_n * sp_size + sp_rank * num_tile_n + tile_id_n_tail + ) thread_idx = tid(0) if thread_idx == 0: st(notify_barrier_ptr, 1, scope="sys", semantic="release") class SpUlysessOAll2AllGemmKernel: - def __init__( self, world_group: torch.distributed.ProcessGroup, @@ -492,14 +471,13 @@ def finalize(self): def init_symm_buffer(self): max_local_seq = self.max_seqlen // self.sp_size self._comm_output_buffer = nvshmem_create_tensor( - [self.max_num_comm_buf, self.max_batch, max_local_seq, self.num_head * self.head_dim], - self.input_dtype) + [self.max_num_comm_buf, self.max_batch, max_local_seq, self.num_head * self.head_dim], self.input_dtype + ) self._barrier_buffer = nvshmem_create_tensor( - [triton.cdiv(self.max_batch * self.max_seqlen, self.BLOCK_SIZE_M) * self.num_head], - torch.int32) + [triton.cdiv(self.max_batch * self.max_seqlen, self.BLOCK_SIZE_M) * self.num_head], torch.int32 + ) self._barrier_buffer.zero_() - self._intra_node_sync_buffer = nvshmem_create_tensor([self.sp_size * self.max_sms], - torch.int32) + self._intra_node_sync_buffer = nvshmem_create_tensor([self.sp_size * self.max_sms], torch.int32) self._intra_node_sync_buffer.zero_() self._sp_group_sync_buffer = nvshmem_create_tensor([self.world_size], torch.int32) self._sp_group_sync_buffer.zero_() @@ -525,30 +503,31 @@ def sp_group_barrier_all_intra_node(self, stream=None): stream = torch.cuda.current_stream() if stream is None else stream sp_local_rank = self.local_rank % self.sp_size with torch.cuda.stream(stream): - barrier_all_intra_node_atomic_cas_block[(1,)](sp_local_rank, self.rank, self.sp_size, - self._sp_group_sync_buffer) + barrier_all_intra_node_atomic_cas_block[(1,)](sp_local_rank, self.rank, self.sp_size, self._sp_group_sync_buffer) def reset_cusum_seq_lens(self, local_seqlen, seq_lens_cpu=None): if seq_lens_cpu is None: seq_lens_cpu = [local_seqlen] * self.sp_size else: seq_lens_cpu = seq_lens_cpu.tolist() - assert local_seqlen == seq_lens_cpu[ - self.local_rank % self. 
- sp_size], f"local_seqlen {local_seqlen} != seq_lens_cpu[{self.local_rank % self.sp_size}]={seq_lens_cpu[self.local_rank % self.sp_size]}" + assert local_seqlen == seq_lens_cpu[self.local_rank % self.sp_size], ( + f"local_seqlen {local_seqlen} != seq_lens_cpu[{self.local_rank % self.sp_size}]={seq_lens_cpu[self.local_rank % self.sp_size]}" + ) cum_seqlen_cpu = [0] + list(itertools.accumulate(seq_lens_cpu)) self._cum_seq_len_cpu_tuple = tuple(cum_seqlen_cpu) - def forward(self, - inputs: torch.Tensor, - weight: torch.Tensor, - seq_lens_cpu: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - output: Optional[torch.Tensor] = None, - a2a_output: Optional[torch.Tensor] = None, - transpose_weight: bool = False, - num_comm_sms: int = -1, - sm_margin: int = 0): + def forward( + self, + inputs: torch.Tensor, + weight: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, + a2a_output: Optional[torch.Tensor] = None, + transpose_weight: bool = False, + num_comm_sms: int = -1, + sm_margin: int = 0, + ): if num_comm_sms == -1: num_comm_sms = self.world_size assert num_comm_sms >= 0, "num_comm_sms must be non-negative" @@ -582,7 +561,7 @@ def forward(self, self.reset_cusum_seq_lens(local_seqlen=local_seq_len, seq_lens_cpu=seq_lens_cpu) - gemm_input_a = self._comm_output_buffer.view(-1)[:M * K].view([M, K]) + gemm_input_a = self._comm_output_buffer.view(-1)[: M * K].view([M, K]) cur_stream = torch.cuda.current_stream() @@ -618,46 +597,42 @@ def forward(self, ) if output is None: - output = torch.empty([bs, local_seq_len, N], - device=inputs.device, - dtype=self.output_dtype) + output = torch.empty([bs, local_seq_len, N], device=inputs.device, dtype=self.output_dtype) assert len(output.shape) == 3, f"output must be 4D tensor, got {len(output)}D" - assert output.shape[ - 0] == bs, f"output batch size {output.shape[0]} must be equal to input batch size {bs}" - assert output.shape[ - 1] == local_seq_len, f"output seq_len {output.shape[1]} must be equal to local_seq_len {local_seq_len}" - assert output.shape[ - 2] == N, f"output head {output.shape[2]} must be equal to output size {N}" + assert output.shape[0] == bs, f"output batch size {output.shape[0]} must be equal to input batch size {bs}" + assert output.shape[1] == local_seq_len, f"output seq_len {output.shape[1]} must be equal to local_seq_len {local_seq_len}" + assert output.shape[2] == N, f"output head {output.shape[2]} must be equal to output size {N}" assert output.is_contiguous(), f"output must be contiguous, got {output.shape}" - assert self.max_gemm_sms - num_comm_sms - sm_margin > 0, f"max_gemm_sms {self.max_gemm_sms} - num_comm_sms {num_comm_sms} - sm_margin {sm_margin} must be greater than 0" + assert self.max_gemm_sms - num_comm_sms - sm_margin > 0, ( + f"max_gemm_sms {self.max_gemm_sms} - num_comm_sms {num_comm_sms} - sm_margin {sm_margin} must be greater than 0" + ) gemm_config = triton.Config( { - 'BLOCK_SIZE_M': self.BLOCK_SIZE_M, - 'BLOCK_SIZE_N': self.BLOCK_SIZE_N, - 'BLOCK_SIZE_K': self.BLOCK_SIZE_K, - 'GROUP_SIZE_M': self.GROUP_SIZE_M, - 'A2A_TILE_M': self.A2A_TILE_M, - 'A2A_TILE_N': self.A2A_TILE_N, - 'NUM_GEMM_SMS': self.max_gemm_sms - num_comm_sms - sm_margin + "BLOCK_SIZE_M": self.BLOCK_SIZE_M, + "BLOCK_SIZE_N": self.BLOCK_SIZE_N, + "BLOCK_SIZE_K": self.BLOCK_SIZE_K, + "GROUP_SIZE_M": self.GROUP_SIZE_M, + "A2A_TILE_M": self.A2A_TILE_M, + "A2A_TILE_N": self.A2A_TILE_N, + "NUM_GEMM_SMS": self.max_gemm_sms - num_comm_sms - 
sm_margin, }, num_stages=self.num_stages, - num_warps=self.num_warps) + num_warps=self.num_warps, + ) with torch.cuda.stream(self.compute_stream): - matmul_descriptor_persistent(self.sp_rank, self.sp_size, gemm_input_a, weight, bias, - output, self._barrier_buffer, gemm_config, - self.warp_specialize) + matmul_descriptor_persistent( + self.sp_rank, self.sp_size, gemm_input_a, weight, bias, output, self._barrier_buffer, gemm_config, self.warp_specialize + ) if a2a_output is not None: - assert a2a_output.shape == ( - bs, local_seq_len, local_head * self.sp_size, head_dim - ), f"a2a_output shape {a2a_output.shape} must be equal to (bs, local_seq_len, local_head * self.sp_size, head_dim) ({bs}, {local_seq_len}, {local_head * self.sp_size}, {head_dim})" - assert a2a_output.is_contiguous( - ), f"a2a_output must be contiguous, got {a2a_output.shape}" - a2a_output.copy_( - gemm_input_a.view(bs, local_seq_len, local_head * self.sp_size * head_dim)) + assert a2a_output.shape == (bs, local_seq_len, local_head * self.sp_size, head_dim), ( + f"a2a_output shape {a2a_output.shape} must be equal to (bs, local_seq_len, local_head * self.sp_size, head_dim) ({bs}, {local_seq_len}, {local_head * self.sp_size}, {head_dim})" + ) + assert a2a_output.is_contiguous(), f"a2a_output must be contiguous, got {a2a_output.shape}" + a2a_output.copy_(gemm_input_a.view(bs, local_seq_len, local_head * self.sp_size * head_dim)) ret = (output, a2a_output) else: ret = (output,) @@ -701,7 +676,7 @@ def post_attn_a2a( self.reset_cusum_seq_lens(local_seqlen=local_seq_len, seq_lens_cpu=seq_lens_cpu) assert comm_buf_idx < self.max_num_comm_buf, f"comm_buf_idx {comm_buf_idx} must be less than num_comm_buf {self.max_num_comm_buf}" - gemm_input_a = self._comm_output_buffer[comm_buf_idx].view(-1)[:M * K].view([M, K]) + gemm_input_a = self._comm_output_buffer[comm_buf_idx].view(-1)[: M * K].view([M, K]) cur_stream = torch.cuda.current_stream() diff --git a/examples/dsa_sparse_finetune/dsa.py b/examples/dsa_sparse_finetune/dsa.py new file mode 100644 index 000000000..9fae8e5e3 --- /dev/null +++ b/examples/dsa_sparse_finetune/dsa.py @@ -0,0 +1,223 @@ +from typing import Optional +import torch +import torch.nn.functional as F +from indexer_topk_reducesum import indexer_topk_reducesum_interface +from indexer_bwd import indexer_bwd_interface +from sparse_mla_fwd import sparse_mla_fwd_interface +from sparse_mla_bwd import sparse_mla_bwd +from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface +from einops import einsum, repeat +from utils import get_abs_err, get_err_ratio + + +class RegsiterLossFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.save_for_backward(loss) + return x + + @staticmethod + def backward(ctx, grad): + loss = ctx.saved_tensors + return grad, torch.ones(1, dtype=loss[0].dtype, device=loss[0].device) + + +register_loss = RegsiterLossFunction.apply + + +def ref_deepseek_sparse_attention_innner( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + dtype = q.dtype + q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights)) + + index_sm_scale = index_q.shape[-1] ** -0.5 + b, s = index_q.shape[:2] + + # tl_topk_indices = tl_topk_indices.to(torch.int64) + # tl_topk_indices[tl_topk_indices == -1] = s + + casual_mask = (torch.arange(s)[:, None] >= 
torch.arange(s)[None, :]).to(q.device) + index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2") + index_logits = F.relu(index_logits) + index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale + index_logits = torch.where(casual_mask, index_logits, float("-inf")) + topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices + topk_logits = torch.gather(F.pad(index_logits, (0, 1), value=float("-inf")), dim=-1, index=topk_indices) + topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32) + index_topk_score = topk_score + + if sm_scale is None: + sm_scale = kv.shape[-1] ** -0.5 + + h = q.shape[-2] + index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda").scatter_( + dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool) + )[:, :, :-1] + mask = repeat(casual_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h) + k, v = kv, kv[..., :dim_v] + logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d") + + attn_score = attn_score.sum(dim=-2) # [b, s1, s2] + attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices) + attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True) + + loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum") + o = register_loss(o, loss) + + return o.to(dtype), topk_indices + + +def ref_deepseek_sparse_attention( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + all_o, all_topk_indices = [], [] + for i in range(offsets.shape[0] - 1): + o, topk_indices = ref_deepseek_sparse_attention_innner( + q[None, offsets[i] : offsets[i + 1]], + kv[None, offsets[i] : offsets[i + 1]], + index_q[None, offsets[i] : offsets[i + 1]], + index_k[None, offsets[i] : offsets[i + 1]], + weights[None, offsets[i] : offsets[i + 1]], + topk, + dim_v, + sm_scale, + index_sm_scale, + ) + all_o.append(o.squeeze(0)) + all_topk_indices.append(topk_indices.squeeze(0)) + o = torch.cat(all_o, dim=0) + topk_indices = torch.cat(all_topk_indices, dim=0) + return o, topk_indices + + +class DSAFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + ): + # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk) + topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets) + o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v) + ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets) + ctx.topk = topk + ctx.dim_v = dim_v + ctx.sm_scale = sm_scale + return o, topk_indices + + @staticmethod + def backward( + ctx, + do: torch.Tensor, + _1: torch.Tensor, + ): + q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors + attn_score = sparse_mla_topk_reducesum_interface( + q, 
kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v + ).squeeze(-2) + dq, dkv = sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale) + dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets) + return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None + + +def deepseek_sparse_attention( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, +): + return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale) + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + index_D=128, + topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_() + index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_() + weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_() + index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_() + do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_() + offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda() + + o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) + o.backward(do) + q_grad, q.grad = q.grad, None + kv_grad, kv.grad = kv.grad, None + index_q_grad, index_q.grad = index_q.grad, None + index_k_grad, index_k.grad = index_k.grad, None + weights_grad, weights.grad = weights.grad, None + + ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) + ref_o.backward(do) + ref_q_grad, q.grad = q.grad, None + ref_kv_grad, kv.grad = kv.grad, None + ref_index_q_grad, index_q.grad = index_q.grad, None + ref_index_k_grad, index_k.grad = index_k.grad, None + ref_weights_grad, weights.grad = weights.grad, None + + print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}") + print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}") + print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}") + print( + f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}" + ) + print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}") + print(f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}") + + intersections = [] + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + mask = trt_np != -1 + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + intersections.append(len(intersection) / len(set_ref)) + print("average intersections: {:.4f}".format(sum(intersections) / len(intersections))) + + +test_kernel() diff --git a/examples/dsa_sparse_finetune/index.py b/examples/dsa_sparse_finetune/index.py new file mode 100644 index 000000000..5e4800411 --- /dev/null +++ b/examples/dsa_sparse_finetune/index.py @@ -0,0 +1,82 @@ +# Modified from: 
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py +import torch +import torch.nn.functional as F +import functools +from typing import Callable, Any + + +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if ( + (last_args is not None and last_kwargs is not None) + and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) + and all(a is b for a, b in zip(args, last_args, strict=False)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_cu_seqlens_from_lens( + lens: torch.LongTensor, + dtype: torch.dtype | None = torch.int32, +) -> torch.LongTensor: + return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0)) + + +@tensor_cache +def prepare_lens_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()]) + + +@tensor_cache +def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1 + + +@tensor_cache +def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + position_ids = prepare_position_ids(cu_seqlens) + return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens) diff --git a/examples/dsa_sparse_finetune/indexer_bwd.py b/examples/dsa_sparse_finetune/indexer_bwd.py new file mode 100644 index 000000000..68508ad4e --- /dev/null +++ b/examples/dsa_sparse_finetune/indexer_bwd.py @@ -0,0 +1,254 @@ +import torch +import torch.nn.functional as F +from einops import einsum, repeat + +import tilelang as tl +import tilelang.language as T +from typing import Optional +from index import prepare_token_indices + +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_bwd_impl( + heads: int, + dim: int, + topk: int, + sm_scale: Optional[float] = None, + block_I: int = 32, + num_stages: int = 0, + num_threads: int = 128, +): + assert num_stages == 0 + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_I == 0 + assert 
heads <= 64 and heads % 8 == 0 + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + dtype: str = BF16 + accum_dtype: str = FP32 + index_q_shape = [seq_len, heads, dim] + weights_shape = [seq_len, heads] + index_k_shape = [seq_len, dim] + shape_p = [seq_len, topk] + topk_indices_shape = [seq_len, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + if sm_scale is None: + sm_scale = dim**-0.5 + + @T.prim_func + def tl_indexer_bwd_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + dIndexQ: T.Tensor(index_q_shape, dtype), + dWeights: T.Tensor(weights_shape, dtype), + dIndexK: T.Tensor(index_k_shape, dtype), + AttnScore: T.Tensor(shape_p, FP32), + IndexScore: T.Tensor(shape_p, FP32), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), + ): + with T.Kernel(seq_len, threads=num_threads) as (bx): + i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] + bos = Offsets[i_b] + num_blocks = T.ceildiv(topk, block_I) + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + weights_shared = T.alloc_shared([heads], dtype=dtype) + + d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype) + d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype) + + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.copy(Weights[bos + i_t, :], weights_shared) + T.fill(d_index_q_frag, 0) + T.fill(d_weights_frag, 0) + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + + for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): + i_st = bi_i * block_I + i_ed = (bi_i + 1) * block_I + + indices_shared = T.alloc_shared([block_I], dtype=INT32) + T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared) + + index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype) + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0) + + attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + for i in T.Parallel(block_I): + attn_score_shared[i] = AttnScore[bos + i_t, i_st + i] + index_score_shared[i] = IndexScore[bos + i_t, i_st + i] + + logits = T.alloc_fragment((block_I, heads), accum_dtype) + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + for i, j in T.Parallel(block_I, heads): + logits[i, j] = T.max(logits[i, j], 0) + + # dw + d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype) + for i, j in T.Parallel(block_I, heads): + d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j] + T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) + + d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype) + d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype) + d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype) + + for i, j in T.Parallel(block_I, heads): + d_relu = T.alloc_var(accum_dtype) + if logits[i, j] > 0: + d_relu = 1.0 + else: + d_relu = 0.0 + d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j] + + # dq + T.copy(d_logits_qk, d_logits_qk_cast1) + T.gemm( + d_logits_qk_cast1, # [BS, HQ] + index_k_shared, # [BS, K] + d_index_q_frag, # [HQ, K] + transpose_A=True, + transpose_B=False, 
+ clear_accum=False, + ) + + # dk + T.copy(d_logits_qk, d_logits_qk_cast2) + d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype) + T.gemm( + d_logits_qk_cast2, # [BS, HQ] + index_q_shared, # [HQ, K] + d_index_k_frag, # [BS, K] + transpose_A=False, + transpose_B=False, + clear_accum=True, + ) + + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + if (pos > -1) & (pos <= i_t): + T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j]) + + for i, j in T.Parallel(heads, dim): + d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale + + T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :]) + T.copy(d_weights_frag, dWeights[bos + i_t, :]) + + return tl_indexer_bwd_kernel + + +def indexer_bwd_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + attn_score: torch.Tensor, + index_score: torch.Tensor, + topk_indices: torch.Tensor, + offsets: torch.Tensor, +): + _, heads, dim, topk = *q.shape, topk_indices.shape[-1] + token_indices = prepare_token_indices(offsets) + dq = torch.zeros_like(q) + dweights = torch.zeros_like(weights) + dk = torch.zeros_like(k) + kernel = tl_indexer_bwd_impl(heads, dim, topk) + kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices) + return dq, dweights, dk + + +def ref_indexer_bwd( + Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor +) -> torch.Tensor: + Q.requires_grad_(True) + Weights.requires_grad_(True) + K.requires_grad_(True) + softmax_scale = Q.shape[-1] ** -0.5 + all_loss = [] + all_log_topk_prob = [] + for i in range(offsets.shape[0] - 1): + assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1] + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + attn_score = AttnScore[offsets[i] : offsets[i + 1]] + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale + logits = F.relu(logits) + score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) + score = torch.where(mask, score, float("-inf")) + topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64)) + log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32) + loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum") + all_loss.append(loss) + all_log_topk_prob.append(log_topk_prob) + loss = torch.stack(all_loss).sum() + loss.backward() + log_topk_prob = torch.cat(all_log_topk_prob, dim=0) + return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad + + +def test_kernel( + B=1, + S=2048, + H=16, + D=128, + topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D)).cuda().bfloat16() + w = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + all_attn_score = [] + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device) + logits = torch.ones(seq_len, topk).cuda() + logits = torch.where(mask, logits, float("-inf")) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + all_attn_score.append(attn_score) + attn_score = torch.cat(all_attn_score, dim=0) + + topk_indices = 
repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
+    index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets)
+
+    dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets)
+
+    print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}")
+    print(f"dw err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}")
+    print(f"dk err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}")
+
+
+if __name__ == "__main__":
+    test_kernel()
diff --git a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py
new file mode 100644
index 000000000..d76eb0272
--- /dev/null
+++ b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py
@@ -0,0 +1,273 @@
+import math
+import torch
+import torch.nn.functional as F
+from einops import einsum
+
+import tilelang as tl
+import tilelang.language as T
+from typing import Optional
+from index import prepare_token_indices
+
+from utils import get_abs_err, get_err_ratio
+
+BF16 = T.bfloat16
+FP32 = T.float32
+INT32 = T.int32
+
+pass_configs = {
+    tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
+    tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+    tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+}
+
+
+@tl.jit(pass_configs=pass_configs)
+def tl_indexer_topk_reducesum_impl(
+    heads: int,
+    dim: int,
+    topk: int,
+    sm_scale: Optional[float] = None,
+    block_K: int = 32,
+    dtype: str = FP32,
+    num_stages: int = 0,
+    num_threads: int = 128,
+):
+    assert topk == tl.math.next_power_of_2(topk)
+    assert topk % block_K == 0
+    assert heads <= 64 and heads % 8 == 0
+    assert num_stages == 0
+    batch_plus_one = T.symbolic("batch_plus_one")
+    seq_len = T.symbolic("seq_len")
+
+    index_q_shape = [seq_len, heads, dim]
+    weights_shape = [seq_len, heads]
+    index_k_shape = [seq_len, dim]
+    topk_indices_shape = [seq_len, topk]
+    offsets_shape = [batch_plus_one]
+    token_indices_shape = [seq_len, 2]
+
+    N = 2 * topk
+    num_iters = int(round(math.log2(N)))
+    if sm_scale is None:
+        sm_scale = dim**-0.5
+
+    @T.macro
+    def bitonic_sort(
+        topk_index_shared: T.SharedBuffer([N], dtype=INT32),
+        topk_value_shared: T.SharedBuffer([N], dtype=FP32),
+    ):
+        T.sync_threads()
+        for i1 in T.serial(num_iters):
+            for i2 in T.serial(i1 + 1):
+                for i in T.Parallel(N):
+                    ascending = (i & (1 << (i1 + 1))) != 0
+                    j = i ^ (1 << (i1 - i2))
+                    if i < j and (
+                        (ascending and topk_value_shared[i] > topk_value_shared[j])
+                        or (not ascending and topk_value_shared[i] < topk_value_shared[j])
+                    ):
+                        val = topk_value_shared[i]
+                        topk_value_shared[i] = topk_value_shared[j]
+                        topk_value_shared[j] = val
+                        idx = topk_index_shared[i]
+                        topk_index_shared[i] = topk_index_shared[j]
+                        topk_index_shared[j] = idx
+                T.sync_threads()
+
+    @T.prim_func
+    def tl_indexer_topk_reducesum_kernel(
+        IndexQ: T.Tensor(index_q_shape, dtype),
+        Weights: T.Tensor(weights_shape, dtype),
+        IndexK: T.Tensor(index_k_shape, dtype),
+        TopkIndices: T.Tensor(topk_indices_shape, INT32),
+        ReduceSum: T.Tensor(topk_indices_shape, FP32),
+        Offsets: T.Tensor(offsets_shape, INT32),
+        TokenIndices: T.Tensor(token_indices_shape, INT32),
+    ):
+        with T.Kernel(seq_len, threads=num_threads) as (bx):
+            i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
+            bos, eos = Offsets[i_b], Offsets[i_b + 1]
+            num_blocks = T.ceildiv(i_t + 1, block_K)
+
+            topk_index_shared = T.alloc_shared([N], dtype=INT32)
+            topk_value_shared = T.alloc_shared([N], dtype=FP32)
+
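+            # Running top-k over the causal prefix: the first `topk` slots of these
+            # shared buffers hold the current best candidates, while the second `topk`
+            # slots act as a staging area that is refilled block by block. Whenever the
+            # staging half fills up, bitonic_sort (descending) merges the two halves so
+            # the largest `topk` scores always end up back in slots [0, topk).
+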
T.fill(topk_index_shared, -1) + T.fill(topk_value_shared, float("-inf")) + T.sync_threads() + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.sync_threads() + + weights_frag = T.alloc_shared([heads], dtype=dtype) + T.copy(Weights[bos + i_t, :], weights_frag) + T.sync_threads() + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + T.sync_threads() + + for bk_i in T.Pipelined(num_blocks, num_stages=num_stages): + k_st = bk_i * block_K + k_ed = T.min((bk_i + 1) * block_K, eos - bos) + + index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype) + for i, j in T.Parallel(block_K, dim): + index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0) + T.sync_threads() + + logits = T.alloc_fragment((block_K, heads), FP32) + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + T.sync_threads() + + for i, j in T.Parallel(block_K, heads): + logits[i, j] = T.max(logits[i, j], 0) * weights_frag[j] + T.sync_threads() + + logits_sum = T.alloc_fragment(block_K, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + T.sync_threads() + + offset = T.alloc_var(INT32) + if k_st >= topk: + offset = topk + (k_st % topk) + else: + offset = k_st + T.sync_threads() + for i in T.Parallel(block_K): + if k_st + i > i_t: + logits_sum[i] = float("-inf") + j = offset + i + topk_index_shared[j] = k_st + i + topk_value_shared[j] = logits_sum[i] + T.sync_threads() + + if k_ed > topk and k_ed % topk == 0: + bitonic_sort(topk_index_shared, topk_value_shared) + + bitonic_sort(topk_index_shared, topk_value_shared) + + logits_max_frag = T.alloc_fragment([1], dtype=FP32) + logits_frag = T.alloc_fragment([topk], dtype=FP32) + reducesum_shared = T.alloc_shared([topk], dtype=FP32) + + T.copy(topk_value_shared[:topk], logits_frag) + T.sync_threads() + + T.reduce_max(logits_frag, logits_max_frag, dim=-1) + T.sync_threads() + + for i in T.Parallel(topk): + logits_frag[i] = T.exp(logits_frag[i] - logits_max_frag[0]) + T.sync_threads() + + lse_frag = T.alloc_fragment([1], dtype=FP32) + T.reduce_sum(logits_frag, lse_frag) + T.sync_threads() + + for i in T.Parallel(topk): + reducesum_shared[i] = logits_frag[i] / lse_frag[0] + T.sync_threads() + + # for i in T.Parallel(topk): + # reducesum_shared[i] = logits_frag[i] + # T.sync_threads() + + for i in T.Parallel(topk): + if topk_index_shared[i] > i_t: + topk_index_shared[i] = -1 + T.sync_threads() + + T.copy(topk_index_shared[:topk], TopkIndices[bos + i_t, :]) + T.copy(reducesum_shared[:topk], ReduceSum[bos + i_t, :]) + + return tl_indexer_topk_reducesum_kernel + + +def indexer_topk_reducesum_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + topk: int, + offsets: torch.Tensor, + dtype: str = BF16, +): + seq_len, heads, dim = q.shape + kernel = tl_indexer_topk_reducesum_impl(heads=heads, dim=dim, topk=topk, dtype=dtype) + token_indices = prepare_token_indices(offsets) + topk_indices = torch.zeros((seq_len, topk), device=q.device, dtype=torch.int32) + topk_score = torch.zeros((seq_len, topk), device=q.device, dtype=torch.float32) + kernel(q, weights, k, topk_indices, topk_score, offsets, token_indices) + return topk_indices, topk_score + + +def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor: + all_topk_indices = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + assert (offsets[i + 
1] - offsets[i]).item() >= topk + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + softmax_scale = q.shape[-1] ** -0.5 + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") + logits = F.relu(logits) + logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale + logits = torch.where(mask, logits, float("-inf")) + topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1) + topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32) + all_topk_indices.append(topk_indices) + all_topk_score.append(topk_score) + topk_indices = torch.cat(all_topk_indices, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return topk_indices, topk_score + + +def test_kernel( + B=1, + S=2048, + H=64, + D=128, + topk=64, +): + torch.manual_seed(42) + + q = torch.randn((S, H, D)).cuda().bfloat16() + weights = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, S], dtype=torch.int32).cuda() + + ref_topk_indices, ref_topk_score = ref_index_score(q, weights, k, topk, offsets) + + topk_indices, topk_score = indexer_topk_reducesum_interface(q, weights, k, topk, offsets) + + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + ref_np_val = ref_topk_score[j] + trt_np_val = topk_score[j] + + mask = (ref_np_val > 0).cpu().numpy() + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + + print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) + + print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py new file mode 100644 index 000000000..53e5f8bfe --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -0,0 +1,347 @@ +# ruff: noqa +import tilelang +from tilelang import language as T +import torch +from index import prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit(out_idx=[-1]) +def preprocess( + H, + D, + block_ND=32, + num_stages=5, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + S = T.symbolic("S") + + shape = [S, H, D] + + @T.prim_func + def preprocess_kernel( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([S, H], accum_dtype), + ): + with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by): + o = T.alloc_fragment([block_ND, block_ND], accum_dtype) + do = T.alloc_fragment([block_ND, block_ND], accum_dtype) + delta = T.alloc_fragment([block_ND], accum_dtype) + acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) + T.clear(acc) + for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx]) + + return preprocess_kernel + + +@tilelang.jit(out_idx=[-1]) +def 
postprocess( + D, + D_tail, + kv_group=1, + block_N=64, + threads=128, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + S_kv = T.symbolic("S_kv") + + dkv_shape = [S_kv, kv_group, D + D_tail] + + @T.prim_func + def postprocess_kernel( + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), + ): + with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by): + T.copy( + dKV[bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bx * block_N : (bx + 1) * block_N, by, :], + ) + + return postprocess_kernel + + +@tilelang.jit( + out_idx=[-2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def bwd( + H, + D, + D_tail, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_size=32, + num_stages=0, + threads=128, + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 + + if sm_scale is None: + sm_scale = (D + D_tail) ** (-0.5) + + B_plus_one = T.symbolic("B_plus_one") + S = T.symbolic("S") + + H_kv = H // kv_group + q_shape = [S, H, D + D_tail] + k_shape = [S, kv_group, D + D_tail] + o_shape = [S, H, D] + indices_shape = [S, kv_group, topk] + delta_shape = [S, H] + lse_shape = [S, H] + offsets_shape = [B_plus_one] + token_indices_shape = [S, 2] + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + H = H_kv + padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + BS = block_size + NS = tilelang.cdiv(topk, block_size) + + split_store = 2 + + @T.prim_func + def sparse_mla_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + Offsets: T.Tensor(offsets_shape, indices_dtype), + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), + ): + with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz): + Q_shared = T.alloc_shared([padded_H, D], dtype) + Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + KV_shared = T.alloc_shared([BS, D], dtype) + KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) + dO_shared = T.alloc_shared([padded_H, D], dtype) + mask = T.alloc_fragment([BS], "bool") + + P_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dP_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dQ_shared = T.alloc_shared([padded_H, D], dtype) + dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + + acc_p = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([padded_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) + acc_dkv = T.alloc_fragment([BS, D], accum_dtype) + acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) + acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) + acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) + + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = 
Offsets[b_i], Offsets[b_i + 1] + + max_kv_i = s_i + + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) + + T.clear(acc_dq) + T.clear(acc_dq_tail) + + # Process each block of indices + for i_i in T.Pipelined(NS, num_stages=num_stages): + # Check which indices are valid + for bi_i in T.Parallel(BS): + mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1) + + # Compute attention scores + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) + + # Load KV, V for this block of indices + for bi_i, d_i in T.Parallel(BS, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i] + + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for bi_i, d_i in T.Parallel(BS, D_tail): + KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i]) + + T.copy(acc_p, P_shared_cast) + + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale + + T.copy(acc_dp, dP_shared_cast) + T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) + + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + T.clear(acc_dkv_tail) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + for s in range(split_store): + for bi_i, d_i in T.Parallel(BS, D): + if bi_i < BS // split_store: + acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS, D_tail): + if bi_i < BS // split_store: + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS // split_store, D // 4): + T.atomic_addx4( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) + + # Atomically update dKV, dKV_tail tensors + for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): + T.atomic_addx4( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) + + # Store the accumulated dQ + T.copy(acc_dq, dQ_shared) + T.copy(acc_dq_tail, dQ_tail_shared) + + T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:]) + + return sparse_mla_bwd_kernel + + +def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True, return_kernel=False, delta=None): + assert q.is_contiguous() + assert kv.is_contiguous() + assert indices.is_contiguous() + assert 
lse.is_contiguous() + S, H, dim_plus_tail_dim = q.shape + S_kv, kv_group, _ = kv.shape + assert kv.shape[-1] == dim_plus_tail_dim + assert S == S_kv + # dim should be assigned + D = 512 + + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + assert indices.shape == (S, kv_group, topk) + assert lse.shape == (S, H) + + token_indices = prepare_token_indices(offsets) + + # Get kernels + preprocess_kernel = preprocess(H, D) + bwd_kernel = bwd(H, D, D_tail, topk, kv_group, sm_scale, is_casual) + postprocess_kernel = postprocess(D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, do, indices, lse, delta, offsets, token_indices, dkv) + dkv = postprocess_kernel(dkv) + + return dq, dkv + + +def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True): + from sparse_mla_fwd import ref_sparse_mla_fwd_interface + + q = q.detach().clone() + kv = kv.detach().clone() + q.requires_grad = True + kv.requires_grad = True + o = ref_sparse_mla_fwd_interface(q, kv, indices, offsets, sm_scale, is_casual) + o.backward(do) + return q.grad, kv.grad + + +def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True): + # Prepare data + q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((S, H, DV), dtype=dtype, device="cuda") + offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, : len(i_i)] = i_i + + # Forward + from sparse_mla_fwd import sparse_mla_fwd_interface + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets) + + tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None, offsets) + + if check_correctness: + assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq") + assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") + print("assert_tensors_similar passed") + + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) + from tilelang.profiler import do_bench + + def fn(): + return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + + ms = do_bench(fn, rep=100, warmup=250) + print(f"Average time: {ms:.3f} ms") + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/examples/dsa_sparse_finetune/sparse_mla_fwd.py new file mode 100644 index 000000000..d87523695 --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -0,0 +1,310 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from index import prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit( + 
out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=32, + num_stages=2, + threads=128, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 + else: + sm_scale = sm_scale + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + + head_kv = heads // kv_group + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len, kv_group, dim + tail_dim] + o_shape = [seq_len, heads, dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + 
s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, Output[bos + s_i, H0:H1, :]) + T.copy(sumexp, Lse[bos + s_i, H0:H1]) + + return main + + +def sparse_mla_fwd_interface( + q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128 +): + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + seq_len, heads, dim_plus_tail_dim = q.shape + seq_len_kv, kv_group, _ = kv.shape + assert seq_len == seq_len_kv + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + _, _, topk = indices.shape + assert indices.shape == (seq_len, kv_group, topk) + + token_indices = prepare_token_indices(offsets) + + kernel = sparse_mla_fwd( + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) + out, lse = kernel(q, kv, indices, offsets, token_indices) + return out, lse + + +def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casual=True): + Q = Q.float() + KV = KV.float() + all_o = [] + for i in range(offsets.shape[0] - 1): + q = Q[None, offsets[i] : offsets[i + 1]] + kv = KV[None, offsets[i] : offsets[i + 1]] + indices = Indices[None, offsets[i] : offsets[i + 1]].clone() + + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + 
).view(1, -1) + + indices[indices > sk] = sk + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : 1 - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + all_o.append(o.squeeze(0)) + o = torch.cat(all_o, dim=0) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): + torch.random.manual_seed(0) + q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + offsets = torch.tensor([0, S // 2 - 1, S], dtype=torch.int32, device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, : len(i_i)] = i_i + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + if check_correctness: + # otherwise may cause out of memory + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, offsets) + assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out") + print("assert_tensors_similar passed") + + def fn(): + return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=100, + warmup=250, + ) + print(f"Average time: {ms:.3f} ms") + print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=1024, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, + ) diff --git a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py new file mode 100644 index 000000000..a03bc74f5 --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -0,0 +1,226 @@ +# ruff: noqa +import torch +import torch.nn as nn +import torch.nn.functional as F +import tilelang +from tilelang import language as T +from einops import repeat, rearrange, einsum +from index import prepare_token_indices +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tilelang.jit(pass_configs=pass_configs) +def tl_sparse_mla_topk_reducesum_impl( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + block_I=32, + num_stages=2, + threads=128, +): + assert dim 
== tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + seq_len_kv = T.symbolic("seq_len_kv") + + head_kv = heads // kv_group + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len_kv, kv_group, dim + tail_dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + + @T.prim_func + def tl_sparse_mla_topk_reducesum_kernel( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + reducesum = T.alloc_fragment([BI], accum_dtype) + lse = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(lse, 0) + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + r_i = bx % REPLICATE_H + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + T.copy(Lse[bos + s_i, H0:H1], lse) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, 
bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i]) + T.reduce_sum(acc_s, reducesum, dim=0) + T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI]) + + return tl_sparse_mla_topk_reducesum_kernel + + +def sparse_mla_topk_reducesum_interface( + q: torch.Tensor, + kv: torch.Tensor, + topk_indices: torch.Tensor, + lse: torch.Tensor, + offsets: torch.Tensor, + dim_v: int, +): + assert kv.shape[-2] == 1 + seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1] + REPLICATE_H = max(heads // 64, 1) + tail_dim = dim_plus_tail_dim - dim_v + token_indices = prepare_token_indices(offsets) + + reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device) + kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, dim=dim_v, tail_dim=tail_dim, topk=topk) + kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum) + reducesum = reducesum.sum(dim=-2) # [batch, seq_len, 1, RH, topk] -> [batch, seq_len, 1, topk] + attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True) + + return attn_score + + +def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor): + # q: [batch, seq_len, heads, dim] + # k: [batch, seq_len, dim] + sm_scale = Q.shape[-1] ** -0.5 + all_lse = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + q = Q[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + seq_len = q.shape[0] + mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() + logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) + score = F.softmax(logits, dim=-1, dtype=torch.float32) + score_sum = score.sum(dim=-2) + topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64)) + topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True) + max_logits = logits.amax(dim=-1).to(torch.float32) + lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits + all_lse.append(lse) + all_topk_score.append(topk_score) + lse = torch.cat(all_lse, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return lse, topk_score + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + topk=128, +): + torch.manual_seed(42) + + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() + + lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets) + + kv = kv.unsqueeze(-2) + topk_indices = topk_indices.unsqueeze(-2) + + attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) + print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/examples/dsa_sparse_finetune/utils.py b/examples/dsa_sparse_finetune/utils.py new file 
mode 100644
index 000000000..96afd064d
--- /dev/null
+++ b/examples/dsa_sparse_finetune/utils.py
@@ -0,0 +1,73 @@
+import torch
+
+
+def get_abs_err(y, x):
+    x = x.to(torch.float32)
+    y = y.to(torch.float32)
+    return (x - y).flatten().abs().max().item()
+
+
+def get_err_ratio(y, x):
+    x = x.to(torch.float32)
+    y = y.to(torch.float32)
+    err = (x - y).flatten().square().mean().sqrt().item()
+    base = (x).flatten().square().mean().sqrt().item()
+    return err / base
+
+
+def calculate_tensor_similarity(x, y, name="tensor"):
+    """
+    Calculate similarity between two tensors using a normalized dot product metric.
+
+    Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
+    element-wise differences, this function computes a global similarity score:
+        sim = 2 * <x, y> / (||x||^2 + ||y||^2)
+
+    This metric is scale-invariant and measures the cosine-like similarity normalized
+    by the magnitude of both tensors. It returns 1 for identical tensors and values
+    closer to 0 for dissimilar ones. This is particularly useful for comparing tensors
+    with varying magnitudes where relative errors matter more than absolute differences.
+
+    Args:
+        x: First tensor to compare
+        y: Second tensor to compare
+        name: Name of the tensor for logging purposes
+
+    Returns:
+        Similarity score in range [-1, 1] where 1 means identical
+    """
+    x, y = x.data.double(), y.data.double()
+    denominator = (x * x + y * y).sum()
+    if denominator == 0:
+        print(f"\033[33mWARNING: {name} all zero\033[0m")
+        return 1
+    sim = 2 * (x * y).sum() / denominator
+    return sim
+
+
+def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
+    """
+    Assert that two tensors are similar using a global similarity metric.
+
+    Key differences from torch.testing.assert_close:
+    - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking
+      that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers
+      and requires all elements to satisfy the tolerance.
+    - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the
+      normalized dot product. It's more robust to outliers and focuses on overall
+      tensor similarity rather than element-wise precision. This is better suited for
+      comparing large tensors where a few outlier elements shouldn't fail the test.
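+
+    A quick worked example: if y equals x then sim = 1 and diff = 0, while if y
+    equals -x then sim = -1 and diff = 2, so the check fails for any small eps.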
+ + Args: + x: First tensor to compare + y: Second tensor to compare + eps: Maximum allowed difference (1 - similarity), default 1e-8 + name: Name of the tensor for error messages + raise_assert: Whether to raise assertion error on failure + """ + sim = calculate_tensor_similarity(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") + if raise_assert: + assert False # noqa: B011 diff --git a/examples/dynamic_shape/example_dynamic.py b/examples/dynamic_shape/example_dynamic.py index be018c8b7..e338d76ca 100644 --- a/examples/dynamic_shape/example_dynamic.py +++ b/examples/dynamic_shape/example_dynamic.py @@ -1,10 +1,9 @@ import tilelang import tilelang.language as T import tilelang.testing -from tilelang import tvm as tvm -@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}) +@tilelang.jit def matmul_dynamic_mnk( block_M, block_N, @@ -17,9 +16,9 @@ def matmul_dynamic_mnk( num_stages, threads, ): - M = tvm.te.var("m") - N = tvm.te.var("n") - K = tvm.te.var("k") + M = T.dynamic("m") + N = T.dynamic("n") + K = T.dynamic("k") A_shape = (K, M) if trans_A else (M, K) B_shape = (N, K) if trans_B else (K, N) @@ -29,9 +28,9 @@ def matmul_dynamic_mnk( @T.prim_func def dynamic_matmul( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -53,15 +52,14 @@ def dynamic_matmul( return dynamic_matmul -def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads): +def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads): print( f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}" ) - kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) + kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) import torch + if trans_A: A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) else: @@ -103,8 +101,30 @@ def main(M=16384, N=16384, K=16384): accum_dtype = "float32" num_stages = 3 threads = 128 - matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) + matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) + + +def run_regression_perf(M=4096, N=4096, K=4096): + block_M, block_N, block_K = 128, 128, 32 + trans_A, trans_B = False, False + in_dtype, out_dtype = "float16", "float16" + accum_dtype = "float32" + num_stages = 3 + threads = 128 + kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) + import torch + + if trans_A: + A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + if 
trans_B: + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + else: + B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(input_tensors=[A, B, C], backend="cupti") if __name__ == "__main__": diff --git a/examples/dynamic_shape/regression_example_dynamic.py b/examples/dynamic_shape/regression_example_dynamic.py new file mode 100644 index 000000000..958695990 --- /dev/null +++ b/examples/dynamic_shape/regression_example_dynamic.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_dynamic + + +def regression_example_dynamic(): + tilelang.testing.process_func(example_dynamic.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index bc9bb4df5..32da94015 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -3,19 +3,25 @@ import torch import tilelang import tilelang.language as T -from tilelang.autotuner import AutoTuner def ref_program(x, y): return x + y +def get_configs(): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [64, 128, 256] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +@tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): - @T.prim_func - def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), in_dtype) B_shared = T.alloc_shared((block_M, block_N), in_dtype) @@ -24,7 +30,7 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T. T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(B[by * block_M, bx * block_N], B_shared) - for (local_y, local_x) in T.Parallel(block_M, block_N): + for local_y, local_x in T.Parallel(block_M, block_N): C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) @@ -32,53 +38,40 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T. 
return elem_add -def get_configs(M, N): - block_M = [64, 128, 256] - block_N = [64, 128, 256] - threads = [64, 128, 256] - configs = list(itertools.product(block_M, block_N, threads)) - return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] - - -def get_best_config(M, N): +def main(M=1024, N=1024, use_autotune=False): + a = torch.randn(M, N, dtype=torch.float32, device="cuda") + b = torch.randn(M, N, dtype=torch.float32, device="cuda") - def kernel(block_M=None, block_N=None, threads=None): - return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads) + if use_autotune: + kernel = elementwise_add(M, N, in_dtype=T.float32, out_dtype=T.float32) + else: + # Default config + config = {"block_M": 32, "block_N": 32, "threads": 128} + kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32) - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N)).set_compile_args( - out_idx=[-1], - target="cuda", - ).set_profile_args( - supply_type=tilelang.TensorSupplyType.Auto, - ref_prog=ref_program, - skip_check=False, - ) - return autotuner.run(warmup=3, rep=20) + out = kernel(a, b) + torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) -def main(): +def run_regression_perf(): parser = argparse.ArgumentParser() - parser.add_argument("--m", type=int, default=1024) - parser.add_argument("--n", type=int, default=1024) - parser.add_argument("--use_autotune", action="store_true", default=False) + parser.add_argument("--m", type=int, default=4096) + parser.add_argument("--n", type=int, default=4096) args, _ = parser.parse_known_args() M, N = args.m, args.n - a = torch.randn(M, N, dtype=torch.float32, device="cuda") b = torch.randn(M, N, dtype=torch.float32, device="cuda") + config = {"block_M": 32, "block_N": 32, "threads": 128} + kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") + from tilelang.profiler import do_bench - if args.use_autotune: - result = get_best_config(M, N) - kernel = result.kernel - else: - # Default config - config = {"block_M": 32, "block_N": 32, "threads": 128} - kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") - - out = kernel(a, b) - torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) + return do_bench(lambda: kernel(a, b), backend="cupti") if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=1024) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + main(args.m, args.n, args.use_autotune) diff --git a/examples/elementwise/example_elementwise_add_tma_1d.py b/examples/elementwise/example_elementwise_add_tma_1d.py index 0467eba88..501e1f00d 100644 --- a/examples/elementwise/example_elementwise_add_tma_1d.py +++ b/examples/elementwise/example_elementwise_add_tma_1d.py @@ -10,10 +10,8 @@ def ref_program(x, y): @tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): - @T.prim_func - def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), in_dtype) B_shared = 
T.alloc_shared((block_M, block_N), in_dtype) @@ -22,7 +20,7 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T. T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(B[by * block_M, bx * block_N], B_shared) - for (local_y, local_x) in T.Parallel(block_M, block_N): + for local_y, local_x in T.Parallel(block_M, block_N): C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) diff --git a/examples/elementwise/regression_example_elementwise.py b/examples/elementwise/regression_example_elementwise.py new file mode 100644 index 000000000..261202a56 --- /dev/null +++ b/examples/elementwise/regression_example_elementwise.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_elementwise_add + + +def regression_example_elementwise_add(): + tilelang.testing.process_func(example_elementwise_add.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/elementwise/test_example_elementwise.py b/examples/elementwise/test_example_elementwise.py index ff0b45a0a..24f675cd6 100644 --- a/examples/elementwise/test_example_elementwise.py +++ b/examples/elementwise/test_example_elementwise.py @@ -1,14 +1,13 @@ import tilelang.testing import example_elementwise_add -import example_elementwise_add_tma_1d def test_example_elementwise_add(): example_elementwise_add.main() -def test_example_elementwise_add_tma_1d(): - example_elementwise_add_tma_1d.main() +def test_example_elementwise_add_autotune(): + example_elementwise_add.main(use_autotune=True) if __name__ == "__main__": diff --git a/examples/flash_attention/README.md b/examples/flash_attention/README.md index be11a8dc6..355ed7325 100644 --- a/examples/flash_attention/README.md +++ b/examples/flash_attention/README.md @@ -34,8 +34,6 @@ def flash_attention( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - # Annotate layout for Q_shared, e.g., use a swizzled layout to optimize memory access - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) # Copy a block of Q from global memory to Q_shared T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) @@ -77,6 +75,8 @@ def flash_attention( # Compute the maximum value per row on dimension 1 (block_N) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # Compute the factor by which we need to rescale previous partial sums for i in T.Parallel(block_M): @@ -106,4 +106,4 @@ def flash_attention( # Write back the final output block from acc_o to the Output buffer T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) -``` \ No newline at end of file +``` diff --git a/examples/flash_attention/bert_padding.py b/examples/flash_attention/bert_padding.py index 7058fd773..15c4097ce 100644 --- a/examples/flash_attention/bert_padding.py +++ b/examples/flash_attention/bert_padding.py @@ -6,7 +6,6 @@ class IndexFirstAxis(torch.autograd.Function): - @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) @@ -15,9 +14,7 @@ def forward(ctx, input, indices): second_dim = other_shape.numel() # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. # return input[indices] - return torch.gather( - rearrange(input, "b ... 
-> b (...)"), 0, - repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) @staticmethod def backward(ctx, grad_output): @@ -40,14 +37,12 @@ def backward(ctx, grad_output): class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod def forward(ctx, values, indices, first_axis_dim): ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 - output = torch.zeros( - first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. output[indices] = values # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) @@ -66,7 +61,6 @@ def backward(ctx, grad_output): class IndexFirstAxisResidual(torch.autograd.Function): - @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) @@ -128,7 +122,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng """ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). - + For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: ``` [ @@ -177,9 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng """ length = attention_mask_in_length.sum(dim=-1) seqlen = attention_mask_in_length.size(-1) - attention_mask_2d = torch.arange( - seqlen, device=length.device, dtype=length.dtype).expand(len(length), - seqlen) < length.unsqueeze(1) + attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index 907a121d2..801927faf 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -6,25 +6,27 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: 
T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -39,26 +41,25 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): @@ -72,29 +73,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_v] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) 
@@ -103,81 +106,74 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_qk] blk = 64 @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: 
T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -197,35 +193,35 @@ def flash_bwd( dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -237,49 +233,41 @@ def flash_bwd( for i, j in T.Parallel(block_N, dim_qk): T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + 
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -299,37 +287,35 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim_v], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + 
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -342,16 +328,15 @@ def flash_bwd( T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -369,7 +354,10 @@ def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -386,17 +374,8 @@ def maybe_contiguous(x): if ctx.use_atomic: kernel = flashattn_bwd_atomic_add( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -409,17 +388,8 @@ def maybe_contiguous(x): dv = dv.to(torch.float16) else: kernel = flashattn_bwd_split( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel @@ -441,53 +411,45 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: 
int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -504,7 +466,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -522,19 +484,61 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + head_kv = H // groups + Q = torch.randn(BATCH, N_CTX, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + kernel = flashattn_bwd_split( + BATCH, + H, + N_CTX, + D_HEAD_QK, + D_HEAD_V, + causal, + block_M=128, + block_N=32, + threads=256, + num_stages=2, + groups=groups, + ) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(groups, BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.float16) + dV = torch.zeros(groups, BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - 
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Handle backward compatibility and logic @@ -546,5 +550,4 @@ def run1(): # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py index b0732eb5a..fea547b6e 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -5,27 +5,31 @@ from tilelang.contrib import nvcc import argparse +tilelang.disable_cache() + @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -40,26 +44,27 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = 
T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops + # We should set it to negative large number instead + T.fill(scores_max, T.Cast(accum_dtype, -1e30)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): @@ -73,29 +78,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_v] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -104,12 +111,12 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by 
* blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep @@ -120,12 +127,14 @@ def make_dq_layout(dQ): @tilelang.jit( - out_idx=[3, 4, 5], pass_configs={ + out_idx=[3, 4, 5], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] @@ -133,64 +142,55 @@ def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(q_shape, dtype), # type: ignore - dK_out: T.Tensor(k_shape, dtype), # type: ignore - dV_out: T.Tensor(v_shape, dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) - T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk, - by, :]) + T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :]) with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz): - T.annotate_layout({ - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - }) - T.copy(dK[bz, bx * blk:(bx + 1) * blk, by, :], dK_out[bz, bx * blk:(bx + 1) * blk, - by, :]) - T.copy(dV[bz, bx * blk:(bx + 1) * blk, by, :], dV_out[bz, bx * blk:(bx + 1) * blk, - by, :]) + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :]) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: 
ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -211,37 +211,37 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -251,53 +251,43 @@ def flash_bwd( T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True) T.copy(dv, dv_shared) - 
T.atomic_add( - dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) T.copy(dk, dk_shared) - T.atomic_add( - dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -317,37 +307,35 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim_v], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, 
num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -360,16 +348,15 @@ def flash_bwd( T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -387,7 +374,10 @@ def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -404,17 +394,8 @@ def maybe_contiguous(x): if ctx.use_atomic: kernel = flashattn_bwd_atomic_add( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -424,18 +405,9 @@ def maybe_contiguous(x): kernel(q, k, v, do, lse, delta, dq, dk, dv) dq, dk, dv = mod_post(dq, dk, dv) else: - kernel = flashattn_bwd_split( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + kernel = flashattn_bwd_split_novarlen( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel @@ -443,8 +415,7 @@ def maybe_contiguous(x): dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) kernel(q, k, v, do, lse, delta, dq, dk, dv) - dq, _, _ = 
mod_post(dq, torch.zeros_like(k, dtype=torch.float32), - torch.zeros_like(v, dtype=torch.float32)) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) dk, dv = dk.sum(0), dv.sum(0) return dq, dk, dv, None, None, None @@ -458,53 +429,45 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -521,7 +484,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -544,17 +507,15 @@ def run1(): print(f"Detected 
GPU compute capability: {arch}") assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Handle backward compatibility and logic @@ -566,5 +527,4 @@ def run1(): # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index 82d363768..a9f45e077 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -7,56 +7,44 @@ from einops import rearrange, repeat from bert_padding import pad_input, unpad_input -torch.manual_seed(1) - def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): assert mode in ["full", "random", "third"] if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": - lengths = torch.randint( - max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) elif mode == "third": lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) - padding_mask = ( - repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths return padding_mask @tilelang.jit( - out_idx=[5, 6], pass_configs={ + out_idx=[5, 6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_fwd(batch, - total_q, - total_kv, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - groups=1): - scale = 
(1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] v_shape = [total_kv, head_kv, dim_v] o_shape = [total_q, heads, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - lse: T.Tensor([total_q, heads], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -78,8 +66,6 @@ def flash_fwd( q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - for i, d in T.Parallel(block_M, dim_qk): if bx * block_M + i < q_current_seqlen: Q_shared[i, d] = Q[q_start_idx + bx * block_M + i, by, d] @@ -88,7 +74,9 @@ def flash_fwd( T.fill(acc_o, 0.0) T.fill(logsum, 0.0) - T.fill(scores_max, -T.infinity(accum_dtype)) + # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops + # We should set it to negative large number instead + T.fill(scores_max, T.Cast(accum_dtype, -1e30)) loop_range = T.ceildiv(k_current_seqlen, block_N) for k in T.Pipelined(loop_range, num_stages=1): for i, d in T.Parallel(block_N, dim_qk): @@ -99,15 +87,17 @@ def flash_fwd( if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and - (bx * block_M + i < q_current_seqlen and - k * block_N + j < k_current_seqlen), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= k * block_N + j) + and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen), + 0, + T.Cast(accum_dtype, -1e30), + ) else: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( - bx * block_M + i < q_current_seqlen and - k * block_N + j < k_current_seqlen, 0, -T.infinity(acc_s.dtype)) + bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30) + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, d in T.Parallel(block_N, dim_v): if k * block_N + i < k_current_seqlen: @@ -116,6 +106,8 @@ def flash_fwd( V_shared[i, d] = 0.0 T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): 
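A quick illustration of the masking change in the hunk above (the -1e30 sentinel plus the explicit running-max update). This is a minimal NumPy sketch, not TileLang; rescale_factor, the sentinel values, and the head dimension of 192 are illustrative assumptions, not part of the kernel API. It shows why initializing scores_max to -inf breaks the online-softmax rescale for a fully masked (padded) row, while a large finite negative sentinel keeps it well defined.

import numpy as np

def rescale_factor(prev_max, cur_max, scale):
    # mirrors: scores_scale[i] = exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
    return np.exp2(prev_max * scale - cur_max * scale)

scale = np.float32((1.0 / 192) ** 0.5 * 1.44269504)  # 1/sqrt(dim_qk) * log2(e), dim_qk assumed to be 192

# Fully masked row (e.g. a query slot beyond the true varlen sequence length):
# every score, and therefore the running max, equals the sentinel.
print(rescale_factor(np.float32(-np.inf), np.float32(-np.inf), scale))  # nan: (-inf) - (-inf) is undefined
print(rescale_factor(np.float32(-1e30), np.float32(-1e30), scale))      # 1.0: finite sentinel cancels cleanly

Taking the maximum of the previous and current block maxima (the added T.max(scores_max[i], scores_max_prev[i]) loop) keeps the running maximum non-decreasing, so the rescale factor never exceeds 1 even when a later block is entirely masked.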
@@ -137,27 +129,29 @@ def flash_fwd( for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale if bx * block_M + i < q_current_seqlen: - lse[q_start_idx + bx * block_M + i, by] = logsum[i] + lse[bz, by, bx * block_M + i] = logsum[i] return flash_fwd @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_bwd_preprocess(batch, heads, total_q, max_seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + }, +) +def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): + dtype = T.float16 + accum_dtype = T.float32 shape = [total_q, heads, dim_v] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - Delta: T.Tensor([total_q, heads], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -185,23 +179,25 @@ def flash_bwd_prep( for i in T.Parallel(blk): if by * blk + i < q_current_seqlen: - Delta[q_start_idx + by * blk + i, bx] = delta[i] + Delta[bz, bx, by * blk + i] = delta[i] return flash_bwd_prep def make_dq_layout(dQ): - # bshd -> bhld to use tma reduction instruction - return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d]) + # bshd -> bhsd to use tma reduction instruction + return T.Layout(dQ.shape, lambda l, h, d: [h, l, d]) @tilelang.jit( - out_idx=[3, 4, 5], pass_configs={ + out_idx=[3, 4, 5], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] v_shape = [total_kv, head_kv, dim_v] @@ -209,69 +205,62 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(q_shape, dtype), # type: ignore - dK_out: T.Tensor(k_shape, dtype), # type: ignore - dV_out: T.Tensor(v_shape, dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by): - # T.annotate_layout({dQ: make_dq_layout(dQ)}) - T.copy(dQ[bx * blk:(bx + 1) * blk, by, :], dQ_out[bx * blk:(bx + 1) * blk, by, :]) + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy(dQ[bx * blk : (bx + 1) * blk, by, :], dQ_out[bx * blk : (bx + 1) * blk, by, :]) with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by): - # T.annotate_layout({ - # dK: make_dq_layout(dK), - # dV: make_dq_layout(dV), - # }) - T.copy(dK[bx * blk:(bx + 1) * blk, by, :], dK_out[bx * blk:(bx + 1) * blk, by, :]) - T.copy(dV[bx * 
blk:(bx + 1) * blk, by, :], dV_out[bx * blk:(bx + 1) * blk, by, :]) + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bx * blk : (bx + 1) * blk, by, :], dK_out[bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bx * blk : (bx + 1) * blk, by, :], dV_out[bx * blk : (bx + 1) * blk, by, :]) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - total_q, - total_kv, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] v_shape = [total_kv, head_kv, dim_v] do_shape = [total_q, heads, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor(do_shape, dtype), # type: ignore - lse: T.Tensor([total_q, heads], accum_dtype), # type: ignore - Delta: T.Tensor([total_q, heads], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): - with T.Kernel( - heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -286,6 +275,9 @@ def flash_bwd( dv = T.alloc_fragment([block_M, dim_v], accum_dtype) dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) q_start_idx = cu_seqlens_q[bz] k_start_idx = cu_seqlens_k[bz] @@ -294,71 +286,53 @@ def flash_bwd( q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout({ - # dQ: make_dq_layout(dQ), - # dK: make_dq_layout(dK), - # 
dV: make_dq_layout(dV), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) - for i, d in T.Parallel(block_M, dim_qk): - if by * block_M + i < k_current_seqlen: - K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d] - V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d] - else: - K_shared[i, d] = 0.0 - V_shared[i, d] = 0.0 + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) - loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0) + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 loop_ed = T.ceildiv(q_current_seqlen, block_N) for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - for i, d in T.Parallel(block_N, dim_qk): - if k_base * block_N + i < q_current_seqlen: - q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d] - else: - q[i, d] = 0.0 + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - for i in T.Parallel(block_N): - if k_base * block_N + i < q_current_seqlen: - lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx] - else: - lse_shared[i] = 0.0 + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and - (by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen), - qkT[i, j], 0) + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) else: for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else( - by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) - for i, d in T.Parallel(block_N, dim_v): - if k_base * block_N + i < q_current_seqlen: - do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d] - else: - do[i, d] = 0.0 + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) T.clear(dsT) # dsT: (block_kv, block_q) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - for i in T.Parallel(block_N): - if k_base * block_N + i < q_current_seqlen: - delta[i] = Delta[q_start_idx + k_base * block_N + i, bx] - else: - delta[i] = 0.0 + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) @@ -366,44 +340,42 @@ def flash_bwd( T.copy(dsT_cast, dsT_shared) T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.copy(dq, dq_shared) T.atomic_add( - dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N, - bx, :], - dq, - memory_order="release") + 
dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :], + dq_shared, + memory_order="relaxed", + use_tma=True, + ) + T.copy(dv, dv_shared) T.atomic_add( - dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :], - dv, - memory_order="release") + dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], + dv_shared, + memory_order="relaxed", + use_tma=True, + ) + T.copy(dk, dk_shared) T.atomic_add( - dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :], - dk, - memory_order="release") + dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], + dk_shared, + memory_order="relaxed", + use_tma=True, + ) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split(batch, - total_q, - total_kv, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] @@ -411,25 +383,24 @@ def flashattn_bwd_split(batch, do_shape = [total_q, heads, dim_v] dk_shape = [groups, total_kv, head_kv, dim_qk] # sum after kernel dv_shape = [groups, total_kv, head_kv, dim_v] # sum after kernel - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor(do_shape, dtype), # type: ignore - lse: T.Tensor([total_q, heads], accum_dtype), # type: ignore - Delta: T.Tensor([total_q, heads], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): - with T.Kernel( - heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -454,67 +425,52 @@ def flash_bwd( q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = 
k_end_idx - k_start_idx - T.annotate_layout({ - # dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) - for i, d in T.Parallel(block_M, dim_qk): - if by * block_M + i < k_current_seqlen: - K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d] - V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d] - else: - K_shared[i, d] = 0.0 - V_shared[i, d] = 0.0 + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) - loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0) + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 loop_ed = T.ceildiv(q_current_seqlen, block_N) for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - for i, d in T.Parallel(block_N, dim_qk): - if k_base * block_N + i < q_current_seqlen: - q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d] - else: - q[i, d] = 0.0 + # Note: The padding zero of varlen should be considered in T.copy + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - for i, d in T.Parallel(block_N, dim_v): - if k_base * block_N + i < q_current_seqlen: - do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d] - else: - do[i, d] = 0.0 + + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) + T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - for i in T.Parallel(block_N): - if k_base * block_N + i < q_current_seqlen: - lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx] - else: - lse_shared[i] = 0.0 + + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and - (by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen), - qkT[i, j], 0) + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) else: for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else( - by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - for i in T.Parallel(block_N): - if k_base * block_N + i < q_current_seqlen: - delta[i] = Delta[q_start_idx + k_base * block_N + i, bx] - else: - delta[i] = 0.0 + + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -525,57 +481,38 @@ def flash_bwd( T.gemm(dsT_shared, K_shared, dq, transpose_A=True) for i, j in T.Parallel(block_N, dim_qk): if k_base * block_N + i < q_current_seqlen: - 
T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j]) + T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j], memory_order="relaxed") T.copy(dv, dv_shared) - for i, d in T.Parallel(block_M, dim_v): - if by * block_M + i < k_current_seqlen: - dV[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dv[i, d] + T.copy(dv_shared, dV[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) T.copy(dk, dk_shared) - for i, d in T.Parallel(block_M, dim_qk): - if by * block_M + i < k_current_seqlen: - dK[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dk[i, d] + T.copy(dk_shared, dK[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod - def forward(ctx, - q, - k, - v, - seqlens_q, - seqlens_k, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - causal, - groups=1, - use_atomic=True): + def forward( + ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups=1, use_atomic=True + ): BATCH, N_CTX, H, D_HEAD_QK = q.shape D_HEAD_V = v.shape[-1] block_M = 128 block_N = 64 - q_unpad, indices_q, _, _ = unpad_input( - q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) - k_unpad, indices_k, _, _ = unpad_input( - k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) - v_unpad, _, _, _ = unpad_input( - v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) total_q = q_unpad.shape[0] total_kv = k_unpad.shape[0] - mod = flashattn_fwd(BATCH, total_q, total_kv, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, - block_M, block_N, groups) + mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) o = pad_input(o_unpad, indices_q, BATCH, N_CTX) - ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, - cu_seqlens_q, cu_seqlens_k) + ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k) + ctx.batch = BATCH ctx.causal = causal ctx.use_atomic = use_atomic ctx.max_seqlen_q = max_seqlen_q @@ -587,9 +524,9 @@ def forward(ctx, @staticmethod def backward(ctx, do): N_CTX = do.shape[1] - q, k, v, o, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors - do_unpad, _, _, _ = unpad_input( - do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + # lse_clone = lse.clone() + do_unpad, _, _, _ = unpad_input(do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) total_q, H, D_HEAD_QK = q.shape total_kv, HEAD_KV, D_HEAD_V = v.shape groups = H // HEAD_KV @@ -603,7 +540,7 @@ def maybe_contiguous(x): do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)] block_M = 128 block_N = 32 - mod_prep = 
flashattn_bwd_preprocess(BATCH, H, total_q, ctx.max_seqlen_q, D_HEAD_V) + mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V) mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V) delta = mod_prep(o, do, cu_seqlens_q) @@ -612,6 +549,7 @@ def maybe_contiguous(x): BATCH, total_q, total_kv, + N_CTX, H, ctx.max_seqlen_q, D_HEAD_QK, @@ -621,17 +559,19 @@ def maybe_contiguous(x): block_N, threads=256, num_stages=2, - groups=groups) + groups=groups, + ) dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.zeros_like(k, dtype=torch.float32) dv = torch.zeros_like(v, dtype=torch.float32) - kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) + kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) dq, dk, dv = mod_post(dq, dk, dv) else: kernel = flashattn_bwd_split( BATCH, total_q, total_kv, + N_CTX, H, ctx.max_seqlen_q, D_HEAD_QK, @@ -641,13 +581,13 @@ def maybe_contiguous(x): block_N, threads=256, num_stages=2, - groups=groups) + groups=groups, + ) dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device) dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device) - kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) - dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), - torch.zeros_like(v, dtype=torch.float32)) + kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) dk, dv = dk.sum(0), dv.sum(0) dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX) @@ -666,15 +606,13 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): # HQ = HKV * groups # To handle precision issue Q, K, V = Q.float(), K.float(), V.float() - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if padding_mask is not None: scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf")) @@ -682,41 +620,35 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) if padding_mask is not None: output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + 
N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random") seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32) cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0)) @@ -725,8 +657,7 @@ def main(BATCH: int = 1, # In training backward pass, seqlens_k should be the same as seqlens_q seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q - O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, - max_seqlen_k, causal, groups, use_atomic) + O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None dK, K.grad = K.grad.clone(), None @@ -738,12 +669,6 @@ def main(BATCH: int = 1, dK_ref, K.grad = K.grad.clone(), None dV_ref, V.grad = V.grad.clone(), None - torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2) - torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) - torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') - def run(): O_ref.backward(dO, retain_graph=True) @@ -759,24 +684,85 @@ def run1(): print("tilelang: {:.2f} ms".format(latency)) print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") + print( + "Note: this varlen kernel performance is as good as the non-varlen kernel shown in Nsight-Compute. As you may observe that the TFLOPS is a bit lower, that's because the unpad operation is included in the above benchmark." 
+ ) + + +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + total_q = BATCH * N_CTX + total_kv = BATCH * N_CTX + head_kv = H // groups + Q = torch.randn(total_q, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(total_kv, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(total_kv, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(total_q, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(total_q, H, D_HEAD_V, device=device, dtype=torch.half) + cu_seqlens_q = torch.arange(0, (BATCH + 1) * N_CTX, N_CTX, device=device, dtype=torch.int32) + cu_seqlens_k = cu_seqlens_q + max_seqlen_q = N_CTX + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, max_seqlen_q, D_HEAD_V) + kernel = flashattn_bwd_split( + BATCH, + total_q, + total_kv, + N_CTX, + H, + max_seqlen_q, + D_HEAD_QK, + D_HEAD_V, + causal, + block_M=128, + block_N=32, + threads=256, + num_stages=2, + groups=groups, + ) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(groups, total_kv, head_kv, D_HEAD_QK, device=device, dtype=torch.float16) + dV = torch.zeros(groups, total_kv, head_kv, D_HEAD_V, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO, cu_seqlens_q) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, cu_seqlens_q, cu_seqlens_k, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + if __name__ == "__main__": arch = nvcc.get_target_compute_version() print(f"Detected GPU compute capability: {arch}") assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() + # Can be set to True/False for testing + args.causal = True # Handle backward compatibility and logic if 
args.use_split: @@ -787,5 +773,4 @@ def run1(): # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index ed07e7d9d..2da64472c 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -6,25 +6,27 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -39,26 +41,25 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): 
+ scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): @@ -72,29 +73,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_v] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -103,50 +106,42 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - 
dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -167,45 +162,30 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm( - K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm( - V_shared, - do, - dsT, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - wg_wait=-1) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.wait_wgmma(1) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -217,18 +197,17 @@ def flash_bwd( T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.wait_wgmma(0) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) 
- T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -246,7 +225,10 @@ def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -260,18 +242,7 @@ def maybe_contiguous(x): mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) delta = mod_prep(o, do) - kernel = flashattn_bwd( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -294,52 +265,36 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False): +def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, 
dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -356,7 +311,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -374,15 +329,34 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False +): + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + O = attention(Q, K, V, causal, groups) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + return do_bench(run1, warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal) diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index 4d9d06a4f..e884a8158 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -9,7 +9,6 @@ class FlashAttentionTuneSpace: - def __init__( self, block_sizes=(64, 128, 256), @@ -40,7 +39,7 @@ def get_configs(user_config=None): warp_M = block_M // warp_count warp_N = block_N // warp_count - if (warp_M % config.warp_alignment != 0 
or warp_N % config.warp_alignment != 0): + if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0: continue shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N) @@ -48,114 +47,38 @@ def get_configs(user_config=None): continue for num_stages in config.num_stages_range: - valid_configs.append({ - "block_M": block_M, - "block_N": block_N, - "num_stages": num_stages, - "threads": threads, - }) + valid_configs.append( + { + "block_M": block_M, + "block_N": block_N, + "num_stages": num_stages, + "threads": threads, + } + ) return valid_configs @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - groups=1, - block_M=64, - block_N=64, - num_stages=0, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" - - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. 
- # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -171,25 +94,49 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, 
acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main @@ -199,50 +146,34 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D] # V: [B, T, HV, D] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(batch: int = 1, - heads: int = 64, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 16, - tune: bool = False): +def main( + batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False +): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - groups=groups, - block_M=64, - block_N=64, - num_stages=2, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -264,14 +195,22 @@ def main(batch: int = 1, print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False +): + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') - 
parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 1c1fc12d2..73a725d9f 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -24,9 +24,11 @@ def get_configs(): rep=10, ) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( batch, heads, @@ -39,90 +41,19 @@ def flashattn( num_stages=0, threads=128, ): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" - - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. 
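Throughout these GQA examples the kernel reads K and V with head index `by // groups`, while `ref_program` expands K and V with `repeat_interleave(groups, dim=2)`. A small sketch (shapes chosen arbitrarily for illustration, not from this diff) confirming the two views of grouped heads agree:

# GQA head-mapping sketch: indexing K with q_head // groups is equivalent to
# expanding K via repeat_interleave(groups, dim=2) and indexing with q_head.
import torch

batch, seq_len, heads, dim, groups = 2, 16, 8, 4, 4
head_kv = heads // groups                      # 2 KV heads shared by 8 query heads
K = torch.randn(batch, seq_len, head_kv, dim)

K_expanded = K.repeat_interleave(groups, dim=2)    # what ref_program does
for q_head in range(heads):
    kv_head = q_head // groups                     # what the kernel does (by // groups)
    assert torch.equal(K_expanded[:, :, q_head, :], K[:, :, kv_head, :])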
- # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -138,30 +69,55 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, 
dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main @@ -171,23 +127,21 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D] # V: [B, T, HV, D] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -205,18 +159,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - groups=groups, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -238,14 +182,28 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, + heads: int = 64, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + groups: int = 16, +): + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') - parser.add_argument('--groups', type=int, default=16, help='groups') + 
parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_varlen.py b/examples/flash_attention/example_gqa_fwd_varlen.py index 37e81ebb3..0e8e21c43 100644 --- a/examples/flash_attention/example_gqa_fwd_varlen.py +++ b/examples/flash_attention/example_gqa_fwd_varlen.py @@ -4,80 +4,36 @@ import tilelang import tilelang.language as T import tilelang.testing -from einops import rearrange, repeat from tilelang.profiler import do_bench from varlen_utils import generate_random_padding_mask, generate_qkv -def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - causal=False, - window_size=(-1, -1), - upcast=True, -): - if causal: - window_size = (window_size[0], 0) - dtype_og = q.dtype - if upcast: - q, k, v = q.float(), k.float(), v.float() - dim = q.shape[-1] - scale = (1.0 / dim)**0.5 - k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) - v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) - scores = torch.einsum("bthd,bshd->bhts", q, k) - if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - scores = scores * scale - attention = torch.softmax(scores, dim=-1).to(v.dtype) - - if query_padding_mask is not None: - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) - output = torch.einsum("bhts,bshd->bthd", attention, v) - if query_padding_mask is not None: - output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) - - @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch_size, - groups, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [UQ, heads, dim] kv_shape = [UKV, head_kv, dim] o_shape = [UQ, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(kv_shape, dtype), - V_unpad: T.Tensor(kv_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, 
block_M), heads, batch_size, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -96,54 +52,51 @@ def main( kv_head_idx = head_idx // groups q_start_idx = cu_seqlens_q[batch_idx] - k_start_idx = cu_seqlens_k[batch_idx] - v_start_idx = cu_seqlens_k[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] q_end_idx = cu_seqlens_q[batch_idx + 1] k_end_idx = cu_seqlens_k[batch_idx + 1] - v_end_idx = cu_seqlens_k[batch_idx + 1] q_current_seqlen = q_end_idx - q_start_idx - k_current_seqlen = k_end_idx - k_start_idx - v_current_seqlen = v_end_idx - v_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx - T.copy( - Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], - Q_shared) - for i, d in T.Parallel(block_M, dim): - if bx * block_M + i >= q_current_seqlen: - Q_shared[i, d] = 0 + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv(k_current_seqlen, block_N) + offset = kv_current_seqlen - q_current_seqlen # always align on the right + max_visible_k_idx = offset + (bx + 1) * block_M + loop_range = ( + T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + if is_causal + else T.ceildiv(kv_current_seqlen, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy( - K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N, - kv_head_idx, :], K_shared) - for i, d in T.Parallel(block_N, dim): - if k * block_N + i >= k_current_seqlen: - K_shared[i, d] = 0 + T.copy(K_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and - (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i + offset < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), + -1e9, + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0 + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) @@ -157,19 +110,15 @@ def main( for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] - T.copy( - V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N, - kv_head_idx, :], V_shared) - for i, d in T.Parallel(block_N, dim): - if k * block_N + i >= v_current_seqlen: - V_shared[i, d] = 0 + T.copy(V_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared) T.gemm(acc_s_cast, 
V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) + # When sq > skv, some tokens can see nothing + acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i] + T.copy(acc_o, O_shared) for i, d in T.Parallel(block_M, dim): if bx * block_M + i < q_current_seqlen: Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] @@ -177,13 +126,9 @@ def main( return main -def main(batch: int = 1, - heads: int = 64, - q_seqlen: int = 2048, - k_seqlen: int = 2048, - dim: int = 128, - groups: int = 16, - is_causal: bool = False): +def main( + batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False +): assert heads % groups == 0, "heads must be divisible by groups" flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim @@ -191,8 +136,7 @@ def main(batch: int = 1, tilelang.testing.set_random_seed(0) - causal = False - if causal: + if is_causal: total_flops *= 0.5 tilelang.testing.set_random_seed(0) @@ -201,9 +145,9 @@ def main(batch: int = 1, device = torch.device("cuda") head_kv = heads // groups - q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True) - k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True) - v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True) + q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device) + k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") @@ -222,53 +166,46 @@ def main(batch: int = 1, output_pad_fn, _, _, - ) = generate_qkv( - q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) UQ = q_unpad.shape[0] UKV = k_unpad.shape[0] - kernel = flashattn( - batch, - groups, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128) + kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) out = output_pad_fn(out_unpad) - out_ref, _ = attention_ref( - q, - k, - v, - query_padding_mask=query_padding_mask, - key_padding_mask=key_padding_mask, + import flash_attn + + fa_out_unpad = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, causal=is_causal, ) - torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) + fa_out = output_pad_fn(fa_out_unpad) + torch.testing.assert_close(out, fa_out, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") - latency = do_bench( - lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)) + latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5) print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) if __name__ == "__main__": parser = argparse.ArgumentParser() - 
parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='query heads') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length') - parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--is_causal', action='store_true', help='causal attention') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="head dim") + parser.add_argument("--is_causal", action="store_true", help="causal attention") args = parser.parse_args() - main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, - args.is_causal) + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal) diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py index 1595ae764..34e8fefc5 100644 --- a/examples/flash_attention/example_mha_bwd_bhsd.py +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -7,22 +7,24 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -38,29 +40,28 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) # T.copy(Q_shared, Q_local) # for i, j in T.Parallel(block_M, dim): # Q_local[i, j] *= scale - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, by, k * block_N:(k + 
1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -74,29 +75,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -105,68 +108,71 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, 
dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 64 @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -190,36 +196,36 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], 
lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -232,14 +238,13 @@ def flash_bwd( T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j]) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) - T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, H, N_CTX, D_HEAD = q.shape @@ -281,15 +286,15 @@ def maybe_contiguous(x): def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(2) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -304,9 +309,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -345,12 +348,43 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 16 + N_CTX = 512 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(0) + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float16) + dV = 
torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd.py b/examples/flash_attention/example_mha_bwd_bshd.py similarity index 65% rename from examples/flash_attention/example_mha_bwd.py rename to examples/flash_attention/example_mha_bwd_bshd.py index 543c2c0e7..fc8328fa4 100644 --- a/examples/flash_attention/example_mha_bwd.py +++ b/examples/flash_attention/example_mha_bwd_bshd.py @@ -7,22 +7,24 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -38,25 +40,25 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + 
i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -70,29 +72,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -101,68 +105,71 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 64 @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, 
accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -186,33 +193,36 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - }) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, 
j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -225,14 +235,13 @@ def flash_bwd( T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) - T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, N_CTX, H, D_HEAD = q.shape @@ -274,15 +283,15 @@ def maybe_contiguous(x): def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -297,9 +306,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -336,12 +343,43 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 16 + N_CTX = 512 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(42) + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float16) + dV = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, 
default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py similarity index 64% rename from examples/flash_attention/example_mha_bwd_wgmma_pipelined.py rename to examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py index 7ad417ef5..c0fe4e33d 100644 --- a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -7,22 +7,24 @@ @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -37,27 +39,26 @@ def flash_fwd( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in 
T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -71,29 +72,31 @@ def flash_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 32 @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -102,37 +105,39 @@ def flash_bwd_prep( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: 
T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -157,47 +162,34 @@ def flash_bwd( dk_shared = T.alloc_shared([block_M, dim], dtype) dq_shared = T.alloc_shared([block_N, dim], accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm( - K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm( - V_shared, - do, - dsT, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - wg_wait=-1) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.wait_wgmma(1) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. 
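The masking and running-max update introduced above follow the standard streaming (online) softmax recurrence. As a minimal NumPy sketch (toy inputs, nothing here is taken from the patch itself): keep a per-row running maximum that never decreases, rescale the partial normalizer and accumulator whenever that maximum grows, and out-of-range key columns masked to -inf then simply contribute exp(-inf) = 0.

import numpy as np

def online_attention_row(q, keys, values, block=4):
    # Streaming softmax-attention for a single query row, processed block by block.
    d = q.shape[-1]
    scale = 1.0 / np.sqrt(d)
    m = -np.inf            # running max of the scaled scores
    l = 0.0                # running sum of exp(score - m)
    acc = np.zeros(d)      # running weighted sum of value rows
    for start in range(0, keys.shape[0], block):
        s = (keys[start:start + block] @ q) * scale
        # (a padded / out-of-range column would be set to -inf here)
        m_new = max(m, s.max())            # the running max never decreases
        alpha = np.exp(m - m_new)          # rescales previously accumulated partials
        p = np.exp(s - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ values[start:start + block]
        m = m_new
    return acc / l

def dense_attention_row(q, keys, values):
    # Dense reference for comparison.
    s = (keys @ q) / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max())
    return (p @ values) / p.sum()

rng = np.random.default_rng(0)
q = rng.standard_normal(8)
keys = rng.standard_normal((10, 8))
values = rng.standard_normal((10, 8))
assert np.allclose(online_attention_row(q, keys, values),
                   dense_attention_row(q, keys, values))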
T.wait_wgmma(0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -208,17 +200,16 @@ def flash_bwd( T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.wait_wgmma(0) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) - T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, N_CTX, H, D_HEAD = q.shape @@ -260,15 +251,15 @@ def maybe_contiguous(x): def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -283,9 +274,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -305,7 +294,7 @@ def main( assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -321,12 +310,44 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(0) + block_M = 128 + block_N = 128 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros_like(Q, dtype=torch.float16) + dV = torch.zeros_like(Q, dtype=torch.float16) + Delta = mod_prep(O, dO) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, 
Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index f07f7a618..400736541 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -15,107 +15,27 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: 
T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -131,43 +51,69 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, 
block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -185,18 +131,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -219,14 +155,28 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 64, + is_causal: bool = False, + tune: bool = False, +): + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=1, help='heads') - parser.add_argument('--seq_q', type=int, default=256, help='query sequence length') - parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=1, help="heads") + parser.add_argument("--seq_q", type=int, default=256, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal", default=False) + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git 
a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index 26167b34b..90514f762 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -15,107 +15,27 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
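For reference, the 1.44269504 constant folded into `scale` here (and kept in the inlined code) is log2(e); computing exp2(x * scale - m * scale) is mathematically the same as exp((x - m) * sm_scale), it just maps onto exp2 plus a fused multiply-add on the GPU. A quick self-contained check with arbitrary values:

import math

dim = 64
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504          # sm_scale * log2(e)

x, row_max = 3.7, 5.2
via_exp = math.exp((x - row_max) * sm_scale)
via_exp2 = 2.0 ** (x * scale - row_max * scale)
assert math.isclose(via_exp, via_exp2, rel_tol=1e-6)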
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -131,48 +51,75 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, 
by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -190,18 +137,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -224,14 +161,28 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length') - parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=4096, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index 6a1f707e5..e584971c0 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -15,100 +15,23 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, 
- threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" - - @T.macro - def MMA0( - K: T.Tensor(shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
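The Check_inf note kept (commented out) in these kernels concerns rows whose scores are all masked to -inf: the row max is then also -inf, and exp(-inf - (-inf)) evaluates to NaN, so the row max would need to be clamped to a finite value before exponentiating. A small illustrative sketch with a hypothetical fully-masked row:

import numpy as np

scores = np.full(4, -np.inf)               # every position masked out
naive = np.exp(scores - scores.max())      # exp(-inf - (-inf)) -> NaN
row_max = scores.max() if np.isfinite(scores.max()) else 0.0
checked = np.exp(scores - row_max)         # clean zeros instead

assert np.isnan(naive).all()
assert (checked == 0.0).all()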
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -124,40 +47,64 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if 
is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -174,17 +121,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=1, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -206,13 +144,19 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf(batch: int = 8, heads: int = 32, seq_len: int = 4096, dim: int = 128, is_causal: bool = False): + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index 3928db4c3..d6e1490c9 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -15,100 +15,23 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" - - @T.macro - def MMA0( - K: T.Tensor(shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], 
accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -124,45 +47,70 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + 
T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -179,17 +127,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -211,13 +150,19 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf(batch: int = 8, heads: int = 32, seq_len: int = 4096, dim: int = 128, is_causal: bool = False): + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py index f381e900a..0f3610b11 100644 --- a/examples/flash_attention/example_mha_fwd_varlen.py +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -4,109 +4,51 @@ import tilelang.language as T import tilelang.testing import argparse +from tilelang.profiler import do_bench +from tilelang.autotuner import set_autotune_inputs, autotune import torch -from einops import rearrange, repeat from varlen_utils import generate_random_padding_mask, generate_qkv +import itertools -def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - 
upcast=True, -): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) - attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) - dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) - causal: whether to apply causal masking - window_size: (int, int), left and right window size - upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - output back to fp16/bf16. - reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) - without changing the math. This is to estimate the numerical error from operation - reordering. - Output: - output: (batch_size, seqlen_q, nheads, head_dim) - attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout - """ - if causal: - window_size = (window_size[0], 0) - dtype_og = q.dtype - if upcast: - q, k, v = q.float(), k.float(), v.float() - dim = q.shape[-1] - scale = (1.0 / dim)**0.5 # log2(e) - k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) - v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) - scores = torch.einsum("bthd,bshd->bhts", q, k) - if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - # scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0) - scores = scores * scale - attention = torch.softmax(scores, dim=-1).to(v.dtype) - - # We want to mask here so that the attention matrix doesn't have any NaNs - # Otherwise we'll get NaN in dV - if query_padding_mask is not None: - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) - output = torch.einsum("bhts,bshd->bthd", attention, v) - if query_padding_mask is not None: - output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) +def get_configs(): + iter_params = dict(block_M=[64, 128], block_N=[64, 128], num_stages=[0, 1, 2, 3], threads=[128, 256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] +@autotune(configs=get_configs()) @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch_size, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=0, - threads=32): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [UQ, heads, dim] k_shape = [UKV, heads, dim] v_shape = [UKV, heads, dim] o_shape = [UQ, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(k_shape, dtype), - V_unpad: T.Tensor(v_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(k_shape, dtype), + V_unpad: T.Tensor(v_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + 
max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, - threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype, "shared") - K_shared = T.alloc_shared([block_N, dim], dtype, "shared") - V_shared = T.alloc_shared([block_N, dim], dtype, "shared") - O_shared = T.alloc_shared([block_M, dim], dtype, "shared") + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) acc_o = T.alloc_fragment([block_M, dim], accum_dtype) @@ -120,46 +62,46 @@ def main( head_idx = by q_start_idx = cu_seqlens_q[batch_idx] - k_start_idx = cu_seqlens_k[batch_idx] - v_start_idx = cu_seqlens_k[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] q_end_idx = cu_seqlens_q[batch_idx + 1] - k_end_idx = cu_seqlens_k[batch_idx + 1] - v_end_idx = cu_seqlens_k[batch_idx + 1] + kv_end_idx = cu_seqlens_k[batch_idx + 1] q_current_seqlen = q_end_idx - q_start_idx - k_current_seqlen = k_end_idx - k_start_idx - v_current_seqlen = v_end_idx - v_start_idx + kv_current_seqlen = kv_end_idx - kv_start_idx - for i, d in T.Parallel(block_M, dim): - if bx * block_M + i < q_current_seqlen: - Q_shared[i, d] = Q_unpad[q_start_idx + bx * block_M + i, head_idx, d] - else: - Q_shared[i, d] = 0 + T.copy( + Q_unpad[q_start_idx + bx * block_M : q_start_idx + bx * block_M + block_M, head_idx, :], Q_shared + ) # OOB positions will be handled below T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = T.ceildiv(k_current_seqlen, block_N) + offset = kv_current_seqlen - q_current_seqlen # always align on the right + loop_range = ( + T.min(T.ceildiv(offset + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) + if is_causal + else T.ceildiv(kv_current_seqlen, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): # Q * K - for i, d in T.Parallel(block_N, dim): - if k * block_N + i < k_current_seqlen: - K_shared[i, d] = K_unpad[k_start_idx + k * block_N + i, head_idx, d] - else: - K_shared[i, d] = 0 + T.copy( + K_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], K_shared + ) # OOB positions will be handled below if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and - (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i + offset < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), + -1e9, + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0 + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -167,6 +109,8 @@ def main( T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, 
clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. @@ -189,18 +133,17 @@ def main( acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - for i, d in T.grid(block_N, dim): - if k * block_N + i < v_current_seqlen: - V_shared[i, d] = V_unpad[v_start_idx + k * block_N + i, head_idx, d] - else: - V_shared[i, d] = 0 + T.copy( + V_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], V_shared + ) # OOB positions' weights are 0 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) + # When sq > skv, some tokens can see nothing + acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i] + T.copy(acc_o, O_shared) for i, d in T.Parallel(block_M, dim): if bx * block_M + i < q_current_seqlen: Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] @@ -208,19 +151,17 @@ def main( return main -def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): +def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False, tune: bool = False): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul tilelang.testing.set_random_seed(0) - causal = False if causal: total_flops *= 0.5 dtype = torch.float16 device = torch.device("cuda") - window_size = (-1, -1) q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) @@ -240,30 +181,23 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): k, v, output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) = generate_qkv( - q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) UQ = q_unpad.shape[0] # unpadded query length - UK = k_unpad.shape[0] # unpadded key length UKV = k_unpad.shape[0] # unpadded query key length - kernel = flashattn(batch, UQ, UKV, heads, dim, causal) + if tune: + with set_autotune_inputs(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q): + kernel = flashattn(batch, UQ, UKV, heads, dim, causal) + else: + kernel = flashattn(batch, UQ, UKV, heads, dim, causal, block_M=64, block_N=64, num_stages=1, threads=128) + # NOTE: (128, 128, 2or3, 256) is recommended for Hopper out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) out = output_pad_fn(out_unpad) - out_ref, _ = attention_ref( - q, - k, - v, - query_padding_mask, - key_padding_mask, - causal=causal, - ) - torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) - import flash_attn fla_out_unpad = flash_attn.flash_attn_varlen_func( @@ -282,13 +216,67 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): print("All checks passed.✅") + # benchmark + t = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)) + print(f"Tilelang time: {t} ms") + print(f"Tilelang: {total_flops / t * 1e-9} TFlops") + t = do_bench( + lambda: flash_attn.flash_attn_varlen_func( + q_unpad, k_unpad, v_unpad, cu_seqlens_q, 
cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, causal=causal + ) + ) + print(f"FA2 time: {t} ms") + print(f"FA2: {total_flops / t * 1e-9} TFlops") + + +def run_regression_perf(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + tilelang.testing.set_random_seed(0) + if causal: + total_flops *= 0.5 + dtype = torch.float16 + device = torch.device("cuda") + q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + UQ = q_unpad.shape[0] + UKV = k_unpad.shape[0] + kernel = flashattn(batch, UQ, UKV, heads, dim, causal, block_M=128, block_N=128, num_stages=2, threads=256) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) + + return do_bench(run_kernel_only, backend="cupti") + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=2048, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=2048, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", default=False, help="causal attention") + parser.add_argument("--tune", action="store_true", default=False, help="tune the kernel") args = parser.parse_args() - main(args.batch, args.heads, args.seq_len, args.dim) + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/regression_example_flash_attention.py b/examples/flash_attention/regression_example_flash_attention.py new file mode 100644 index 000000000..8710bbb6e --- /dev/null +++ b/examples/flash_attention/regression_example_flash_attention.py @@ -0,0 +1,74 @@ +import tilelang.testing +import example_gqa_fwd_bshd +import example_gqa_fwd_bshd_wgmma_pipelined +import example_mha_fwd_bhsd +import example_mha_fwd_bhsd_wgmma_pipelined +import example_mha_fwd_bshd +import example_mha_fwd_bshd_wgmma_pipelined +import example_mha_fwd_varlen +import example_gqa_bwd_tma_reduce_varlen +import example_gqa_bwd +import example_gqa_bwd_wgmma_pipelined +import example_mha_bwd_bshd +import example_mha_bwd_bhsd +import example_mha_bwd_bshd_wgmma_pipelined + + +def regression_example_gqa_bwd_tma_reduce_varlen(): + tilelang.testing.process_func(example_gqa_bwd_tma_reduce_varlen.run_regression_perf) + + +def regression_example_gqa_bwd(): + 
tilelang.testing.process_func(example_gqa_bwd.run_regression_perf) + + +def regression_example_gqa_bwd_wgmma_pipelined(): + tilelang.testing.process_func(example_gqa_bwd_wgmma_pipelined.run_regression_perf) + + +def regression_example_mha_bwd_bshd(): + tilelang.testing.process_func(example_mha_bwd_bshd.run_regression_perf) + + +def regression_example_mha_bwd_bhsd(): + tilelang.testing.process_func(example_mha_bwd_bhsd.run_regression_perf) + + +def regression_example_mha_bwd_bshd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_bwd_bshd_wgmma_pipelined.run_regression_perf) + + +def regression_example_gqa_fwd_bshd_wgmma_pipelined(): + tilelang.testing.process_func( + example_gqa_fwd_bshd_wgmma_pipelined.run_regression_perf, batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16 + ) + + +def regression_example_gqa_fwd_bshd(): + tilelang.testing.process_func( + example_gqa_fwd_bshd.run_regression_perf, batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16 + ) + + +def regression_example_mha_fwd_bhsd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_fwd_bhsd_wgmma_pipelined.run_regression_perf) + + +def regression_example_mha_fwd_bhsd(): + tilelang.testing.process_func(example_mha_fwd_bhsd.run_regression_perf) + + +def regression_example_mha_fwd_bshd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_fwd_bshd_wgmma_pipelined.run_regression_perf, batch=1, heads=32, seq_len=256) + + +def regression_example_mha_fwd_bshd(): + tilelang.testing.process_func(example_mha_fwd_bshd.run_regression_perf, batch=1, seq_len=256) + + +def regression_example_mha_fwd_varlen(): + tilelang.testing.process_func(example_mha_fwd_varlen.run_regression_perf, batch=4, heads=16, seq_len=512, dim=64) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py index 8a58f3b6a..a74bf071b 100644 --- a/examples/flash_attention/test_example_flash_attention.py +++ b/examples/flash_attention/test_example_flash_attention.py @@ -2,7 +2,7 @@ import example_gqa_bwd import example_gqa_bwd_wgmma_pipelined -import example_mha_bwd +import example_mha_bwd_bshd import example_mha_bwd_bhsd import example_mha_fwd_bhsd_wgmma_pipelined import example_gqa_fwd_bshd @@ -10,8 +10,15 @@ import example_gqa_fwd_bshd_wgmma_pipelined import example_mha_fwd_bshd_wgmma_pipelined import example_mha_fwd_varlen -import example_mha_bwd_wgmma_pipelined +import example_mha_bwd_bshd_wgmma_pipelined import example_mha_fwd_bhsd +import example_gqa_bwd_tma_reduce_varlen +import example_gqa_fwd_varlen + + +@tilelang.testing.requires_cuda +def test_example_gqa_bwd_tma_reduce_varlen(): + example_gqa_bwd_tma_reduce_varlen.main() @tilelang.testing.requires_cuda @@ -27,31 +34,41 @@ def test_example_gqa_bwd_wgmma_pipelined(): @tilelang.testing.requires_cuda def test_example_mha_bwd(): - example_mha_bwd.main(BATCH=1) + example_mha_bwd_bshd.main( + BATCH=1, + H=16, + N_CTX=512, + D_HEAD=64, + causal=False, + ) @tilelang.testing.requires_cuda def test_example_mha_bwd_bhsd(): - example_mha_bwd_bhsd.main(BATCH=1) + example_mha_bwd_bhsd.main( + BATCH=1, + H=16, + N_CTX=512, + D_HEAD=64, + causal=False, + ) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_mha_bwd_wgmma_pipelined(): - example_mha_bwd_wgmma_pipelined.main(BATCH=1) + example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False) 
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_gqa_fwd_bshd_wgmma_pipelined(): - example_gqa_fwd_bshd_wgmma_pipelined.main( - batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) @tilelang.testing.requires_cuda def test_example_gqa_fwd_bshd(): - example_gqa_fwd_bshd.main( - batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) @tilelang.testing.requires_cuda @@ -78,7 +95,14 @@ def test_example_mha_fwd_bshd(): @tilelang.testing.requires_cuda def test_example_mha_fwd_varlen(): - example_mha_fwd_varlen.main() + example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64, causal=False) + example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64, causal=True) + + +@tilelang.testing.requires_cuda +def test_example_gqa_fwd_varlen(): + example_gqa_fwd_varlen.main(batch=4, heads=16, q_seqlen=512, k_seqlen=512, dim=64, is_causal=False) + example_gqa_fwd_varlen.main(batch=4, heads=16, q_seqlen=512, k_seqlen=512, dim=64, is_causal=True) if __name__ == "__main__": diff --git a/examples/flash_attention/varlen_utils.py b/examples/flash_attention/varlen_utils.py index 4301215d5..43e21cc3b 100644 --- a/examples/flash_attention/varlen_utils.py +++ b/examples/flash_attention/varlen_utils.py @@ -9,22 +9,14 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": - lengths = torch.randint( - max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) elif mode == "third": lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) - padding_mask = ( - repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths return padding_mask -def generate_qkv(q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False): +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: q: (batch_size, seqlen_q, nheads, d) @@ -39,15 +31,12 @@ def generate_qkv(q, if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q - ) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) else: q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) max_seqlen_q = seqlen_q - output_pad_fn = lambda output_unpad: rearrange( - output_unpad, "(b s) h d -> b s h d", b=batch_size) + output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, 
max_seqlen_k = unpad_input(k, key_padding_mask) @@ -55,8 +44,7 @@ def generate_qkv(q, else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) max_seqlen_k = seqlen_k if qkvpacked: @@ -67,8 +55,7 @@ def generate_qkv(q, if query_padding_mask is not None: dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) else: - dqkv_pad_fn = lambda dqkv_unpad: rearrange( - dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) return ( qkv_unpad.detach().requires_grad_(), cu_seqlens_q, @@ -84,8 +71,7 @@ def generate_qkv(q, if key_padding_mask is not None: dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) else: - dkv_pad_fn = lambda dkv_unpad: rearrange( - dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) return ( q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(), diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 5f946d8b5..9e6f36017 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -15,18 +15,12 @@ def get_configs(): block_N = [64, 128] block_H = [64] - num_split = [2, 4, 8] + num_split = [1, 2, 4, 8] num_stages = [1, 2, 3] threads = [128] _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - configs = [{ - 'block_N': c[0], - 'block_H': c[1], - 'num_split': c[2], - 'num_stages': c[3], - 'threads': c[4] - } for c in _configs] + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] return configs @@ -40,45 +34,44 @@ def get_heuristic_config() -> Tuple[Dict, int]: sm_version = sm_major * 10 + sm_minor print(f"CUDA device capability: {sm_version}") if sm_version == 89: - cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=0, threads=128) + cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128) else: - cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=2, threads=128) + cfg = dict(block_N=128, block_H=64, num_split=8, num_stages=2, threads=128) return cfg, sm_version # TODO(lei): fix warp specialized and tma lower pass def get_pass_configs(): - return { - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - } + return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) -def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, - threads): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [batch, seqlen_kv, groups, dim] shape_v = [batch, seqlen_kv, groups, dim] shape_o = [batch, heads, dim] - dtype = "float16" - accum_dtype = "float" + 
dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // groups part_shape = [batch, heads, num_split, dim] valid_block_H = min(block_H, kv_group_num) valid_block_N = min(block_N, seqlen_kv // num_split) - @T.macro - def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - Output: T.Tensor([batch, heads, dim], dtype), + @T.prim_func + def flashattn_gqa_decode_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): + # split with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) @@ -96,25 +89,43 @@ def flash_attn( bid = bx hid = by + sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared) - T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local) + T.copy( + K[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + K_shared, + ) + T.copy( + mask[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + ], + mask_local, + ) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], - -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -125,23 +136,66 @@ def flash_attn( T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared) + T.copy( + V[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + V_shared, + ) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) - - @T.macro - def 
flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), + T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :]) + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local = T.alloc_fragment([num_split, 128], dtype) + lse_logsum_local = T.alloc_fragment([128], accum_dtype) + lse_max_local = T.alloc_fragment([128], accum_dtype) + scale_local = T.alloc_fragment([128], accum_dtype) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + for k, j in T.Parallel(num_split, 128): + lse_local[k, j] = glse[bz, by, k] + T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) + for k in T.serial(num_split): + for j in T.Parallel(128): + lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j]) + for j in T.Parallel(128): + lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + for j in T.Parallel(128): + scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j]) + # Note: Pay attention to dim and the number of threads in Parallel + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[i] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -160,34 +214,26 @@ def flash_attn_split( bid = bx hid = by - sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv((seqlen_kv // num_split), block_N) - for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy( - K[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head, :], K_shared) - T.copy( - mask[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head], mask_local) + T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, - j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), - acc_s[i, j], -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + 
scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -198,88 +244,14 @@ def flash_attn_split( T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.copy( - V[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head, :], V_shared) + T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H, - sid, :]) - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), - ): - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim], dtype) - o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local = T.alloc_fragment([num_split, 128], dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_fragment([128], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), - # lse_local: (local_id, thread_id) - lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - for k, j in T.Parallel(num_split, 128): - lse_local[k, j] = glse[bz, by, k] - T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dim): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] - for i in T.Parallel(dim): - Output[bz, by, i] = o_accum_local[i] - - @T.prim_func - def flashattn_gqa_decode_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), - ): - flash_attn_split(Q, K, V, mask, glse, Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), - ): - flash_attn(Q, K, V, mask, Output) + 
T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) if num_split > 1: return flashattn_gqa_decode_split @@ -300,27 +272,21 @@ def ref_program(query, key, value, mask, glse, Output_partial): dim = query.shape[-1] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] if mask is not None: - mask = rearrange(mask, 'b s h -> b h s') + mask = rearrange(mask, "b s h -> b h s") mask = mask.unsqueeze(1) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -334,16 +300,12 @@ def flash_split_ref(Q, K, V, mask): seqlen_kv = K.size(1) num_head_groups = nheads // groups - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float) - acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), - device="cuda", - dtype=torch.float16) + acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float) scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) - scores_max_prev = torch.empty((batch, num_head_groups, groups), - device="cuda", - dtype=torch.float) + scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) @@ -351,25 +313,25 @@ def flash_split_ref(Q, K, V, mask): glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) Q_ = Q * scale - Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups) + Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups) for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) 
+ scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bghd,bkhd->bghk', Q_, - K[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + acc_s = torch.einsum( + "bghd,bkhd->bghk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] if mask is not None: - mask_local = mask[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :] - mask_local = rearrange(mask_local, 'b s h -> b h s') + mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :] + mask_local = rearrange(mask_local, "b s h -> b h s") mask_local = mask_local.unsqueeze(1) - acc_s = acc_s.masked_fill(mask_local == 0, float('-inf')) + acc_s = acc_s.masked_fill(mask_local == 0, float("-inf")) scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] @@ -377,15 +339,16 @@ def flash_split_ref(Q, K, V, mask): acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_o += torch.einsum( - 'bghk,bkhd->bghd', acc_s_cast, - V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bghk,bkhd->bghd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum - acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d') - logsum_out = rearrange(logsum, 'b g h->b (h g)') + acc_o_out = rearrange(acc_o, "b g h d->b (h g) d") + logsum_out = rearrange(logsum, "b g h->b (h g)") acc_o_out /= logsum_out[:, :, None] - logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)') + logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)") gacc_o[ks, :, :, :] = acc_o_out glogsum[ks, :, :] = logsum_out @@ -421,7 +384,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -429,28 +392,23 @@ def calc_sim(x, y, name="tensor"): def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True): sim = calc_sim(x, y, name) - diff = 1. 
- sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff}') + print_red_warning(f"{name} Error: {diff}") if assert_: - raise AssertionError(f'{name} Error: {diff}') + raise AssertionError(f"{name} Error: {diff}") else: if print_: - print(f'passed: {name} diff={diff}') + print(f"passed: {name} diff={diff}") -def main(batch: int = 1, - heads: int = 32, - groups: int = 8, - kv_seqlen: int = 8192, - dim: int = 128, - tune: bool = False): +def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False): batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim qk_flops = 2 * batch * heads * kv_seqlen * dim pv_flops = 2 * batch * heads * kv_seqlen * dim total_flops = qk_flops + pv_flops - if (not tune): + if not tune: config, sm_version = get_heuristic_config() kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) @@ -459,8 +417,9 @@ def main(batch: int = 1, k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16) v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16) mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8) - glse = torch.empty(batch, heads, 16, device="cuda", dtype=torch.float16) - Output_partial = torch.empty(batch, heads, 16, dim, device="cuda", dtype=torch.float16) + split = config["num_split"] + glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16) + Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16) o = kernel(q, k, v, mask, glse, Output_partial) o_ref = ref_program(q, k, v, mask, glse, Output_partial) o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial) @@ -469,7 +428,7 @@ def main(batch: int = 1, print(o_ref) assert_similar(o, o_ref, name="o_ref") - assert_similar(o_ref_split, o_ref, name="o_ref_split") + assert_similar(o, o_ref_split, name="o_ref_split") print("All checks pass.") latency = profiler.do_bench(ref_program, warmup=500) @@ -489,13 +448,21 @@ def main(batch: int = 1, print(f"Ref latency: {ref_latency}") +def run_regression_perf(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128): + batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim + config, _ = get_heuristic_config() + kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + 
parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py new file mode 100644 index 000000000..30acd879e --- /dev/null +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -0,0 +1,785 @@ +import torch +import triton +import triton.language as tl +import math +import argparse +import tilelang +import tilelang.language as T + +torch.manual_seed(0) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +@triton.jit +def _fwd_inner( + q, + k_ptrs, + v_ptrs, + s_ptrs, + m_i, + l_i, + acc, + offs_h, + mask_h, + offs_n, + seqlen, + softmax_scale, + lo, + hi, + stride_kt, + stride_vt, + stride_sh, + stride_sn, + BLOCK_N: tl.constexpr, +): + """Inner loop computation for attention""" + + for blk_idx in tl.range(lo, hi): + start_n = blk_idx * BLOCK_N + k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < seqlen) + v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < seqlen) + + qk = tl.dot(q, k) + qk *= softmax_scale + qk += tl.where(offs_n[None, :] + start_n < seqlen, 0, -1.0e9) + + row_max = tl.max(qk, 1) + tl.store(s_ptrs + offs_h * stride_sh + blk_idx * stride_sn, row_max, mask=mask_h) + + m_ij = tl.maximum(m_i, row_max) + qk -= m_ij[:, None] + p = tl.math.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + m_i = m_ij + acc *= alpha[:, None] + p = p.to(v.type.element_ty) + acc += tl.dot(p, v) + + return m_i, l_i, acc + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]], + key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"], +) +@triton.jit +def _fwd_kernel_varlen( + Q, # [token_q = b, h_q, dim] + K, # [token_k, h_kv, dim] + V, + O, + S, + s_aux, + softmax_scale, + cu_seqlens_k, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_sb, + stride_sh, + stride_sn, # bmask shape [b, q_h, seq/BLOCK_N] + gqa_group_size: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + off_z = tl.program_id(0) + off_h_for_kv = tl.program_id(1) + off_h_q = off_h_for_kv * gqa_group_size + + cu_k_start = tl.load(cu_seqlens_k + off_z) + cu_k_end = tl.load(cu_seqlens_k + off_z + 1) + + seqlen_k = cu_k_end - cu_k_start + + offs_h = tl.arange(0, BLOCK_H) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + Q_ptrs = Q + off_z * stride_qt + off_h_q * stride_qh + K_ptrs = K + (cu_k_start) * stride_kt + off_h_for_kv * stride_kh + V_ptrs = V + (cu_k_start) * stride_vt + off_h_for_kv * stride_vh + O_ptrs = O + off_z * stride_ot + off_h_q * stride_oh + S_ptrs = S + off_z * stride_sb + 
off_h_q * stride_sh + + mask_h = offs_h < gqa_group_size + q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) + + if s_aux is not None: + sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) + l_i = tl.zeros([BLOCK_H], dtype=tl.float32) + m_i = tl.zeros([BLOCK_H], dtype=tl.float32) + sink + else: + l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) + m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) + + acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + + k_ptrs = K_ptrs + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V_ptrs + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + lo, hi = 0, tl.cdiv(seqlen_k, BLOCK_N) + m_i, l_i, acc = _fwd_inner( + q, + k_ptrs, + v_ptrs, + S_ptrs, + m_i, + l_i, + acc, + offs_h, + mask_h, + offs_n, + seqlen_k, + softmax_scale, + lo, + hi, + stride_kt, + stride_vt, + stride_sh, + stride_sn, + BLOCK_N, + ) + + if s_aux is not None: + sink = tl.math.exp(sink - m_i) + l_i = l_i + sink + acc = acc / l_i[:, None] + + else: + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + + for blk_idx in tl.range(lo, hi): + s = tl.load(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, mask=mask_h) + s = tl.exp(s - m_i) / l_i + tl.store(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, s, mask=mask_h) + + acc = acc.to(O.dtype.element_ty) + + tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None]) + + +def get_configs(): + import itertools + + block_N = [64, 128] + block_H = [64] + num_split = [1] + num_stages = [1, 2, 3] + threads = [128] + _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) + + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] + return configs + + +@tilelang.jit(out_idx=[-2, -1]) +def flashattn( + batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128 +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, dim] + shape_k = [total_seqlen_k, k_heads, dim] + shape_v = [total_seqlen_k, k_heads, dim] + shape_o = [batch, heads, dim] + shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // k_heads + + valid_block_H = min(block_H, kv_group_num) + # TODO: check if max_seqlen_kv is correct for varlen case + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = 
T.alloc_fragment([block_H], accum_dtype) + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) + # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) + s_aux_shared = T.alloc_shared([block_H], T.float32) + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num // valid_block_H) + + cur_start_k = cu_seqlens_k[bid] + cur_end_k = cu_seqlens_k[bid + 1] + cur_seqlen_k = cur_end_k - cur_start_k + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], + # -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # scores_max_prev is m_i + # scores_max is row_max->m_ij in triton + T.copy(scores_max, S_shared[:, k]) + # scores_scale is alpha in triton + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + # scores_sum is l_ij in triton + # logsum is l_i in triton + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + if has_sink: + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) + for i in T.Parallel(block_H): + logsum[i] += s_aux_shared[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): + S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] + # T.copy(S_shared, S_fragment) + # for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): + # S_fragment[h, k] = T.exp2((S_fragment[h, k] - scores_max[h]) * scale) / logsum[h] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + # T.copy(S_fragment, S_shared) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + + # TODO: split version + return flashattn_gqa_decode_no_split + + +def flash_attn_with_attn_pool_decode_tilelang( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, 
+ block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, + tl_kernel=None, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + # assert K.is_contiguous() + # assert V.is_contiguous() + + gqa_group_size = q_h // k_h + + O_tl = torch.zeros_like(Q) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) + O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) + + if use_per_kv_head_sparse_index: + S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O_tl, S_tl + + +def flash_attn_with_attn_pool_decode( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, + block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + # assert K.is_contiguous() + # assert V.is_contiguous() + + gqa_group_size = q_h // k_h + + BLOCK_D = head_size + BLOCK_N = block_size + BLOCK_H = 64 + + O = torch.zeros_like(Q) + S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device) + + def grid(META): + return (batch, k_h) + + with torch.cuda.device(Q.device.index): + _fwd_kernel_varlen[grid]( + Q, + K, + V, + O, + S, + s_aux, + softmax_scale, + cu_seqlens_k, + *Q.stride(), + *K.stride(), + *V.stride(), + *O.stride(), + *S.stride(), + gqa_group_size, + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D, + ) + + if use_per_kv_head_sparse_index: + S = torch.max_pool2d(S, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O, S + + +def test_varlen_decode_main(args): + """Test decode kernel with variable sequence lengths""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen # Use as max sequence length + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Generate variable length k sequences + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + print(f"k_seqlens: {k_seqlens}") + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + print(f"cu_seqlens_k: {cu_seqlens_k}") + + # 
Generate tensors - Q is [batch_size, q_heads, head_size] for decode + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + print(f"Actual max_seqlen_k: {max_seqlen_k}") + print(f"q_decode shape: {q_decode.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Create torch reference - pad tensors for comparison + k_padded_list = [] + v_padded_list = [] + + for i in range(batch_size): + actual_k_len = k_seqlens[i] + + # Extract and pad k, v for this batch + k_start = cu_seqlens_k[i] + k_end = cu_seqlens_k[i + 1] + + # Pad to max_seqlen_k + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + + k_padded[:actual_k_len] = k_varlen[k_start:k_end] + v_padded[:actual_k_len] = v_varlen[k_start:k_end] + + k_padded_list.append(k_padded) + v_padded_list.append(v_padded) + + # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + + # Expand q to match kv heads: [b, q_heads, 1, head_size] + q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] + + print(f"q_expanded shape: {q_expanded.shape}") + print(f"k_padded_batched shape: {k_padded_batched.shape}") + print(f"v_padded_batched shape: {v_padded_batched.shape}") + + # Compute torch reference + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + + if sink is None: + # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_score[i, :, :, actual_k_len:] = float("-inf") + + attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, 
q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + logits[i, :, :, actual_k_len:] = float("-inf") + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] + + O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] + + # Compute attention score pooling for S + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, max_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + + print(f"O_triton shape: {O_triton.shape}") + print(f"O_tilelang shape: {O_tilelang.shape}") + print(f"O_torch shape: {O_torch.shape}") + print(f"S_triton shape: {S_triton.shape}") + print(f"S_tilelang shape: {S_tilelang.shape}") + print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + + # Compare results + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") + + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_s_tl = torch.max( + torch.abs( + S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)] + ) + ) + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") + + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose( + S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], + attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)], + atol=1e-2, + rtol=1e-2, + ), f"Score mismatch: {max_diff_s_tl.item()}" + + print("✅ All tests passed!") + + +def do_bench(fn, *args, warmup=10, rep=10, **kwargs): + """ + Do benchmark for a function. 
+ """ + start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + for _ in range(warmup): + fn(*args, **kwargs) + + torch.cuda.synchronize() + for i in range(rep): + start_event[i].record() + fn(*args, **kwargs) + end_event[i].record() + torch.cuda.synchronize() + + # Record clocks + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + + return times.mean().item() + + +def speed_benchmark_decode_comparison(args): + """Speed benchmark for decode kernel""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print("\n=== Decode Speed Benchmark Comparison ===") + print("Configuration:") + print(f" Batch size: {batch_size}") + print(f" Q heads: {q_heads}, KV heads: {kv_heads}") + print(f" Max K sequence length: {max_k_seqlen}") + print(f" Head size: {head_size}") + print(f" Block size: {block_size}") + print(f" Data type: {dtype}") + print(f" Variable lengths: {args.test_varlen}") + print(f" s_aux attention: {args.test_sink}") + print() + + # Generate input data + if args.test_varlen: + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + else: + k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + # Generate tensors + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(" Using sink attention with sink values") + + print("Setup complete:") + print(f" Total K tokens: {total_k_tokens}") + print(f" Actual max K seq len: {max_seqlen_k}") + if args.test_varlen: + print(f" K sequence lengths: {k_seqlens.tolist()}") + + # Warmup + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + + # Benchmark + print("⚡ Benchmarking Tilelang kernel (100 iterations)...") + tilelang_time = do_bench( + flash_attn_with_attn_pool_decode_tilelang, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + False, + tl_kernel, + ) + print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") + + # Benchmark + print("⚡ Benchmarking Triton kernel (100 iterations)...") + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) + print(f"Average decode kernel time Triton: {triton_time:.3f} ms") + + print(f"Speedup: 
{(triton_time / tilelang_time):.3f}") + + +def main(): + args = argparse.Namespace( + batch_size=1, + q_heads=32, + kv_heads=8, + k_seqlen=8192, + head_size=128, + block_size=128, + dtype=T.float16, + ) + args.test_sink = True + args.test_varlen = True + args.dtype = T.float16 + args.num_split = 1 + test_varlen_decode_main(args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") + parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") + args = parser.parse_args() + args.test_sink = True + args.test_varlen = True + args.dtype = T.float16 + args.num_split = 1 + + if args.benchmark: + speed_benchmark_decode_comparison(args) + else: + test_varlen_decode_main(args) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py new file mode 100644 index 000000000..87748512d --- /dev/null +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py @@ -0,0 +1,550 @@ +import torch +import math +import argparse +import tilelang +import tilelang.language as T +from example_gqa_decode_varlen_logits import flash_attn_with_attn_pool_decode, repeat_kv, do_bench + +torch.manual_seed(0) + + +def get_configs(): + import itertools + + block_N = [64, 128] + block_H = [64] + num_split = [1] + num_stages = [1, 2, 3] + threads = [128] + _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) + + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] + return configs + + +# @autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[-2, -1]) +def flashattn( + batch, + heads, + k_heads, + max_seqlen_kv, + total_seqlen_k, + dim, + has_sink, + page_block_size, + block_N=128, + block_H=64, + num_split=1, + num_stages=1, + threads=128, +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, dim] + shape_k = [total_seqlen_k, k_heads, dim] + shape_v = [total_seqlen_k, k_heads, dim] + shape_o = [batch, heads, dim] + shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // k_heads + assert page_block_size >= block_N and page_block_size % block_N == 0, ( + "page_block_size must be larger than block_N and a multiple of block_N" + ) + + valid_block_H = min(block_H, kv_group_num) + # TODO: check if max_seqlen_kv is correct for varlen case + + @T.prim_func + def 
flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) + s_aux_shared = T.alloc_shared([block_H], T.float32) + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num // valid_block_H) + + cur_start_k = cu_seqlens_k[bid] + cur_end_k = cu_seqlens_k[bid + 1] + cur_seqlen_k = cur_end_k - cur_start_k + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # scores_max_prev is m_i + # scores_max is row_max->m_ij in triton + T.copy(scores_max, S_shared[:, k]) + # scores_scale is alpha in triton + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + # scores_sum is l_ij in triton + # logsum is l_i in triton + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + if has_sink: + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) + for i in T.Parallel(block_H): + logsum[i] += s_aux_shared[i] + for i, j in T.Parallel(block_H, 
dim): + acc_o[i, j] /= logsum[i] + for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): + S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + + # TODO: split version + return flashattn_gqa_decode_no_split + + +def flash_attn_with_attn_pool_decode_tilelang( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, + block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, + tl_kernel=None, + block_table: torch.Tensor = None, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + assert K.is_contiguous() + assert V.is_contiguous() + + gqa_group_size = q_h // k_h + + O_tl = torch.zeros_like(Q) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) + O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table) + + if use_per_kv_head_sparse_index: + S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O_tl, S_tl + + +def test_varlen_decode_main(args): + """Test decode kernel with variable sequence lengths""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen # Use as max sequence length + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Generate variable length k sequences + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + print(f"k_seqlens: {k_seqlens}") + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + print(f"cu_seqlens_k: {cu_seqlens_k}") + + # Generate tensors - Q is [batch_size, q_heads, head_size] for decode + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + print(f"Actual max_seqlen_k: {max_seqlen_k}") + print(f"q_decode shape: 
{q_decode.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) + + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + block_table=block_table, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Create torch reference - pad tensors for comparison + k_padded_list = [] + v_padded_list = [] + + for i in range(batch_size): + actual_k_len = k_seqlens[i] + + # Extract and pad k, v for this batch + k_start = cu_seqlens_k[i] + k_end = cu_seqlens_k[i + 1] + + # Pad to max_seqlen_k + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + + k_padded[:actual_k_len] = k_varlen[k_start:k_end] + v_padded[:actual_k_len] = v_varlen[k_start:k_end] + + k_padded_list.append(k_padded) + v_padded_list.append(v_padded) + + # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + + # Expand q to match kv heads: [b, q_heads, 1, head_size] + q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] + + print(f"q_expanded shape: {q_expanded.shape}") + print(f"k_padded_batched shape: {k_padded_batched.shape}") + print(f"v_padded_batched shape: {v_padded_batched.shape}") + + # Compute torch reference + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + + if sink is None: + # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_score[i, :, :, actual_k_len:] = float("-inf") + + attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = 
torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + logits[i, :, :, actual_k_len:] = float("-inf") + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] + + O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] + + # Compute attention score pooling for S + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, max_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + + print(f"O_triton shape: {O_triton.shape}") + print(f"O_tilelang shape: {O_tilelang.shape}") + print(f"O_torch shape: {O_torch.shape}") + print(f"S_triton shape: {S_triton.shape}") + print(f"S_tilelang shape: {S_tilelang.shape}") + print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + + # Compare results + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") + + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") + + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( + f"Score mismatch: {max_diff_s_tl.item()}" + ) + + print("✅ All tests passed!") + + +def speed_benchmark_decode_comparison(args): + """Speed benchmark for decode kernel""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print("\n=== Decode Speed Benchmark Comparison ===") + print("Configuration:") + print(f" Batch size: {batch_size}") + print(f" Q heads: {q_heads}, KV heads: {kv_heads}") + print(f" Max K sequence length: {max_k_seqlen}") + print(f" Head size: {head_size}") + print(f" Block size: 
{block_size}") + print(f" Data type: {dtype}") + print(f" Variable lengths: {args.test_varlen}") + print(f" s_aux attention: {args.test_sink}") + print() + + # Generate input data + if args.test_varlen: + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + else: + k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + # Generate tensors + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(" Using sink attention with sink values") + + print("Setup complete:") + print(f" Total K tokens: {total_k_tokens}") + print(f" Actual max K seq len: {max_seqlen_k}") + if args.test_varlen: + print(f" K sequence lengths: {k_seqlens.tolist()}") + + # Warmup + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) + + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Benchmark + print("⚡ Benchmarking Tilelang kernel (100 iterations)...") + tilelang_time = do_bench( + flash_attn_with_attn_pool_decode_tilelang, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + False, + tl_kernel, + block_table, + ) + print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") + + # Benchmark + print("⚡ Benchmarking Triton kernel (100 iterations)...") + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) + print(f"Average decode kernel time Triton: {triton_time:.3f} ms") + print(f"Speedup: {(triton_time / tilelang_time):.3f}") + + +def main(): + args = argparse.Namespace( + batch_size=1, + q_heads=32, + kv_heads=8, + k_seqlen=8192, + head_size=128, + block_size=128, + dtype=T.float16, + ) + args.test_sink = True + args.test_varlen = True + args.dtype = T.float16 + args.num_split = 1 + args.page_block_size = 128 + test_varlen_decode_main(args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + 
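# NOTE: --k_seqlen acts as the maximum KV length here: it sets the kernel's max_seqlen_kv and the block-table width. +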
parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") + parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") + parser.add_argument("--page_block_size", type=int, default=128, help="Page block size") + args = parser.parse_args() + args.test_sink = True + args.test_varlen = True + args.dtype = T.float16 + args.num_split = 1 + + if args.benchmark: + speed_benchmark_decode_comparison(args) + else: + test_varlen_decode_main(args) diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py index b4285a64f..24a90c57b 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -10,102 +10,24 @@ @tilelang.jit(out_idx=[5]) def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, seqlen_q, heads, dim] shape_kv = [batch, seqlen_kv, heads, dim] part_shape = [batch, seqlen_q, heads, num_split, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 - @T.macro - def MMA0( + @T.prim_func + def flashattn_mha_inference( + Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_kv, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - mid: T.int32, - hid: T.int32, - bid: T.int32, - sid: T.int32, - ): - T.copy( - K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + - (k + 1) * block_N, hid, :], K_shared) - # TODO: Handle causal split case - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( V: T.Tensor(shape_kv, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - hid: T.int32, - bid: T.int32, - sid: T.int32, - ): - T.copy( - V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + - (k + 1) * block_N, hid, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: 
T.FragmentBuffer([block_M], accum_dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] + Output: T.Tensor(shape_q, dtype), ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_kv, dtype), - V: T.Tensor(shape_kv, dtype), - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), - ): - with T.Kernel( - T.ceildiv(seqlen_q, block_M), heads * batch, num_split, - threads=128) as (bx, by, bz): + # split + with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -126,43 +48,73 @@ def flash_attn_split( # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently # disable relevant tma copy and use SIMT as fallback for now - T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) + T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) # TODO: Handle causal split case loop_range = ( - T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv( - (mid + 1) * block_M, block_N)) if is_causal else T.ceildiv( - (seqlen_kv // num_split), block_N)) + T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N)) + if is_causal + else T.ceildiv((seqlen_kv // num_split), block_N) + ) for k in T.Pipelined(loop_range, num_stages=2): - MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid) + T.copy( + K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], + K_shared, + ) + # TODO: Handle causal split case + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + 
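# non-causal path: clear the score tile to zero before the QK^T GEMM accumulates into it +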
T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy( + V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], + V_shared, + ) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M]) + T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M]) T.copy(acc_o, O_shared) - T.copy( - O_shared, - Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :], - disable_tma=True) - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_q, dtype), - ): + T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True) + + # combine with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz): po_local = T.alloc_fragment([block_M, dim], dtype) - po_shared = T.alloc_shared([block_M, dim], dtype) o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype) o_shared = T.alloc_shared([block_M, dim], dtype) lse_local = T.alloc_fragment([num_split, block_M], dtype) @@ -171,20 +123,17 @@ def combine( lse_max_local = T.alloc_fragment([block_M], accum_dtype) scale_local = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({ - o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), - o_shared: tilelang.layout.make_swizzled_layout(o_shared), - po_shared: tilelang.layout.make_swizzled_layout(po_shared), - }) - T.clear(lse_logsum_local) T.clear(o_accum_local) - T.copy(glse[ - bz, - by, - :, - bx * block_M:(bx + 1) * block_M, - ], lse_local) + T.copy( + glse[ + bz, + by, + :, + bx * block_M : (bx + 1) * block_M, + ], + lse_local, + ) T.reduce_max(lse_local, lse_max_local, dim=0, clear=False) for k in T.Pipelined(num_split): T.copy(lse_local[k, :], lse_local_split) @@ -193,11 +142,7 @@ def combine( for i in T.Parallel(block_M): lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] for k in T.Pipelined(num_split, num_stages=2): - 
T.copy( - Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], - po_shared, - disable_tma=True) - T.copy(po_shared, po_local) + T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_local) for i in T.Parallel(block_M): lse_local_split[i] = lse_local[k, i] for i in T.Parallel(block_M): @@ -205,19 +150,7 @@ def combine( for i, j in T.Parallel(block_M, dim): o_accum_local[i, j] += po_local[i, j] * scale_local[i] T.copy(o_accum_local, o_shared) - T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True) - - @T.prim_func - def flashattn_mha_inference( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_kv, dtype), - V: T.Tensor(shape_kv, dtype), - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] - Output: T.Tensor(shape_q, dtype), - ): - flash_attn_split(Q, K, V, glse, Output_partial) - combine(glse, Output_partial, Output) + T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True) return flashattn_mha_inference @@ -225,10 +158,10 @@ def flashattn_mha_inference( def ref_program(Q, K, V, glse, Output_partial, causal): assert causal is False dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -256,7 +189,7 @@ def flash_split_ref(Q, K, V, causal): block_N = 128 seqlen_kv = K.size(1) - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float) acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float) @@ -273,14 +206,15 @@ def flash_split_ref(Q, K, V, causal): for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_, - K[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, seqlen, nheads, block_N] + acc_s = torch.einsum( + "bqhd,bkhd->bhqk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, seqlen, nheads, block_N] scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] scores_scale = torch.exp2(scores_max_prev - scores_max) @@ -288,9 +222,10 @@ def flash_split_ref(Q, K, V, causal): acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s_cast = acc_s.to(torch.float16) acc_o += torch.einsum( - 'bhqk,bkhd->bqhd', acc_s_cast, - V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhqk,bkhd->bqhd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum acc_o /= logsum[:, :, :, 
None].transpose(1, 2) @@ -298,13 +233,10 @@ def flash_split_ref(Q, K, V, causal): gacc_o[ks, :, :, :, :] = acc_o glogsum[ks, :, :, :] = logsum - return glogsum.to(torch.float16).permute(1, 2, 0, - 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) + return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) -def main(): - BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128 - causal = False +def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD total_flops = 2 * flops_per_matmul if causal: @@ -325,5 +257,13 @@ def main(): print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): + BLOCK_M = 128 + BLOCK_N = 64 + kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/flash_decoding/regression_example_flash_decoding.py b/examples/flash_decoding/regression_example_flash_decoding.py new file mode 100644 index 000000000..476bceb34 --- /dev/null +++ b/examples/flash_decoding/regression_example_flash_decoding.py @@ -0,0 +1,17 @@ +import tilelang.testing +import example_gqa_decode +import example_mha_inference + + +def regression_example_gqa_decode(): + tilelang.testing.process_func(example_gqa_decode.run_regression_perf) + + +def regression_example_mha_inference(): + tilelang.testing.process_func( + example_mha_inference.run_regression_perf, BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False + ) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/flash_decoding/test_example_flash_decoding.py b/examples/flash_decoding/test_example_flash_decoding.py index a6ec1c68e..a02a92097 100644 --- a/examples/flash_decoding/test_example_flash_decoding.py +++ b/examples/flash_decoding/test_example_flash_decoding.py @@ -2,6 +2,8 @@ import example_gqa_decode import example_mha_inference +import example_gqa_decode_varlen_logits +import example_gqa_decode_varlen_logits_paged # TODO(lei): fix the correctness of gqa decode on sm90 @@ -12,7 +14,15 @@ def test_example_example_gqa_decode(): def test_example_example_mha_inference(): - example_mha_inference.main() + example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False) + + +def test_example_example_gqa_decode_varlen_logits(): + example_gqa_decode_varlen_logits.main() + + +def test_example_example_gqa_decode_varlen_logits_paged(): + example_gqa_decode_varlen_logits_paged.main() if __name__ == "__main__": diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py index a8d684965..5c236dd80 100644 --- a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -9,17 +9,18 @@ @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def moe_forward_tilelang_shared(d_hidden, - d_expert, - n_shared_experts, - dtype, - num_tokens, - block_token=128, - block_dhidden=128, - block_dexpert=128, - threads=256, - num_stages=1): - +def moe_forward_tilelang_shared( + d_hidden, + d_expert, + n_shared_experts, + dtype, + num_tokens, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, +): scale = 1.44269504 # 
log2(e) # Parameters @@ -32,21 +33,19 @@ def moe_forward_tilelang_shared(d_hidden, shared_W_up_shape = (dexpert, dhidden) shared_W_down_shape = (dhidden, dexpert) - accum_type = "float32" + accum_type = T.float32 @T.prim_func def kernel_shared( - input: T.Tensor(input_shape, dtype), # type: ignore - shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore - shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore - shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore - up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore - output: T.Tensor(input_shape, dtype), # type: ignore + input: T.Tensor(input_shape, dtype), # type: ignore + shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore + shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore + shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore + up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore ): # Step 1: Compute gate and up logits - with T.Kernel( - T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): # Split the block to shared experts and routed experts input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) @@ -70,16 +69,13 @@ def kernel_shared( # Fuse with SiLU and element-wise product for i, j in T.Parallel(block_token, block_dexpert): - gate_logits_local[i, j] = gate_logits_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert]) # Step 2: Compute down logits - with T.Kernel( - T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by): up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type) @@ -98,20 +94,21 @@ def kernel_shared( @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def moe_forward_tilelang_routed(d_hidden, - d_expert, - n_routed_experts, - dtype, - group_sum, - group_count, - block_token=128, - block_dhidden=128, - block_dexpert=128, - threads=256, - num_stages=1, - k_pack=1, - coalesced_width=None): - +def moe_forward_tilelang_routed( + d_hidden, + d_expert, + n_routed_experts, + dtype, + group_sum, + group_count, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + k_pack=1, + coalesced_width=None, +): scale = 1.44269504 # log2(e) # Parameters @@ -124,7 +121,7 @@ def moe_forward_tilelang_routed(d_hidden, # group_count = len(group_sizes_list) # M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list]) M = math.ceil(group_sum / block_token) + group_count - accum_dtype = "float32" + accum_dtype = T.float32 # Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, 
dhidden) for grouped gemm input_shape = (group_sum, dhidden) @@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden, routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden) routed_expert_up_shape = (n_routed_experts, dexpert, dhidden) routed_expert_down_shape = (n_routed_experts, dhidden, dexpert) - routed_expert_weights_shape = (group_sum) - group_sizes_shape = (n_routed_experts) + routed_expert_weights_shape = group_sum + group_sizes_shape = n_routed_experts @T.prim_func def kernel( - input: T.Tensor(input_shape, dtype), # type: ignore - routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore - routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore - routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore - routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore - group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore - up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore - output: T.Tensor(input_shape, dtype), # type: ignore + input: T.Tensor(input_shape, dtype), # type: ignore + routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore + routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore + routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore + routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore + group_sizes: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_padded_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_idx_for_bx: T.Tensor((M,), T.int32), # type: ignore + up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore ): # Step 1: Compute gate and up logits with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): @@ -158,58 +155,44 @@ def kernel( gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) - cur_group_idx = T.alloc_local([1], "int32") - cur_group_size = T.alloc_local([1], "int32") - T.use_swizzle(10, enable=True) m_start_padded = bx * block_token - cur_group_idx[0] = group_idx_for_bx[bx] + cur_group_idx = group_idx_for_bx[bx] - cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ - cur_group_idx[0]] - actual_rows = T.max( - 0, - T.min(block_token, cur_group_size[0] - - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + cur_group_size = group_sizes[cur_group_idx] + m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx] + actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx]))) T.clear(gate_logits_local) T.clear(up_logits_local) for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): T.copy( - input[m_start:m_start + block_token, k * block_dhidden:(k + 1) * block_dhidden], + input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden], input_shared, - coalesced_width=coalesced_width) + 
coalesced_width=coalesced_width, + ) T.copy( - routed_expert_gate[cur_group_idx[0], - by * block_dexpert:(by + 1) * block_dexpert, - k * block_dhidden:(k + 1) * block_dhidden], + routed_expert_gate[ + cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], routed_expert_gate_shared, - coalesced_width=coalesced_width) - T.gemm( - input_shared, - routed_expert_gate_shared, - gate_logits_local, - k_pack=k_pack, - transpose_B=True) + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True) T.copy( - routed_expert_up[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert, - k * block_dhidden:(k + 1) * block_dhidden], - routed_expert_up_shared, - coalesced_width=coalesced_width) - T.gemm( - input_shared, + routed_expert_up[ + cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], routed_expert_up_shared, - up_logits_local, - k_pack=k_pack, - transpose_B=True) + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True) for i, j in T.Parallel(block_token, block_dexpert): - gate_logits_local[i, j] = gate_logits_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] for i, j in T.Parallel(block_token, block_dexpert): @@ -222,60 +205,42 @@ def kernel( routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype) - cur_group_idx = T.alloc_local([1], "int32") - cur_group_size = T.alloc_local([1], "int32") - T.use_swizzle(10, enable=True) m_start_padded = bx * block_token - cur_group_idx[0] = group_idx_for_bx[bx] + cur_group_idx = group_idx_for_bx[bx] - cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ - cur_group_idx[0]] - actual_rows = T.max( - 0, - T.min(block_token, cur_group_size[0] - - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + cur_group_size = group_sizes[cur_group_idx] + m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx] + actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx]))) T.clear(output_local) for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): T.copy( - up_logits[m_start:m_start + block_token, - k * block_dexpert:(k + 1) * block_dexpert], + up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert], up_logits_shared, - coalesced_width=coalesced_width) + coalesced_width=coalesced_width, + ) T.copy( - routed_expert_down[cur_group_idx[0], - by * block_dhidden:(by + 1) * block_dhidden, - k * block_dexpert:(k + 1) * block_dexpert], - routed_expert_down_shared, - coalesced_width=coalesced_width) - T.gemm( - up_logits_shared, + routed_expert_down[ + cur_group_idx, by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert + ], routed_expert_down_shared, - output_local, - k_pack=k_pack, - transpose_B=True) + coalesced_width=coalesced_width, + ) + T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, 
transpose_B=True) for i, j in T.Parallel(block_token, block_dhidden): if i < actual_rows: - output[m_start + i, by * block_dhidden + - j] = output_local[i, j] * routed_expert_weights[m_start + i] + output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i] return kernel class Expert(nn.Module): - - def __init__(self, - config: Dict, - gate: torch.Tensor, - up: torch.Tensor, - down: torch.Tensor, - d_expert: Optional[int] = None): + def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None): super().__init__() self.config = config self.act_fn = nn.SiLU() @@ -294,14 +259,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MoEGate(nn.Module): - def __init__(self, config: Dict, weights: Dict): super().__init__() self.top_k: int = config["n_experts_per_token"] self.num_experts: int = config["n_routed_experts"] self.d_hidden: int = config["d_hidden"] - self.W_g_weight = weights['router.weight'].t() + self.W_g_weight = weights["router.weight"].t() def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: logits = x @ self.W_g_weight @@ -312,76 +276,69 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: class MoE(nn.Module): - - def __init__(self, - config: Dict, - shared_kernel: tilelang.JITKernel, - routed_kernel: tilelang.JITKernel, - weights: Dict, - padding_M: int = 128): + def __init__( + self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128 + ): super().__init__() self.config = config self.shared_kernel = shared_kernel self.routed_kernel = routed_kernel self.padding_M = padding_M - self.experts = nn.ModuleList([ - Expert( - config, - gate=weights[f'experts.{i}.0.weight'], - up=weights[f'experts.{i}.1.weight'], - down=weights[f'experts.{i}.2.weight']) for i in range(config["n_routed_experts"]) - ]) + self.experts = nn.ModuleList( + [ + Expert( + config, + gate=weights[f"experts.{i}.0.weight"], + up=weights[f"experts.{i}.1.weight"], + down=weights[f"experts.{i}.2.weight"], + ) + for i in range(config["n_routed_experts"]) + ] + ) self.device = torch.device("cuda") self.gating_network = MoEGate(config, weights).to(self.device) shared_expert_dim = config["d_expert"] * config["n_shared_experts"] self.shared_expert = Expert( config=config, - gate=weights['shared_experts.0.weight'], - up=weights['shared_experts.1.weight'], - down=weights['shared_experts.2.weight'], - d_expert=shared_expert_dim).to(self.device) + gate=weights["shared_experts.0.weight"], + up=weights["shared_experts.1.weight"], + down=weights["shared_experts.2.weight"], + d_expert=shared_expert_dim, + ).to(self.device) self.expert_cache = torch.zeros( - (config["batch_size"] * config["seq_len"], config["d_hidden"]), - dtype=torch.float16, - device=self.device) - self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], - dim=0) - self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], - dim=0) - self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], - dim=0) + (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device + ) + self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0) + self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0) + self.stacked_expert_w_down = 
torch.stack([expert.W_down_weight for expert in self.experts], dim=0) self.stacked_expert_tokens = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_hidden"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) self.stacked_expert_weights = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device + ) self.stacked_expert_tokens_idxs = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), - dtype=torch.int64, - device=self.device) + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device + ) self.up_logits_shared = torch.empty( - (config["batch_size"] * config["seq_len"], self.config["d_expert"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device + ) self.expert_output_shared = torch.empty( - (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device + ) self.up_logits_routed = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_expert"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) self.expert_output_routed = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_hidden"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -413,22 +370,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs - self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[ - idxs[start_idx:end_idx]] + self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]] group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device) - group_offset = torch.tensor( - tokens_per_expert - counts, dtype=torch.int32, device=self.device) + group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device) group_padded_offsets = [0 for _ in range(len(group_sizes))] for i in range(1, len(group_sizes)): - group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil( - (counts[i - 1] + 1) / self.padding_M) * self.padding_M + group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M block_token = 128 - M = math.ceil( - self.config["batch_size"] * self.config["seq_len"] * - self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"] + M = ( + math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token) + + 
self.config["n_routed_experts"] + ) group_idx_for_bx = [0 for _ in range(M)] for bx in range(M): @@ -437,8 +392,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if m_start_padded >= group_padded_offsets[i]: group_idx_for_bx[bx] = i - group_padded_offsets = torch.tensor( - group_padded_offsets, dtype=torch.int32, device=self.device) + group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device) group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device) # Multi-stream execution @@ -448,11 +402,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.cuda.stream(routed_stream): # Tilelang version: Grouped GEMM - self.routed_kernel(self.stacked_expert_tokens, self.stacked_expert_w_gate, - self.stacked_expert_w_up, self.stacked_expert_w_down, - self.stacked_expert_weights, group_sizes, group_offset, - group_padded_offsets, group_idx_for_bx, self.up_logits_routed, - self.expert_output_routed) + self.routed_kernel( + self.stacked_expert_tokens, + self.stacked_expert_w_gate, + self.stacked_expert_w_up, + self.stacked_expert_w_down, + self.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + self.up_logits_routed, + self.expert_output_routed, + ) # Scatter reduce self.expert_cache = torch.scatter_reduce( @@ -460,14 +422,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: 0, self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]), self.expert_output_routed, - reduce='sum') + reduce="sum", + ) routed_output = self.expert_cache.view(*orig_shape) with torch.cuda.stream(shared_stream): - - self.shared_kernel(x_flat, self.shared_expert.W_gate_weight, - self.shared_expert.W_up_weight, self.shared_expert.W_down_weight, - self.up_logits_shared, self.expert_output_shared) + self.shared_kernel( + x_flat, + self.shared_expert.W_gate_weight, + self.shared_expert.W_up_weight, + self.shared_expert.W_down_weight, + self.up_logits_shared, + self.expert_output_shared, + ) shared_output = self.expert_output_shared.view(*orig_shape) torch.cuda.synchronize() @@ -491,14 +458,15 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: """ input_tensor, weights, config = data - dtype_str = "float16" + dtype_str = T.float16 shared_kernel = moe_forward_tilelang_shared( config["d_hidden"], config["d_expert"], config["n_shared_experts"], dtype=dtype_str, - num_tokens=config["batch_size"] * config["seq_len"]) + num_tokens=config["batch_size"] * config["seq_len"], + ) routed_kernel = moe_forward_tilelang_routed( config["d_hidden"], config["d_expert"], @@ -512,7 +480,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: threads=256, num_stages=1, k_pack=1, - coalesced_width=2) + coalesced_width=2, + ) moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) @@ -521,13 +490,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: return output -def main(d_hidden=7168, - d_expert=2048, - n_routed_experts=8, - n_shared_experts=1, - n_experts_per_token=4, - batch_size=1, - seq_len=8192): +def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192): config = { "dhidden": d_hidden, "dexpert": d_expert, @@ -536,7 +499,7 @@ def main(d_hidden=7168, "nexpertspertoken": n_experts_per_token, "bs": batch_size, "seqlen": seq_len, - "seed": 81394 + "seed": 81394, } data = generate_input(**config) @@ -551,5 +514,121 @@ def main(d_hidden=7168, 
print("✅ Tilelang and Torch match") +def run_regression_perf( + d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192 +): + config = { + "dhidden": d_hidden, + "dexpert": d_expert, + "nroutedexperts": n_routed_experts, + "nsharedexperts": n_shared_experts, + "nexpertspertoken": n_experts_per_token, + "bs": batch_size, + "seqlen": seq_len, + "seed": 81394, + } + from tilelang.profiler import do_bench + + data = generate_input(**config) + + x, weights, config = data + + dtype_str = "float16" + + shared_kernel = moe_forward_tilelang_shared( + config["d_hidden"], + config["d_expert"], + config["n_shared_experts"], + dtype=dtype_str, + num_tokens=config["batch_size"] * config["seq_len"], + ) + routed_kernel = moe_forward_tilelang_routed( + config["d_hidden"], + config["d_expert"], + config["n_routed_experts"], + dtype=dtype_str, + group_sum=config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], + group_count=config["n_routed_experts"], + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + k_pack=1, + coalesced_width=2, + ) + + moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) + batch_size, seq_len, hidden_dim = x.shape + expert_indices, expert_scores = moe.gating_network(x) + flat_expert_indices = expert_indices.view(-1) + flat_expert_weights = expert_scores.view(-1) + x_flat = x.view(-1, hidden_dim) + idxs = flat_expert_indices.argsort() + counts = flat_expert_indices.bincount().cpu().numpy() + tokens_per_expert = counts.cumsum() + num_per_tok = moe.config["n_experts_per_token"] + token_idxs = idxs // num_per_tok + for expert_id, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + if start_idx == end_idx: + continue + exp_token_idxs = token_idxs[start_idx:end_idx] + expert_tokens = x_flat[exp_token_idxs] + moe.stacked_expert_tokens[start_idx:end_idx] = expert_tokens + moe.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs + moe.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]] + group_sizes = torch.tensor(counts, dtype=torch.int32, device=moe.device) + group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=moe.device) + group_padded_offsets = [0 for _ in range(len(group_sizes))] + for i in range(1, len(group_sizes)): + group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / moe.padding_M) * moe.padding_M + block_token = 128 + M = ( + math.ceil(moe.config["batch_size"] * moe.config["seq_len"] * moe.config["n_experts_per_token"] / block_token) + + moe.config["n_routed_experts"] + ) + group_idx_for_bx = [0 for _ in range(M)] + for bx in range(M): + m_start_padded = bx * block_token + for i in range(moe.config["n_routed_experts"]): + if m_start_padded >= group_padded_offsets[i]: + group_idx_for_bx[bx] = i + group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=moe.device) + group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=moe.device) + + def run_shared_kernel_only(): + moe.routed_kernel( + moe.stacked_expert_tokens, + moe.stacked_expert_w_gate, + moe.stacked_expert_w_up, + moe.stacked_expert_w_down, + moe.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + moe.up_logits_routed, + moe.expert_output_routed, + ) + + def run_routed_kernel_only(): + moe.routed_kernel( + 
moe.stacked_expert_tokens, + moe.stacked_expert_w_gate, + moe.stacked_expert_w_up, + moe.stacked_expert_w_down, + moe.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + moe.up_logits_routed, + moe.expert_output_routed, + ) + + return do_bench(run_routed_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/fusedmoe/example_fusedmoe_torch.py b/examples/fusedmoe/example_fusedmoe_torch.py index 00219c6e9..6b6322aff 100644 --- a/examples/fusedmoe/example_fusedmoe_torch.py +++ b/examples/fusedmoe/example_fusedmoe_torch.py @@ -6,7 +6,6 @@ # Reference code in PyTorch class ExpertTorch(nn.Module): - def __init__(self, config: Dict, d_expert: Optional[int] = None): super().__init__() self.config = config @@ -25,7 +24,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MoEGateTorch(nn.Module): - def __init__(self, config: Dict): super().__init__() self.top_k: int = config["n_experts_per_token"] @@ -43,12 +41,10 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: class MoETorch(nn.Module): - def __init__(self, config: Dict): super().__init__() self.config = config - self.experts = nn.ModuleList( - [ExpertTorch(config) for _ in range(config["n_routed_experts"])]) + self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])]) self.gating_network = MoEGateTorch(config) shared_expert_dim = config["d_expert"] * config["n_shared_experts"] self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim) @@ -67,8 +63,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return routed_output + shared_output @torch.no_grad() - def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, - flat_expert_weights: torch.Tensor) -> torch.Tensor: + def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor: expert_cache = torch.zeros_like(x) # test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) # test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) @@ -91,8 +86,7 @@ def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, expert_out = expert(expert_tokens) expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) - expert_cache.scatter_reduce_( - 0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') + expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum") return expert_cache @@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: moe = MoETorch(config) # Fill in the given weights of the model - moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight']) + moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"]) for i in range(num_experts): - gate_proj_weight = weights[f'experts.{i}.0.weight'] - up_proj_weight = weights[f'experts.{i}.1.weight'] - down_proj_weight = weights[f'experts.{i}.2.weight'] + gate_proj_weight = weights[f"experts.{i}.0.weight"] + up_proj_weight = weights[f"experts.{i}.1.weight"] + down_proj_weight = weights[f"experts.{i}.2.weight"] # Transpose weights to match expected shape for nn.Linear moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t()) moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t()) moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t()) - 
moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t()) - moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t()) - moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t()) + moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t()) + moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t()) + moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t()) output = moe(input_tensor) @@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: # Input generation for the reference code -def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, - nexpertspertoken: int, bs: int, seqlen: int, - seed: int) -> Tuple[torch.Tensor, Dict, Dict]: - +def generate_input( + dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int +) -> Tuple[torch.Tensor, Dict, Dict]: # Really dumb but for now _ isn't parsing correctly. d_hidden = dhidden d_expert = dexpert @@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper "seq_len": seq_len, } - gen = torch.Generator(device='cuda') + gen = torch.Generator(device="cuda") gen.manual_seed(seed) num_experts = n_routed_experts expert_dim = d_expert weights = {} - input_tensor = torch.randn((batch_size, seq_len, d_hidden), - device='cuda', - dtype=torch.float16, - generator=gen).contiguous() + input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous() # Initialize router weights - weights['router.weight'] = torch.randn( - (num_experts, d_hidden), device="cuda", dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) + weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden) for i in range(num_experts): - weights[f'experts.{i}.0.weight'] = torch.randn( - (d_hidden, expert_dim), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim) - - weights[f'experts.{i}.1.weight'] = torch.randn( - (d_hidden, expert_dim), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim) - - weights[f'experts.{i}.2.weight'] = torch.randn( - (expert_dim, d_hidden), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) - - weights['shared_experts.0.weight'] = torch.randn( - (d_hidden, expert_dim * n_shared_experts), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim * n_shared_experts) - weights['shared_experts.1.weight'] = torch.randn( - (d_hidden, expert_dim * n_shared_experts), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim * n_shared_experts) - weights['shared_experts.2.weight'] = torch.randn((expert_dim * n_shared_experts, d_hidden), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) + weights[f"experts.{i}.0.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.1.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.2.weight"] = torch.randn( + (expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) + + 
weights["shared_experts.0.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.1.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.2.weight"] = torch.randn( + (expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) return (input_tensor, weights, config) diff --git a/examples/fusedmoe/regression_example_fusedmoe.py b/examples/fusedmoe/regression_example_fusedmoe.py new file mode 100644 index 000000000..ac0f18aae --- /dev/null +++ b/examples/fusedmoe/regression_example_fusedmoe.py @@ -0,0 +1,19 @@ +import tilelang.testing +import example_fusedmoe_tilelang + + +def regression_example_fusedmoe_tilelang(): + tilelang.testing.process_func( + example_fusedmoe_tilelang.run_regression_perf, + d_hidden=1024, + d_expert=256, + n_routed_experts=8, + n_shared_experts=1, + n_experts_per_token=4, + batch_size=1, + seq_len=1024, + ) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/fusedmoe/test_example_fusedmoe.py b/examples/fusedmoe/test_example_fusedmoe.py index 806aff49e..ba8415895 100644 --- a/examples/fusedmoe/test_example_fusedmoe.py +++ b/examples/fusedmoe/test_example_fusedmoe.py @@ -4,13 +4,8 @@ def test_example_fusedmoe_tilelang(): example_fusedmoe_tilelang.main( - d_hidden=1024, - d_expert=256, - n_routed_experts=8, - n_shared_experts=1, - n_experts_per_token=4, - batch_size=1, - seq_len=1024) + d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024 + ) if __name__ == "__main__": diff --git a/examples/gdn/example_chunk_delta_bwd.py b/examples/gdn/example_chunk_delta_bwd.py index 518b0ee21..4230df525 100644 --- a/examples/gdn/example_chunk_delta_bwd.py +++ b/examples/gdn/example_chunk_delta_bwd.py @@ -12,6 +12,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__, flush=True) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu except ImportError: @@ -24,7 +25,7 @@ torch.random.manual_seed(0) # torch.set_printoptions(profile="full") -from utils import * +from test_utils import assert_similar def prepare_input( @@ -49,6 +50,7 @@ def prepare_input( G = F.logsigmoid(G) try: from fla.ops.utils.cumsum import chunk_local_cumsum + G = chunk_local_cumsum(G, chunk_size) except ImportError: print("fla not found, skip cumsum") @@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu( DV = dv.shape[-1] block_S = 64 BS = S // block_S - dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty( - (B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype) + dh, dh0, dv2 = ( + torch.empty((B, BS, H, DK, DV), dtype=output_dtype), + torch.empty((B, H, DK, DV), dtype=state_dtype), + torch.empty((B, S, H, DV), dtype=output_dtype), + ) dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype) dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype) Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype) @@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu( for i_s in range(BS - 1, -1, -1): dh[:, i_s, :, :, :] = dh_tmp - dv_tmp = torch.matmul(K[:, i_s * block_S:(i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), - dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) + dv_tmp = 
torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) if use_g: for i_bh in range(B * H): i_b, i_h = i_bh // H, i_bh % H for i_s2 in range(block_S): - if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, - i_h] <= 0: - dv_tmp[i_b, i_s2, - i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - - G[i_b, i_s * block_S + i_s2, i_h]) + if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0: + dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h]) else: dv_tmp[i_b, i_s2, i_h, :] = 0 - dv_tmp += dv[:, i_s * block_S:(i_s + 1) * block_S, :, :] - dv2[:, i_s * block_S:(i_s + 1) * block_S, :, :] = dv_tmp + dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp if use_g: G_last = G[:, i_s * block_S + block_S - 1, :] for i_bh in range(B * H): i_b, i_h = i_bh // H, i_bh % H dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h]) - Q_tmp = Q[:, i_s * block_S:(i_s + 1) * block_S, :, :] + Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :] for i_s2 in range(block_S): for i_k in range(DK): Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :]) Q_tmp *= scale - W_tmp = W[:, i_s * block_S:(i_s + 1) * block_S, :, :] - dO_tmp = dO[:, i_s * block_S:(i_s + 1) * block_S, :, :] + W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :] torch.backends.cuda.matmul.allow_tf32 = True dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3)) @@ -223,25 +224,24 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( @T.prim_func def kernel( - # Input - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - h0: T.Tensor(h0_shape, dtype=input_dtype), - dht: T.Tensor(dht_shape, dtype=input_dtype), - dO: T.Tensor(dO_shape, dtype=input_dtype), - dv: T.Tensor(dv_shape, dtype=input_dtype), - # Output - dh: T.Tensor(dh_shape, dtype=output_dtype), - dh0: T.Tensor(dh0_shape, dtype=state_dtype), - dv2: T.Tensor(dv2_shape, dtype=output_dtype), + # Input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + h0: T.Tensor(h0_shape, dtype=input_dtype), + dht: T.Tensor(dht_shape, dtype=input_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + # Output + dh: T.Tensor(dh_shape, dtype=output_dtype), + dh0: T.Tensor(dh0_shape, dtype=state_dtype), + dv2: T.Tensor(dv2_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): bb, bh = bbh // H, bbh % H b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype) - b_dh_shared_fp32 = T.alloc_shared((DK, block_DV), dtype=state_dtype) b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) @@ -249,17 +249,14 @@ def kernel( dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) - dO_shared_t = T.alloc_shared((block_DV, block_S), 
dtype="float32") - dO_fragment = T.alloc_fragment((block_S, block_DV), dtype="float32") - dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype="float32") + dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32) + dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32) + dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32) K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) - Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype="float32") W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) - G_last_local = T.alloc_local((1), dtype=gate_dtype) - G_last_local_exp = T.alloc_local((1), dtype=gate_dtype) G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared") G_fragment = T.alloc_fragment((block_S), dtype=gate_dtype) G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype) @@ -269,20 +266,15 @@ def kernel( T.use_swizzle(10) - T.annotate_layout({ - b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared), - b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), - dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - }) + T.annotate_layout( + { + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + } + ) if use_final_state_gradient: - T.copy(dht[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_dh_shared) + T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) T.copy(b_dh_shared, b_dh_fragment) else: T.clear(b_dh_fragment) @@ -293,57 +285,45 @@ def kernel( # Store the updated dh T.copy(b_dh_fragment, b_dh_shared) - T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) # Update dv - T.copy(K[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], K_shared) + T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared) T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) if use_g: - T.copy( - G[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh], - G_shared, - disable_tma=True) + T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True) T.copy(G_shared, G_fragment) - G_last_local[0] = G_shared[block_S - 1] - G_last_local_exp[0] = T.exp(G_last_local[0]) + G_last_local = G_shared[block_S - 1] + G_last_local_exp = T.exp(G_last_local) for i_s2 in T.Parallel(block_S): - G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2]) + G_fragment_post[i_s2] = T.exp(G_last_local - G_fragment[i_s2]) for i_s2, i_v in T.Parallel(block_S, block_DV): - # with T.If(G_last_local[0] - G_shared[i_s2] <= 0): - with T.If(G_last_local[0] - G_fragment[i_s2] <= 0): - with T.Then(): - dv_fragment[i_s2, - i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] - with T.Else(): - dv_fragment[i_s2, i_v] = 0 - - T.copy( - dv[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV], dv_shared) + dv_fragment[i_s2, i_v] = ( + dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] if G_last_local 
- G_fragment[i_s2] <= 0 else 0 + ) + + T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared) T.copy(dv_shared, dv_fragment_2) for i_s2, i_v in T.Parallel(block_S, block_DV): dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] # Store the updated dv T.copy(dv_fragment, dv_shared) - T.copy( - dv_shared, dv2[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) # Update dh - T.copy(Q[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) - T.copy(W[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], W_shared) + T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) + T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared) T.clear(Q_fragment) if use_g: for i_k, i_v in T.Parallel(DK, block_DV): - b_dh_fragment[i_k, i_v] *= G_last_local_exp[0] + b_dh_fragment[i_k, i_v] *= G_last_local_exp T.copy(Q_shared, Q_fragment) for i_s2 in T.Parallel(block_S): G_fragment_exp[i_s2] = T.exp(G_shared[i_s2]) for i_s2, i_k in T.Parallel(block_S, DK): - # Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale else: T.copy(Q_shared, Q_fragment) @@ -353,9 +333,7 @@ def kernel( for i_s2, i_k in T.Parallel(block_S, DK): Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k] - T.copy( - dO[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV], dO_shared) + T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared) T.copy(dO_shared, dO_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v] @@ -369,7 +347,7 @@ def kernel( b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v] if use_initial_state: - T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -444,44 +422,61 @@ def run_test( num_stages=0, use_torch=False, ): - Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dh_ref, dh0_ref, dv2_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref, dh0_ref, dv2_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) # fla ref print("fla running...", flush=True) if use_g: - dh_ref, dh0_ref, dv2_ref = 
chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, - scale) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) else: G = G.fill_(0) - dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, - scale) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) # tilelang print("tilelang running...", flush=True) - kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, - chunk_size, scale, use_g, use_initial_state, - use_final_state_gradient, block_DV, threads, - num_stages) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) # kernel = tilelang.compile(program) print(kernel.get_kernel_source()) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) - fla_time = do_bench( - chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) + fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) print(f"fla time: {fla_time} ms") @@ -496,19 +491,47 @@ def run_test( print("torch running...", flush=True) if use_g: dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( - Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state, - use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), - getattr(torch, accum_dtype), getattr(torch, - gate_dtype), getattr(torch, state_dtype)) + Q, + K, + W, + G, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dh_ref_torch = dh_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda() else: dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( - Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state, - use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), - getattr(torch, accum_dtype), getattr(torch, - gate_dtype), getattr(torch, state_dtype)) + Q, + K, + W, + None, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dh_ref_torch = dh_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda() @@ -554,11 +577,11 @@ def main(): H=8, DK=DK, DV=128, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, scale=DK**-0.5, use_g=True, diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py index 4d6b657ff..2ee84e7bf 100644 --- a/examples/gdn/example_chunk_delta_h.py +++ b/examples/gdn/example_chunk_delta_h.py @@ -3,12 +3,14 @@ import sys # noqa: F401 import tilelang 
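+# `autotune` (imported below) benchmarks each candidate returned by `get_configs()` and
+# keeps the fastest (block_DK, block_DV, threads, num_stages) combination for the fwd_h kernel.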
import tilelang.language as T +from tilelang.autotuner import autotune # Add your fla repository path to sys.path # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h except ImportError: @@ -19,7 +21,7 @@ import torch.nn.functional as F from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 -from utils import * +from test_utils import assert_similar # (zhengju) We can slightly modify the generated cuda code from tilelang lowering # in the debug folder to make the performance better. To enable this callback, @@ -55,6 +57,7 @@ def prepare_input( G = F.logsigmoid(G) try: from fla.ops.utils.cumsum import chunk_local_cumsum + G = chunk_local_cumsum(G, chunk_size) except ImportError: print("fla not found, skip cumsum") @@ -80,7 +83,21 @@ def prepare_output( return h, final_state, V_new -@tilelang.jit(out_idx=[-3, -2, -1]) +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [128, 256] + num_stages = [1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) def tilelang_chunk_gated_delta_rule_fwd_h( # task config B, @@ -94,15 +111,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h( gate_dtype, state_dtype, chunk_size, - use_g=True, - use_initial_state=True, - store_final_state=True, - save_new_value=True, + use_g, + use_initial_state, + store_final_state, + save_new_value, # kernel config block_DK=64, - block_DV=64, - threads=256, - num_stages=0, + block_DV=32, + threads=128, + num_stages=1, ): block_S = chunk_size BS = S // block_S @@ -118,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - U: T.Tensor(U_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), - h: T.Tensor(h_shape, dtype=output_dtype), - final_state: T.Tensor(final_state_shape, dtype=state_dtype), - V_new: T.Tensor(V_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + U: T.Tensor(U_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=output_dtype), + final_state: T.Tensor(final_state_shape, dtype=state_dtype), + V_new: T.Tensor(V_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): bb, bh = bbh // H, bbh % H @@ -139,39 +156,35 @@ def kernel( V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) - G_last_local = T.alloc_local((1), dtype=gate_dtype) + G_last_local = T.alloc_var(T.float32) G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype) G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype) - T.annotate_layout({ - b_h_shared: 
tilelang.layout.make_swizzled_layout(b_h_shared), - U_shared: tilelang.layout.make_swizzled_layout(U_shared), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - G_shared: tilelang.layout.make_swizzled_layout(G_shared), - }) + T.annotate_layout( + { + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + G_shared: tilelang.layout.make_swizzled_layout(G_shared), + } + ) T.use_swizzle(10) if use_initial_state: - T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared) + T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared) T.copy(b_h_shared, b_h_fragment) else: T.clear(b_h_fragment) for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): # Store previous result to the hidden tensor, like the epilogue - T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) # Recurrence - T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared) + T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared) T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) # U - W * S - T.copy( - U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], - U_shared) + T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared) T.copy(U_shared, U_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] @@ -179,27 +192,24 @@ def kernel( # Save V_new if save_new_value: T.copy(V_new_fragment, dst=V_new_shared) - T.copy( - V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) - T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared) + T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) # use_g if use_g: - G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh] + G_last_local = G[bb, (i_s + 1) * block_S - 1, bh] for i_s2, i_v in T.Parallel(block_S, block_DV): G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh] T.copy(G_shared, G_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): - with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): - with T.Then(): - V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp( - G_last_local[0] - G_fragment[i_s2, i_v]) - with T.Else(): - V_new_fragment[i_s2, i_v] = 0 - G_last_local[0] = T.exp(G_last_local[0]) + V_new_fragment[i_s2, i_v] = ( + V_new_fragment[i_s2, i_v] * T.exp2((G_last_local - G_fragment[i_s2, i_v]) * 1.442695) + if G_last_local - G_fragment[i_s2, i_v] <= 0 + else 0 + ) + G_last_local = T.exp2(G_last_local * 1.442695) for i_k, i_v in T.Parallel(DK, block_DV): - b_h_fragment[i_k, i_v] *= G_last_local[0] + b_h_fragment[i_k, i_v] *= G_last_local # Update intermediate results T.copy(V_new_fragment, V_new_shared) @@ -209,7 +219,7 @@ def kernel( # Save final state if store_final_state: - T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -260,47 +270,77 @@ def run_test( threads=128, num_stages=0, ): - K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, 
- getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) - h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, state_dtype)) - h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, state_dtype)) + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + h_ref, final_state_ref, V_new_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) # fla ref - h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state, - store_final_state, chunk_size, - save_new_value) + h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + ) # tilelang - kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - use_g, use_initial_state, store_final_state, - save_new_value, block_DK, block_DV, threads, - num_stages) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + ) h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # (zhengju) If you want to print the generated cuda code, you can uncomment the following line # print("CUDA Code:\n", kernel.get_kernel_source()) - fla_time = do_bench(chunk_gated_delta_rule_fwd_h, K, W, U, G, initial_state, store_final_state, - chunk_size, save_new_value) + fla_time = do_bench( + chunk_gated_delta_rule_fwd_h, + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + ) tilelang_time = do_bench(kernel, K, W, U, G, initial_state) # check correctness try: h_ref_fp32 = h_ref.to(torch.float32) h_tilelang_fp32 = h_tilelang.to(torch.float32) - assert_similar( - h_ref_fp32, - h_tilelang_fp32, - eps=1e-5, - name="tilelang chunk gated delta rule fwd h", - raise_assert=False) + assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False) print("tilelang chunk gated delta rule fwd h passed √") except Exception as e: print("tilelang chunk gated delta rule fwd h failed ✗") @@ -314,7 +354,8 @@ def run_test( final_state_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd final_state", - raise_assert=False) + raise_assert=False, + ) print("tilelang chunk gated delta rule fwd final_state passed √") except Exception as e: print("tilelang chunk gated delta rule fwd final_state failed ✗") @@ -323,12 +364,7 @@ def run_test( try: V_new_ref_fp32 = V_new_ref.to(torch.float32) V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32) - assert_similar( - V_new_ref_fp32, - V_new_tilelang_fp32, - eps=1e-5, - name="tilelang chunk gated delta rule 
fwd V_new", - raise_assert=False) + assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False) print("tilelang chunk gated delta rule fwd V_new passed √") except Exception as e: print("tilelang chunk gated delta rule fwd V_new failed ✗") @@ -345,20 +381,20 @@ def main(): H=32, DK=128, DV=128, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, use_g=True, - use_initial_state=True, + use_initial_state=False, store_final_state=True, save_new_value=True, - block_DK=64, + block_DK=32, block_DV=32, threads=128, - num_stages=1, + num_stages=2, ) diff --git a/examples/gdn/example_chunk_o.py b/examples/gdn/example_chunk_o.py index 1c084be70..a4d7281f5 100644 --- a/examples/gdn/example_chunk_o.py +++ b/examples/gdn/example_chunk_o.py @@ -9,6 +9,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_o import chunk_fwd_o except ImportError: @@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o( @T.prim_func def kernel( - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - HIDDEN: T.Tensor(H_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - O: T.Tensor(O_shape, dtype=output_dtype), + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + HIDDEN: T.Tensor(H_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + O: T.Tensor(O_shape, dtype=output_dtype), ): - with T.Kernel( - T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, - threads=threads) as (bv, bs, bbh): + with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh): bb, bh = bbh // H, bbh % H Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) @@ -109,28 +108,13 @@ def kernel( G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared") G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype) - T.annotate_layout({ - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - H_shared: tilelang.layout.make_swizzled_layout(H_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - T.clear(A_fragment) T.clear(O_fragment) T.disable_warp_group_reg_alloc() for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - Q_shared) - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) - T.copy( - HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK, - bv * block_DV:(bv + 1) * block_DV], H_shared) + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared) T.gemm(Q_shared, 
H_shared, O_fragment) T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True) @@ -145,8 +129,7 @@ def kernel( for i_s1, i_s2 in T.Parallel(block_S, block_S): with T.If(G_diff_local[i_s1, i_s2] <= 0): with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( - G_diff_local[i_s1, i_s2]) + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) with T.Else(): A_fragment[i_s1, i_s2] = 0 @@ -155,8 +138,7 @@ def kernel( with T.Then(): A_fragment[i_s1, i_s2] = 0 - T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared) T.copy(A_fragment, A_shared) T.gemm(A_shared, V_shared, O_fragment) @@ -164,8 +146,7 @@ def kernel( O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale T.copy(O_fragment, O_shared) - T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -191,8 +172,9 @@ def run_test( output_dtype_torch = getattr(torch, output_dtype) accum_dtype_torch = getattr(torch, accum_dtype) gate_dtype_torch = getattr(torch, gate_dtype) - Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch, - output_dtype_torch, accum_dtype_torch, gate_dtype_torch) + Q, K, V, HIDDEN, G = prepare_input( + B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch + ) scale = 1.0 / DK**0.5 O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) @@ -200,9 +182,25 @@ def run_test( block_S = chunk_size O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) - kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, - threads, num_stages) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) O_tilelang = kernel(Q, K, V, HIDDEN, G) try: @@ -221,10 +219,10 @@ def main(): DK=128, DV=128, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, use_g=True, block_DK=128, block_DV=128, diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 76b4792df..e589818f4 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -7,13 +7,12 @@ import tilelang.language as T from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 -print(tilelang.__file__) - # Add your fla repository path to sys.path # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_o import chunk_bwd_dqkwg except ImportError: @@ -21,7 +20,7 @@ fla = None import torch -from utils import * +from test_utils import assert_similar torch.random.manual_seed(0) # torch.set_printoptions(profile="full") @@ -110,10 +109,8 @@ def prepare_output( @tilelang.jit( out_idx=[-4, -3, -2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - 
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) def tilelang_chunk_o_bwd_dqkwg( # task config B, @@ -157,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg( @T.prim_func def kernel( - # input - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - h: T.Tensor(h_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - dO: T.Tensor(dO_shape, dtype=input_dtype), - dh: T.Tensor(dh_shape, dtype=input_dtype), - dv: T.Tensor(dv_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - # output - dq: T.Tensor(dq_shape, dtype=output_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dw: T.Tensor(dw_shape, dtype=output_dtype), - dg: T.Tensor(dg_shape, dtype=gate_dtype), + # input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dh: T.Tensor(dh_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + # output + dq: T.Tensor(dq_shape, dtype=output_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dw: T.Tensor(dw_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), ): - with T.Kernel( - T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, - threads=threads) as (bk, bs, bbh): + with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh): bb, bh = bbh // H, bbh % H V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) @@ -204,27 +199,27 @@ def kernel( dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype) dg_fragment_2 = T.alloc_fragment((block_S,), dtype=gate_dtype) dg_fragment_final = T.alloc_fragment((block_S,), dtype=gate_dtype) - dg_last_local = T.alloc_local((2,), dtype=gate_dtype) + dg_last_local_0 = T.alloc_var(dtype=gate_dtype) + dg_last_local_1 = T.alloc_var(dtype=gate_dtype) + G_last_local = T.alloc_var(dtype=gate_dtype) + dg_last_fragment = T.alloc_fragment((block_DV * block_DK), dtype=gate_dtype) dg_last_fragment_scalar = T.alloc_fragment((1,), dtype=gate_dtype) dg_last_fragment_2 = T.alloc_fragment((block_S * block_DK), dtype=gate_dtype) dg_last_fragment_scalar_2 = T.alloc_fragment((1,), dtype=gate_dtype) - G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype, scope="shared") - G_last_local = T.alloc_local((1,), dtype=gate_dtype) + G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype) T.use_swizzle(10) - T.annotate_layout({ - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), - h_shared: tilelang.layout.make_swizzled_layout(h_shared), - dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - q_shared: tilelang.layout.make_swizzled_layout(q_shared), - k_shared: tilelang.layout.make_swizzled_layout(k_shared), - }) - - T.clear(dg_last_local) + T.annotate_layout( + { + q_shared: tilelang.layout.make_swizzled_layout(q_shared), + k_shared: tilelang.layout.make_swizzled_layout(k_shared), + } + ) + + T.clear(dg_last_local_0) + T.clear(dg_last_local_1) T.clear(G_last_local) T.clear(G_shared) T.clear(q_fragment) @@ -237,18 +232,10 @@ def kernel( 
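+            # The pipelined loop below accumulates ds (dO @ V^T), dq (dO @ h^T) and
+            # dk (V @ dh^T) over DV tiles, plus dw (dv @ h^T) when use_dw is enabled.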
T.clear(dw_fragment) for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) - T.copy( - dO[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], dO_shared) - T.copy( - h[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK, - i_v * block_DV:(i_v + 1) * block_DV], h_shared) - T.copy( - dh[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK, - i_v * block_DV:(i_v + 1) * block_DV], dh_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared) + T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared) + T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared) if use_g: T.clear(dg_last_fragment_scalar) @@ -256,32 +243,25 @@ def kernel( # for i_kv in T.Parallel(block_DK * block_DV): # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] for i_kv in T.Parallel(block_DK * block_DV): - i_k, i_v = i_kv // block_DV, i_kv % block_DV - dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v] + dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) - dg_last_local[0] += dg_last_fragment_scalar[0] + dg_last_local_0 = dg_last_local_0 + dg_last_fragment_scalar[0] T.gemm(dO_shared, V_shared, ds_fragment, transpose_B=True) T.gemm(dO_shared, h_shared, dq_fragment, transpose_B=True) T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True) if use_dw: - T.copy( - dv[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], dv_shared) + T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared) T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True) if use_dw: for i_s, i_k in T.Parallel(block_S, block_DK): dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] - T.copy( - dw_fragment, dw[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - - T.copy(Q[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK], - q_shared) - T.copy(K[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK], - k_shared) + T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared) T.copy(q_shared, q_fragment) T.copy(k_shared, k_fragment) @@ -290,13 +270,12 @@ def kernel( T.clear(dg_fragment_2) for i_s, i_k in T.Parallel(block_S, block_DK): G_shared[i_s, i_k] = G[bb, bs * block_S + i_s, bh] - G_last_local[0] = G[bb, bs * block_S + block_S - 1, bh] + dg_last_local_0 = G[bb, bs * block_S + block_S - 1, bh] # Use gmem directly instead of local register - dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh]) + dg_last_local_0 = dg_last_local_0 * T.exp(G[bb, bs * block_S + block_S - 1, bh]) for i_s, i_k in T.Parallel(block_S, block_DK): - dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, - bh]) * scale + 
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale T.clear(dg_fragment_reduce_tmp) for i_s, i_k in T.Parallel(block_S, block_DK): dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k] @@ -304,12 +283,11 @@ def kernel( T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False) for i_s, i_k in T.Parallel(block_S, block_DK): - with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0): - with T.Then(): - dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp( - G_last_local[0] - G[bb, bs * block_S + i_s, bh]) - with T.Else(): - dk_fragment[i_s, i_k] = 0 + dk_fragment[i_s, i_k] = ( + dk_fragment[i_s, i_k] * T.exp(G_last_local - G[bb, bs * block_S + i_s, bh]) + if G_last_local - G[bb, bs * block_S + i_s, bh] <= 0 + else 0 + ) T.clear(dg_fragment_reduce_tmp) for i_s, i_k in T.Parallel(block_S, block_DK): dg_fragment_reduce_tmp[i_s, i_k] = dk_fragment[i_s, i_k] * (-k_shared[i_s, i_k]) @@ -323,24 +301,20 @@ def kernel( i_s, i_k = i_sk // block_DK, i_sk % block_DK dg_last_fragment_2[i_sk] = dk_shared[i_s, i_k] * k_shared[i_s, i_k] T.reduce_sum(dg_last_fragment_2, dg_last_fragment_scalar_2, dim=-1, clear=False) - dg_last_local[1] = dg_last_fragment_scalar_2[0] + dg_last_local_1 = dg_last_fragment_scalar_2[0] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 >= i_s2 and - G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): - with T.Then(): - ds_fragment[i_s1, i_s2] = ds_fragment[ - i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - - G[bb, bs * block_S + i_s2, bh]) * scale - with T.Else(): - ds_fragment[i_s1, i_s2] = 0 + ds_fragment[i_s1, i_s2] = ( + (ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale) + if G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0 + else 0 + ) T.clear(ds_fragment_positive) T.clear(ds_fragment_positive_transpose) T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True) for i_s1, i_s2 in T.Parallel(block_S, block_S): - ds_fragment_positive[ - i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] + ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False) @@ -362,25 +336,16 @@ def kernel( T.gemm(ds_shared, q_shared, dk_fragment, transpose_A=True) for i_s in T.Parallel(block_S): - with T.If(i_s >= block_S - 1): # noqa: SIM117 - with T.Then(): - dg_fragment_final[ - i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1] - - T.copy( - dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) + dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local_0 + dg_last_local_1 + + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) for i_s in T.Parallel(block_S): dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s] else: for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 < i_s2): # noqa: SIM117 - with T.Then(): - ds_fragment[i_s1, i_s2] = 0 + ds_fragment[i_s1, i_s2] = 0 if i_s1 < i_s2 else ds_fragment[i_s1, i_s2] T.clear(dk_fragment_2) 
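+                    # Non-gated path: ds was masked to the causal lower triangle above and is
+                    # now folded back into dq and dk through the GEMMs below.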
T.copy(ds_fragment, ds_shared) T.gemm(ds_shared, k_shared, dq_fragment) @@ -388,12 +353,8 @@ def kernel( for i_s, i_k in T.Parallel(block_S, block_DK): dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale - T.copy( - dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) return kernel @@ -443,33 +404,53 @@ def run_test( threads=256, num_stages=0, ): - Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype), block_DK) + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dq_ref, dk_ref, dw_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype), block_DK) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) # ref if use_g: - dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( - Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) else: - dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( - Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) # tilelang - kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw, - block_DK, block_DV, threads, num_stages) - print(kernel.get_kernel_source()) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_dw, + block_DK, + block_DV, + threads, + num_stages, + ) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) if use_g: @@ -516,11 +497,11 @@ def main(): H=8, DK=DK, DV=DV, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, scale=DK**-0.5, # scale=1, diff --git a/examples/gdn/example_chunk_scaled_dot_kkt.py b/examples/gdn/example_chunk_scaled_dot_kkt.py index d07a4776a..8c7a4d573 100644 --- 
a/examples/gdn/example_chunk_scaled_dot_kkt.py +++ b/examples/gdn/example_chunk_scaled_dot_kkt.py @@ -9,6 +9,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd except ImportError: @@ -56,9 +57,9 @@ def tilelang_chunk_scaled_dot_kkt_fwd( H, DK, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, use_g=True, # kernel config block_S=64, @@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=accum_dtype), - A: T.Tensor(output_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=accum_dtype), + A: T.Tensor(output_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -93,20 +94,13 @@ def kernel( G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared") G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - }) - T.fill(A_fragment, 0) T.disable_warp_group_reg_alloc() for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True) @@ -119,8 +113,7 @@ def kernel( for i_s1, i_s2 in T.Parallel(block_S, block_S): with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( - G_diff_local[i_s1, i_s2]) + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) with T.Else(): A_fragment[i_s1, i_s2] = 0 else: @@ -130,7 +123,7 @@ def kernel( A_fragment[i_s1, i_s2] = 0 T.copy(A_fragment, A_shared) - T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :]) return kernel @@ -149,24 +142,21 @@ def run_test( threads, num_stages, ): - K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype)) + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) # reference if use_g: - A_ref = chunk_scaled_dot_kkt_fwd( - K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) else: - A_ref = chunk_scaled_dot_kkt_fwd( - K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + A_ref = 
chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) # tilelang block_S = chunk_size - kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, - accum_dtype, use_g, block_S, block_DK, threads, - num_stages) + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) A_tilelang = kernel(K, Beta, G) try: @@ -186,13 +176,14 @@ def main(): H=32, DK=128, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, use_g=True, block_DK=64, threads=128, - num_stages=2) + num_stages=2, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_cumsum.py b/examples/gdn/example_cumsum.py index 9896c7ecf..0760b4964 100644 --- a/examples/gdn/example_cumsum.py +++ b/examples/gdn/example_cumsum.py @@ -10,6 +10,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.utils.cumsum import chunk_local_cumsum_scalar except ImportError: @@ -20,11 +21,8 @@ @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} +) def tilelang_chunk_local_cumsum_scalar( # task config B, @@ -34,43 +32,43 @@ def tilelang_chunk_local_cumsum_scalar( is_varlen=False, head_first=False, reverse=False, - input_dtype="float16", - output_dtype="float32", + input_dtype=T.float16, + output_dtype=T.float32, # kernel config block_S=64, threads=256, use_fragment=False, ): G_shape = (B, H, S) if head_first else (B, S, H) - assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" assert chunk_size == block_S, "chunk_size must be equal to block_S" @T.prim_func def kernel( - G: T.Tensor(G_shape, dtype=input_dtype), - G_new: T.Tensor(G_shape, dtype=output_dtype), + G: T.Tensor(G_shape, dtype=input_dtype), + G_new: T.Tensor(G_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared") if head_first: - T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared) + T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared) else: - T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared) + T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared) if use_fragment: G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared") T.copy(G_shared, G_fragment) T.cumsum(G_fragment, dim=1, reverse=reverse) if head_first: - T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) + T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) else: - T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) + T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) else: T.cumsum(G_shared, dim=1, reverse=reverse) if head_first: - T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) + T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) else: - T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * 
block_S, bh]) + T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) return kernel @@ -113,11 +111,8 @@ def run_test( # reference cumsum G_new_ref = chunk_local_cumsum_scalar( - g=G, - chunk_size=chunk_size, - reverse=reverse, - head_first=head_first, - output_dtype=getattr(torch, output_dtype)) + g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype) + ) # tilelang cumsum block_S = chunk_size @@ -159,10 +154,11 @@ def main(): chunk_size=64, reverse=True, head_first=False, - input_dtype="float32", - output_dtype="float32", + input_dtype=T.float32, + output_dtype=T.float32, threads=256, - use_fragment=False) + use_fragment=False, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_wy_fast.py b/examples/gdn/example_wy_fast.py index 0a0983a82..d36dcf9b7 100644 --- a/examples/gdn/example_wy_fast.py +++ b/examples/gdn/example_wy_fast.py @@ -9,6 +9,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd except ImportError: @@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=output_dtype), - W: T.Tensor(K_shape, dtype=output_dtype), - U: T.Tensor(V_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=output_dtype), + W: T.Tensor(K_shape, dtype=output_dtype), + U: T.Tensor(V_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -95,49 +96,37 @@ def kernel( W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - U_shared: tilelang.layout.make_swizzled_layout(U_shared), - W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), - U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + } + ) T.disable_warp_group_reg_alloc() for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) for i_s, i_v2 in T.Parallel(block_S, block_DV): U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True) # First copy to smem, then copy to gmem to reduce U2RU instructions 
T.copy(U_fragment, U_shared) - T.copy( - U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV]) + T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): - W_Beta_shared[i_s, - i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] + W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True) # First copy to smem, then copy to gmem to reduce U2RU instructions T.copy(W_fragment, W_shared) - T.copy( - W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) return kernel @@ -159,15 +148,8 @@ def run_test( num_stages, ): K, V, Beta, G, A = prepare_input( - B, - S, - H, - DK, - DV, - chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - gate_dtype=getattr(torch, gate_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) @@ -191,7 +173,8 @@ def run_test( block_DK=block_DK, block_DV=block_DV, threads=threads, - num_stages=num_stages) + num_stages=num_stages, + ) print(kernel.get_kernel_source()) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) @@ -217,14 +200,15 @@ def main(): DK=128, DV=128, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - gate_dtype="float32", - accum_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + gate_dtype=T.float32, + accum_dtype=T.float32, block_DK=64, block_DV=32, threads=128, - num_stages=3) + num_stages=3, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py index 618a82b4c..de8afc2b7 100644 --- a/examples/gdn/example_wy_fast_bwd_split.py +++ b/examples/gdn/example_wy_fast_bwd_split.py @@ -10,6 +10,7 @@ # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr except ImportError: @@ -93,10 +94,8 @@ def prepare_output( @tilelang.jit( out_idx=[-5, -4, -3, -2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) def tilelang_wy_fast_bwd( # task config B, @@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd( @T.prim_func def kernel( - # input - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=input_dtype), - dw: T.Tensor(dw_shape, dtype=input_dtype), - du: T.Tensor(du_shape, dtype=input_dtype), - # output - dA: T.Tensor(dA_shape, dtype=input_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), 
- dv: T.Tensor(dv_shape, dtype=output_dtype), - dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), - dg: T.Tensor(dg_shape, dtype=gate_dtype), + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + # output + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -187,7 +186,7 @@ def kernel( T.clear(dbeta_fragment_v) T.clear(dg_fragment) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh] @@ -195,51 +194,37 @@ def kernel( # Update dk for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): - K_shared_beta_g[i_s, - i_k2] = K_shared[i_s, - i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] - T.copy( - dw[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK], dw_shared) + K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared) T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True) T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True) for i_s, i_k2 in T.Parallel(block_S, block_DK): - dk_fragment[ - i_s, - i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[ - i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dg_fragment_reduce_tmp[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[ - i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + dg_fragment_reduce_tmp[i_s, i_k2] = ( + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + ) T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False) # correct dk - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + 
T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) # Update dv for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) for i_s, i_v2 in T.Parallel(block_S, block_DV): V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] - T.copy( - du[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], du_shared) + T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared) T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True) T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True) for i_s, i_v2 in T.Parallel(block_S, block_DV): @@ -247,30 +232,22 @@ def kernel( # for i_s, i_v2 in T.Parallel(block_S, block_DV): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] for i_s, i_v2 in T.Parallel(block_S, block_DV): - dbeta_fragment_reduce_tmpv[i_s, - i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, - i_v2] + dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False) - T.copy( - dv_fragment, dv[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV]) + T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) # Temporary store dbeta, dg and dA for i_s in T.Parallel(block_S): dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s] dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s] # correct dA - T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :]) return kernel -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_wy_fast_bwd_split( # task config B, @@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split( @T.prim_func def kernel( - # input - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=input_dtype), - dw: T.Tensor(dw_shape, dtype=input_dtype), - du: T.Tensor(du_shape, dtype=input_dtype), - dA: T.Tensor(dA_shape, dtype=input_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dv: T.Tensor(dv_shape, dtype=output_dtype), - dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), - dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), - dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), + 
dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), + dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -350,7 +327,7 @@ def kernel( T.clear(dA_A_fragment_1) T.clear(dA_A_fragment_2) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh] @@ -361,7 +338,7 @@ def kernel( # for i_s in T.Parallel(block_S): # dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh] # dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh] - T.copy(dA[bb, bs * block_S:(bs + 1) * block_S, bh, :], dA_shared) + T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared) # T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) # Update dA @@ -385,8 +362,7 @@ def kernel( for i_s1, i_s2 in T.Parallel(block_S, block_S): with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): with T.Then(): - dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - - G[bb, bs * block_S + i_s2, bh]) + dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) with T.Else(): dA_fragment[i_s1, i_s2] = 0 T.copy(dA_fragment, dA_shared) @@ -397,12 +373,8 @@ def kernel( # Update dk using previous dk T.clear(A_fragment) for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) - T.copy( - dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK], dk_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared) T.copy(dk_shared, dk_fragment) for i_s, i_k2 in T.Parallel(block_S, block_DK): K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] @@ -411,18 +383,14 @@ def kernel( # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dbeta_fragment_reduce_tmpk[i_s, - i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, - i_k2] + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True) for i_s, i_k2 in T.Parallel(block_S, block_DK): dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2] - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) # Update dg and dbeta T.copy(A_fragment, A_shared) @@ -460,19 +428,25 @@ def run_test( threads=128, num_stages=0, ): - K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, - accum_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(B, S, 
H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) BS = chunk_size dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() @@ -480,28 +454,55 @@ def run_test( dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() # ref - dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr( - K, V, G, Beta, A, dw, du, cu_seqlens=None) + dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None) # tilelang - kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, - num_stages) - dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( - K, V, Beta, G, A, dw, du) + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) torch.cuda.synchronize() - kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - block_DK, block_DV, threads, num_stages) - kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, - dg_tilelang_A_positive, dg_tilelang_A_negative) + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) torch.cuda.synchronize() dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang - dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( - dim=-1) + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) + + from test_utils import assert_similar - from utils import assert_similar assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) @@ -517,11 +518,11 @@ def main(): H=8, DK=DK, DV=DV, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + 
gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, block_DK=32, block_DV=32, diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py index e184dbcac..6f9fa5d2f 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -1,16 +1,16 @@ -import tilelang.testing import torch +from tilelang import language as T B = 1 S = 1024 # small but for test only. H = 32 DK = 128 DV = 128 -input_dtype = "bfloat16" -output_dtype = "bfloat16" -accum_dtype = "float32" -gate_dtype = "float32" -state_dtype = "float32" +input_dtype = T.bfloat16 +output_dtype = T.bfloat16 +accum_dtype = T.float32 +gate_dtype = T.float32 +state_dtype = T.float32 chunk_size = 64 use_g = True use_initial_state = True @@ -20,21 +20,15 @@ block_DK = 64 block_DV = 32 threads = 128 -num_stages = 1 +num_stages = 0 def test_example_wy_fast_compilation(): from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input + K, V, Beta, G, A = prepare_input( - B, - S, - H, - DK, - DV, - chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - gate_dtype=getattr(torch, gate_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) # tilelang block_S = chunk_size kernel = tilelang_recompute_w_u_fwd( @@ -52,22 +46,31 @@ def test_example_wy_fast_compilation(): block_DK=block_DK, block_DV=block_DV, threads=threads, - num_stages=num_stages) + num_stages=num_stages, + ) print(kernel.get_kernel_source()) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) def test_example_wy_fast_bwd_split_compilation(): from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output - K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, - accum_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) BS = chunk_size dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() @@ -75,67 +78,146 @@ def test_example_wy_fast_bwd_split_compilation(): dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() # tilelang - kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, - num_stages) - dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( - K, V, Beta, G, A, dw, du) + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) torch.cuda.synchronize() - 
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - block_DK, block_DV, threads, num_stages) - kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, - dg_tilelang_A_positive, dg_tilelang_A_negative) + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) torch.cuda.synchronize() dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang - dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( - dim=-1) + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) def test_example_chunk_o_compilation(): from example_chunk_o import tilelang_chunk_fwd_o, prepare_input - Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) + + Q, K, V, HIDDEN, G = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) scale = 1.0 / DK**0.5 block_S = chunk_size - kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, - threads, num_stages) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841 def test_example_chunk_o_bwd_compilation(): from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input - Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, 1.0, use_g, True, - block_DK, block_DV, threads, num_stages) - dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, - W) # noqa: F841 + + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + True, + block_DK, + block_DV, + threads, + num_stages, + ) + + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841 if use_g: dg_tilelang = dg_tilelang.sum(dim=0) def test_example_chunk_scaled_dot_kkt_compilation(): from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input - K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype)) + + K, Beta, G = 
prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) block_S = chunk_size - kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, - accum_dtype, use_g, block_S, block_DK, threads, - num_stages) + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) A_tilelang = kernel(K, Beta, G) # noqa: F841 def test_example_cumsum_compilation(): from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output + G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype)) G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype)) block_S = chunk_size @@ -157,35 +239,82 @@ def test_example_cumsum_compilation(): def test_example_chunk_delta_h_compilation(): from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input - K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) - kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - use_g, use_initial_state, store_final_state, - save_new_value, block_DK, block_DV, threads, - num_stages) - h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, - initial_state) # noqa: F841 + + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + block_DK, + block_DV, + threads, + num_stages, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841 def test_example_chunk_delta_bwd_compilation(): from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input - Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, - chunk_size, 1.0, use_g, use_initial_state, - use_final_state_gradient, block_DV, threads, - num_stages) + + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841 if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + test_example_chunk_delta_bwd_compilation() diff --git a/examples/gdn/test_utils.py 
b/examples/gdn/test_utils.py
new file mode 100644
index 000000000..3588551ce
--- /dev/null
+++ b/examples/gdn/test_utils.py
@@ -0,0 +1,38 @@
+import torch
+
+
+def print_red_warning(message):
+    print(f"\033[31mWARNING: {message}\033[0m")
+
+
+def calc_sim(x, y, name="tensor"):
+    x, y = x.data.double(), y.data.double()
+    denominator = (x * x + y * y).sum()
+    if denominator == 0:
+        print_red_warning(f"{name} all zero")
+        return 1
+    sim = 2 * (x * y).sum() / denominator
+    return sim
+
+
+def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
+    x_mask = torch.isfinite(x)
+    y_mask = torch.isfinite(y)
+    if not torch.all(x_mask == y_mask):
+        print_red_warning(f"{name} Error: isfinite mask mismatch")
+        if raise_assert:
+            raise AssertionError
+    if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
+        print_red_warning(f"{name} Error: nonfinite value mismatch")
+        if raise_assert:
+            raise AssertionError
+    x = x.masked_fill(~x_mask, 0)
+    y = y.masked_fill(~y_mask, 0)
+    sim = calc_sim(x, y, name)
+    diff = 1.0 - sim
+    if not (0 <= diff <= eps):
+        print_red_warning(f"{name} Error: {diff}")
+        if raise_assert:
+            raise AssertionError
+    else:
+        print(f"{name} {data} passed")
diff --git a/examples/gdn/utils.py b/examples/gdn/utils.py
index 37f8d8e69..3588551ce 100644
--- a/examples/gdn/utils.py
+++ b/examples/gdn/utils.py
@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"):
     x, y = x.data.double(), y.data.double()
     denominator = (x * x + y * y).sum()
     if denominator == 0:
-        print_red_warning(f'{name} all zero')
+        print_red_warning(f"{name} all zero")
         return 1
     sim = 2 * (x * y).sum() / denominator
     return sim
@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
     x_mask = torch.isfinite(x)
     y_mask = torch.isfinite(y)
     if not torch.all(x_mask == y_mask):
-        print_red_warning(f'{name} Error: isfinite mask mismatch')
+        print_red_warning(f"{name} Error: isfinite mask mismatch")
         if raise_assert:
             raise AssertionError
-    if not torch.isclose(
-            x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
-            equal_nan=True).all():
-        print_red_warning(f'{name} Error: nonfinite value mismatch')
+    if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
+        print_red_warning(f"{name} Error: nonfinite value mismatch")
         if raise_assert:
             raise AssertionError
     x = x.masked_fill(~x_mask, 0)
     y = y.masked_fill(~y_mask, 0)
     sim = calc_sim(x, y, name)
-    diff = 1. - sim
+    diff = 1.0 - sim
     if not (0 <= diff <= eps):
-        print_red_warning(f'{name} Error: {diff}')
+        print_red_warning(f"{name} Error: {diff}")
         if raise_assert:
             raise AssertionError
     else:
diff --git a/examples/gemm/README.md b/examples/gemm/README.md
index 059d08c84..9ab7fb661 100644
--- a/examples/gemm/README.md
+++ b/examples/gemm/README.md
@@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi
 
 ## Table of Contents
 
-1. [Getting Started](#getting-started)
-2. [Simple GEMM Example](#simple-gemm-example)
-   - [Code Walkthrough](#code-walkthrough)
-   - [Compiling and Profiling](#compiling-and-profiling)
-3. [Advanced GEMM Features](#advanced-gemm-features)
-   - [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
-   - [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
-   - [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
-4. 
[Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations) -5. [Verifying Correctness](#verifying-correctness) -6. [Fine-grained MMA Computations](#fine-grained-mma-computations) - - [Example Workflow](#example-workflow) - - [Summary](#summary) -7. [References](#references) +- [Table of Contents](#table-of-contents) +- [Getting Started](#getting-started) + - [Prerequisites](#prerequisites) + - [Installation](#installation) +- [Simple GEMM Example](#simple-gemm-example) + - [Code Walkthrough](#code-walkthrough) + - [Compiling and Profiling](#compiling-and-profiling) +- [Advanced GEMM Features](#advanced-gemm-features) + - [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling) + - [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining) + - [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality) +- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations) +- [Verifying Correctness](#verifying-correctness) +- [Fine-grained MMA Computations](#fine-grained-mma-computations) + - [Example Workflow](#example-workflow) + - [Summary](#summary) +- [References](#references) --- @@ -25,10 +28,10 @@ TileLang is a domain-specific language designed to simplify the process of writi ### Prerequisites -- **Python 3.8+** -- **NVIDIA GPU** with a recent CUDA toolkit installed +- **Python 3.8+** +- **NVIDIA GPU** with a recent CUDA toolkit installed - **PyTorch** (optional, for easy correctness verification) -- **tilelang** +- **tilelang** - **bitblas** (optional; used for swizzle layout utilities in the advanced examples) ### Installation @@ -50,7 +53,7 @@ import tilelang from tilelang import Profiler import tilelang.language as T -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -87,26 +90,26 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ### Code Walkthrough -1. **Define the Kernel Launch Configuration:** +1. **Define the Kernel Launch Configuration:** ```python with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): ``` This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads. -2. **Shared Memory Allocation:** +2. **Shared Memory Allocation:** ```python A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) ``` Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access. -3. **Local Fragment Accumulation:** +3. **Local Fragment Accumulation:** ```python C_local = T.alloc_fragment((block_M, block_N), accum_dtype) ``` Partial results are stored in registers (or local memory) to reduce writes to global memory. -4. **Pipelined Loading and GEMM:** +4. **Pipelined Loading and GEMM:** ```python for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): T.copy(...) @@ -114,7 +117,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ``` Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation. -5. **Copy Out the Results:** +5. **Copy Out the Results:** ```python T.copy(C_local, C[by * block_M, bx * block_N]) ``` @@ -173,7 +176,7 @@ import tilelang.language as T # that helps align data for MMA (Matrix Multiply-Accumulate) operations. 
from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -216,10 +219,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo return main ``` -**Key Differences vs. Basic Example** -1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling). -2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization. -3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions. +**Key Differences vs. Basic Example** +1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling). +2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization. +3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions. --- @@ -247,7 +250,7 @@ print("Results match!") ## Fine-grained MMA Computations -For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points. +For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points. ### Example Workflow @@ -262,18 +265,18 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -394,10 +397,10 @@ def tl_matmul( ] ``` -1. **Set Up Tile Sizes and Thread Bindings** +1. **Set Up Tile Sizes and Thread Bindings** Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...),` which ensure that the correct number of threads are active, and each thread is bound to a specific role (e.g., warp ID or lane ID). -2. **Allocate Warp-local Fragments** +2. 
**Allocate Warp-local Fragments** Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like: ```python A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) @@ -406,7 +409,7 @@ def tl_matmul( ``` Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles. -3. **Load Data via `ldmatrix`** +3. **Load Data via `ldmatrix`** Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well: ```python for ki in T.serial(0, (block_K // micro_size_k)): @@ -418,7 +421,7 @@ def tl_matmul( ``` Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers. -4. **Perform the MMA Instruction** +4. **Perform the MMA Instruction** After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially: \[ C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}} @@ -429,7 +432,7 @@ def tl_matmul( ``` Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel. -5. **Store Results via `stmatrix`** +5. **Store Results via `stmatrix`** Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet: ```python mma_emitter.stmatrix(C_local, C_shared) @@ -444,6 +447,6 @@ By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with ma ## References -- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM. -- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA. +- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM. +- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA. - [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul. 
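A minimal end-to-end sketch of the verification flow described in the README changes above, written against the `@tilelang.jit(out_idx=[-1])` `matmul` factory defined in `examples/gemm/example_gemm.py` (next hunk). The helper name `check_gemm`, the tile sizes, and the tolerances are illustrative assumptions rather than part of this patch; it presumes a CUDA-capable GPU and is intended to be run from `examples/gemm/`.

```python
# Sketch only: check the basic GEMM kernel against a PyTorch reference, mirroring
# the profiler.assert_allclose(..., atol=1e-2, rtol=1e-2) checks used elsewhere
# in these examples. Assumes it runs from examples/gemm/ with a CUDA GPU available.
import torch

from example_gemm import matmul  # @tilelang.jit(out_idx=[-1]) factory from example_gemm.py


def check_gemm(M: int = 1024, N: int = 1024, K: int = 1024) -> None:
    # Illustrative tile sizes; any configuration accepted by `matmul` works.
    kernel = matmul(M, N, K, block_M=128, block_N=128, block_K=32)
    a = torch.randn(M, K, dtype=torch.float16, device="cuda")
    b = torch.randn(K, N, dtype=torch.float16, device="cuda")
    c = kernel(a, b)  # out_idx=[-1] means C is allocated and returned by the kernel
    ref = (a.float() @ b.float()).half()
    torch.testing.assert_close(c, ref, rtol=1e-2, atol=1e-2)
    print("Results match!")


if __name__ == "__main__":
    check_gemm()
```

The reference product is accumulated in float32 and cast back to float16, matching the kernel's `accum_dtype`, so the 1e-2 tolerances used by the example profilers are appropriate here as well.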
diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index f18cd388a..dfa431121 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -3,13 +3,12 @@ @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -58,5 +57,11 @@ def main(): print(f"tilelang Latency: {latency}ms") +def run_regression_perf(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index 661ef1276..016d448a4 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -51,9 +51,9 @@ def get_configs(M, N, K, with_roller=False, topk=20): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, ).with_arch(arch) func = carve_template.equivalent_function() @@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs def get_best_config(M, N, K, with_roller=False): - def kernel( block_M=None, block_N=None, @@ -115,17 +116,16 @@ def kernel( thread_num=None, enable_rasteration=None, ): - dtype = "bfloat16" - accum_dtype = "float" + dtype = T.bfloat16 + accum_dtype = T.float32 @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -146,15 +146,18 @@ def main( return main - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( out_idx=[-1], target="auto", - ).set_profile_args( + ) + .set_profile_args( supply_type=tl.TensorSupplyType.Integer, ref_prog=ref_program, skip_check=False, ) + ) return autotuner.run(warmup=3, rep=20) @@ -167,52 +170,20 @@ def get_heuristic_config() -> dict: sm_version = sm_major * 10 + sm_minor print(f"CUDA device capability: {sm_version}") if sm_version in {80}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 2, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, 
"block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} elif sm_version in {90}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 64, - "num_stages": 3, - "thread_num": 256, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} else: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 0, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} @tl.jit(out_idx=[-1]) -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm_autotune( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -236,11 +207,7 @@ def gemm_autotune( return gemm_autotune -def main(M: int = 4096, - N: int = 4096, - K: int = 4096, - use_autotune: bool = False, - with_roller: bool = False): +def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False): use_autotune = True if use_autotune: result = get_best_config(M, N, K, with_roller) @@ -261,20 +228,19 @@ def main(M: int = 4096, print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}") +def run_regression_perf(M: int = 4096, N: int = 4096, K: int = 4096): + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") - parser.add_argument( - "--use_autotune", - action="store_true", - default=False, - help="Whether to use autotune for matmul configs") - parser.add_argument( - "--with_roller", - action="store_true", - default=False, - help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space") args = parser.parse_args() main(args.m, args.n, args.k, args.use_autotune, args.with_roller) diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index 5c014ce3a..d4bc9480f 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -4,7 +4,8 @@ import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func @@ -34,18 +35,18 @@ def tl_matmul( 
accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -53,7 +54,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 64 warp_col_tiles = 64 - # chunk = 32 if in_dtype == "float16" else 64 + # chunk = 32 if in_dtype == T.float16 else 64 chunk = 32 shared_scope = "shared.dyn" @@ -99,12 +100,11 @@ def tl_matmul( @T.prim_func def gemm_intrinsics( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -112,10 +112,12 @@ def gemm_intrinsics( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -123,7 +125,6 @@ def gemm_intrinsics( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -133,7 +134,6 @@ def gemm_intrinsics( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a(A_local, A_shared, ki) @@ -163,7 +163,7 @@ def ref_program(A, B): def main(M=4096, N=4096, K=4096): - in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" + in_dtype, out_dtype, accum_dtype = T.float16, T.float16, T.float32 kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) src_code = kernel.get_kernel_source() # src_code is the generated cuda source @@ -181,5 +181,12 @@ def main(M=4096, N=4096, K=4096): profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) +def run_regression_perf(M=4096, N=4096, K=4096): + in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main(M=4096, N=4096, K=4096) diff --git a/examples/gemm/example_gemm_persistent.py b/examples/gemm/example_gemm_persistent.py index a2a7122d3..ad3d556ed 100644 --- a/examples/gemm/example_gemm_persistent.py +++ b/examples/gemm/example_gemm_persistent.py @@ -5,22 +5,12 @@ @tilelang.jit(out_idx=[-1]) -def matmul_non_persistent(M, - N, - K, - block_M, - block_N, - block_K, - threads, - num_stages, - dtype="float16", - accum_dtype="float"): - +def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( - A: T.Tensor((M, K), dtype), 
- B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -43,18 +33,9 @@ def main( @tilelang.jit(out_idx=[-1]) -def matmul_persistent(M, - N, - K, - block_M, - block_N, - block_K, - threads, - num_stages, - dtype="float16", - accum_dtype="float", - use_persistent_primitive=True): - +def matmul_persistent( + M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32, use_persistent_primitive=True +): sm_num = driver.get_num_sms() m_blocks = T.ceildiv(M, block_M) n_blocks = T.ceildiv(N, block_N) @@ -63,9 +44,9 @@ def matmul_persistent(M, @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(sm_num, threads=threads) as (block_id): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -90,9 +71,9 @@ def main( @T.prim_func def main_persistent_primitive( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(sm_num, threads=threads) as (block_id): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -100,8 +81,7 @@ def main_persistent_primitive( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), dtype) - for bx, by in T.Persistent( - [T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): + for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[bx * block_M, k * block_K], A_shared) @@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096): num_stages = 3 persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) - persistent_profiler = persistent_kernel.get_profiler( - tensor_supply_type=tilelang.TensorSupplyType.Randn) + persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("Persistent GEMM: All check passed.") persistent_latency = persistent_profiler.do_bench(warmup=500) print(f"Persistent GEMM Latency: {persistent_latency} ms") print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops") - non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, - num_stages) - non_persistent_profiler = non_persistent_kernel.get_profiler( - tensor_supply_type=tilelang.TensorSupplyType.Randn) + non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("Non-Persistent GEMM: All check passed.") non_persistent_latency = non_persistent_profiler.do_bench(warmup=500) @@ -149,11 +126,22 @@ def main(M=4096, N=4096, K=4096): print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}") +def run_regression_perf(M=4096, N=4096, K=4096): + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K 
= 64 + threads = 256 + num_stages = 3 + persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + return persistent_profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--M', type=int, default=8192, help='M dimension') - parser.add_argument('--N', type=int, default=8192, help='N dimension') - parser.add_argument('--K', type=int, default=8192, help='K dimension') + parser.add_argument("--M", type=int, default=8192, help="M dimension") + parser.add_argument("--N", type=int, default=8192, help="N dimension") + parser.add_argument("--K", type=int, default=8192, help="K dimension") args = parser.parse_args() M, N, K = args.M, args.N, args.K main(M, N, K) diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py index f4727412b..17dbcc568 100644 --- a/examples/gemm/example_gemm_schedule.py +++ b/examples/gemm/example_gemm_schedule.py @@ -3,13 +3,12 @@ @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm_schedule( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -65,5 +64,19 @@ def main(): print(kernel.get_kernel_source()) +def run_regression_perf(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm/regression_example_gemm.py b/examples/gemm/regression_example_gemm.py new file mode 100644 index 000000000..3583cf16a --- /dev/null +++ b/examples/gemm/regression_example_gemm.py @@ -0,0 +1,25 @@ +import tilelang.testing +import example_gemm +import example_gemm_autotune +import example_gemm_intrinsics +import example_gemm_schedule + + +def regression_example_gemm_autotune(): + tilelang.testing.process_func(example_gemm_autotune.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_gemm_intrinsics(): + tilelang.testing.process_func(example_gemm_intrinsics.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_gemm_schedule(): + tilelang.testing.process_func(example_gemm_schedule.run_regression_perf) + + +def regression_example_gemm(): + tilelang.testing.process_func(example_gemm.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/gemm_fp8/README.md b/examples/gemm_fp8/README.md index 9d7011a06..2b3dc9560 100644 --- a/examples/gemm_fp8/README.md +++ b/examples/gemm_fp8/README.md @@ -1 +1 @@ -**Notes**: Now we only support fp8 with mma instructions instead of `T.gemm`, because the cutlass version of tilelang is too old, we should update the cutlass version in future. 
\ No newline at end of file +**Notes**: FP8 GEMM is currently supported only via MMA instructions rather than `T.gemm`, because the CUTLASS version bundled with tilelang is too old; the bundled CUTLASS should be upgraded in a future release. diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd.py b/examples/gemm_fp8/example_tilelang_gemm_amd.py index 0e6ace757..93f8c4980 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd.py @@ -17,10 +17,8 @@ def supply_prog(args): a_param, b_param = args M, K = a_param.shape N, _ = b_param.shape - a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) return [a, b] @@ -35,40 +33,36 @@ def get_configs(): valid_configs = [] - for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, - num_stages, num_threads, k_packs, - gemm_types): - valid_configs.append({ - "block_M": m, - "block_N": n, - "block_K": k, - "num_stages": stages, - "num_threads": t, - "k_pack": kp, - "gemm_type": gemm_type, - }) + for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "block_K": k, + "num_stages": stages, + "num_threads": t, + "k_pack": kp, + "gemm_type": gemm_type, + } + ) return valid_configs @tilelang.autotune( - configs=get_configs(), - cache_input_tensors=True, - ref_prog=ref_program, - manual_check_prog=manual_check_prog, - supply_prog=supply_prog) + configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog +) @tilelang.jit(out_idx=[-1]) def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type): - dtype = "float8_e4m3fnuz" - accum_dtype = "float" + dtype = T.float8_e4m3fnuz + accum_dtype = T.float32 @T.prim_func def gemm_fp8_rs( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_local = T.alloc_fragment((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -77,24 +71,17 @@ def gemm_fp8_rs( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_local) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_local, - B_shared, - C_local, - transpose_B=True, - k_pack=k_pack, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) T.copy(C_local, C[by * block_M, bx * block_N]) @T.prim_func def gemm_fp8_ss( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), 
threads=num_threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -103,13 +90,7 @@ def gemm_fp8_ss( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - k_pack=k_pack, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) T.copy(C_local, C[by * block_M, bx * block_N]) @@ -123,10 +104,8 @@ def gemm_fp8_ss( def test_gemm_fp8(M, N, K): kernel = fp8_matmul(M, N, K) - a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) c = kernel(a, b) ref_c = ref_program(a, b) torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2) diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index a403ed068..086997975 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -1,7 +1,6 @@ import torch import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type def calc_diff(x, y): @@ -12,13 +11,12 @@ def calc_diff(x, y): @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): @T.prim_func def gemm_fp8( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -37,12 +35,12 @@ def gemm_fp8( def test_gemm_fp8(M, N, K, dtype): - torch_dtype = map_torch_type(dtype) + torch_dtype = T.dtype(dtype).as_torch() kernel = matmul(M, N, K, 128, 128, 64, dtype) - a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) - b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) + a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) + b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) c = kernel(a, b) @@ -57,8 +55,21 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3') - test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2') + test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn) + test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2) + + +def run_regression_perf(): + M, N, K = 4096, 4096, 4096 + dtype = "float8_e4m3" + kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + dtype = "float8_e5m2" + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = 
kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index 1d9207aff..a702e8ae0 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -1,11 +1,10 @@ import torch import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): # for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128. # if block_K < 128, promote after 128/block_K iters. # if block_K > 128, promote after every iter. @@ -13,9 +12,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): @T.prim_func def gemm_fp8_2xAcc( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -55,18 +54,18 @@ def calc_diff(x, y): def test_gemm_fp8(M, N, K, dtype): - torch_dtype = map_torch_type(dtype) + torch_dtype = T.dtype(dtype).as_torch() kernel = matmul(M, N, K, 128, 128, 64, dtype) - a = torch.rand(M, K, dtype=torch.float16, device='cuda') + a = torch.rand(M, K, dtype=torch.float16, device="cuda") a = (100 * (2 * a - 1)).to(dtype=torch_dtype) - b = torch.rand(N, K, dtype=torch.float16, device='cuda') + b = torch.rand(N, K, dtype=torch.float16, device="cuda") b = (100 * (2 * b - 1)).to(dtype=torch_dtype) c = kernel(a, b) - ref_c = (a.float() @ b.float().T) + ref_c = a.float() @ b.float().T diff = calc_diff(c, ref_c) print(f"diff: {diff}") @@ -74,8 +73,21 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3') - test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2') + test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn) + test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2) + + +def run_regression_perf(): + M, N, K = 1024, 1024, 8192 + dtype = "float8_e4m3" + kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + dtype = "float8_e5m2" + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index ed44aab69..762885ec3 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -5,7 +5,8 @@ import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type @@ -38,21 +39,26 @@ def tl_matmul( accum_dtype, ): assert in_dtype in 
[ - "float16", - "float8_e4m3", - "float8_e5m2", - "int8", + T.float16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] - if out_dtype == "int32" or is_float8: + is_float8 = in_dtype in [ + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e4m3fn, + T.float8_e5m2fnuz, + ] + if out_dtype == T.int32 or is_float8: micro_size_k = 32 # This is a debug config @@ -60,7 +66,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 32 warp_col_tiles = 32 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -105,12 +111,11 @@ def tl_matmul( @T.prim_func def gemm_fp8_intrinsic( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -118,10 +123,12 @@ def gemm_fp8_intrinsic( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -129,7 +136,6 @@ def gemm_fp8_intrinsic( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -139,7 +145,6 @@ def gemm_fp8_intrinsic( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -215,8 +220,22 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def main(): - assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32") - assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32") + assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) + assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) + + +def run_regression_perf(): + M, N, K = 4096, 4096, 4096 + out_dtype, accum_dtype = "float32", "float32" + in_dtype = T.float8_e4m3fn + kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + in_dtype = T.float8_e5m2 + kernel_e5m2 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py 
b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py new file mode 100644 index 000000000..aa7e8b360 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -0,0 +1,124 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm_v2( + A_shared, + B_shared, + C_tmem, + trans_A, + trans_B, + mbar=mbar, + wg_wait=-1, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 64, 256, 32 +trans_A, trans_B = False, True +num_stages = 2 +threads = 256 +for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]: + for tvm_acc_dtype in [T.float16, T.float32]: # , torch.float16]: + torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) + torch_acc_dtype = map_torch_type(tvm_acc_dtype) + print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") + in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype + + func = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + ) + jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, + }, + ) + # jit_kernel.export_ptx("./dump.ptx") + # jit_kernel.export_sources("./dump.cu") + + a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + + c = jit_kernel(a, b) + ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float() + c = c.float() + diff = calc_diff(c, ref_c) + # assert diff < 1e-3, f"{diff}" + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}") + + profiler = jit_kernel.get_profiler() + latency = profiler.do_bench() + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_fp8/regression_example_gemm_fp8.py 
b/examples/gemm_fp8/regression_example_gemm_fp8.py new file mode 100644 index 000000000..3ba2f4f27 --- /dev/null +++ b/examples/gemm_fp8/regression_example_gemm_fp8.py @@ -0,0 +1,20 @@ +import tilelang.testing +import example_tilelang_gemm_fp8 +import example_tilelang_gemm_fp8_2xAcc +import example_tilelang_gemm_fp8_intrinsic + + +def regression_example_tilelang_gemm_fp8_2xAcc(): + tilelang.testing.process_func(example_tilelang_gemm_fp8_2xAcc.run_regression_perf) + + +def regression_example_tilelang_gemm_fp8_intrinsic(): + tilelang.testing.process_func(example_tilelang_gemm_fp8_intrinsic.run_regression_perf) + + +def regression_example_tilelang_gemm_fp8(): + tilelang.testing.process_func(example_tilelang_gemm_fp8.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/gemm_fp8/test_example_gemm_fp8.py b/examples/gemm_fp8/test_example_gemm_fp8.py index 19a9ee00a..8a60d0e02 100644 --- a/examples/gemm_fp8/test_example_gemm_fp8.py +++ b/examples/gemm_fp8/test_example_gemm_fp8.py @@ -1,17 +1,30 @@ +import pytest +import torch import tilelang.testing import example_tilelang_gemm_fp8_2xAcc import example_tilelang_gemm_fp8_intrinsic import example_tilelang_gemm_fp8 +def requires_sm89(): + """FP8 tensor core MMA requires SM89 (Ada Lovelace) or higher.""" + major, minor = torch.cuda.get_device_capability() + return pytest.mark.skipif( + major < 9 and not (major == 8 and minor >= 9), reason="FP8 tensor core MMA requires SM89 or higher (Ada Lovelace/Hopper)" + ) + + +@requires_sm89() def test_example_tilelang_gemm_fp8_2xAcc(): example_tilelang_gemm_fp8_2xAcc.main() +@requires_sm89() def test_example_tilelang_gemm_fp8_intrinsic(): example_tilelang_gemm_fp8_intrinsic.main() +@requires_sm89() def test_example_tilelang_gemm_fp8(): example_tilelang_gemm_fp8.main() diff --git a/examples/gemm_sm100/README.md b/examples/gemm_sm100/README.md index 73dd76c30..d630d2d0d 100644 --- a/examples/gemm_sm100/README.md +++ b/examples/gemm_sm100/README.md @@ -40,19 +40,19 @@ import tilelang.language as T @T.prim_func def main( - A: T.Tensor((M, K), "bfloat16"), - B: T.Tensor((N, K), "bfloat16"), - C: T.Tensor((M, N), "bfloat16"), + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((N, K), T.bfloat16), + C: T.Tensor((M, N), T.bfloat16), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): # 1. Allocate memory buffers - A_shared = T.alloc_shared((block_M, block_K), "bfloat16") # A matrix shared memory - B_shared = T.alloc_shared((block_N, block_K), "bfloat16") # B matrix shared memory - C_tmem = T.alloc_tmem([block_M, block_N], "float") # TCGEN5MMA output to Tensor Memory + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) # A matrix shared memory + B_shared = T.alloc_shared((block_N, block_K), T.bfloat16) # B matrix shared memory + C_tmem = T.alloc_tmem([block_M, block_N], T.float) # TCGEN5MMA output to Tensor Memory mbar = T.alloc_barrier(1) # mbarrier synchronization primitive - C_local = T.alloc_fragment((block_M, block_N), "float") # Register storage - C_shared = T.alloc_shared((block_M, block_N), "bfloat16") # Output shared memory + C_local = T.alloc_fragment((block_M, block_N), T.float) # Register storage + C_shared = T.alloc_shared((block_M, block_N), T.bfloat16) # Output shared memory # 2. 
Main computation loop for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): @@ -103,4 +103,3 @@ latency = profiler.do_bench() print(f"Latency: {latency} ms") print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS") ``` - diff --git a/examples/gemm_sm100/gemm_mma.py b/examples/gemm_sm100/gemm_mma.py index a58e5a7c0..226e33c01 100644 --- a/examples/gemm_sm100/gemm_mma.py +++ b/examples/gemm_sm100/gemm_mma.py @@ -4,13 +4,12 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -62,7 +61,8 @@ def main( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) print(jit_kernel.get_kernel_source()) # 3. Test the kernel in Python with PyTorch data import torch diff --git a/examples/gemm_sm100/gemm_tcgen5mma.py b/examples/gemm_sm100/gemm_tcgen5mma.py index 9008c7ef5..523a94fea 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma.py +++ b/examples/gemm_sm100/gemm_tcgen5mma.py @@ -25,9 +25,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -40,15 +40,7 @@ def main( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_tmem, - trans_A, - trans_B, - mbar=mbar, - wg_wait=-1, - clear_accum=k == 0) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) T.mbarrier_wait_parity(mbar, k % 2) T.copy(C_tmem, C_local) @@ -62,12 +54,11 @@ def main( M, N, K = 4096, 4096, 8192 block_M, block_N, block_K = 128, 256, 128 trans_A, trans_B = False, True -in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float" +in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float num_stages = 2 threads = 256 -func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) +func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) jit_kernel = tilelang.compile( func, out_idx=[2], @@ -75,7 +66,8 @@ def main( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) print(jit_kernel.get_kernel_source()) @@ -88,4 +80,4 @@ def main( profiler = jit_kernel.get_profiler() latency = profiler.do_bench() print(f"Latency: {latency} ms") -print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS") +print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py new file mode 100644 
index 000000000..0544b8255 --- /dev/null +++ b/examples/gemm_sp/example_custom_compress.py @@ -0,0 +1,337 @@ +import argparse + +import tilelang +import tilelang.language as T + +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.sparse import randn_semi_sparse +from tilelang.utils.tensor import torch_assert_close + +from triton.testing import do_bench + +import torch + +torch.manual_seed(42) + +DEFAULT_CONFIG = { # take best config from autotune script + "4090": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, + "h20": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, +} + +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp_fp16_custom_compress( + M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout +): + e_factor, e_dtype = (16, T.int16) + + @T.prim_func + def gemm_sp_fp16_custom_compress( + A_sparse: T.Tensor((M, K // 2), T.float16), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + B_shared = T.alloc_shared((block_K, block_N), T.float16) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + if use_cutlass_layout: + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), + } + ) + T.clear(C_local) + T.disable_warp_group_reg_alloc() + T.use_swizzle(panel_size=10, enable=enable_rasterization) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(E[by * block_M, k * block_K // e_factor], E_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_sp_fp16_custom_compress + + +def torch_compress(dense): + """ + A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout. 
+ """ + if dense.dim() != 2: + raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") + + m, k = dense.shape + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}") + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + return (sparse, meta) + + +def decode_metadata(meta: torch.Tensor) -> torch.Tensor: + assert meta.dtype is torch.int16 + groups_per_meta = 16 // 4 # 4 groups per uint16 + out = [] + for g in range(groups_per_meta): + group_bits = (meta >> (g * 4)) & 0xF + idx0 = group_bits & 0x3 + idx1 = (group_bits >> 2) & 0x3 + out.append(torch.stack([idx0, idx1], dim=-1)) + return torch.concat(out, dim=-1).view(meta.shape[0], -1) + + +@tilelang.jit( + out_idx=[1, 2], + pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, + }, +) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), + } + ) + T.clear(A_sp_shared) + T.clear(E_shared) + # TODO: alloc_var seems buggy here + non_zero_cnt = T.alloc_local((1,), dtype=T.uint8) + non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8) + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + non_zero_cnt[0] = 0 + for i in range(elem): + non_zero_elt_log_idx[i] = 0 + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = 
A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel + + +def main(m=16384, n=16384, k=16384, use_cutlass_layout=False, use_torch_compressor=False, accum_dtype=None, cfg="4090"): + if accum_dtype is None: + accum_dtype = T.float + kernel = matmul_sp_fp16_custom_compress(m, n, k, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype], use_cutlass_layout=use_cutlass_layout) + + a = randn_semi_sparse(m, k, device="cuda", dtype=torch.half) + b = torch.randn(k, n, device="cuda", dtype=torch.half) + + if use_torch_compressor: + assert not use_cutlass_layout, "torch sparse must be used with naive layout" + a_sparse, e = torch_compress(a) + else: + a_sparse, e = compress_kernel(m, k, 32, 32, T.float16, use_cutlass_layout=use_cutlass_layout)(a) + + c = kernel(a_sparse, e, b) + + ref_c = a @ b + + assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" + torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3) + print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}") + + latency = do_bench(lambda: kernel(a_sparse, e, b)) + ref_latency = do_bench(lambda: a @ b) + + total_flops = 2 * m * n * k + tflops = total_flops / latency / 1e9 + ref_tflops = total_flops / ref_latency / 1e9 + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") + parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") + args = parser.parse_args() + accum_dtype = T.float if args.accum_dtype == "float" else T.float16 + main(args.m, args.n, args.k, args.use_cutlass_layout, args.use_torch_compressor, accum_dtype, args.cfg) diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py index 505f2b883..8163c84cc 100644 --- a/examples/gemm_sp/example_gemm_sp.py +++ b/examples/gemm_sp/example_gemm_sp.py @@ -1,11 +1,9 @@ -# Copyright (c) Tile-AI Corporation. -# Licensed under the MIT License. 
import argparse import tilelang import tilelang.language as T -from tilelang.layout import make_metadata_layout +from tilelang.layout import make_cutlass_metadata_layout from tilelang.utils.sparse import compress, randn_semi_sparse from tilelang.contrib import nvcc from triton.testing import do_bench @@ -14,86 +12,79 @@ arch = nvcc.get_target_compute_version() -ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} - -default_config = { # take best config from autotune script +DEFAULT_CONFIG = { # take best config from autotune script "4090": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 64, - 'num_stages': 1, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 256, - 'block_N': 128, - 'block_K': 64, - 'num_stages': 2, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } }, "h20": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } - } + T.float16: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, } +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + @tilelang.jit(out_idx=[-1]) -def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, - enable_rasterization): +def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization): e_factor, e_dtype = ARCH_INFO[arch] @T.prim_func def gemm_sp_fp16( - A_sparse: T.Tensor((M, K // 2), 'float16'), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), 'float16'), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), T.float16), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K // 2), 'float16') + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - B_shared = T.alloc_shared((block_K, block_N), 'float16') + B_shared = T.alloc_shared((block_K, block_N), T.float16) C_shared = T.alloc_shared((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout({ - E: - make_metadata_layout( - E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch), - E_shared: - 
make_metadata_layout( - E_shared, - mma_dtype="float16", - backend="cutlass", - block_k=block_K, - arch=arch), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) T.copy(E[by * block_M, k * block_K // e_factor], E_shared) @@ -106,30 +97,15 @@ def gemm_sp_fp16( return gemm_sp_fp16 -def main(): - parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") - parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") - parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") - parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True) - args = parser.parse_args() - kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, - **default_config[args.cfg][args.accum_dtype]) +def main(m=16384, n=16384, k=16384, accum_dtype=None, cfg="4090"): + if accum_dtype is None: + accum_dtype = T.float + kernel = matmul_sp_fp16(m, n, k, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype]) - a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) - b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) + a = randn_semi_sparse(m, k, device="cuda", dtype=torch.half) + b = torch.randn(k, n, device="cuda", dtype=torch.half) - a_sparse, e = compress( - a, - transposed=False, - block_k=default_config[args.cfg][args.accum_dtype]['block_K'], - arch=arch) + a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[cfg][accum_dtype]["block_K"], arch=arch) c = kernel(a_sparse, e, b) ref_c = a @ b @@ -141,12 +117,20 @@ def main(): latency = do_bench(lambda: kernel(a_sparse, e, b)) ref_latency = do_bench(lambda: a @ b) - total_flops = 2 * args.m * args.n * args.k + total_flops = 2 * m * n * k tflops = total_flops / latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9 - print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s") - print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s") + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") + args = parser.parse_args() + accum_dtype = T.float if args.accum_dtype == "float" else T.float16 + main(args.m, args.n, args.k, accum_dtype, args.cfg) diff --git a/examples/gemm_sp/test_example_gemm_sp.py b/examples/gemm_sp/test_example_gemm_sp.py new file mode 100644 index 000000000..fe26df144 --- /dev/null +++ b/examples/gemm_sp/test_example_gemm_sp.py @@ -0,0 +1,16 @@ +import 
tilelang.testing + +import example_custom_compress +import example_gemm_sp + + +def test_example_custom_compress(): + example_custom_compress.main() + + +def test_example_gemm_sp(): + example_gemm_sp.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index c96669711..64ffade8e 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -3,27 +3,16 @@ @tilelang.jit -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - split_k, - dtype="float16", - accum_dtype="float", - out_dtype="float32"): - +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): splitK = K // split_k @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) @@ -67,5 +56,28 @@ def main(): torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) +def run_regression_perf(): + M = 4096 + N = 4096 + K = 4096 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + import torch + + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b, c) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py index 145d622ed..3d33478cf 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py @@ -3,27 +3,16 @@ @tilelang.jit -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - split_k, - dtype="float16", - accum_dtype="float", - out_dtype="float32"): - +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): splitK = K // split_k @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) @@ -66,5 +55,29 @@ def main(): torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) +def run_regression_perf(): + M = 4096 + N = 4096 + K = 4096 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + + kernel = matmul(M, N, 
K, block_M, block_N, block_K, split_k) + import torch + + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b, c) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm_splitk/regression_example_gemm_splitk.py b/examples/gemm_splitk/regression_example_gemm_splitk.py new file mode 100644 index 000000000..c76b7e55c --- /dev/null +++ b/examples/gemm_splitk/regression_example_gemm_splitk.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_tilelang_gemm_splitk +import example_tilelang_gemm_splitk_vectorize_atomicadd + + +def regression_example_tilelang_gemm_splitk(): + tilelang.testing.process_func(example_tilelang_gemm_splitk.run_regression_perf) + + +def regression_example_tilelang_gemm_splitk_vectorize_atomicadd(): + tilelang.testing.process_func(example_tilelang_gemm_splitk_vectorize_atomicadd.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index 31cf40647..b2e8e9369 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -39,7 +39,7 @@ def cdiv(a, b): # Two-tile SK + DP streamk_tiles = total_tiles % streamk_programs -if (total_tiles - streamk_tiles > streamk_programs): # (total_tiles // total_programs > 1) +if total_tiles - streamk_tiles > streamk_programs: # (total_tiles // total_programs > 1) streamk_tiles += streamk_programs blocking_tiles = total_tiles - streamk_tiles @@ -77,95 +77,71 @@ def tl_matmul_streamk( A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M) B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) - @T.macro - def compute_first_wave( - pid: T.int32, - A_buf: T.Tensor, - A_buf_shared: T.SharedBuffer, - B_buf: T.Tensor, - B_buf_shared: T.SharedBuffer, - C: T.Tensor, - C_local: T.LocalBuffer, - ): - start_iter = T.alloc_fragment((1,), "int32", "local") - end_iter = T.alloc_fragment((1,), "int32", "local") - - start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) - last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles) - - while start_iter[0] < last_iter: - end_iter[0] = T.min( - start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)), - last_iter, - ) - - tile_id = start_iter[0] // iters_per_tile - remain_iters = start_iter[0] % iters_per_tile - pid_m = tile_id // T.ceildiv(N, block_N) - pid_n = tile_id % T.ceildiv(N, block_N) - - T.clear(C_local) - for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages): - T.copy( - A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K], - A_buf_shared, - ) - T.copy( - B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K], - B_buf_shared, - ) - T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B) - - # last iteration of the tile always happens before its start on another SM - if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0): - T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) - else: - for i, j in T.Parallel(block_M, block_N): - T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j]) - - start_iter[0] = end_iter[0] - - @T.macro - def compute_full_tiles( - pid: 
T.int32, - A_buf: T.Tensor, - A_shared: T.SharedBuffer, - B_buf: T.Tensor, - B_shared: T.SharedBuffer, - C: T.Tensor, - C_local: T.LocalBuffer, - ): - - for p in T.serial(sm_patition_factor): - tile_id = pid + streamk_tiles + p * total_sm - pid_m = tile_id // T.ceildiv(N, block_N) - pid_n = tile_id % T.ceildiv(N, block_N) - T.clear(C_local) - - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): - T.copy(A_buf[pid_m * block_M, k * block_K], A_shared) - T.copy(B_buf[pid_n * block_N, k * block_K], B_shared) - T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B) - T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) - @T.prim_func def main( - A: T.Tensor(A_shape, dtypeAB), - B: T.Tensor(B_shape, dtypeAB), - C: T.Tensor((M, N), dtypeC), + A: T.Tensor(A_shape, dtypeAB), + B: T.Tensor(B_shape, dtypeAB), + C: T.Tensor((M, N), dtypeC), ): with T.Kernel(streamk_programs, threads=threads) as pid: - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) B_shared = T.alloc_shared(B_shared_shape, dtypeAB) A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local) + # compute first wave + start_iter = T.alloc_fragment((1,), T.int32, "local") + end_iter = T.alloc_fragment((1,), T.int32, "local") + + start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) + last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles) + while start_iter[0] < last_iter: + end_iter[0] = T.min( + start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)), + last_iter, + ) + + tile_id = start_iter[0] // iters_per_tile + remain_iters = start_iter[0] % iters_per_tile + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + + T.clear(C_local) + for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages): + T.copy( + A[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K], + A_shared, + ) + T.copy( + B[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K], + B_shared, + ) + T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B) + + # last iteration of the tile always happens before its start on another SM + if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0): + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + else: + for i, j in T.Parallel(block_M, block_N): + T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j]) + + start_iter[0] = end_iter[0] + + # compute full tiles if sm_patition_factor > 0: - compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local) + for p in T.serial(sm_patition_factor): + tile_id = pid + streamk_tiles + p * total_sm + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A[pid_m * block_M, k * block_K], A_shared_full_tiles) + T.copy(B[pid_n * block_N, k * block_K], B_shared_full_tiles) + T.gemm(A_shared_full_tiles, B_shared_full_tiles, C_local, transpose_B=trans_B) + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) return main @@ -181,9 +157,9 @@ def main(): BLOCK_SIZE_K, False, True, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, 2, 64, ) @@ -201,5 +177,30 @@ def main(): torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2) +def run_regression_perf(): + kernel = 
tl_matmul_streamk( + m, + n, + k, + streamk_tiles, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + False, + True, + "float16", + "float16", + "float32", + 2, + 64, + ) + b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16) + torch.cuda.synchronize() + + from tilelang.profiler import do_bench + + return do_bench(lambda: kernel(A, B, b_c), backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm_streamk/test_example_tilelang_gemm_splitk.py b/examples/gemm_streamk/test_example_tilelang_gemm_streamk.py similarity index 100% rename from examples/gemm_streamk/test_example_tilelang_gemm_splitk.py rename to examples/gemm_streamk/test_example_tilelang_gemm_streamk.py diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 4e43dcd9a..8ca77a2e8 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -17,15 +17,14 @@ def naive_gemv( K: int, BLOCK_N: int, BLOCK_K: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): - @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn: tn = T.get_thread_binding(0) # tn = threadIdx.x @@ -38,8 +37,7 @@ def main( A_shared[tk] = A[bk * BLOCK_K + tk] B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] for tk in T.serial(BLOCK_K): - C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, - tk].astype(accum_dtype) + C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype) C[bn * BLOCK_N + tn] = C_reg[0] return main @@ -51,15 +49,14 @@ def naive_splitk_gemv( K: int, BLOCK_N: int, BLOCK_K: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): - @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn: tn = T.get_thread_binding(0) @@ -88,16 +85,16 @@ def splitk_gemv( BLOCK_N: int, BLOCK_K: int, reduce_threads: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): TILE_K = T.ceildiv(BLOCK_K, reduce_threads) @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -127,8 +124,8 @@ def splitk_gemv_vectorized( K: int, BLOCK_N: int, reduce_threads: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): MAX_TRANSACTION_SIZE_IN_BITS = 128 TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits @@ -136,9 +133,9 @@ def splitk_gemv_vectorized( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -168,8 +165,8 @@ def splitk_gemv_vectorized_tvm( K: int, BLOCK_N: int, 
reduce_threads: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): MAX_TRANSACTION_SIZE_IN_BITS = 128 TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits @@ -177,9 +174,9 @@ def splitk_gemv_vectorized_tvm( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -197,9 +194,9 @@ def main( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -209,7 +206,8 @@ def main( C_reduced[0], tk, dtype="handle", - )) + ) + ) C[bn * BLOCK_N + tn] = C_reduced[0] @@ -218,10 +216,8 @@ def main( def get_block_template_configs(): iter_params = dict( - block_M=[2, 4, 8, 32, 64, 128], - block_N=[2, 4, 8, 32, 64, 128], - num_stages=[0, 1, 2, 3, 4], - threads=[32, 64, 128, 256]) + block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256] + ) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -237,18 +233,11 @@ def get_block_template_configs(): }, out_idx=[2], ) -def gemv_alloc_reducer(M, - N, - block_M=128, - block_N=128, - num_stages=2, - threads=256, - dtype: str = "float16", - accum_dtype: str = "float"): - +def gemv_alloc_reducer( + M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float +): @T.prim_func - def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, - dtype)): # type: ignore + def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m: o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all") T.clear(o_reducer) @@ -287,17 +276,17 @@ def get_autotuned_kernel( BLOCK_N=None, reduce_threads=None, ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 MAX_TRANSACTION_SIZE_IN_BITS = 128 TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits BLOCK_K = reduce_threads * TILE_K @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -315,9 +304,9 @@ def main( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -327,21 +316,22 @@ def main( C_reduced[0], tk, dtype="handle", - )) + ) + ) C[bn * BLOCK_N + tn] = C_reduced[0] return main -def 
check_correctness_and_bench(kernel, N, K, bench_ref=True): +def check_correctness_and_bench(kernel, N, K, do_bench=True): profiler = kernel.get_profiler() profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2) - if bench_ref: + if do_bench: latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=50) print(f"Torch Latency: {latency} ms") - latency = profiler.do_bench(kernel, warmup=50) - print(f"TileLang Latency: {latency} ms\n") + latency = profiler.do_bench(kernel, warmup=50) + print(f"TileLang Latency: {latency} ms\n") def main(do_bench: bool = True): @@ -350,16 +340,16 @@ def main(do_bench: bool = True): parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") args, _ = parser.parse_known_args() N, K = args.n, args.k - check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K) - check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K) - check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K) - check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K) - check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K) - check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K) + check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K, do_bench=do_bench) + check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench) print("Test passed!") - if not do_bench: + if do_bench: best_result = get_autotuned_kernel(N, K) best_config = best_result.config kernel = splitk_gemv_vectorized_tvm(N, K, **best_config) @@ -374,5 +364,23 @@ def main(do_bench: bool = True): print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n") +def run_regression_perf(): + N, K = 4096, 4096 + latency = 0.0 + kernel_list = [ + naive_gemv(N, K, 128, 128), + naive_splitk_gemv(N, K, 32, 32), + splitk_gemv(N, K, 32, 32, 32), + splitk_gemv_vectorized(N, K, 2, 32), + splitk_gemv_vectorized_tvm(N, K, 2, 32), + gemv_alloc_reducer(N, K, block_M=128, block_N=128), + ] + for kernel in kernel_list: + profiler = kernel.get_profiler() + # Benchmark the TileLang kernel itself, not the PyTorch reference. 
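+        # Accumulate the per-kernel latency reported by do_bench (milliseconds, matching the "ms" prints above);
+        # the mean over all GEMV variants in kernel_list is returned below as the regression metric.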
+ latency += profiler.do_bench(backend="cupti") + return latency / len(kernel_list) + + if __name__ == "__main__": main() diff --git a/examples/gemv/regression_example_gemv.py b/examples/gemv/regression_example_gemv.py new file mode 100644 index 000000000..dd6f1d39f --- /dev/null +++ b/examples/gemv/regression_example_gemv.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_gemv + + +def regression_example_gemv(): + tilelang.testing.process_func(example_gemv.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/gemv/test_example_gemv.py b/examples/gemv/test_example_gemv.py index 3881ca769..323337a7a 100644 --- a/examples/gemv/test_example_gemv.py +++ b/examples/gemv/test_example_gemv.py @@ -1,5 +1,3 @@ -import tilelang.testing - import example_gemv @@ -8,4 +6,4 @@ def test_example_gemv(): if __name__ == "__main__": - tilelang.testing.main() + test_example_gemv() diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py index ac8da7e2c..49cce0d1d 100644 --- a/examples/grouped_gemm/example_grouped_gemm_bwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -5,78 +5,55 @@ import tilelang.language as T -@tilelang.jit( - out_idx=[2], pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) -def grouped_gemm_fwd(batch_sum, - batch_count, - K, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: a (torch.Tensor): Input tensor of shape (M, K). b (torch.Tensor): Input tensor of shape (G, K, N). 
""" - accum_dtype = "float32" + accum_dtype = T.float32 @T.prim_func def kernel( - A: T.Tensor([batch_sum, K], dtype), # type: ignore - B: T.Tensor([batch_count, K, N], dtype), # type: ignore - C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore ): - - with T.Kernel( - T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) - cur_batch_idx = T.alloc_local([1], "int32") - cur_batch_size = T.alloc_local([1], "int32") + cur_batch_idx = T.alloc_var(dtype=T.int32) + cur_batch_size = T.alloc_var(dtype=T.int32) m_start_padded = bx * block_M for i in range(batch_count): - in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) - cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] + cur_batch_idx = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx) - cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ - cur_batch_idx[0]] - actual_rows = T.max( - 0, - T.min(block_M, - cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + cur_batch_size = batch_sizes[cur_batch_idx] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx] + batch_offsets[cur_batch_idx] + actual_rows = T.max(0, T.min(block_M, cur_batch_size + batch_padded_offsets[cur_batch_idx] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) - T.copy( - B[cur_batch_idx[0], k * block_K:(k + 1) * block_K, - by * block_N:(by + 1) * block_N], B_shared) + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx, k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): - with T.If(i < actual_rows), T.Then(): + if i < actual_rows: C[m_start + i, by * block_N + j] = C_local[i, j] return kernel class _GroupedGEMM(torch.autograd.Function): - @staticmethod def forward(ctx, a, b, batch_sizes): block_M = 64 @@ -99,15 +76,11 @@ def forward(ctx, a, b, batch_sizes): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes[i] + 1) / padding_M) * - padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes[i] + 1) / padding_M) * padding_M) batch_offsets = torch.tensor(batch_offsets_list, device=a.device, 
dtype=torch.int32) - batch_padded_offsets = torch.tensor( - batch_padded_offsets_list, device=a.device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32) - kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, - num_stages, threads) + kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads) o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets) ctx.save_for_backward(a, b, batch_sizes, batch_offsets) @@ -135,8 +108,7 @@ def maybe_contiguous(x): return x A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)] - kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, - num_stages, threads) + kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads) dB = kernel(A, grad_output, batch_sizes, batch_offsets) return None, dB, None @@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes_list[i] + 1) / padding_M) * - padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M) A = torch.randn(batch_sum, K, device=device, dtype=dtype) B = torch.randn(batch_count, K, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype) @@ -187,40 +157,24 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets -@tilelang.jit( - out_idx=[2], pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) -def grouped_gemm_bwd(batch_sum, - batch_count, - M, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: a (torch.Tensor): Input tensor of shape (M, K). b (torch.Tensor): Input tensor of shape (G, K, N). 
""" - accum_dtype = "float32" + accum_dtype = T.float32 @T.prim_func def kernel( - A: T.Tensor([batch_sum, M], dtype), # type: ignore - B: T.Tensor([batch_sum, N], dtype), # type: ignore - C: T.Tensor([batch_count, M, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, M], dtype), # type: ignore + B: T.Tensor([batch_sum, N], dtype), # type: ignore + C: T.Tensor([batch_count, M, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore ): - - with T.Kernel( - T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz): A_shared = T.alloc_shared([block_K, block_M], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -228,13 +182,9 @@ def kernel( T.clear(C_local) for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages): for i, j in T.Parallel(block_K, block_M): - A_shared[i, j] = T.if_then_else( - i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, - bx * block_M + j], 0) + A_shared[i, j] = T.if_then_else(i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0) for i, j in T.Parallel(block_K, block_N): - B_shared[i, j] = T.if_then_else( - i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, - by * block_N + j], 0) + B_shared[i, j] = T.if_then_else(i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0) T.gemm(A_shared, B_shared, C_local, transpose_A=True) T.copy(C_local, C[bz, bx * block_M, by * block_N]) @@ -242,23 +192,12 @@ def kernel( return kernel -def run_tilelang_grouped_gemm(batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages=2, - threads=128, - profile=False): - +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): padding_M = block_M device = torch.device("cuda") dtype = torch.float16 - A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( - batch_sizes_list, K, M, False, padding_M, device, dtype) + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, False, padding_M, device, dtype) A.requires_grad_(False) B.requires_grad_(True) @@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, O.backward(dO, retain_graph=True) dB, B.grad = B.grad.clone(), None - if ( - torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and \ - torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2) - ): + if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2): print("✅ Tilelang and Torch match") else: print("❌ Tilelang and Torch mismatch") @@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') - parser.add_argument('--K', type=int, default=8192, help='reduce dim') - parser.add_argument('--M', type=int, default=8192, help='output dim') - parser.add_argument('--trans_b', action="store_true", help="transpose B") - parser.add_argument('--profile', action="store_true", 
help="profile") + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") args = parser.parse_args() batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] @@ -301,14 +236,4 @@ def run_tilelang_grouped_gemm(batch_sizes_list, num_stages = 2 threads = 256 - run_tilelang_grouped_gemm( - batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages, - threads, - profile=args.profile) + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd.py b/examples/grouped_gemm/example_grouped_gemm_fwd.py index 9b58e3a21..b71472741 100644 --- a/examples/grouped_gemm/example_grouped_gemm_fwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_fwd.py @@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): torch.Tensor: Resulting tensor after grouped matrix multiplication. """ assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a" - assert b.shape[0] == len( - batch_sizes), "The first dimension of b must match the length of batch_sizes" + assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes" # Initialize output tensor output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype) @@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): @tilelang.jit(out_idx=[2]) -def grouped_gemm(batch_sizes_list, - K, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: a (torch.Tensor): Input tensor of shape (M, K). 
@@ -54,50 +45,43 @@ def grouped_gemm(batch_sizes_list, """ batch_sum = sum(batch_sizes_list) batch_count = len(batch_sizes_list) - accum_dtype = "float32" + accum_dtype = T.float32 total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list) @T.prim_func def kernel( - A: T.Tensor([batch_sum, K], dtype), # type: ignore - B: T.Tensor([batch_count, K, N], dtype), # type: ignore - C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore ): - with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) - cur_batch_idx = T.alloc_local([1], "int32") - cur_batch_size = T.alloc_local([1], "int32") + cur_batch_idx = T.alloc_var(dtype=T.int32) + cur_batch_size = T.alloc_var(dtype=T.int32) m_start_padded = bx * block_M for i in range(batch_count): - in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) - cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] + cur_batch_idx = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx) - cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ - cur_batch_idx[0]] - actual_rows = T.max( - 0, - T.min(block_M, - cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + cur_batch_size = batch_sizes[cur_batch_idx] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx] + batch_offsets[cur_batch_idx] + actual_rows = T.max(0, T.min(block_M, cur_batch_size + batch_padded_offsets[cur_batch_idx] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) - T.copy( - B[cur_batch_idx[0], k * block_K:(k + 1) * block_K, - by * block_N:(by + 1) * block_N], B_shared) + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx, k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): - with T.If(i < actual_rows), T.Then(): + if i < actual_rows: C[m_start + i, by * block_N + j] = C_local[i, j] return kernel @@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) A = torch.randn(batch_sum, K, device=device, dtype=dtype) B = 
torch.randn(batch_count, K, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype) @@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets -def run_tilelang_grouped_gemm(batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages=2, - threads=128, - profile=False): +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): padding_M = block_M batch_sum = sum(batch_sizes_list) - kernel = grouped_gemm( - tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) + kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) # print(kernel.get_kernel_source()) device = torch.device("cuda") dtype = torch.float16 - A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( - batch_sizes_list, K, M, trans_b, padding_M, device, dtype) + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype) out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets) ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b) # print(out) @@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, if profile: profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) - latency = profiler.do_bench( - warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) + latency = profiler.do_bench(warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) print(f"Latency: {latency} ms") print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops") @@ -173,12 +144,11 @@ def test_grouped_gemm(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') - parser.add_argument('--K', type=int, default=8192, help='reduce dim') - parser.add_argument('--M', type=int, default=8192, help='output dim') - parser.add_argument('--trans_b', action="store_true", help="transpose B") - parser.add_argument('--profile', action="store_true", help="profile") + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") args = parser.parse_args() batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] @@ -190,14 +160,4 @@ def test_grouped_gemm(): num_stages = 2 threads = 256 - run_tilelang_grouped_gemm( - batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages, - threads, - profile=args.profile) + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/hadamard_transform/example_hadamard.py b/examples/hadamard_transform/example_hadamard.py index 531d46891..65f463b71 100644 --- a/examples/hadamard_transform/example_hadamard.py +++ b/examples/hadamard_transform/example_hadamard.py @@ -17,7 +17,7 @@ def is_pow_of_2(n): def hadamard(b, n, dtype): assert is_pow_of_2(n), "n must be a power 
of 2" assert 2 <= n <= 32768, "n must be in [2, 32768]" - elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype] + elem_size = {T.float32: 4, T.float16: 2, T.bfloat16: 2}[dtype] logN = int(math.log2(n)) threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN] @@ -40,23 +40,21 @@ def hadamard(b, n, dtype): # print(f'{exchange_round=}') @T.macro - def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), - round: int): + def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), round: int): tx = T.get_thread_binding(0) for i in T.serial(round): tx_stride = 1 << i another_tx = tx ^ tx_stride - sign = ( - tx >> i - ) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] + sign = (tx >> i) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] for j in T.Pipelined(thread_elem, num_stages=1): buf[j] = T.tvm_warp_shuffle( - 0xffffffff, # mask of all threads + 0xFFFFFFFF, # mask of all threads local[j], another_tx % warp_size, warp_size, - warp_size) + warp_size, + ) local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j]) @T.prim_func @@ -78,10 +76,8 @@ def main(A: T.Tensor((b, n), dtype), B: T.Tensor((b, n), dtype)): for j in T.serial(chunknum): chunkbase = j * chunksize for k in T.serial(chunksize // 2): - local[chunkbase + - k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] - local[chunkbase + k + chunksize // - 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] + local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] + local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] # 3. 
Hadamard inside warp, n<=512 # In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory @@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor): assert x.ndim == 2 dim = x.shape[-1] assert is_pow_of_2(dim) - return F.linear( - x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) + return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='Batch size') - parser.add_argument('--dim', type=int, default=32768, help='Dimension') + parser.add_argument("--batch", type=int, default=64, help="Batch size") + parser.add_argument("--dim", type=int, default=32768, help="Dimension") args = parser.parse_args() B, D = args.batch, args.dim - x = torch.randn((B, D), device='cuda') - kernel = hadamard(B, D, 'float32') + x = torch.randn((B, D), device="cuda") + kernel = hadamard(B, D, T.float32) y = kernel(x) y_ref = ref_program(x) torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) - print('All tests passed.') + print("All tests passed.") profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) latency = profiler.do_bench(warmup=100) print("Tile-lang: {:.2f} ms".format(latency)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/lazy_jit/lazyjit.en.ipynb b/examples/lazy_jit/lazyjit.en.ipynb new file mode 100644 index 000000000..5b5df8e6a --- /dev/null +++ b/examples/lazy_jit/lazyjit.en.ipynb @@ -0,0 +1,977 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e0deecc", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", + "import tilelang\n", + "import torch\n", + "import tilelang.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "1ca2c56d", + "metadata": {}, + "source": [ + "# Tilelang Lazy JIT" + ] + }, + { + "cell_type": "markdown", + "id": "156e7370", + "metadata": {}, + "source": [ + "## Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "b070c109", + "metadata": {}, + "source": [ + "Tilelang Lazy JIT merges JIT kernel generation and invocation into a single workflow.\n", + "\n", + "The function signature looks similar to Triton, but we add many enhancements; the most important one is allowing rich Tensor annotations:\n", + "\n", + "* If a Tensor has complex shape constraints, we can move its annotation into the function body.\n", + "* Use `T.const` or `T.dynamic` to create shape variables, then annotate complex Tensors with `T.Tensor`.\n", + "* Use `T.empty` to declare return tensors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60bf8954", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm(\n", + " A,\n", + " B,\n", + " out_dtype: T.dtype = T.float32,\n", + " block_M: int = 128,\n", + " block_N: int = 128,\n", + " block_K: int = 32,\n", + "):\n", + " M, N, K = T.const(\"M, N, K\")\n", + "\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + "\n", + " C = T.empty((M, N), out_dtype)\n", + "\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "28f868fe", + "metadata": {}, + "source": [ + "Calling the function with Tensors directly triggers the full JIT compile-and-run pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ee13394a", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B)\n", + "\n", + "# check output is correct\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "c6705091", + "metadata": {}, + "source": [ + "Changing the call arguments may trigger a recompilation when compilation parameters change:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8aab5b7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B, block_M=64, block_N=64)" + ] + }, + { + "cell_type": "markdown", + "id": "ce6b7391", + "metadata": {}, + "source": [ + "You can also explicitly call the `compile` method to build the kernel.\n", + "\n", + "1. `ker.compile` compiles the kernel\n", + "2. `ker.get_tir` retrieves the TIR\n", + "3. `ker.par_compile` compiles in parallel" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3cf3a2d", + "metadata": {}, + "outputs": [], + "source": [ + "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", + "C = kernel(A, B)" + ] + }, + { + "cell_type": "markdown", + "id": "921761b5", + "metadata": {}, + "source": [ + "## More Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "4539e54e", + "metadata": {}, + "source": [ + "### Use macros to separate implementation" + ] + }, + { + "cell_type": "markdown", + "id": "ad96ba65", + "metadata": {}, + "source": [ + "Next, we implement a simple GEMM in several different ways. 
For convenience, we first write a macro that contains the core GEMM logic:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "171d4fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "446a1acd", + "metadata": {}, + "source": [ + "### Use `T.dynamic` to mark dynamic shapes\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6a38aa95", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_dyn_K(A, B):\n", + " M, N, K = T.dynamic(\"M, N, K\")\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "fe6cfdc8", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_dyn_K(A, B)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee97bf7", + "metadata": {}, + "source": [ + "### Use `T.StridedTensor` to annotate tensors with strides\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9dde1dae", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def as_contingious(A):\n", + " M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n", + " A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n", + " B = T.empty((M, N), A.dtype)\n", + " block_M = 128\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " T.copy(\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " )\n", + " return B" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "dec2c0a7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 1024, device=\"cuda\")\n", + "B = as_contingious(A.T)\n", + "B_ref = A.T.contiguous()\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f5fb20d6", + "metadata": {}, + "source": [ + "## More Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "890df0a2", + "metadata": {}, + "source": [ + "### Use parameters directly as annotations" + ] + }, + { + "cell_type": "markdown", + "id": "e9a47d42", + "metadata": {}, + "source": [ + "You can directly use function parameters in the annotations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "0fc17af6", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr(\n", + " A,\n", + " B,\n", + " M,\n", + " N,\n", + " K,\n", + "):\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8e52a554", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "6b19ef90", + "metadata": {}, + "source": [ + "### Annotations for runtime variables" + ] + }, + { + "cell_type": "markdown", + "id": "bba5f27f", + "metadata": {}, + "source": [ + "Runtime variables work the same; if the function annotation becomes too long, you can move it into the function body." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c1e7598a", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr_dyn(A, B, M, N, K):\n", + " M: T.int32\n", + " N: T.int32\n", + " K: T.int32\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9e9a4c88", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "81427765", + "metadata": {}, + "source": [ + "### Constraints for constants" + ] + }, + { + "cell_type": "markdown", + "id": "4d6b084b", + "metadata": {}, + "source": [ + "A constant annotation created by `T.const` must be used directly at least once, otherwise an error is raised." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c90dd24f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Constexpr variable `M` is not used in any buffer shape or stride.\n", + "At least one **DIRECT** usage is required. Please check:\n", + "(1) the variable is not used\n", + "(2) all uses are indirect, e.g. M * 2, M * 3. 
(you can replace them with separate constexpr variables)\n", + "Buffer shapes: {A: [M * 2, M * 3]}\n", + "Buffer strides: {A: [M * 3, 1]}\n" + ] + } + ], + "source": [ + "@tilelang.lazy_jit\n", + "def example_wrong_kernel(A):\n", + " M = T.const(\"M\")\n", + " A: T.Tensor[[M * 2, M * 3], T.float32]\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + "\n", + "\n", + "try:\n", + " A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + " example_wrong_kernel(A)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "e07e762b", + "metadata": {}, + "source": [ + "### Dynamic dimensions" + ] + }, + { + "cell_type": "markdown", + "id": "f48e5d7a", + "metadata": {}, + "source": [ + "If you want certain parameters in a Tensor annotation to change, it is recommended to switch to the `T.ptr` + `T.match_buffer` style." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1d050321", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@tilelang.lazy_jit\n", + "def dyn_annot(\n", + " A: T.ptr, # 1. T.ptr type annotation\n", + " is_2d=False,\n", + "):\n", + " if is_2d:\n", + " M, N = T.const(\"M, N\")\n", + " # 2. dynamic shape annotation inside function body\n", + " A = T.match_buffer(A, [M, N], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + " else:\n", + " L = T.const(\"L\")\n", + " A = T.match_buffer(A, [L], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0]\n", + "\n", + "\n", + "A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + "dyn_annot(A, is_2d=True)" + ] + }, + { + "cell_type": "markdown", + "id": "2e9f1bb3", + "metadata": {}, + "source": [ + "### Default arguments" + ] + }, + { + "cell_type": "markdown", + "id": "f7fc9917", + "metadata": {}, + "source": [ + "Scalar annotations like `T.float32` can carry default values." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "42ec86a1", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def add_one(X, data: T.float32 = 1):\n", + " M, N = T.const(\"M, N\")\n", + " X: T.Tensor[[M, N], T.float32]\n", + " Y = T.empty((M, N), T.float32)\n", + " with T.Kernel(T.ceildiv(M, 128), threads=128) as bx:\n", + " for i, j in T.Parallel(128, N):\n", + " Y[bx * 128 + i, j] = X[bx * 128 + i, j] + data\n", + " return Y" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d49e1120", + "metadata": {}, + "outputs": [], + "source": [ + "X = torch.randn(1024, 1024, dtype=torch.float32, device=\"cuda\")\n", + "Y = add_one(X)\n", + "torch.testing.assert_close(Y, X + 1)" + ] + }, + { + "cell_type": "markdown", + "id": "a02baedc", + "metadata": {}, + "source": [ + "## Overhead of argument matching" + ] + }, + { + "cell_type": "markdown", + "id": "860a2972", + "metadata": {}, + "source": [ + "LazyJIT has very small overhead; each additional constant annotation costs about 200 ns.\n", + "* 200 ns is roughly the cost of an FFI call that reads parameters from a `torch.Tensor`'s shape/stride." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc676e33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Kernel call : 7.68 us\n", + "Parse cache key: 0.41 us\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "A = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def dummy_kernel(A, B):\n", + " M, N = T.const(\"M, N\")\n", + " A: T.Tensor[[M, N], T.float16]\n", + " B: T.Tensor[[M, N], T.float16]\n", + " with T.Kernel(1) as _:\n", + " pass\n", + "\n", + "\n", + "# compile it first\n", + "dummy_kernel(A, B)\n", + "\n", + "\n", + "def eval_overhead(f):\n", + " start = time.perf_counter_ns()\n", + " for _ in range(10000):\n", + " f()\n", + " stop = time.perf_counter_ns()\n", + " return (stop - start) / 10000 / 1000\n", + "\n", + "\n", + "kernel_call_overhead = eval_overhead(lambda: dummy_kernel(A, B))\n", + "parse_cache_key_overhead = eval_overhead(lambda: dummy_kernel.parse_cache_key(A, B))\n", + "\n", + "print(f\"Kernel call : {kernel_call_overhead:.2f} us\")\n", + "print(f\"Parse cache key: {parse_cache_key_overhead:.2f} us\")" + ] + }, + { + "cell_type": "markdown", + "id": "39166cb4", + "metadata": {}, + "source": [ + "## Compilation and parallel compilation" + ] + }, + { + "cell_type": "markdown", + "id": "8c6fbe08", + "metadata": {}, + "source": [ + "Both `lazyjit` and the original `jit` support parallel compilation.\n", + "\n", + "To avoid wasting memory on temporary `torch.Tensor` objects, you can use `T.Tensor` to create placeholders." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7222e57b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a4e4eb3cd4445bda6e8693da31ef3b8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Elaborating: 0%| | 0/8 [00:00,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "\n", + "def get_configs():\n", + " return [\n", + " {\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", + " }\n", + " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", + " ]\n", + "\n", + "\n", + "gemm.par_compile(get_configs())" + ] + }, + { + "cell_type": "markdown", + "id": "5160d2cc", + "metadata": {}, + "source": [ + "## More convenient macros" + ] + }, + { + "cell_type": "markdown", + "id": "be44afc4", + "metadata": {}, + "source": [ + "tilelang's macros have been improved:\n", + "\n", + "1. Allow using `T.Ref` as an annotation, similar to C++ references.\n", + "2. Allow returning multiple values.\n", + "3. Allow nesting and recursion." + ] + }, + { + "cell_type": "markdown", + "id": "79575972", + "metadata": {}, + "source": [ + "### Passing references with `T.Ref`\n", + "\n", + "A `T.Ref` reference can point to a scalar variable or to an element of a buffer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "90eaa6e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo(x_handle: T.handle):\n", + " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", + " # with T.block(\"root\"):\n", + " bx = T.launch_thread(\"blockIdx.x\", 1)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", + " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", + " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", + " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", + " x[1] = T.float32(1.0)\n", + " _tmp: T.int32 = idx[0]\n", + " x[_tmp] = T.float32(1.0)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def macro_with_ref(x: T.Ref):\n", + " x = 1 # noqa: F841\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo(x: T.Tensor((2,))):\n", + " with T.Kernel(1) as _:\n", + " # Supports constant indices\n", + " macro_with_ref(x[1])\n", + "\n", + " # Also supports variable indices\n", + " idx = T.alloc_var(T.int32, 0)\n", + " macro_with_ref(x[idx])\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "markdown", + "id": "7bb447a2", + "metadata": {}, + "source": [ + "### Pass macros as arguments\n", + "\n", + "You can pass a macro as a function argument." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "dc7bb779", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def element_wise(A, fn):\n", + " N = T.dynamic(\"N\")\n", + " A: T.Tensor[[N], T.float32]\n", + " B = T.empty((N,), dtype=A.dtype)\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", + " for i in T.Parallel(block_N):\n", + " idx = bx * block_N + i\n", + " B[idx] = fn(A[idx])\n", + " return B\n", + "\n", + "\n", + "@T.macro\n", + "def add_one(x):\n", + " return x + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a89fdb44", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, device=\"cuda\")\n", + "B = element_wise(A, add_one)\n", + "B_ref = A + 1\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "ef6e403a", + "metadata": {}, + "source": [ + "### Recursive macros\n", + "\n", + "You may not need this often, but macros can be recursive as long as the termination condition is known at compile time." 
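Because a macro call is expanded inline while the kernel is traced, the recursion has to bottom out on values known at trace time (Python values or constants); a condition that depends on runtime tensor data cannot terminate it. A minimal sketch of such compile-time unrolling; `fill_pow2` and `powers` are hypothetical helpers in the same style as the `n31` example in the next cell:

```python
# Hedged sketch: recursion over a Python int is fully unrolled during tracing.
@T.macro
def fill_pow2(buf, n):
    if n > 0:  # `n` is a Python value, so this branch is resolved at trace time
        buf[n - 1] = 2 ** (n - 1)
        fill_pow2(buf, n - 1)


@tilelang.lazy_jit
def powers(A: T.Tensor[[8], T.int32], n: int):
    with T.Kernel(1) as _:
        fill_pow2(A, n)


A = torch.zeros(8, dtype=torch.int32, device="cuda")
powers(A, 8)  # A becomes [1, 2, 4, ..., 128]
```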
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7703cab5", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def n31(x, var: T.Ref):\n", + " if x == 1:\n", + " pass\n", + " elif x % 2 == 0:\n", + " var = var // 2\n", + " n31(x // 2, var)\n", + " else:\n", + " var = var * 3 + 1\n", + " n31(x * 3 + 1, var)\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def foo(A: T.Tensor[[1], T.int32], n: int):\n", + " with T.Kernel(1) as _:\n", + " n31(n, A[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "542ddd4e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([18], device='cuda:0', dtype=torch.int32)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", + "foo(A, 5)\n", + "A" + ] + }, + { + "cell_type": "markdown", + "id": "dc30c2d2", + "metadata": {}, + "source": [ + "### Macros returning multiple values" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "d5a2388f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " # with T.block(\"root\"):\n", + " x = T.launch_thread(\"blockIdx.x\", 32)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " T.writes()\n", + " s: T.int32 = T.sin(x)\n", + " c: T.int32 = T.cos(x)\n", + " a: T.int32 = s + c\n", + " b: T.int32 = s - c\n", + " T.evaluate(0)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def sincos(x):\n", + " return T.sin(x), T.cos(x)\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " with T.Kernel(32) as x:\n", + " s, c = sincos(x)\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd83fea7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tilelang-dev_0", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/lazy_jit/lazyjit.zh.ipynb b/examples/lazy_jit/lazyjit.zh.ipynb new file mode 100644 index 000000000..387aff461 --- /dev/null +++ b/examples/lazy_jit/lazyjit.zh.ipynb @@ -0,0 +1,977 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e0deecc", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", + "import tilelang\n", + "import torch\n", + "import tilelang.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "1ca2c56d", + "metadata": {}, + "source": [ + "# Tilelang Lazy JIT" + ] + }, + { + "cell_type": "markdown", + "id": "156e7370", + "metadata": {}, + "source": [ + "## Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "b070c109", + "metadata": {}, + "source": [ + "Tilelang Lazy JIT 将 jit 
生成和调用的逻辑合并到一起\n", + "\n", + "函数签名的写法与 triton 相似,但做了大量增强,最主要的增强是允许对 Tensor 的标注:\n", + "\n", + "* 如果一个 Tensor 有复杂的 shape 约束,我们可以把它的标注移动到函数内部\n", + "* 通过 `T.const` 或 `T.dynamic` 来建立一些 shape 变量,然后用 `T.Tensor` 标注复杂的 Tensor\n", + "* 用 `T.empty` 来声明返回值" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60bf8954", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm(\n", + " A,\n", + " B,\n", + " out_dtype: T.dtype = T.float32,\n", + " block_M: int = 128,\n", + " block_N: int = 128,\n", + " block_K: int = 32,\n", + "):\n", + " M, N, K = T.const(\"M, N, K\")\n", + "\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + "\n", + " C = T.empty((M, N), out_dtype)\n", + "\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "28f868fe", + "metadata": {}, + "source": [ + "直接将 Tensor 作为参数调用,即可触发完整的 jit 编译运行流程:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ee13394a", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B)\n", + "\n", + "# check output is correct\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "c6705091", + "metadata": {}, + "source": [ + "更改调用的参数,如果编译器参数发生了变化,会触发重新编译:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8aab5b7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B, block_M=64, block_N=64)" + ] + }, + { + "cell_type": "markdown", + "id": "ce6b7391", + "metadata": {}, + "source": [ + "你也可以手动调用 compile 函数编译 kernel\n", + "\n", + "1. `ker.compile` 编译 kernel\n", + "2. `ker.get_tir` 获取 tir\n", + "3. 
`ker.par_compile` 并行编译" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3cf3a2d", + "metadata": {}, + "outputs": [], + "source": [ + "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", + "C = kernel(A, B)" + ] + }, + { + "cell_type": "markdown", + "id": "921761b5", + "metadata": {}, + "source": [ + "## More Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "4539e54e", + "metadata": {}, + "source": [ + "### 用 macro 来分离实现" + ] + }, + { + "cell_type": "markdown", + "id": "ad96ba65", + "metadata": {}, + "source": [ + "接下来,我们会用各种方式来实现一个简单的 gemm,为了方便,我们先写一个 macro 把 gemm 的主要逻辑写出来:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "171d4fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "446a1acd", + "metadata": {}, + "source": [ + "### 用 T.dynamic 标记动态 Shape\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6a38aa95", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_dyn_K(A, B):\n", + " M, N, K = T.dynamic(\"M, N, K\")\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "fe6cfdc8", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_dyn_K(A, B)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee97bf7", + "metadata": {}, + "source": [ + "### 用 T.StridedTensor 标记带 stride 的 Tensor\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9dde1dae", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def as_contingious(A):\n", + " M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n", + " A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n", + " B = T.empty((M, N), A.dtype)\n", + " block_M = 128\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " T.copy(\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " )\n", + " return B" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "dec2c0a7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 1024, device=\"cuda\")\n", + "B = as_contingious(A.T)\n", + "B_ref = A.T.contiguous()\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f5fb20d6", + "metadata": {}, + "source": [ + "## More Annotation" 
+ ] + }, + { + "cell_type": "markdown", + "id": "890df0a2", + "metadata": {}, + "source": [ + "### 直接用参数当 annotation" + ] + }, + { + "cell_type": "markdown", + "id": "e9a47d42", + "metadata": {}, + "source": [ + "可以直接把函数参数写到 annotation 里面" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "0fc17af6", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr(\n", + " A,\n", + " B,\n", + " M,\n", + " N,\n", + " K,\n", + "):\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8e52a554", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "6b19ef90", + "metadata": {}, + "source": [ + "### 对运行时变量的 annotation" + ] + }, + { + "cell_type": "markdown", + "id": "bba5f27f", + "metadata": {}, + "source": [ + "运行时变量也是一样,如果嫌函数 annotation 太长,可以放到函数体里面" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c1e7598a", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr_dyn(A, B, M, N, K):\n", + " M: T.int32\n", + " N: T.int32\n", + " K: T.int32\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9e9a4c88", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "81427765", + "metadata": {}, + "source": [ + "### 常量的约束" + ] + }, + { + "cell_type": "markdown", + "id": "4d6b084b", + "metadata": {}, + "source": [ + "`T.const` 创建的常量 annotation 只要要被直接使用一次,否则会报错" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c90dd24f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Constexpr variable `M` is not used in any buffer shape or stride.\n", + "At least one **DIRECT** usage is required. Please check:\n", + "(1) the variable is not used\n", + "(2) all uses are indirect, e.g. M * 2, M * 3. 
(you can replace them with separate constexpr variables)\n", + "Buffer shapes: {A: [M * 2, M * 3]}\n", + "Buffer strides: {A: [M * 3, 1]}\n" + ] + } + ], + "source": [ + "@tilelang.lazy_jit\n", + "def example_wrong_kernel(A):\n", + " M = T.const(\"M\")\n", + " A: T.Tensor[[M * 2, M * 3], T.float32]\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + "\n", + "\n", + "try:\n", + " A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + " example_wrong_kernel(A)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "e07e762b", + "metadata": {}, + "source": [ + "### 动态维度的" + ] + }, + { + "cell_type": "markdown", + "id": "f48e5d7a", + "metadata": {}, + "source": [ + "如果想要 Tensor 的 annotation 类型某个参数变化,建议改成 T.ptr + T.match_buffer 格式。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1d050321", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@tilelang.lazy_jit\n", + "def dyn_annot(\n", + " A: T.ptr, # 1. T.ptr type annotation\n", + " is_2d=False,\n", + "):\n", + " if is_2d:\n", + " M, N = T.const(\"M, N\")\n", + " # 2. dynamic shape annotation inside function body\n", + " A = T.match_buffer(A, [M, N], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + " else:\n", + " L = T.const(\"L\")\n", + " A = T.match_buffer(A, [L], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0]\n", + "\n", + "\n", + "A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + "dyn_annot(A, is_2d=True)" + ] + }, + { + "cell_type": "markdown", + "id": "2e9f1bb3", + "metadata": {}, + "source": [ + "### 带默认参数的" + ] + }, + { + "cell_type": "markdown", + "id": "f7fc9917", + "metadata": {}, + "source": [ + "类似 `T.float32` 标注的标量可以带默认参数" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "42ec86a1", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def add_one(X, data: T.float32 = 1):\n", + " M, N = T.const(\"M, N\")\n", + " X: T.Tensor[[M, N], T.float32]\n", + " Y = T.empty((M, N), T.float32)\n", + " with T.Kernel(T.ceildiv(M, 128), threads=128) as bx:\n", + " for i, j in T.Parallel(128, N):\n", + " Y[bx * 128 + i, j] = X[bx * 128 + i, j] + data\n", + " return Y" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d49e1120", + "metadata": {}, + "outputs": [], + "source": [ + "X = torch.randn(1024, 1024, dtype=torch.float32, device=\"cuda\")\n", + "Y = add_one(X)\n", + "torch.testing.assert_close(Y, X + 1)" + ] + }, + { + "cell_type": "markdown", + "id": "a02baedc", + "metadata": {}, + "source": [ + "## 参数匹配的 Overhead" + ] + }, + { + "cell_type": "markdown", + "id": "860a2972", + "metadata": {}, + "source": [ + "LazyJIT overhead 很小,每个 constant 添加约 200ns 的 overhead\n", + "* 200ns 大约是从 torch.Tensor 的 shape/stride 中拿参数的 ffi call 的代价" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc676e33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Kernel call : 7.68 us\n", + "Parse cache key: 0.41 us\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "A = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def dummy_kernel(A, B):\n", + " M, N = T.const(\"M, N\")\n", + " A: T.Tensor[[M, N], T.float16]\n", + " B: T.Tensor[[M, N], 
T.float16]\n", + " with T.Kernel(1) as _:\n", + " pass\n", + "\n", + "\n", + "# compile it first\n", + "dummy_kernel(A, B)\n", + "\n", + "\n", + "def eval_overhead(f):\n", + " start = time.perf_counter_ns()\n", + " for _ in range(10000):\n", + " f()\n", + " stop = time.perf_counter_ns()\n", + " return (stop - start) / 10000 / 1000\n", + "\n", + "\n", + "kernel_call_overhead = eval_overhead(lambda: dummy_kernel(A, B))\n", + "parse_cache_key_overhead = eval_overhead(lambda: dummy_kernel.parse_cache_key(A, B))\n", + "\n", + "print(f\"Kernel call : {kernel_call_overhead:.2f} us\")\n", + "print(f\"Parse cache key: {parse_cache_key_overhead:.2f} us\")" + ] + }, + { + "cell_type": "markdown", + "id": "39166cb4", + "metadata": {}, + "source": [ + "## 编译与并行编译" + ] + }, + { + "cell_type": "markdown", + "id": "8c6fbe08", + "metadata": {}, + "source": [ + "lazyjit 和原来的 jit 都支持并行编译\n", + "\n", + "为了防止 torch.tensor 白白浪费内存,可以使用 T.Tensor 来创建 placeholder" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7222e57b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a4e4eb3cd4445bda6e8693da31ef3b8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Elaborating: 0%| | 0/8 [00:00,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "\n", + "def get_configs():\n", + " return [\n", + " {\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", + " }\n", + " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", + " ]\n", + "\n", + "\n", + "gemm.par_compile(get_configs())" + ] + }, + { + "cell_type": "markdown", + "id": "5160d2cc", + "metadata": {}, + "source": [ + "## 更便利的 Macro" + ] + }, + { + "cell_type": "markdown", + "id": "be44afc4", + "metadata": {}, + "source": [ + "tilelang 的 macro 现在已经升级:\n", + "\n", + "1. 允许用 `T.Ref` 作为 annotation,这类似与 C++ 的引用传递\n", + "2. 允许返回多个值\n", + "3. 
允许嵌套,递归" + ] + }, + { + "cell_type": "markdown", + "id": "79575972", + "metadata": {}, + "source": [ + "### T.Ref 传递引用\n", + "\n", + "T.Ref 传递的引用可以 var 也可以是 Buffer 的索引" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "90eaa6e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo(x_handle: T.handle):\n", + " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", + " # with T.block(\"root\"):\n", + " bx = T.launch_thread(\"blockIdx.x\", 1)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", + " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", + " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", + " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", + " x[1] = T.float32(1.0)\n", + " _tmp: T.int32 = idx[0]\n", + " x[_tmp] = T.float32(1.0)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def macro_with_ref(x: T.Ref):\n", + " x = 1 # noqa: F841\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo(x: T.Tensor((2,))):\n", + " with T.Kernel(1) as _:\n", + " # 支持常量 index\n", + " macro_with_ref(x[1])\n", + "\n", + " # 也支持变量 index\n", + " idx = T.alloc_var(T.int32, 0)\n", + " macro_with_ref(x[idx])\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "markdown", + "id": "7bb447a2", + "metadata": {}, + "source": [ + "### 当作参数传递\n", + "\n", + "你可以把 macro 当做参数传递" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "dc7bb779", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def element_wise(A, fn):\n", + " N = T.dynamic(\"N\")\n", + " A: T.Tensor[[N], T.float32]\n", + " B = T.empty((N,), dtype=A.dtype)\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", + " for i in T.Parallel(block_N):\n", + " idx = bx * block_N + i\n", + " B[idx] = fn(A[idx])\n", + " return B\n", + "\n", + "\n", + "@T.macro\n", + "def add_one(x):\n", + " return x + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a89fdb44", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, device=\"cuda\")\n", + "B = element_wise(A, add_one)\n", + "B_ref = A + 1\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "ef6e403a", + "metadata": {}, + "source": [ + "### Macro 递归\n", + "\n", + "虽然不知道有没有这种需求,但 macro 是可以递归的,终止条件要求编译期间确定" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7703cab5", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def n31(x, var: T.Ref):\n", + " if x == 1:\n", + " pass\n", + " elif x % 2 == 0:\n", + " var = var // 2\n", + " n31(x // 2, var)\n", + " else:\n", + " var = var * 3 + 1\n", + " n31(x * 3 + 1, var)\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def foo(A: T.Tensor[[1], T.int32], n: int):\n", + " with T.Kernel(1) as _:\n", + " n31(n, A[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "542ddd4e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([18], device='cuda:0', dtype=torch.int32)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + 
} + ], + "source": [ + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", + "foo(A, 5)\n", + "A" + ] + }, + { + "cell_type": "markdown", + "id": "dc30c2d2", + "metadata": {}, + "source": [ + "### Macro 返回多个值" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "d5a2388f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " # with T.block(\"root\"):\n", + " x = T.launch_thread(\"blockIdx.x\", 32)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " T.writes()\n", + " s: T.int32 = T.sin(x)\n", + " c: T.int32 = T.cos(x)\n", + " a: T.int32 = s + c\n", + " b: T.int32 = s - c\n", + " T.evaluate(0)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def sincos(x):\n", + " return T.sin(x), T.cos(x)\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " with T.Kernel(32) as x:\n", + " s, c = sincos(x)\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd83fea7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tilelang-dev_0", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index 568bcc55f..82ae1d982 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -13,20 +13,20 @@ pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + } +) def tl_fused_chunk_bwd_kernel( B, S, H, DK, DV, - dtype: str = 'float16', + dtype: T.dtype = T.float16, scale: float = None, ) -> torch.Tensor: - if scale is None: scale = DK**-0.5 - accum_dtype = 'float' + accum_dtype = T.float32 chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -37,13 +37,13 @@ def tl_fused_chunk_bwd_kernel( @T.prim_func def fused_chunk_linear_attn_bwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - dO: T.Tensor([B, S, H, DV], dtype), # type: ignore - dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore - dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore - dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + dO: T.Tensor([B, S, H, DV], dtype), # type: ignore + dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H @@ -66,11 
+66,6 @@ def fused_chunk_linear_attn_bwd( dh = T.alloc_fragment([BK, BV], accum_dtype) dh_shared = T.alloc_shared([BK, BV], dtype) - T.annotate_layout({ - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared) - }) T.use_swizzle(10) T.clear(h) @@ -78,10 +73,9 @@ def fused_chunk_linear_attn_bwd( # Calculate dQ for i in T.Pipelined(0, NT): - T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) - T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], - do) + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) T.gemm(do, v, ds, transpose_B=True, clear_accum=True) for row, col in T.Parallel(chunk_size, chunk_size): @@ -94,29 +88,19 @@ def fused_chunk_linear_attn_bwd( for row, col in T.Parallel(chunk_size, BK): dq[row, col] *= scale T.copy(dq, dq_shared) - T.atomic_add( - dQ[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], - dq_shared) + T.atomic_add(dQ[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dq_shared) # Calculate dK, dV (reversely) for i in T.Pipelined(1, NT + 1): start = NT - i for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale - T.copy( - K[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_k * BK:(i_k + 1) * BK], k) - T.copy( - V[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV], v) - T.copy( - dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV], do) + T.copy(K[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) # Calculate dk - T.gemm( - v, do, ds, transpose_B=True, clear_accum=True - ) # ds here actually means `s`, but we simply reuse the buffer `ds` + T.gemm(v, do, ds, transpose_B=True, clear_accum=True) # ds here actually means `s`, but we simply reuse the buffer `ds` for row, col in T.Parallel(chunk_size, chunk_size): ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) T.gemm(ds_shared, q, dk, clear_accum=True) @@ -134,13 +118,9 @@ def fused_chunk_linear_attn_bwd( T.gemm(q, do, dh, transpose_A=True) T.copy(dk, dk_shared) - T.atomic_add( - dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_k * BK:(i_k + 1) * BK], dk_shared) + T.atomic_add(dK[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dk_shared) T.copy(dv, dv_shared) - T.atomic_add( - dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV], dv_shared) + T.atomic_add(dV[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], dv_shared) return fused_chunk_linear_attn_bwd @@ -155,34 +135,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO): return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16) -def ref_program(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: Optional[float] = 
None) -> Tuple[torch.Tensor, torch.Tensor]: +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = q.float(), k.float(), v.float() if scale is None: - scale = q.shape[-1]**-0.5 + scale = q.shape[-1] ** -0.5 chunk_size = 64 - q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale - k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) - v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) kv = k.transpose(-1, -2) @ v kv = kv.cumsum(2) h = kv[:, :, -1, :, :] kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) inter = q @ kv - intra = ((q @ k.transpose(-1, -2)).masked_fill_( - torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), - 0)) @ v + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v o = inter + intra - return rearrange(o, 'b h n c d -> b (n c) h d'), h + return rearrange(o, "b h n c d -> b (n c) h d"), h def main(B=1, S=1024, H=16, D=128): - q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) - k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) - v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) - do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) # qk norm is necessary for linear attn q = l2norm_fwd(q)[0].requires_grad_(True) @@ -193,30 +170,42 @@ def main(B=1, S=1024, H=16, D=128): o_ref, _ = ref_program(q, k, v) o_ref.backward(do, retain_graph=True) - assert torch.allclose( - dq, q.grad, atol=1e-2, rtol=1e-2), f'dq max err: {(dq - q.grad).abs().max()}' - assert torch.allclose( - dk, k.grad, atol=1e-2, rtol=1e-2), f'dk max err: {(dk - k.grad).abs().max()}' - assert torch.allclose( - dv, v.grad, atol=1e-2, rtol=1e-2), f'dv max err: {(dv - v.grad).abs().max()}' - print('Passed all tests!✅') + assert torch.allclose(dq, q.grad, atol=1e-2, rtol=1e-2), f"dq max err: {(dq - q.grad).abs().max()}" + assert torch.allclose(dk, k.grad, atol=1e-2, rtol=1e-2), f"dk max err: {(dk - k.grad).abs().max()}" + assert torch.allclose(dv, v.grad, atol=1e-2, rtol=1e-2), f"dv max err: {(dv - v.grad).abs().max()}" + print("Passed all tests!✅") # Benchmark q.grad = k.grad = v.grad = None o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) - t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti') - t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti') - print(f'Triton latency: {t1:.3f} ms') - print(f'TileLang latency: {t2:.3f} ms') - print(f'Speedup: {t1/t2:.3f}x') + t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend="cupti") + print(f"Triton latency: {t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") + + +def 
run_regression_perf(B=1, S=1024, H=16, D=128): + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + q = l2norm_fwd(q)[0].requires_grad_(True) + k = l2norm_fwd(k)[0].requires_grad_(True) + kernel = tl_fused_chunk_bwd_kernel(B, S, H, D, D) + dQ = torch.zeros_like(q, dtype=torch.float32) + dK = torch.zeros_like(k, dtype=torch.float32) + dV = torch.zeros_like(v, dtype=torch.float32) + kernel(q, k, v, do, dQ, dK, dV) + return do_bench(lambda: kernel(q, k, v, do, dQ, dK, dV), backend="cupti") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=1024, help='Seq len') - parser.add_argument('--H', type=int, default=32, help='Num heads') - parser.add_argument('--D', type=int, default=128, help='Head dim') + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") args = parser.parse_args() main(args.B, args.S, args.H, args.D) diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index cbf352bbc..cdfd5cb72 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -14,20 +14,20 @@ pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) def tl_fused_chunk_fwd_kernel( B, S, H, DK, DV, - dtype: str = 'float16', + dtype: T.dtype = T.float16, scale: float = None, ) -> torch.Tensor: - if scale is None: scale = DK**-0.5 - accum_dtype = 'float' + accum_dtype = T.float32 chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -38,11 +38,12 @@ def tl_fused_chunk_fwd_kernel( @T.prim_func def fused_chunk_linear_attn_fwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore - final_state: T.Tensor([B, H, DK, DV], accum_dtype)): # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + final_state: T.Tensor([B, H, DK, DV], accum_dtype), + ): # type: ignore with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H @@ -57,7 +58,6 @@ def fused_chunk_linear_attn_fwd( o = T.alloc_fragment([chunk_size, BV], accum_dtype) o_shared = T.alloc_shared([chunk_size, BV], accum_dtype) - T.annotate_layout({o_shared: tilelang.layout.make_swizzled_layout(o_shared)}) T.use_swizzle(10) T.clear(h) @@ -65,8 +65,8 @@ def fused_chunk_linear_attn_fwd( for i in T.Pipelined(0, NT): for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale - T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - 
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) T.gemm(q, k, s, clear_accum=True, transpose_B=True) for row, col in T.Parallel(chunk_size, chunk_size): @@ -77,13 +77,10 @@ def fused_chunk_linear_attn_fwd( T.gemm(k, v, h, transpose_A=True) T.gemm(q, h_shared, o) T.copy(o, o_shared) - T.atomic_add( - O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], - o_shared) - #TODO: consider using vectorized atomic add or tma reduce for sm90 + T.atomic_add(O[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], o_shared) # Output final state - T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV]) + T.copy(h, final_state[i_b, i_h, i_k * BK : (i_k + 1) * BK, i_v * BV : (i_v + 1) * BV]) return fused_chunk_linear_attn_fwd @@ -91,38 +88,36 @@ def fused_chunk_linear_attn_fwd( def tl_fused_chunk_fwd(q, k, v): B, S, H, D = q.shape kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) - o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32) + print(kernel.get_kernel_source()) + o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32) h = kernel(q, k, v, o) return o, h -def ref_program(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = q.float(), k.float(), v.float() if scale is None: - scale = q.shape[-1]**-0.5 + scale = q.shape[-1] ** -0.5 chunk_size = 64 - q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale - k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) - v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) kv = k.transpose(-1, -2) @ v kv = kv.cumsum(2) h = kv[:, :, -1, :, :] kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) inter = q @ kv - intra = ((q @ k.transpose(-1, -2)).masked_fill_( - torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), - 0)) @ v + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v o = inter + intra - return rearrange(o, 'b h n c d -> b (n c) h d'), h + return rearrange(o, "b h n c d -> b (n c) h d"), h def main(B=1, S=512, H=16, D=128): - q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) # qk norm is necessary for linear attn q, _ = l2norm_fwd(q) @@ -131,25 +126,35 @@ def main(B=1, S=512, H=16, D=128): o, h = tl_fused_chunk_fwd(q, k, v) o_ref, h_ref = ref_program(q, k, v) - assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}' - assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f'h max err: 
{(h - h_ref).abs().max()}' - print('Passed all tests!✅') + assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f"o max err: {(o - o_ref).abs().max()}" + assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f"h max err: {(h - h_ref).abs().max()}" + print("Passed all tests!✅") + + t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend="cupti") + print(f"Triton latency: {t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") - t1 = do_bench( - lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), - backend='cupti') - t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend='cupti') - print(f'Triton latency: {t1:.3f} ms') - print(f'TileLang latency: {t2:.3f} ms') - print(f'Speedup: {t1/t2:.3f}x') + +def run_regression_perf(B=1, S=512, H=16, D=128): + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + B, S, H, D = q.shape + kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) + o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32) + return do_bench(lambda: kernel(q, k, v, o), backend="cupti") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=1024, help='Seq len') - parser.add_argument('--H', type=int, default=32, help='Num heads') - parser.add_argument('--D', type=int, default=128, help='Head dim') + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") args = parser.parse_args() main(args.B, args.S, args.H, args.D) diff --git a/examples/linear_attention/example_mamba_chunk_scan.py b/examples/linear_attention/example_mamba_chunk_scan.py index add49052d..88a9b75bc 100644 --- a/examples/linear_attention/example_mamba_chunk_scan.py +++ b/examples/linear_attention/example_mamba_chunk_scan.py @@ -9,6 +9,7 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd + out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D) return out @@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] decay = torch.exp(dt_segment_sum) scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") - causal_mask = torch.tril( - torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) scores_decay = scores_decay.masked_fill(~causal_mask, 0) - out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), - rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) - out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange( - C, "b 
(c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) out = out + out_prev out = rearrange(out, "b c l h p -> b (c l) h p") if D is not None: @@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): def get_configs(): - iter_params = dict( - block_M=[64, 128, 256], - block_N=[32, 64], - block_K=[64, 128, 256], - block_Dstate=[128], - num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -77,56 +74,58 @@ def get_configs(): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def chunk_scan_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - block_Dstate=128, - num_stages=2, - threads=128): - dtype = "float16" - accum_dtype = "float" +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, +): + dtype = T.float16 + accum_dtype = T.float32 nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 @T.prim_func def main( - cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore - x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore - dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore - prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore - D: T.Tensor((nheads), dtype), # type: ignore - Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore ): - with T.Kernel( - nheads, - T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) acc_o_shared = T.alloc_shared((block_M, block_N), dtype) - cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") + cb_shared = T.alloc_shared((block_M, block_K), dtype) cb_local = T.alloc_fragment((block_M, block_K), dtype) - dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared") + dA_cs_k_shared = T.alloc_shared((block_K), dtype) dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype) dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) - dt_shared = T.alloc_shared((block_K), dtype, scope="shared") + dt_shared = 
T.alloc_shared((block_K), dtype) dt_local = T.alloc_fragment((block_K), accum_dtype) x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn") - dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared") + dA_cs_m_shared = T.alloc_shared((block_M), dtype) scale_m_local = T.alloc_fragment((block_M), accum_dtype) C_shared = T.alloc_shared((block_M, block_Dstate), dtype) prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) D_local = T.alloc_fragment((1), accum_dtype) - x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn") + x_residual_shared = T.alloc_shared((block_M, block_N), dtype) x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype) batch_idx = by % batch @@ -136,27 +135,31 @@ def main( m_idx = bx // T.ceildiv(headdim, block_N) n_idx = bx % T.ceildiv(headdim, block_N) - T.annotate_layout({ - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), - cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), - x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) - }) + T.annotate_layout( + { + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) T.no_set_max_nreg() - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], - dA_cs_m_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) T.copy(dA_cs_m_shared, dA_cs_m_local) T.clear(acc_o) for i in T.Parallel(block_M): scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) T.copy( - C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) - T.copy( - prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, - 0:block_Dstate], prev_state_shared) + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] *= scale_m_local[i] @@ -165,34 +168,47 @@ def main( for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - cb[batch_idx, chunk_idx, bz // (nheads // ngroups), - m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], - cb_shared) + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) T.copy(cb_shared, cb_local) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cs_k_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) T.copy(dA_cs_k_shared, dA_cs_k_local) for i, j in T.Parallel(block_M, block_K): - cb_local[i, - j] = cb_local[i, - j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dt_shared, dt_local) for i, j in T.Parallel(block_M, block_K): cb_local[i, j] *= dt_local[j] for i, j in T.Parallel(block_M, block_K): - cb_local[i, 
j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, - cb_local[i, j], 0) + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) T.gemm(cb_local, x_shared, acc_o) D_local[0] = D[bz] T.copy( - x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], - x_residual_shared) + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) T.copy(x_residual_shared, x_residual_local) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] += x_residual_local[i, j] * D_local[0] @@ -200,27 +216,40 @@ def main( T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) return main if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate - if (not args.tune): + if not args.tune: kernel = chunk_scan_fwd( batch, seq_len, @@ -234,7 +263,8 @@ def main( block_K=64, block_Dstate=128, num_stages=2, - threads=128) + threads=128, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") diff --git 
a/examples/linear_attention/example_mamba_chunk_state.py b/examples/linear_attention/example_mamba_chunk_state.py index ad3df0df8..96126889b 100644 --- a/examples/linear_attention/example_mamba_chunk_state.py +++ b/examples/linear_attention/example_mamba_chunk_state.py @@ -10,6 +10,7 @@ def chunk_state_triton(B, x, dt, dA_cumsum): from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd + return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False) @@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum): x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) - return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), - dt.to(x.dtype), x) + return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) def get_configs(): - iter_params = dict( - block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[4]) -def chunk_state_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - num_stages=2, - threads=128): - dtype = "float16" - accum_dtype = "float" +def chunk_state_fwd( + batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, num_stages=2, threads=128 +): + dtype = T.float16 + accum_dtype = T.float32 nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 @T.prim_func - def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor( - (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor( - (batch, nchunks, nheads, headdim, dstate), dtype)): - with T.Kernel( - nheads, - T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + def main( + B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + Output: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), + ): + with T.Kernel(nheads, T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, threads=threads) as (bz, bx, by): x_shared = T.alloc_shared((block_K, block_M), dtype) x_local = T.alloc_fragment((block_K, block_M), dtype) xt_local = T.alloc_fragment((block_M, block_K), dtype) @@ -101,20 +89,22 @@ def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor( m_idx = bx // T.ceildiv(dstate, block_N) n_idx = bx % T.ceildiv(dstate, block_N) - T.annotate_layout({ - x_shared: tilelang.layout.make_swizzled_layout(x_shared), - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared) - }) + T.annotate_layout({x_shared: tilelang.layout.make_swizzled_layout(x_shared)}) dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] T.clear(acc_o) for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size 
+ - (k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cumsum_shared) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + m_idx * block_M : (m_idx + 1) * block_M, + ], + x_shared, + ) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dA_cumsum_shared, dA_cumsum_local) T.copy(dt_shared, dt_local) for i in T.Parallel(block_K): @@ -123,47 +113,50 @@ def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor( for i, j in T.Parallel(block_M, block_K): xt_local[i, j] = x_local[j, i] * scale[j] T.copy( - B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz // (nheads // ngroups), - n_idx * block_N:(n_idx + 1) * block_N], B_shared) + B[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz // (nheads // ngroups), + n_idx * block_N : (n_idx + 1) * block_N, + ], + B_shared, + ) T.gemm(xt_local, B_shared, acc_o) T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M, - n_idx * block_N:(n_idx + 1) * block_N]) + Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N], + ) return main if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) total_flops = 2 * batch * seq_len * heads * dim * dstate - if (not args.tune): + if not args.tune: kernel = chunk_state_fwd( - batch, - seq_len, - chunk_size, - groups, - heads, - dim, - dstate, - block_M=64, - block_N=128, - block_K=64, - num_stages=4, - threads=128) + batch, seq_len, chunk_size, groups, heads, dim, dstate, block_M=64, block_N=128, block_K=64, 
num_stages=4, threads=128 + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/linear_attention/example_retention_fwd.py b/examples/linear_attention/example_retention_fwd.py index 66012e0c1..f45e38388 100644 --- a/examples/linear_attention/example_retention_fwd.py +++ b/examples/linear_attention/example_retention_fwd.py @@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel( H, DK, DV, - dtype: str = 'float16', + dtype: T.dtype = T.float16, scale: float = None, ) -> torch.Tensor: - if scale is None: scale = DK**-0.5 - accum_dtype = 'float' + accum_dtype = T.float32 chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -30,16 +29,16 @@ def chunk_retention_fwd_kernel( @T.prim_func def chunk_retention_fwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H - log_decay = T.alloc_var('float32') - log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)) # Head-specific log decay + log_decay = T.alloc_var(T.float32) + log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h)) # Head-specific log decay q = T.alloc_shared([chunk_size, BK], dtype) k = T.alloc_shared([chunk_size, BK], dtype) @@ -51,26 +50,17 @@ def chunk_retention_fwd( o = T.alloc_fragment([chunk_size, BV], accum_dtype) T.clear(h) - T.annotate_layout({ - q: tl.layout.make_swizzled_layout(q), - k: tl.layout.make_swizzled_layout(k), - v: tl.layout.make_swizzled_layout(v), - h_shared: tl.layout.make_swizzled_layout(h_shared), - s_shared: tl.layout.make_swizzled_layout(s_shared), - }) T.use_swizzle(10) for i in T.Pipelined(0, NT): for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale - T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) T.gemm(q, k, s, clear_accum=True, transpose_B=True) for row, col in T.Parallel(chunk_size, chunk_size): - s_shared[row, - col] = T.if_then_else(row >= col, s[row, col] * T.exp2( - (row - col) * log_decay), 0) + s_shared[row, col] = T.if_then_else(row >= col, s[row, col] * T.exp2((row - col) * log_decay), 0) T.copy(h, h_shared) T.gemm(q, h_shared, o, clear_accum=True) @@ -82,9 +72,7 @@ def chunk_retention_fwd( v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay) for row, col in T.Parallel(BK, BV): h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col] - T.copy( - o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV]) + T.copy(o, O[i_k, i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV]) T.gemm(k, v, h, transpose_A=True) return chunk_retention_fwd @@ -96,24 +84,24 @@ def postprocess(o): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--B', type=int, 
default=8, help='Batch size') - parser.add_argument('--S', type=int, default=4096, help='Seq len') - parser.add_argument('--H', type=int, default=32, help='Num heads') - parser.add_argument('--D', type=int, default=128, help='Head dim') + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=4096, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") args = parser.parse_args() B, S, H, D = args.B, args.S, args.H, args.D total_flops = 2.0 * B * S * S * H * D # causal - q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) kernel = chunk_retention_fwd_kernel(B, S, H, D, D) t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100) - print(f'Tilelang latency: {t:.3f} ms') - print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}') + print(f"Tilelang latency: {t:.3f} ms") + print(f"Tilelang TFLOPs: {total_flops / t * 1e-9}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/linear_attention/regression_linear_attn.py b/examples/linear_attention/regression_linear_attn.py new file mode 100644 index 000000000..ced854087 --- /dev/null +++ b/examples/linear_attention/regression_linear_attn.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_linear_attn_bwd +import example_linear_attn_fwd + + +def regression_example_linear_attn_fwd(): + tilelang.testing.process_func(example_linear_attn_fwd.run_regression_perf) + + +def regression_example_linear_attn_bwd(): + tilelang.testing.process_func(example_linear_attn_bwd.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py index ebf8513a1..91af8b454 100644 --- a/examples/minference/example_vertical_slash_sparse_attn.py +++ b/examples/minference/example_vertical_slash_sparse_attn.py @@ -15,12 +15,11 @@ @tilelang.jit(out_idx=[3]) def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size): - block_M = 64 block_N = 64 num_stages = 2 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 + scale = (1.0 / dim) ** 0.5 * 1.44269504 shape = [batch, heads, seq_len, dim] seq_blocks = (seq_len + block_M - 1) // block_M @@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz offset_shape = count_shape + [slash_size] index_shape = count_shape + [vertical_size] - vertical_size_round, slash_size_round = tilelang.next_power_of_2( - vertical_size), tilelang.next_power_of_2(slash_size) + vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size) - dtype = "float16" - accum_dtype = "float" - int_dtype = "int32" + dtype = T.float16 + accum_dtype = T.float32 + int_dtype = T.int32 def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def Prefetch( K: T.Tensor(shape, dtype), @@ -53,32 +50,30 @@ def Prefetch( ): with T.attr("default", "async_scope", 1): for i, j in T.Parallel(block_N, dim): - K_shared[i, j] = T.if_then_else(k + i < 
column_count, - K[bz, by, column_index[k + i], j], 0) + K_shared[i, j] = T.if_then_else(k + i < column_count, K[bz, by, column_index[k + i], j], 0) with T.attr("default", "async_scope", 1): for i, j in T.Parallel(block_N, dim): - V_shared[i, j] = T.if_then_else(k + i < column_count, - V[bz, by, column_index[k + i], j], 0) + V_shared[i, j] = T.if_then_else(k + i < column_count, V[bz, by, column_index[k + i], j], 0) T.ptx_commit_group() @T.macro def Compute( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - k: T.int32, - column_count: T.int32, - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - count: T.int32, + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + k: T.int32, + column_count: T.int32, + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + count: T.int32, ): T.ptx_wait_group(count) for i, j in T.Parallel(block_M, block_N): @@ -87,6 +82,8 @@ def Compute( T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) @@ -106,17 +103,16 @@ def Compute( @T.prim_func def vs_sparse_flashattn_ws( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), - BlockCount: T.Tensor(count_shape, int_dtype), - BlockOffset: T.Tensor(offset_shape, int_dtype), - ColumnCount: T.Tensor(count_shape, int_dtype), - ColumnIndex: T.Tensor(index_shape, int_dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + BlockCount: T.Tensor(count_shape, int_dtype), + BlockOffset: T.Tensor(offset_shape, int_dtype), + ColumnCount: T.Tensor(count_shape, int_dtype), + ColumnIndex: T.Tensor(index_shape, int_dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz): - bx = T.ceildiv(seq_len, block_M) - 1 - bc Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([2, block_N, dim], dtype) @@ -134,19 +130,15 @@ def vs_sparse_flashattn_ws( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_count = T.alloc_local([1], int_dtype) + block_count = T.alloc_var(dtype=int_dtype) block_offset = T.alloc_shared([slash_size_round], int_dtype, scope="shared") - column_count = T.alloc_local([1], int_dtype) + 
column_count = T.alloc_var(dtype=int_dtype) column_index = T.alloc_shared([vertical_size_round], int_dtype, scope="shared") T.create_list_of_mbarrier([128] * 9) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - - block_count[0] = BlockCount[bz, by, bx] - column_count[0] = ColumnCount[bz, by, bx] + block_count = BlockCount[bz, by, bx] + column_count = ColumnCount[bz, by, bx] for vi in T.Parallel(slash_size_round): if vi < slash_size: @@ -160,15 +152,15 @@ def vs_sparse_flashattn_ws( if tid >= 128: T.annotate_producer_reg_dealloc() - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.mbarrier_arrive(mbarrier=8) - for bi in T.serial(block_count[0]): + for bi in T.serial(block_count): k = block_offset[bi] T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1)) - T.copy(K[bz, by, k:k + block_N, :], K_shared[bi % 2, :, :]) + T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :]) T.mbarrier_arrive(mbarrier=bi % 2) T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1)) - T.copy(V[bz, by, k:k + block_N, :], V_shared[bi % 2, :, :]) + T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :]) T.mbarrier_arrive(mbarrier=bi % 2 + 2) else: T.annotate_consumer_reg_alloc() @@ -176,40 +168,31 @@ def vs_sparse_flashattn_ws( T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) T.mbarrier_wait_parity(mbarrier=8, parity=0) - for bi in T.serial(block_count[0]): + for bi in T.serial(block_count): k = block_offset[bi] for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype)) T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1)) - T.gemm( - Q_shared, - K_shared[bi % 2, :, :], - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.mbarrier_arrive(mbarrier=bi % 2 + 4) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): acc_o[i, j] = acc_o[i, j] * scores_scale[i] T.copy(acc_s, acc_s_cast) - T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=(((bi & 3) >> 1))) - T.gemm( - acc_s_cast, - V_shared[bi % 2, :, :], - acc_o, - policy=T.GemmWarpPolicy.FullRow) + T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1)) + T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow) T.mbarrier_arrive(mbarrier=bi % 2 + 6) @@ -218,39 +201,86 @@ def vs_sparse_flashattn_ws( for i in T.Parallel(block_M): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - if column_count[0] != 0: - Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, - by) - for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1): + if column_count != 0: + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count, 0, bz, by) + for bi in T.serial(T.ceildiv(column_count, block_N) - 1): k = 
bi * block_N if bi % 2 == 0: - Prefetch(K, V, K_shared_2, V_shared_2, column_index, - column_count[0], k + block_N, bz, by) - - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, - column_count[0], Q_shared, K_shared_1, V_shared_1, - scores_scale, scores_sum, logsum, 1) + Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count, k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count, + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 1, + ) else: - Prefetch(K, V, K_shared_1, V_shared_1, column_index, - column_count[0], k + block_N, bz, by) - - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, - column_count[0], Q_shared, K_shared_2, V_shared_2, - scores_scale, scores_sum, logsum, 1) - if T.ceildiv(column_count[0], block_N) % 2 == 0: - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, - T.ceildiv(column_count[0], block_N) * block_N - block_N, - column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, - scores_sum, logsum, 0) + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count, k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count, + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, + scores_sum, + logsum, + 1, + ) + if T.ceildiv(column_count, block_N) % 2 == 0: + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count, block_N) * block_N - block_N, + column_count, + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, + scores_sum, + logsum, + 0, + ) else: - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, - T.ceildiv(column_count[0], block_N) * block_N - block_N, - column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, - scores_sum, logsum, 0) + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count, block_N) * block_N - block_N, + column_count, + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 0, + ) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return vs_sparse_flashattn_ws @@ -466,11 +496,8 @@ def vertical_slash_sparse_attention( import os current_dir = os.path.dirname(os.path.abspath(__file__)) - sources = [ - os.path.join(current_dir, 'ops', 'kernels.cpp'), - os.path.join(current_dir, 'ops', 'vertical_slash_index.cu') - ] - ops = load(name='convert', sources=sources, verbose=False) + sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")] + ops = load(name="convert", sources=sources, verbose=False) convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes batch_size, num_heads, context_size, head_dim = query.shape pad = (block_size_M - context_size) & (block_size_M - 1) @@ -481,15 +508,13 @@ def vertical_slash_sparse_attention( value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) if head_dim not in [16, 32, 64, 128, 256, 512]: - target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim + target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) value = torch.nn.functional.pad(value, [0, 
target_dim, 0, 0, 0, 0, 0, 0]) - v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort( - dim=-1, descending=False)[0] - s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort( - dim=-1, descending=True)[0] + v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device) sm_scale = head_dim**-0.5 @@ -502,8 +527,7 @@ def vertical_slash_sparse_attention( block_size_N, ) - tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, - v_idx.shape[2], s_idx.shape[2]) + tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, v_idx.shape[2], s_idx.shape[2]) def run(is_triton: bool = True): if is_triton: @@ -521,8 +545,7 @@ def run(is_triton: bool = True): block_size_N, ) else: - out = tl_kernel(query, key, value, block_count, block_offset, column_count, - column_index) + out = tl_kernel(query, key, value, block_count, block_offset, column_count, column_index) return out[..., :context_size, :head_dim] return run @@ -532,8 +555,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor): b, h, n, m = mat.shape zero_mat = torch.zeros((b, h, n, n)).to(mat.device) # Zero matrix used for padding mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) # pads the matrix on left and right - mat_strided = mat_padded.as_strided( - (1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides + mat_strided = mat_padded.as_strided((1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides sum_diags = torch.sum(mat_strided, 2) # Sums the resulting matrix's columns return sum_diags[:, :, 1:] @@ -555,24 +577,23 @@ def main(argv=None): vertical_size, slash_size = args.vertical_size, args.slash_size torch.manual_seed(0) - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) q_len = SEQ_LEN vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size) last_q = 64 - qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k) + qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k) arange = torch.arange(last_q, device="cuda") - qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], - qk[:, :, :, -last_q:], -torch.inf) + qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf) qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) vertical = qk.sum(-2, keepdim=True) vertical[..., :30] = torch.inf vertical_topk = torch.topk(vertical, vertical_size, -1).indices - slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1] + slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1] slash[..., -30:] = torch.inf slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices @@ -592,5 +613,78 @@ def main(argv=None): print(f"speedup: {triton_time / tilelang_time:.2f}x") +def 
run_regression_perf(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--heads", type=int, default=1) + parser.add_argument("--seq_len", type=int, default=16384) + parser.add_argument("--head_dim", type=int, default=64) + parser.add_argument("--vertical_size", type=int, default=1000) + parser.add_argument("--slash_size", type=int, default=200) + args = parser.parse_args(argv) + BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim + vertical_size, slash_size = args.vertical_size, args.slash_size + torch.manual_seed(0) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + q_len = SEQ_LEN + vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size) + last_q = 64 + qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k) + arange = torch.arange(last_q, device="cuda") + qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf) + qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + vertical_topk = torch.topk(vertical, vertical_size, -1).indices + slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1] + slash[..., -30:] = torch.inf + slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices + block_size_M = 64 + block_size_N = 64 + query, key, value = q, k, v + v_idx, s_idx = vertical_topk, slash + batch_size, num_heads, context_size, head_dim = query.shape + v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + from torch.utils.cpp_extension import load + import os + + current_dir = os.path.dirname(os.path.abspath(__file__)) + sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")] + ops = load(name="convert", sources=sources, verbose=False) + convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes + batch_size, num_heads, context_size, head_dim = query.shape + pad = (block_size_M - context_size) & (block_size_M - 1) + if pad == block_size_M: + pad = 0 + query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0]) + key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0]) + value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim + query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device) + block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes( + seqlens, + v_idx, + s_idx, + context_size, + block_size_M, + block_size_N, + ) + tl_kernel = 
_tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, vertical_topk.shape[-1], slash.shape[-1]) + + def run_kernel_only(): + tl_kernel(query, key, value, block_count, block_offset, column_count, column_index) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/minference/regression_vs_sparse_attn.py b/examples/minference/regression_vs_sparse_attn.py new file mode 100644 index 000000000..32fdfa9e8 --- /dev/null +++ b/examples/minference/regression_vs_sparse_attn.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_vertical_slash_sparse_attn + + +def regression_example_vertical_slash_sparse_attn(): + tilelang.testing.process_func(example_vertical_slash_sparse_attn.run_regression_perf, argv=[]) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/norm/rms_norm.py b/examples/norm/rms_norm.py index 25bac50fc..57bccc1a0 100644 --- a/examples/norm/rms_norm.py +++ b/examples/norm/rms_norm.py @@ -4,7 +4,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k): - dtype = "float" + dtype = T.float @T.prim_func def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): @@ -21,7 +21,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): A_local[i, j] += A_shared[i, j] * A_shared[i, j] T.reduce_sum(A_local, A_powsum, dim=1) for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12 + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for k in range(num_k_step): # reverse, better cache hit rate @@ -35,7 +35,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): @tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True}) def rms_norm(M, N, blk_m): - dtype = "float" + dtype = T.float @T.prim_func def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): @@ -45,16 +45,16 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): A_local = T.alloc_fragment((blk_m, N), dtype) A_powsum = T.alloc_fragment((blk_m,), dtype) - T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared) + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) T.copy(A_shared, A_local) for i, j in T.Parallel(blk_m, N): A_pow_local[i, j] = A_local[i, j] * A_local[i, j] T.reduce_sum(A_pow_local, A_powsum, dim=1) for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12 + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for i, j in T.Parallel(blk_m, N): A_local[i, j] *= A_powsum[i] - T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :]) + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) return main diff --git a/examples/norm/test_rms_norm.py b/examples/norm/test_rms_norm.py index 8cc413531..53db03d98 100644 --- a/examples/norm/test_rms_norm.py +++ b/examples/norm/test_rms_norm.py @@ -5,7 +5,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k): - dtype = "float" + dtype = T.float @T.prim_func def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): @@ -22,7 +22,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): A_local[i, j] += A_shared[i, j] * A_shared[i, j] T.reduce_sum(A_local, A_powsum, dim=1) for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12 + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for k in range(num_k_step): # reverse, better cache hit rate @@ -35,7 +35,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): def rms_norm(M, N, blk_m): - dtype = "float" + dtype = T.float @T.prim_func def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): @@ -45,16 +45,16 
@@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): A_local = T.alloc_fragment((blk_m, N), dtype) A_powsum = T.alloc_fragment((blk_m,), dtype) - T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared) + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) T.copy(A_shared, A_local) for i, j in T.Parallel(blk_m, N): A_pow_local[i, j] = A_local[i, j] * A_local[i, j] T.reduce_sum(A_pow_local, A_powsum, dim=1) for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12 + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for i, j in T.Parallel(blk_m, N): A_local[i, j] *= A_powsum[i] - T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :]) + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) return main diff --git a/examples/online_softmax/online_softmax.py b/examples/online_softmax/online_softmax.py index 432482d06..811870e44 100644 --- a/examples/online_softmax/online_softmax.py +++ b/examples/online_softmax/online_softmax.py @@ -9,19 +9,19 @@ def softmax_kernel( M, N, - dtype: str = "float16", + dtype: T.dtype = T.float16, ) -> "Callable": BN = min(tl.next_power_of_2(N), 8192) NN = tl.cdiv(N, BN) - accum_dtype = "float" + accum_dtype = T.float32 scale = 1.44269504 # log2(e) @T.prim_func def main( - X: T.Tensor([M, N], dtype), - Y: T.Tensor([M, N], dtype), + X: T.Tensor([M, N], dtype), + Y: T.Tensor([M, N], dtype), ): with T.Kernel(M, threads=128) as (i_m): x = T.alloc_fragment([BN], dtype) @@ -33,7 +33,7 @@ def main( T.fill(lse, -T.infinity(accum_dtype)) for i_n in T.Pipelined(0, NN): - T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x) + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) T.reduce_max(x, max_x, dim=0, clear=True) @@ -45,12 +45,12 @@ def main( lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0]) for i_n in T.Pipelined(0, NN): - T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x) + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) for j in T.Parallel(BN): y[j] = T.exp2(x[j] * scale - lse[0]) - T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN]) + T.copy(y, Y[i_m, i_n * BN : (i_n + 1) * BN]) return main @@ -69,4 +69,4 @@ def main( t2 = do_bench(lambda: kernel(X), warmup=25, rep=100) print(f"torch latency: {t1:.3f} ms") print(f"TileLang latency: {t2:.3f} ms") -print(f"Speedup: {t1/t2:.3f}x") +print(f"Speedup: {t1 / t2:.3f}x") diff --git a/examples/plot_layout/README.md b/examples/plot_layout/README.md index a65d771c2..8204e93d8 100644 --- a/examples/plot_layout/README.md +++ b/examples/plot_layout/README.md @@ -10,7 +10,7 @@ from typing import Literal, Callable from tilelang.intrinsics.utils import get_mma_micro_size from tilelang.tools import plot_layout -def make_mma_load_base_layout(dtype: str = "float16", +def make_mma_load_base_layout(dtype: str = T.float16, matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: """ @@ -69,7 +69,7 @@ def make_mma_load_base_layout(dtype: str = "float16", micro_size_s, _, micro_size_r = get_mma_micro_size(dtype) transform_func = transform_func - inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward_thread(i: int, j: int) -> int: """ @@ -94,7 +94,7 @@ def make_mma_load_base_layout(dtype: str = "float16", # Create a 16×16 matrix layout for ldmatrix operations -base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False) +base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) # Print the layout structure (optional for debugging) 
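Editor's note on the `rms_norm.py` and `test_rms_norm.py` hunks above: moving the `1e-12` epsilon inside `T.rsqrt` matches the usual RMSNorm formulation, where the epsilon guards the division itself rather than being added to the reciprocal afterwards. A minimal PyTorch reference for comparison (the shape is illustrative and not taken from the example's main block):

import torch

def rms_norm_ref(A: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Normalize each row by sqrt(mean(x^2) + eps); with eps added outside the rsqrt,
    # as the old code did, an all-zero row would produce inf * 0 = NaN.
    return A * torch.rsqrt(A.pow(2).mean(dim=-1, keepdim=True) + eps)

A = torch.randn(64, 512)
A[0].zero_()                 # all-zero row: finite with eps inside the sqrt, NaN with it outside
B = rms_norm_ref(A)
assert torch.isfinite(B).all()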
print(base_layout) diff --git a/examples/plot_layout/fragment_mfma_load_a.py b/examples/plot_layout/fragment_mfma_load_a.py new file mode 100644 index 000000000..d45cc227b --- /dev/null +++ b/examples/plot_layout/fragment_mfma_load_a.py @@ -0,0 +1,127 @@ +import tilelang.language as T +from typing import Literal, Callable +from tvm.tir import IndexMap +from tilelang.intrinsics.utils import get_mma_micro_size + +from tilelang.intrinsics.mfma_layout import ( + shared_16x4_to_local_64x1_layout_A, + shared_16x16_to_local_64x4_layout_A, + shared_16x32_to_local_64x8_layout_A, + shared_16x64_to_local_64x16_layout_A, +) + + +def make_mfma_load_base_layout( + dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False +) -> T.Fragment: + """ + Create a layout function for storing MFMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mfma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + dtype : str + The data type of the matrix. + matrix : Literal["A", "B"] + The mfma operand to be loaded. + k_dim : int + The k dimension of the mfma. + transposed : bool + Whether the matrix is transposed, by default False. + + Returns + ------- + T.Fragment + Describes how threads and indices in fragment are laid out. + + """ + + assert matrix in ["A", "B"], "matrix should be either A or B" + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + + if k_dim == 4: + transform_func_sr_a = shared_16x4_to_local_64x1_layout_A + transform_func_sr_b = shared_16x4_to_local_64x1_layout_A + elif k_dim == 16: + transform_func_sr_a = shared_16x16_to_local_64x4_layout_A + transform_func_sr_b = shared_16x16_to_local_64x4_layout_A + elif k_dim == 32: + transform_func_sr_a = shared_16x32_to_local_64x8_layout_A + transform_func_sr_b = shared_16x32_to_local_64x8_layout_A + elif k_dim == 64: + transform_func_sr_a = shared_16x64_to_local_64x16_layout_A + transform_func_sr_b = shared_16x64_to_local_64x16_layout_A + else: + raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix == "A" and not transposed) + is_sr_conditions.append(matrix == "B" and transposed) + is_sr_axis_order = any(is_sr_conditions) + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype) + + # the layout of mma.sync is row.col. 
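Cross-reference for `examples/online_softmax/online_softmax.py` above: the update `lse = max_x * scale + log2(exp2(lse - max_x * scale) + sum_exp_x)` is a streaming log-sum-exp in base 2, and the second pass then emits `exp2(x * scale - lse)`. A small PyTorch sketch of the same recurrence (chunk size and input are illustrative only, not the kernel's API):

import torch

def streaming_softmax(x: torch.Tensor, BN: int = 4) -> torch.Tensor:
    scale = 1.44269504                      # log2(e), same constant as in the kernel
    lse = torch.tensor(float("-inf"))       # running log2(sum(exp(x)))
    for chunk in x.split(BN):
        m = chunk.max() * scale
        s = torch.exp2(chunk * scale - m).sum()
        lse = m + torch.log2(torch.exp2(lse - m) + s)
    # Second pass: softmax(x) == exp2(x * log2(e) - lse)
    return torch.exp2(x * scale - lse)

x = torch.randn(16)
torch.testing.assert_close(streaming_softmax(x), torch.softmax(x, dim=0), atol=1e-5, rtol=1e-5)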
+ # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix == "A": + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + micro_size_s, micro_size_r = micro_size_x, micro_size_k + elif matrix == "B": + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + micro_size_s, micro_size_r = micro_size_k, micro_size_y + else: + raise ValueError(f"Unsupported matrix {matrix}") + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + return base_fragment + + +block_rows = 2 +block_cols = 2 +warp_rows = 2 +warp_cols = 2 +chunk = 2 + +from tilelang.tools import plot_layout + +# ldmatrix layout 16x16 +base_layout = make_mfma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) +print(base_layout) +plot_layout(base_layout, name="base_layout") + +# warp layout 32x32 +warp_layout = base_layout.repeat([warp_rows, warp_cols], repeat_on_thread=False, lower_dim_first=False) +print(warp_layout) +plot_layout(warp_layout, name="warp_layout") + +# block layout 64x32 +block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, lower_dim_first=True).replicate(block_cols) +print(block_layout) +plot_layout(block_layout, name="block_layout") diff --git a/examples/plot_layout/fragment_mma_load_a.py b/examples/plot_layout/fragment_mma_load_a.py index 988899448..df4a0b887 100644 --- a/examples/plot_layout/fragment_mma_load_a.py +++ b/examples/plot_layout/fragment_mma_load_a.py @@ -5,9 +5,7 @@ from tilelang.intrinsics.utils import get_mma_micro_size -def make_mma_load_base_layout(dtype: str = "float16", - matrix: Literal["A", "B"] = "A", - transposed: bool = False) -> T.Fragment: +def make_mma_load_base_layout(dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. 
This layout is used in conjunction with `inverse_mma_store_layout` to @@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16", shared_16x16_to_mma_32x8_layout_sr_b, shared_16x32_to_mma_32x16_layout_sr_b, ) + assert matrix in ["A", "B"], "matrix should be either A or B" dtype_bits = DataType(dtype).bits # s represents spatial axis @@ -67,17 +66,15 @@ def make_mma_load_base_layout(dtype: str = "float16", # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix == "A": - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) micro_size_s, micro_size_r = micro_size_x, micro_size_k elif matrix == "B": - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) micro_size_s, micro_size_r = micro_size_k, micro_size_y else: raise ValueError(f"Unsupported matrix {matrix}") - inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward_thread(i: int, j: int) -> int: """ @@ -110,7 +107,7 @@ def forward_index(i: int, j: int) -> int: from tilelang.tools import plot_layout # ldmatrix layout 16x16 -base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False) +base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) print(base_layout) plot_layout(base_layout, name="base_layout") diff --git a/examples/quickstart.py b/examples/quickstart.py index 42514ee39..e99fc0dbc 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -6,13 +6,12 @@ # target currently can be "cuda" or "hip" or "cpu". # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -55,10 +54,9 @@ def matmul_relu_kernel( block_N = 128 block_K = 32 -# 1. Define the kernel (matmul) and compile/lower it into an executable module +# Define the kernel (matmul) and compile/lower it into an executable module matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) - -# 3. Test the kernel in Python with PyTorch data +# Test the kernel in Python with PyTorch data import torch # Create random input tensors on the GPU @@ -78,7 +76,7 @@ def matmul_relu_kernel( print("Kernel output matches PyTorch reference.") # 4. 
Retrieve and inspect the generated CUDA source (optional) -# cuda_source = jit_kernel.get_kernel_source() +# cuda_source = matmul_relu_kernel.get_kernel_source() # print("Generated CUDA kernel:\n", cuda_source) # 5.Profile latency with kernel diff --git a/examples/rand/rand_uint.py b/examples/rand/rand_uint.py new file mode 100644 index 000000000..466a51b7a --- /dev/null +++ b/examples/rand/rand_uint.py @@ -0,0 +1,57 @@ +import tilelang +import tilelang.language as T +import torch +import triton +import triton.language as tl + + +@tilelang.jit +def tilelang_rand_1d(M=1024, seed=42): + num_per_thread = 128 + threads = 1 + blk_M = num_per_thread * threads + + @T.prim_func + def rand_kernel(A: T.Tensor((M,), "uint32")): + with T.Kernel(T.ceildiv(M, threads * num_per_thread), threads=threads) as bx: + tx = T.get_thread_binding() + T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread) + for i, j in T.Parallel(threads, num_per_thread): + offsets = (bx * threads + i) * num_per_thread + idx = offsets + j + if idx < M: + A[idx] = T.rng_rand() + + return rand_kernel + + +@triton.jit +def triton_rand_1d(X, M, elements_per_thread, seed): + pid = tl.program_id(0) + offset = pid * elements_per_thread + tl.arange(0, elements_per_thread) + + r0, r1, r2, r3 = tl.randint4x(seed, offset) + + base_idx = offset * 4 + tl.store(X + base_idx, r0, mask=base_idx < M) + tl.store(X + base_idx + 1, r1, mask=(base_idx + 1) < M) + tl.store(X + base_idx + 2, r2, mask=(base_idx + 2) < M) + tl.store(X + base_idx + 3, r3, mask=(base_idx + 3) < M) + + +def test_rand_1d(M, seed): + kernel = tilelang_rand_1d(M, seed) + tilelang_result = torch.empty(M, dtype=torch.uint32, device="cuda") + kernel(tilelang_result) + + triton_result = torch.empty(M, dtype=torch.uint32, device="cuda") + grid = (triton.cdiv(M, 128),) + triton_rand_1d[grid](triton_result, tl.constexpr(M), tl.constexpr(128 // 4), seed) + + torch.testing.assert_close(tilelang_result, triton_result) + + +if __name__ == "__main__": + test_rand_1d(1024, 42) + test_rand_1d(512, 123) + test_rand_1d(128, 0) diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py index dcd581c6b..0a3c3a6e3 100644 --- a/examples/seer_attention/block_sparse_attn_tilelang.py +++ b/examples/seer_attention/block_sparse_attn_tilelang.py @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -30,70 +27,33 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F @tilelang.jit( - out_idx=[4], pass_configs={ + out_idx=[4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal): block_M = 64 block_N = 64 num_stages = 0 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] - dtype = "float16" - 
accum_dtype = "float" - block_mask_dtype = "int8" + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.int8 def kernel_func(block_M, block_N, num_stages, threads): - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -108,47 +68,61 @@ def main( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_mask = T.alloc_local([downsample_len], block_mask_dtype) + block_mask = T.alloc_fragment([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - for vj in T.serial(downsample_len): - block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + T.copy(BlockSparseMask[bz, by, bx, :], block_mask) loop_range = T.ceildiv(seq_kv, block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[k] != 0: - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: past_len = seq_kv - seq_q for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else( - bx * block_M + i + past_len >= k * block_N + j, 0, - 
-T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) - - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) - Rescale(acc_o, scores_scale) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main @@ -163,44 +137,40 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.float16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) # Run tilelang kernel - kernel = blocksparse_flashattn( - BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) # Compute reference # Expand block mask to 
full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) # Verify accuracy - assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \ - "TileLang output doesn't match reference" + assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), "TileLang output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -213,42 +183,40 @@ def test_topk_sparse_attention_qlen_lt_klen(): torch.manual_seed(0) # Create inputs. - q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) downsample_factor = BLOCK downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.float16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16) # Force the first column to be high so that the first block is always selected. 
x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - kernel = blocksparse_flashattn( - BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) print(kernel.get_kernel_source()) tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) @@ -264,5 +232,56 @@ def main(): test_topk_sparse_attention_qlen_lt_klen() +def run_regression_perf(): + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 4, 2, 256, 64 + TOPK = 2 + BLOCK = 64 + torch.manual_seed(0) + + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(q, k, v, block_mask.to(torch.int8)) + + latency_1 = do_bench(run_kernel_only, backend="cupti") + + BATCH, N_HEADS = 1, 1 + Q_LEN, K_LEN, D_HEAD = 128, 256, 64 + TOPK = 1 + BLOCK = 64 + torch.manual_seed(0) + + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + downsample_factor = BLOCK + downsample_len = math.ceil(K_LEN / downsample_factor) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) + print(kernel.get_kernel_source()) + + def run_kernel_only2(): + kernel(q, k, v, block_mask.to(torch.int8)) + + latency_2 = do_bench(run_kernel_only2, backend="cupti") + + return (latency_1 + latency_2) / 2 + + if 
__name__ == "__main__": main() diff --git a/examples/seer_attention/block_sparse_attn_triton.py b/examples/seer_attention/block_sparse_attn_triton.py index ed33cc1e2..b4cc3cd00 100644 --- a/examples/seer_attention/block_sparse_attn_triton.py +++ b/examples/seer_attention/block_sparse_attn_triton.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -54,7 +51,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) if mask_val == True: @@ -69,7 +65,7 @@ def _fwd_kernel_inner( qk *= sm_scale # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -149,7 +145,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -185,24 +181,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -247,7 +231,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -271,9 +254,9 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) @@ -281,9 +264,7 @@ def test_topk_sparse_attention(): downsample_len = math.ceil(SEQ_LEN / downsample_factor) 
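    # downsample_len is the number of BLOCK-sized tiles along the sequence axis; the
    # block-level mask built below is expanded back to element level with torch.kron,
    # e.g. torch.kron(torch.tensor([[1., 0.], [1., 1.]]), torch.ones(2, 2)) turns a
    # 2x2 block mask into a 4x4 mask where each entry becomes a 2x2 patch.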
print("downsample_len", downsample_len) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 print("x_ds.shape", x_ds.shape) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -295,22 +276,21 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # print("ref_output", ref_output) # print("triton_output", triton_output) # Verify accuracy - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -322,16 +302,15 @@ def test_topk_sparse_attention_qlt_kl(): torch.manual_seed(0) # Create inputs. - q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) # softmax scale sm_scale = 1.0 / (D_HEAD**0.5) downsample_factor = BLOCK downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) # Force the first column to be high so that the first block is always selected. 
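    # Keeping block column 0 active also guarantees that, after the causal mask is
    # applied, every query row still has at least one unmasked key (position 0), so
    # the reference softmax never reduces over an all -inf row.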
x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -340,26 +319,25 @@ def test_topk_sparse_attention_qlt_kl(): past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # Verify accuracy. - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference when qlen < klen" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" print("Pass topk sparse attention test with qlen < klen") diff --git a/examples/seer_attention/regression_block_sparse_attn_tilelang.py b/examples/seer_attention/regression_block_sparse_attn_tilelang.py new file mode 100644 index 000000000..86d7b3b28 --- /dev/null +++ b/examples/seer_attention/regression_block_sparse_attn_tilelang.py @@ -0,0 +1,10 @@ +import tilelang.testing +import block_sparse_attn_tilelang + + +def regression_block_sparse_attn_tilelang(): + tilelang.testing.process_func(block_sparse_attn_tilelang.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/sparse_tensorcore/regression_example_sparse_tensorcore.py b/examples/sparse_tensorcore/regression_example_sparse_tensorcore.py new file mode 100644 index 000000000..1167c1603 --- /dev/null +++ b/examples/sparse_tensorcore/regression_example_sparse_tensorcore.py @@ -0,0 +1,11 @@ +import tilelang.testing +import tilelang +import tilelang_example_sparse_tensorcore + + +def regression_example_sparse_tensorcore(): + tilelang.testing.process_func(tilelang_example_sparse_tensorcore.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py index 59c79c283..f33832aff 100644 --- a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py +++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py @@ -1,7 +1,8 @@ import torch import tilelang from tilelang.utils.sparse import compress_sm90 -from tilelang.layout import make_metadata_layout +from tilelang.layout import make_cutlass_metadata_layout +from tilelang import language as T import tilelang.testing @@ -24,32 +25,24 @@ def matmul_sp( A_shared_shape = (block_M, block_K // 2) B_shared_shape = (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // 8), 
'uint8'), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // 8), "uint8"), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8') + E_shared = T.alloc_shared((block_M, block_K // 8), "uint8") C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - E: - make_metadata_layout( - E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K), - E_shared: - make_metadata_layout( - E_shared, - mma_dtype="float16", - arch="9.0", - backend="cutlass", - block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="9.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="9.0", block_k=block_K), + } + ) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // 8], E_shared) @@ -61,7 +54,7 @@ def main( return main -def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'): +def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"): if shape[-1] % 4 != 0: raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.") @@ -106,9 +99,9 @@ def run_gemm_sp( num_threads, ) - A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda') + A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda") A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False) - B = torch.randn((K, N), device='cuda', dtype=torch.float16) + B = torch.randn((K, N), device="cuda", dtype=torch.float16) C_sp = kernel(A_sparse, E, B).half() C = torch.matmul(A, B) @@ -117,7 +110,46 @@ def run_gemm_sp( def main(): - run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128) + run_gemm_sp(512, 1024, 768, T.float16, T.float16, T.float32, 128, 128, 128, 2, 128) + + +def run_regression_perf(): + M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, num_threads = ( + 512, + 1024, + 768, + 128, + 128, + 128, + "float16", + "float16", + "float32", + 2, + 128, + ) + kernel = matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + num_threads, + ) + A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda") + A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False) + B = torch.randn((K, N), device="cuda", dtype=torch.float16) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(A_sparse, E, B) + + return do_bench(run_kernel_only, backend="cupti") if __name__ == "__main__": diff --git a/examples/topk/example_topk.py b/examples/topk/example_topk.py index 0ca19fb18..ed5ba0d4a 100644 --- a/examples/topk/example_topk.py +++ b/examples/topk/example_topk.py @@ -22,19 +22,19 @@ def tl_topk( blk_m, threads=128, ): - dtype = "float32" + dtype = T.float32 @T.prim_func def topk_kernel( - logits: T.Tensor([M, N], dtype), - topk_gates: T.Tensor([M, topk], dtype), - topk_indices: T.Tensor([M, topk], "int32"), + logits: T.Tensor([M, N], dtype), + topk_gates: T.Tensor([M, topk], dtype), + topk_indices: T.Tensor([M, topk], T.int32), ): with T.Kernel(T.ceildiv(M, 
blk_m), threads=threads) as bx: logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype) max_val = T.alloc_fragment([blk_m], dtype=dtype) - expand_max_idx = T.alloc_fragment([blk_m, N], "int32") - max_idx = T.alloc_fragment([blk_m], "int32") + expand_max_idx = T.alloc_fragment([blk_m, N], T.int32) + max_idx = T.alloc_fragment([blk_m], T.int32) T.copy(logits[bx * blk_m, 0], logits_frag) @@ -43,15 +43,12 @@ def topk_kernel( T.reduce_max(logits_frag, max_val, dim=1, clear=True) for i, j in T.Parallel(blk_m, N): - expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, - expand_max_idx[i, j]) + expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j]) T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True) for i, j in T.Parallel(blk_m, N): - - logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, - logits_frag[i, j]) + logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j]) for i in T.Parallel(blk_m): topk_gates[bx * blk_m + i, k] = max_val[i] @@ -61,7 +58,6 @@ def topk_kernel( def ref_program(logits, top_k): - top_k_gates, top_k_indices = logits.topk(top_k, dim=1) return top_k_gates, top_k_indices.to(torch.int32) @@ -93,5 +89,29 @@ def main(argv=None): print(f"Tilelang latency: {tilelang_latency}") +def run_regression_perf(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=320, help="num_tokens") + parser.add_argument("--N", type=int, default=128, help="num_experts") + parser.add_argument("--topk", type=int, default=6, help="topk") + parser.add_argument("--blk_m", type=int, default=64, help="blk_m") + # In benchmark mode, ignore process-wide sys.argv unless an explicit argv is provided. + args = parser.parse_args(argv or []) + M, N, topk, blk_m = args.M, args.N, args.topk, args.blk_m + + logits = torch.rand((M, N), device="cuda", dtype=torch.float32) + + kernel = tl_topk(M=M, N=N, topk=topk, blk_m=blk_m) + tl_gates, tl_indices = kernel(logits) + + torch_gates, torch_indices = ref_program(logits, topk) + + torch.testing.assert_close(tl_gates, torch_gates) + torch.testing.assert_close(tl_indices, torch_indices) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/topk/regression_topk_tilelang.py b/examples/topk/regression_topk_tilelang.py new file mode 100644 index 000000000..f59d866e8 --- /dev/null +++ b/examples/topk/regression_topk_tilelang.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_topk + + +def regression_example_topk(): + tilelang.testing.process_func(example_topk.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py new file mode 100644 index 000000000..8fa1eaf85 --- /dev/null +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -0,0 +1,61 @@ +import tilelang +import tilelang.language as T + + +# use pass_configs to enable layout visualization +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True, + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg", + }, +) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), 
dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm + + +def main(): + kernel = matmul(128, 128, 128, 32, 32, 32) + + import torch + + a = torch.randn(128, 128).cuda().half() + b = torch.randn(128, 128).cuda().half() + + c = kernel(a, b) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + # print the layout visualization result and save figures to ./tmp. + """ + C_local inferenced layout: + Shape: [32, 32] -> [8] + Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 + Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] + """ + + +if __name__ == "__main__": + main() diff --git a/examples/warp_specialize/example_warp_specialize_flashmla.py b/examples/warp_specialize/example_warp_specialize_flashmla.py index 4a8f41ee4..155a45970 100644 --- a/examples/warp_specialize/example_warp_specialize_flashmla.py +++ b/examples/warp_specialize/example_warp_specialize_flashmla.py @@ -9,21 +9,23 @@ @tilelang.jit(out_idx=[6]) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" h_dim = dim // 2 - @T.macro - def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): # smem_sQ @@ -81,11 +83,6 @@ def flash_attn( cur_kv_head = hid // (kv_group_num // block_H) - T.annotate_layout({ - O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l), - O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r), - }) - # barriers_Q q_shared_ready_barrier = T.alloc_barrier(arrive_count=256) @@ -108,9 +105,9 @@ def flash_attn( tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H 
: (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.barrier_arrive(q_shared_ready_barrier) T.barrier_wait(q_shared_ready_barrier, 0) @@ -123,25 +120,18 @@ def flash_attn( T.fill(acc_o_l, 0) T.fill(logsum_0, 0) - T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l) + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l) T.barrier_arrive(kv_shared_1_l_is_ready) - T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r) + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r) T.barrier_arrive(kv_shared_1_r_is_ready) - T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], K_pe_shared_1) + T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1) T.barrier_arrive(kv_shared_1_pe_is_ready) for k in T.serial(loop_range): - T.barrier_wait(kv_shared_0_l_is_ready, k % 2) - T.gemm( - Q_shared_l, - KV_shared_0_l, - acc_s_0, - transpose_B=True, - clear_accum=True, - wg_wait=-1) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1) T.barrier_wait(kv_shared_0_r_is_ready, k % 2) T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1) @@ -161,8 +151,7 @@ def flash_attn( for i, j in T.Parallel(block_H, block_N): acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale) for i in T.Parallel(block_H): - scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - - scores_max[i] * scale) + scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale) T.reduce_sum(acc_s_0, scores_sum_0, dim=1) @@ -182,9 +171,7 @@ def flash_attn( T.barrier_wait(scale_1_ready_barrier, k % 2) if k < loop_range - 1: - T.copy( - KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, - cur_kv_head, :h_dim], KV_shared_0_l) + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l) T.barrier_arrive(kv_shared_0_l_is_ready) # Step 11. @@ -204,15 +191,10 @@ def flash_attn( T.gemm(SP1_shared, KV_shared_1_l, acc_o_l) if k < loop_range - 1: - - T.copy( - KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, - cur_kv_head, :h_dim], KV_shared_1_l) + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l) T.barrier_arrive(kv_shared_1_l_is_ready) - T.copy( - K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :], - K_pe_shared_1) + T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1) T.barrier_arrive(kv_shared_1_pe_is_ready) T.copy(logsum_0, logsum) @@ -221,8 +203,7 @@ def flash_attn( for i, j in T.Parallel(block_H, h_dim): acc_o_l[i, j] /= logsum[i] T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[bid, - hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim]) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim]) else: T.copy(Q_pe_shared, Q_pe_local_1) @@ -237,16 +218,9 @@ def flash_attn( T.barrier_arrive(kv_shared_0_pe_is_ready) for k in T.serial(loop_range): - # Step 2. 
T.barrier_wait(kv_shared_1_l_is_ready, k % 2) - T.gemm( - Q_shared_l, - KV_shared_1_l, - acc_s_1, - transpose_B=True, - clear_accum=True, - wg_wait=-1) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1) T.barrier_wait(kv_shared_1_r_is_ready, k % 2) T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1) @@ -265,8 +239,7 @@ def flash_attn( T.copy(scores_max_1, scores_max) for i in T.Parallel(block_H): - scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - - scores_max[i] * scale) + scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale) # Step 8. for i, j in T.Parallel(block_H, block_N): @@ -279,8 +252,7 @@ def flash_attn( acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i]) for i in T.Parallel(block_H): - logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[ - i] + scores_sum_1[i] + logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i] T.barrier_arrive(scale_1_ready_barrier) @@ -291,9 +263,7 @@ def flash_attn( T.barrier_arrive(s_shared_ready_barrier) if k < loop_range - 1: - T.copy( - KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, - h_dim:], KV_shared_1_r) + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r) T.barrier_arrive(kv_shared_1_r_is_ready) T.barrier_wait(p0_1_1_ready_barrier, k % 2) @@ -301,15 +271,10 @@ def flash_attn( T.gemm(SP0_shared, KV_shared_0_r, acc_o_r) if k < loop_range - 1: - - T.copy( - KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, - h_dim:], KV_shared_0_r) + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r) T.barrier_arrive(kv_shared_0_r_is_ready) - T.copy( - K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :], - K_pe_shared_0) + T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0) T.barrier_arrive(kv_shared_0_pe_is_ready) T.barrier_wait(lse_0_ready_barrier, 0) @@ -319,20 +284,7 @@ def flash_attn( for i, j in T.Parallel(block_H, h_dim): acc_o_r[i, j] /= logsum[i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - h_dim:]) - - @T.prim_func - def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - ): - flash_attn(Q, Q_pe, KV, K_pe, Output) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:]) return main_no_split @@ -352,31 +304,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = 
rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -399,12 +344,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py index 3f552795e..1672dbfb8 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py @@ -1,12 +1,13 @@ import tilelang import tilelang.language as T +tilelang.disable_cache() + # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): num_stages = 2 mbarrier_list = [128, 128] * num_stages @@ -30,19 +31,13 @@ def main( for ko in range(T.ceildiv(K, block_K)): with T.ws(1): - T.mbarrier_wait_parity( - mbarrier=ko % num_stages + num_stages, - parity=((ko // num_stages) % num_stages) ^ 1) - T.copy(A[by * block_M:(by + 1) * 
block_M, ko * block_K:(ko + 1) * block_K], - A_shared[ko % num_stages, :, :]) - T.copy(B[ko * block_K:(ko + 1) * block_K, bx * block_N:(bx + 1) * block_N], - B_shared[ko % num_stages, :, :]) + T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1) + T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :]) + T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :]) T.mbarrier_arrive(mbarrier=ko % num_stages) with T.ws(0): - T.mbarrier_wait_parity( - mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages) - T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], - C_local) + T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages) + T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local) T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages) with T.ws(0): @@ -52,11 +47,14 @@ def main( def main(M=16384, N=16384, K=16384): + tilelang.disable_cache() block_M = 128 block_N = 128 block_K = 64 jit_kernel = matmul(M, N, K, block_M, block_N, block_K) + print(jit_kernel.get_kernel_source()) + import torch a = torch.randn(M, K, device="cuda", dtype=torch.float16) @@ -84,5 +82,15 @@ def main(M=16384, N=16384, K=16384): print(f"Latency: {latency} ms") +def run_regression_perf(M=16384, N=16384, K=16384): + tilelang.disable_cache() + block_M = 128 + block_N = 128 + block_K = 64 + jit_kernel = matmul(M, N, K, block_M, block_N, block_K) + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py index 9ba9f6816..b582ee74c 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py @@ -5,20 +5,12 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul_warp_specialize_copy_0_gemm_1(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float"): - +def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -82,5 +74,27 @@ def main(M=1024, N=1024, K=1024): print(f"Latency: {latency} ms") +def run_regression_perf(M=4096, N=4096, K=4096): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K) + + import torch + + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + c = jit_kernel(a, b) + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git 
a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py index faaf48c64..d6d243bb0 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py @@ -5,20 +5,12 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul_warp_specialize_copy_1_gemm_0(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float"): - +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -83,5 +75,28 @@ def main(M=16384, N=16384, K=16384): print(f"Latency: {latency} ms") +def run_regression_perf(M=16384, N=16384, K=16384): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K) + + import torch + + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + c = jit_kernel(a, b) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py index c91274540..5468aa6ea 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py @@ -5,26 +5,20 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }) -def matmul_warp_specialize_copy_1_gemm_0(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float"): - + }, +) +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): warp_group_num = 2 threads = 128 * warp_group_num @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): diff --git a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py index 3b1d86719..54566b785 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py @@ -5,8 +5,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - +def matmul(M, N, K, block_M, 
block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor[(M, K), dtype], @@ -79,5 +78,28 @@ def main(M=16384, N=16384, K=16384): print(f"Latency: {latency} ms") +def run_regression_perf(M=16384, N=16384, K=16384): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul(M, N, K, block_M, block_N, block_K) + + import torch + + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + c = jit_kernel(a, b) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/warp_specialize/regression_example_warp_specialize.py b/examples/warp_specialize/regression_example_warp_specialize.py new file mode 100644 index 000000000..d5cd17d48 --- /dev/null +++ b/examples/warp_specialize/regression_example_warp_specialize.py @@ -0,0 +1,25 @@ +import tilelang.testing +import example_warp_specialize_gemm_barrierpipe_stage2 +import example_warp_specialize_gemm_copy_0_gemm_1 +import example_warp_specialize_gemm_copy_1_gemm_0 +import example_warp_specialize_gemm_softpipe_stage2 + + +def regression_example_warp_specialize_gemm_barrierpipe_stage2(): + tilelang.testing.process_func(example_warp_specialize_gemm_barrierpipe_stage2.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_warp_specialize_gemm_copy_0_gemm_1(): + tilelang.testing.process_func(example_warp_specialize_gemm_copy_0_gemm_1.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_warp_specialize_gemm_copy_1_gemm_0(): + tilelang.testing.process_func(example_warp_specialize_gemm_copy_1_gemm_0.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_warp_specialize_gemm_softpipe_stage2(): + tilelang.testing.process_func(example_warp_specialize_gemm_softpipe_stage2.run_regression_perf, M=1024, N=1024, K=1024) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/format.sh b/format.sh index 8f127433c..3cc4390db 100755 --- a/format.sh +++ b/format.sh @@ -9,7 +9,7 @@ # bash format.sh --all # # -# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# Ruff (format) + Clang formatter (if installed). This script formats all changed files from the last mergebase. # You are encouraged to run this locally before pushing changes for review. # Cause the script to exit if a single command fails @@ -29,10 +29,7 @@ ALL_FILES='' ONLY_CHANGED='' FILES=() if (($# == 0)); then - if [[ -n "$(git status --porcelain --ignore-submodules --untracked-files=no)" ]]; then - echo "Detected uncommitted changes. Please commit or stash them before running $0." >&2 - exit 1 - fi + # Default: allow dirty workspace; run on changed files (committed + worktree) ONLY_CHANGED='true' else while (($# > 0)); do @@ -78,14 +75,17 @@ if [[ -n "${ALL_FILES}" ]]; then echo "Checking all files..." >&2 elif [[ -n "${ONLY_CHANGED}" ]]; then MERGE_BASE="$(get_merge_base)" - echo "Checking changed files compared to merge base (${MERGE_BASE})..." >&2 + echo "Checking changed files vs merge base (${MERGE_BASE}) and working tree..." >&2 elif [[ "${#FILES[@]}" -gt 0 ]]; then echo "Checking specified files: ${FILES[*]}..." >&2 fi +# Some systems set pip's default to --user, which breaks isolated virtualenvs. 
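+# PIP_USER mirrors pip's --user option, and environment variables take precedence over
+# pip config files, so exporting it as 0 keeps every pip invocation below out of the
+# user site-packages.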
+export PIP_USER=0 + # If pre-commit is not installed, install it. if ! python3 -m pre_commit --version &>/dev/null; then - python3 -m pip install pre-commit + python3 -m pip install pre-commit --user fi echo 'tile-lang pre-commit: Check Start' @@ -93,7 +93,17 @@ echo 'tile-lang pre-commit: Check Start' if [[ -n "${ALL_FILES}" ]]; then python3 -m pre_commit run --all-files elif [[ -n "${ONLY_CHANGED}" ]]; then - python3 -m pre_commit run --from-ref "${MERGE_BASE}" --to-ref HEAD + # Collect changed files (committed since merge-base + current worktree) + CHANGED_FILES="$(git diff --name-only --diff-filter=ACM "${MERGE_BASE}" 2>/dev/null || true)" + if [[ -n "${CHANGED_FILES}" ]]; then + echo "Running pre-commit on changed files:" + echo "${CHANGED_FILES}" + # Convert newline-separated files to space-separated and run pre-commit once + CHANGED_FILES_SPACE="$(echo "${CHANGED_FILES}" | tr '\n' ' ')" + python3 -m pre_commit run --files ${CHANGED_FILES_SPACE} + else + echo "No files changed relative to merge base and worktree. Skipping pre-commit." + fi elif [[ "${#FILES[@]}" -gt 0 ]]; then python3 -m pre_commit run --files "${FILES[@]}" fi @@ -105,7 +115,7 @@ echo 'tile-lang clang-tidy: Check Start' if [[ -x "$(command -v run-clang-tidy)" ]]; then # Check if clang-tidy is available if [[ ! -x "$(command -v clang-tidy)" ]]; then - python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt" + python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt" --user fi # Get clang-tidy version CLANG_TIDY_VERSION="$(clang-tidy --version | head -n1 | awk '{print $4}')" diff --git a/images/MatmulExample.svg b/images/MatmulExample.svg index 6e20daf55..294e8f631 100644 --- a/images/MatmulExample.svg +++ b/images/MatmulExample.svg @@ -1 +1 @@ -A_shared=T.alloc_shared((block_M,block_K))B_shared=T.alloc_shared((block_K,block_N))C_local=T.alloc_fragment((block_M,block_N),accum_dtype)importtilelang.languageasTdefMatmul(A:T.Buffer,B:T.Buffer,C:T.Buffer):withT.Kernel(ceildiv(N,block_N),ceildiv(M,block_M),threads=128)as(bx,by):T.clear(C_local)forkinT.Pipelined(ceildiv(K,block_K),num_stages=3):T.copy(A[by*block_M,k*block_K],A_shared)T.copy(B[k*block_K,bx*block_N],B_shared)T.gemm(A_shared,B_shared,C_local)Kernel Context InitializationBuffer AllocationRegisterInitialize Accumulate Buffer with ZeroMain Loop with Pipeline AnnotationT.copy(C_local,C[by*block_M,bx*block_N])Write Back to Global MemoryCopy Data from Global to Shared MemoryGEMMSharedMemoryGlobal MemoryShared MemoryRegister Files(a) Efficient GEMM with Multi-Level Tiling on GPUs(b) Describing Tiled GPU GEMM with TileLang \ No newline at end of file +A_shared=T.alloc_shared((block_M,block_K))B_shared=T.alloc_shared((block_K,block_N))C_local=T.alloc_fragment((block_M,block_N),accum_dtype)importtilelang.languageasTdefMatmul(A:T.Buffer,B:T.Buffer,C:T.Buffer):withT.Kernel(ceildiv(N,block_N),ceildiv(M,block_M),threads=128)as(bx,by):T.clear(C_local)forkinT.Pipelined(ceildiv(K,block_K),num_stages=3):T.copy(A[by*block_M,k*block_K],A_shared)T.copy(B[k*block_K,bx*block_N],B_shared)T.gemm(A_shared,B_shared,C_local)Kernel Context InitializationBuffer AllocationRegisterInitialize Accumulate Buffer with ZeroMain Loop with Pipeline AnnotationT.copy(C_local,C[by*block_M,bx*block_N])Write Back to Global MemoryCopy Data from Global to Shared MemoryGEMMSharedMemoryGlobal MemoryShared MemoryRegister Files(a) Efficient GEMM with Multi-Level Tiling on GPUs(b) Describing Tiled GPU GEMM with TileLang diff --git a/images/logo-row.svg 
b/images/logo-row.svg index 633243f3a..e73244b74 100644 --- a/images/logo-row.svg +++ b/images/logo-row.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/maint/gemm_v2/correctness_evaluation.py b/maint/gemm_v2/correctness_evaluation.py new file mode 100644 index 000000000..44441cdeb --- /dev/null +++ b/maint/gemm_v2/correctness_evaluation.py @@ -0,0 +1,739 @@ +# pytest correctness_evaluation.py -n 32 +import pytest +from tilelang import tvm as tvm +import tilelang.testing +from tilelang import language as T +import torch + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def _compile_and_check( + program, + trans_A, + trans_B, + in_dtype, + out_dtype, +): + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, + }, + ) + + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == T.float32: + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print("assert_allclose") + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, 
block_N) + A_frag_shape = A_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) + # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +def matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + B_frag_shape = B_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B_shared, B_frag) + T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_sr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +def matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + 
num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + B_frag_shape = B_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.copy(B_shared, B_frag) + T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +M_VALUES = [64, 128, 256] +N_VALUES = [16, 32, 64, 128, 256, 512] +K_VALUES = [16, 32, 64, 128] +K_VALUES_8Bit = [32, 64, 128] +FALSE_TRUE_CASES = ( + [ + pytest.param( + k, + T.float16, + T.float16, + T.float16, + id=f"K{k}-float16-float16-float16", + ) + for k in K_VALUES + ] + + [ + pytest.param( + k, + T.int8, + T.int32, + T.int32, + id="K32-int8-int32-int32", + ) + for k in K_VALUES_8Bit + ] + + [ + pytest.param( + k, + T.float8_e5m2, + T.float32, + T.float32, + id="K32-float8_e5m2-float32-float32", + ) + for k in K_VALUES_8Bit + ] + + [ + pytest.param( + k, + T.float8_e4m3fn, + T.float32, + T.float32, + id="K32-float8_e4m3-float32-float32", + ) + for k in K_VALUES_8Bit + ] +) + + +def _ensure_torch_dtypes(*dtype_names): + import torch + + for name in set(dtype_names): + if not hasattr(torch, name): + pytest.skip(f"Torch does not expose dtype {name}") + + +def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k) + + +def run_gemm_rs_false_false(m, n, k): + run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rs_true_false(m, n, k): + run_gemm_rs(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rs_true_true(m, n, k): + run_gemm_rs(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + run_gemm_sr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k) + + +def run_gemm_sr_false_false(m, n, k): + run_gemm_sr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k) + + +def 
run_gemm_sr_true_false(m, n, k): + run_gemm_sr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_sr_true_true(m, n, k): + run_gemm_sr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + run_gemm_rr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k) + + +def run_gemm_rr_false_false(m, n, k): + run_gemm_rr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rr_true_false(m, n, k): + run_gemm_rr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rr_true_true(m, n, k): + run_gemm_rr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k) + + +TRANS_CASES = [ + pytest.param(False, False, id="nn"), + pytest.param(False, True, id="nt"), + pytest.param(True, False, id="tn"), + pytest.param(True, True, id="tt"), +] + + +@pytest.fixture(scope="module", autouse=True) +def _setup_tilelang_environment(): + tilelang.disable_cache() + tilelang.testing.set_random_seed(42) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + import torch + + required_torch_attrs = { + in_dtype, + out_dtype, + accum_dtype, + } + for attr in required_torch_attrs: + if not hasattr(torch, attr): + pytest.skip(f"Torch does not expose dtype {attr}") + run_gemm( + m, + n, + k * 3, + False, + True, + in_dtype, + out_dtype, + accum_dtype, + m, + n, + k, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_false_false(m, n, k): + run_gemm( + m, + n, + k * 3, + False, + False, + T.float16, + T.float16, + T.float16, + m, + n, + k, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_true_false(m, n, k): + run_gemm( + m, + n, + k * 3, + True, + False, + T.float16, + T.float16, + T.float16, + m, + n, + k, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_true_true(m, n, k): + run_gemm( + m, + n, + k * 3, + True, + True, + T.float16, + T.float16, + T.float16, + m, + n, + k, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) + run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rs_false_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rs_false_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") 
+@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rs_true_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rs_true_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rs_true_true(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rs_true_true(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) + run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_sr_false_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_sr_false_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_sr_true_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_sr_true_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_sr_true_true(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_sr_true_true(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) + run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rr_false_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rr_false_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rr_true_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rr_true_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rr_true_true(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rr_true_true(m, n, k) + + +if __name__ == "__main__": + tilelang.testing.main() + + # # Test Pass + # for m in [64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [64, 128, 256]: + # for 
n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} False False =============================") + # run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} True False =============================") + # run_gemm(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m}, {n} {k} Pass") + # print(f"Test {n} Pass") + + # # Test Pass + # for m in [64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} True True =============================") + # run_gemm(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m}, {n} {k} Pass") + # print(f"Test {n} Pass") + + # Test Pass + # for m in [64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm_rs(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # run_gemm_rs(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256) + # print(f"Test {64} {n} {k} Pass") + + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # run_gemm(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256) + # print(f"Test {64} {n} {k} Pass") diff --git a/maint/gemm_v2/correctness_evaluation_sm70.py b/maint/gemm_v2/correctness_evaluation_sm70.py new file mode 100644 index 000000000..606d10261 --- /dev/null +++ b/maint/gemm_v2/correctness_evaluation_sm70.py @@ -0,0 +1,350 @@ +# pytest maint/gemm_v2/correctness_evaluation_sm70.py -n 32 +import pytest +from tilelang import tvm as tvm +import tilelang.testing +from tilelang import language as T + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def _compile_and_check( + program, + trans_A, + trans_B, + in_dtype, + out_dtype, +): + 
kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, + }, + ) + + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == T.float32: + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print("assert_allclose") + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) + # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +M_VALUES = [64, 128] +N_VALUES = [32, 64, 128] +K_VALUES = [16, 32, 64] +FALSE_TRUE_CASES = [ + pytest.param( + k, + T.float16, + T.float16, + T.float16, + id=f"K{k}-float16-float16-float16", + ) + for k in K_VALUES +] + [ + pytest.param( + k, + T.float16, + T.float16, + T.float32, + id=f"K{k}-float16-float16-float32", + ) + for k in K_VALUES +] + + +def _ensure_torch_dtypes(*dtype_names): + import torch + + for name in set(dtype_names): + if not hasattr(torch, name): + 
pytest.skip(f"Torch does not expose dtype {name}") + + +def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128) + + +def run_gemm_rs_false_false(m, n, k): + run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) + + +TRANS_CASES = [ + pytest.param(False, False, id="nn"), + pytest.param(False, True, id="nt"), + pytest.param(True, False, id="tn"), + pytest.param(True, True, id="tt"), +] + + +@pytest.fixture(scope="module", autouse=True) +def _setup_tilelang_environment(): + tilelang.disable_cache() + tilelang.testing.set_random_seed(42) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + import torch + + required_torch_attrs = { + in_dtype, + out_dtype, + accum_dtype, + } + for attr in required_torch_attrs: + if not hasattr(torch, attr): + pytest.skip(f"Torch does not expose dtype {attr}") + run_gemm( + m, + n, + k * 3, + False, + True, + in_dtype, + out_dtype, + accum_dtype, + m, + n, + k, + 2, + 128, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_false_false(m, n, k): + run_gemm( + m, + n, + k * 3, + False, + False, + T.float16, + T.float16, + T.float16, + m, + n, + k, + 2, + 128, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) + run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rs_false_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rs_false_false(m, n, k) + + +if __name__ == "__main__": + tilelang.testing.main() + + # # Test Pass + # for m in [64, 128]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64]: + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [64, 128]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64]: + # print(f"======================= Test {m} {n} {k} False False =============================") + # run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") diff --git a/maint/gemm_v2/correctness_evaluation_tcgen05.py b/maint/gemm_v2/correctness_evaluation_tcgen05.py new file mode 100644 index 000000000..8d9728182 --- /dev/null +++ b/maint/gemm_v2/correctness_evaluation_tcgen05.py @@ -0,0 +1,218 @@ +# pytest correctness_evaluation.py -n 32 +import pytest +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + 
out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def _compile_and_check( + program, + trans_A, + trans_B, + in_dtype, + out_dtype, +): + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == T.float32: + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print("assert_allclose") + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +M_VALUES = [32, 64, 128, 256] +N_VALUES = [64, 128, 256, 512] +K_VALUES = [16, 32, 64, 128] +K_VALUES_8Bit = [32, 64, 128] +FALSE_TRUE_CASES = [ + pytest.param( + k, + T.float16, + T.float32, + T.float32, + id=f"K{k}-float16-float-float", + ) + for k in K_VALUES +] + [ + pytest.param( + k, + T.float8_e5m2, + T.float32, + T.float32, + id="K32-float8_e5m2-float32-float32", + ) + for k in K_VALUES_8Bit +] + +TRANS_CASES = [ + pytest.param(False, True, id="nt"), +] + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + import torch + + required_torch_attrs = { + in_dtype, + out_dtype, + accum_dtype, + } + for attr in required_torch_attrs: + if not hasattr(torch, attr): + pytest.skip(f"Torch does not expose dtype {attr}") + run_gemm( 
+ m, + n, + k * 3, + False, + True, + in_dtype, + out_dtype, + accum_dtype, + m, + n, + k, + ) + + +if __name__ == "__main__": + tilelang.testing.main() + + # # Test Pass + # for m in [32, 64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # if m in [32, 64] and (n not in [64, 128, 256]): + # continue + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [32, 64, 128, 256]: + # for n in [32, 64, 128]: + # for k in [16, 32, 64, 128]: + # if m in [32, 64] and (n not in [64, 128, 256]): + # continue + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 256) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [32, 64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [32, 64, 128]: + # if m in [32, 64] and (n not in [64, 128, 256]): + # continue + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float8_e5m2, T.float, T.float, m, n, k, 2, 128) diff --git a/maint/gemm_v2/latency.py b/maint/gemm_v2/latency.py new file mode 100644 index 000000000..b7b2a2af9 --- /dev/null +++ b/maint/gemm_v2/latency.py @@ -0,0 +1,98 @@ +import tilelang +import tilelang.language as T +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--use_v2", action="store_true") +args = parser.parse_args() + +use_v2 = args.use_v2 + + +# @tilelang.jit(target="cuda") +# target currently can be "cuda" or "hip" or "cpu". +# if not specified, it will be inferred from the input tensors during compile time +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 cache locality (Optional) + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + if use_v2: + T.gemm_v2(A_shared, B_shared, C_local) + else: + T.gemm_v1(A_shared, B_shared, C_local) + + # relu + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_relu_kernel + + +M = 16384 # M = T.dynamic("m") if you want to use dynamic shape +N = 16384 +K = 16384 +block_M = 128 +block_N = 128 +block_K = 32 + +# 1. Define the kernel (matmul) and compile/lower it into an executable module +matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) + +# 3. 
Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(K, N, device="cuda", dtype=torch.float16) +c = torch.empty(M, N, device="cuda", dtype=torch.float16) + +# Run the kernel through the Profiler +matmul_relu_kernel(a, b, c) + +print(c) +# Reference multiplication using PyTorch +ref_c = torch.relu(a @ b) + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = jit_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5.Profile latency with kernel +profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/maint/gemm_v2/latency_gemm.py b/maint/gemm_v2/latency_gemm.py new file mode 100644 index 000000000..5f0450e02 --- /dev/null +++ b/maint/gemm_v2/latency_gemm.py @@ -0,0 +1,98 @@ +import tilelang +import tilelang.language as T +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--use_v2", action="store_true") +args = parser.parse_args() + +use_v2 = args.use_v2 + + +# @tilelang.jit(target="cuda") +# target currently can be "cuda" or "hip" or "cpu". +# if not specified, it will be inferred from the input tensors during compile time +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 cache locality (Optional) + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + if use_v2: + T.gemm_v2(A_shared, B_shared, C_local) + else: + T.gemm_v1(A_shared, B_shared, C_local) + + # relu + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_relu_kernel + + +M = 16384 # M = T.dynamic("m") if you want to use dynamic shape +N = 16384 +K = 16384 +block_M = 128 +block_N = 128 +block_K = 64 + +# 1. Define the kernel (matmul) and compile/lower it into an executable module +matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) + +# 3. 
Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(K, N, device="cuda", dtype=torch.float16) +c = torch.empty(M, N, device="cuda", dtype=torch.float16) + +# Run the kernel through the Profiler +matmul_relu_kernel(a, b, c) + +print(c) +# Reference multiplication using PyTorch +ref_c = torch.relu(a @ b) + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = jit_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5.Profile latency with kernel +profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/maint/gemm_v2/latency_mha_fwd_bhsd.py b/maint/gemm_v2/latency_mha_fwd_bhsd.py new file mode 100644 index 000000000..7a83d7cec --- /dev/null +++ b/maint/gemm_v2/latency_mha_fwd_bhsd.py @@ -0,0 +1,228 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + +parser = argparse.ArgumentParser() +parser.add_argument("--batch", type=int, default=128, help="batch size") +parser.add_argument("--heads", type=int, default=16, help="heads") +parser.add_argument("--seq_q", type=int, default=1024, help="query sequence length") +parser.add_argument("--seq_kv", type=int, default=1024, help="key/value sequence length") +parser.add_argument("--dim", type=int, default=256, help="dim") +parser.add_argument("--is_causal", action="store_true", help="causal") +parser.add_argument("--tune", action="store_true", help="tune configs") +parser.add_argument("--use_v2", action="store_true") + +args = parser.parse_args() + +use_v2 = args.use_v2 + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + if use_v2: + T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + else: + T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + 
@T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + # T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if use_v2: + T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + else: + T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
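+                # Here scale = log2(e) / sqrt(dim), so the expression below equals
+                # exp((x - max) / sqrt(dim)) rewritten in base 2.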
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_q = Q.size(2) + seq_kv = K.size(2) + mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) + return output + + +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 64, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128) + print(kernel.get_kernel_source()) + ref_program_processed = partial(ref_program, is_causal=is_causal) + + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print(f"Ref: {latency:.2f} ms") + print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops") + latency = 
profiler.do_bench(warmup=500) + print(f"Tile-lang: {latency:.2f} ms") + print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops") + else: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + tilelang.disable_cache() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/maint/host_checks/01_num_args_mismatch.py b/maint/host_checks/01_num_args_mismatch.py new file mode 100644 index 000000000..9528652ee --- /dev/null +++ b/maint/host_checks/01_num_args_mismatch.py @@ -0,0 +1,22 @@ +"""Reproduce: Argument count mismatch. + +Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output. +Calling with the wrong number of inputs raises a ValueError before host entry. +""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 256 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda", dtype=torch.float16) + # Missing b + # Expected: ValueError with message about expected vs. actual inputs + fn(a) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/02_pointer_type_error.py b/maint/host_checks/02_pointer_type_error.py new file mode 100644 index 000000000..188a4f8cc --- /dev/null +++ b/maint/host_checks/02_pointer_type_error.py @@ -0,0 +1,23 @@ +"""Reproduce: Pointer-type argument expected but scalar provided. + +We pass an integer for A; wrapper forwards it to the host where a pointer is expected. +Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param). 
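+
+For contrast, the same kernel accepts the call when both operands are CUDA float16
+tensors of the declared shapes (the pattern used by the other scripts in this folder):
+
+    a = torch.empty((M, K), device="cuda", dtype=torch.float16)
+    b = torch.empty((K, N), device="cuda", dtype=torch.float16)
+    fn(a, b)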
+""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 256 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # Wrong type for A (int instead of tensor) + a = 1 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/03_ndim_mismatch.py b/maint/host_checks/03_ndim_mismatch.py new file mode 100644 index 000000000..76637e8de --- /dev/null +++ b/maint/host_checks/03_ndim_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: ndim (rank) mismatch for A.""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A has rank 3 instead of 2 + a = torch.empty((M, K, 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/04_dtype_mismatch.py b/maint/host_checks/04_dtype_mismatch.py new file mode 100644 index 000000000..f3554c1d6 --- /dev/null +++ b/maint/host_checks/04_dtype_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: dtype mismatch for A (float32 vs expected float16).""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + print(fn.get_host_source()) + + a = torch.empty((M, K), device="cuda", dtype=torch.float32) # should be float16 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/05_shape_mismatch.py b/maint/host_checks/05_shape_mismatch.py new file mode 100644 index 000000000..a48248176 --- /dev/null +++ b/maint/host_checks/05_shape_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: shape constant/symbol mismatch on A.""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A's second dimension is wrong (K+1 instead of K) + a = torch.empty((M, K + 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/06_strides_mismatch.py b/maint/host_checks/06_strides_mismatch.py new file mode 100644 index 000000000..7e523cd64 --- /dev/null +++ b/maint/host_checks/06_strides_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: strides check failure (non-contiguous A via transpose).""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda", dtype=torch.float16) + a_nc = a.t() # non-contiguous after transpose + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a_nc, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/07_device_type_mismatch.py b/maint/host_checks/07_device_type_mismatch.py new file mode 100644 index 000000000..af8e5efd5 --- /dev/null +++ b/maint/host_checks/07_device_type_mismatch.py @@ -0,0 +1,18 @@ +"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cpu", dtype=torch.float16) + b = torch.empty((K, N), device="cpu", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff 
--git a/maint/host_checks/08_device_id_mismatch.py b/maint/host_checks/08_device_id_mismatch.py new file mode 100644 index 000000000..280aca157 --- /dev/null +++ b/maint/host_checks/08_device_id_mismatch.py @@ -0,0 +1,25 @@ +"""Reproduce: device_id mismatch (requires >=2 CUDA devices).""" + +import torch +from common import build_matmul_kernel + + +def main(): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + if torch.cuda.device_count() < 2: + print("[SKIP] Need at least 2 CUDA devices to reproduce device_id mismatch.") + return + + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda:0", dtype=torch.float16) + b = torch.empty((K, N), device="cuda:1", dtype=torch.float16) + # Output device is derived by the adapter; mismatch occurs in host checks + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/09_null_data_pointer.py b/maint/host_checks/09_null_data_pointer.py new file mode 100644 index 000000000..09f5de1af --- /dev/null +++ b/maint/host_checks/09_null_data_pointer.py @@ -0,0 +1,26 @@ +"""Reproduce: NULL data pointer (advanced). + +Passing None for a tensor argument will be forwarded through the adapter. Depending on +FFI handling, this commonly triggers a pointer-type assertion (e.g., "Expect buffer to be pointer or tensor") +or a host-side non-NULL pointer check. + +Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script +demonstrates passing None, which still reproduces the intended class of failure. +""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = None # attempt to pass a null-like pointer + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/10_scalar_type_mismatch.py b/maint/host_checks/10_scalar_type_mismatch.py new file mode 100644 index 000000000..4f2c90b8d --- /dev/null +++ b/maint/host_checks/10_scalar_type_mismatch.py @@ -0,0 +1,15 @@ +"""Reproduce: scalar parameter type mismatch (int/bool).""" + +from common import build_scalar_check_kernel + + +def main(): + fn = build_scalar_check_kernel(target="cuda") + + # Wrong types + fn(1.0, True) # x should be int -> Expect arg[0] to be int + fn(1, 2.5) # flag should be bool -> Expect arg[1] to be boolean + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/README.md b/maint/host_checks/README.md new file mode 100644 index 000000000..ac23d6fd2 --- /dev/null +++ b/maint/host_checks/README.md @@ -0,0 +1,21 @@ +# Host-Side Check Repro Scripts + +This folder contains standalone scripts that deliberately trigger host-side (and adapter-side) validation errors described in `docs/compiler_internals/tensor_checks.md`. Each script can be run directly and will reproduce the corresponding error with a minimal example. + +Prerequisites +- CUDA-capable environment (most scripts compile a CUDA-targeted kernel) +- Python packages: torch, tilelang + +Usage +- Run any script, e.g.: + - `python 01_num_args_mismatch.py` + - `python 02_pointer_type_error.py` + - ... up to `10_scalar_type_mismatch.py` + +- Or run all at once with a summary: + - `python run_all.py` + - Logs per test are saved under `logs/` as `