diff --git a/.github/workflows/nvidia-h100.yml b/.github/workflows/nvidia-h100.yml index a396487c9e..cc5963a4dc 100644 --- a/.github/workflows/nvidia-h100.yml +++ b/.github/workflows/nvidia-h100.yml @@ -20,7 +20,7 @@ jobs: if: github.event_name != 'pull_request' || github.event.action != 'closed' uses: ./.github/workflows/reusable-ci-tests.yml with: - runner: 'nvidia-h100-1' + runner: 'nvidia-h100-pt2-7' gpu_type: 'nvidia' conda_env_name: 'pytorch_2_7' pytorch_version: '2.7.0' @@ -31,21 +31,8 @@ jobs: if: github.event_name != 'pull_request' || github.event.action != 'closed' uses: ./.github/workflows/reusable-ci-tests.yml with: - runner: 'nvidia-h100-2' + runner: 'nvidia-h100-3' gpu_type: 'nvidia' conda_env_name: 'pytorch_nightly' pytorch_version: 'nightly' skip_gpu_check: true - - test-h100-pytorch-2-6: - name: Test H100 (PyTorch 2.6) - if: github.event_name != 'pull_request' || github.event.action != 'closed' - uses: ./.github/workflows/reusable-ci-tests.yml - with: - runner: 'nvidia-h100-3' - gpu_type: 'nvidia' - conda_env_name: 'pytorch_2_6' - pytorch_version: '2.6.0' - pytorch_cuda_version: 'cu126' - nvcc_toolkit_version: '12.6.3' - skip_gpu_check: true diff --git a/.github/workflows/reusable-ci-tests.yml b/.github/workflows/reusable-ci-tests.yml index a6e6a70670..9e7f7420cc 100644 --- a/.github/workflows/reusable-ci-tests.yml +++ b/.github/workflows/reusable-ci-tests.yml @@ -54,6 +54,21 @@ jobs: shell: bash run: | set -e + + # 1. ADDED LOGIC: Determine the target conda environment name based on the runner. + # This block implements the core requirement. + TARGET_CONDA_ENV="" + echo "Determining conda environment based on runner: ${{ runner.name }}" + if [[ "${{ runner.name }}" == "nvidia-h100-1" ]]; then + TARGET_CONDA_ENV="${{ inputs.conda_env_name }}" + elif [[ "${{ runner.name }}" == "nvidia-h100-2" ]]; then + TARGET_CONDA_ENV="${{ inputs.conda_env_name }}_1" + else + TARGET_CONDA_ENV="${{ inputs.conda_env_name }}" + echo "Runner is not a special case, using input env: '${TARGET_CONDA_ENV}'" + fi + echo "--> Runner is '${{ runner.name }}', selected environment is '${TARGET_CONDA_ENV}'" + echo "Searching for Conda installation in home directory ($HOME)..." POSSIBLE_NAMES=("miniforge3" "miniconda3" "anaconda3") FOUND_PATH="" @@ -66,12 +81,23 @@ jobs: break fi done + if [ -n "${FOUND_PATH}" ]; then echo "Setting CONDA environment variable to: ${FOUND_PATH}" echo "CONDA=${FOUND_PATH}" >> $GITHUB_ENV - # MODIFICATION: Set both CONDA and CONDA_BIN_PATH here. - # This makes them available to all subsequent steps. - echo "CONDA_BIN_PATH=${FOUND_PATH}/envs/${{ inputs.conda_env_name }}/bin" >> $GITHUB_ENV + + # 2. MODIFIED LOGIC: Use the dynamically determined TARGET_CONDA_ENV variable. + # Instead of using inputs.conda_env_name, we use the variable set in step 1. + echo "CONDA_BIN_PATH=${FOUND_PATH}/envs/${TARGET_CONDA_ENV}/bin" >> $GITHUB_ENV + + # 3. ADDED LOGIC: Also export the determined environment name itself. + # This is very useful for subsequent 'conda activate' steps. + echo "CONDA_ENV_NAME=${TARGET_CONDA_ENV}" >> $GITHUB_ENV + + echo "Successfully set environment variables." + echo "CONDA = ${FOUND_PATH}" + echo "CONDA_ENV_NAME = ${TARGET_CONDA_ENV}" + echo "CONDA_BIN_PATH = ${FOUND_PATH}/envs/${TARGET_CONDA_ENV}/bin" else echo "::error::Could not automatically find a Conda installation." exit 1 @@ -166,11 +192,12 @@ jobs: # ================================================================= # STAGE 2: OPS TESTS # ================================================================= - - name: Find dependent test files for Ops + - name: Find dependent OP test files for Ops if: steps.check_skip.outputs.skip_tests == 'false' id: find-ops-tests shell: bash run: | + export TEST_SCOPE="EXCLUDE_MODELS" TEST_FILES=$($CONDA_BIN_PATH/python scripts/find_dependent_tests.py "${{ steps.changed-files.outputs.all_changed_files }}") echo "Found ops test files: $TEST_FILES" echo "test_files=$TEST_FILES" >> $GITHUB_OUTPUT @@ -179,16 +206,7 @@ jobs: if: steps.find-ops-tests.outputs.test_files && steps.check_skip.outputs.skip_tests == 'false' shell: bash run: | - TRITON_PRINT_AUTOTUNING=0 SKIP_TEST_CHUNK_VARLEN=1 \ - $CONDA_BIN_PATH/pytest -s -v ${{ steps.find-ops-tests.outputs.test_files }} - - - name: Run pytest on ops varlen test files - if: steps.find-ops-tests.outputs.test_files && steps.check_skip.outputs.skip_tests == 'false' - shell: bash - run: | - TRITON_PRINT_AUTOTUNING=0 SKIP_TEST_CHUNK_VARLEN=0 \ - $CONDA_BIN_PATH/pytest -s -v ${{ steps.find-ops-tests.outputs.test_files }} || \ - echo "Varlen tests failed for ops (non-critical)" + $CONDA_BIN_PATH/pytest -s -v ${{ steps.find-ops-tests.outputs.test_files }} - name: Verify FLA import shell: bash @@ -198,9 +216,8 @@ jobs: test-models: runs-on: ${{ inputs.runner }} needs: test-ops - if: always() + if: success() || failure() env: - # MODIFICATION: Removed CONDA_BIN_PATH from here. FLA_CI_ENV: 1 steps: @@ -212,19 +229,50 @@ jobs: shell: bash run: | set -e + + # 1. ADDED LOGIC: Determine the target conda environment name based on the runner. + # This block implements the core requirement. + TARGET_CONDA_ENV="" + echo "Determining conda environment based on runner: ${{ runner.name }}" + if [[ "${{ runner.name }}" == "nvidia-h100-1" ]]; then + TARGET_CONDA_ENV="${{ inputs.conda_env_name }}" + elif [[ "${{ runner.name }}" == "nvidia-h100-2" ]]; then + TARGET_CONDA_ENV="${{ inputs.conda_env_name }}_1" + else + TARGET_CONDA_ENV="${{ inputs.conda_env_name }}" + echo "Runner is not a special case, using input env: '${TARGET_CONDA_ENV}'" + fi + echo "--> Runner is '${{ runner.name }}', selected environment is '${TARGET_CONDA_ENV}'" + + echo "Searching for Conda installation in home directory ($HOME)..." POSSIBLE_NAMES=("miniforge3" "miniconda3" "anaconda3") FOUND_PATH="" for name in "${POSSIBLE_NAMES[@]}"; do CANDIDATE_PATH="$HOME/$name" + echo "--> Checking for path: ${CANDIDATE_PATH}" if [ -d "${CANDIDATE_PATH}" ] && [ -x "${CANDIDATE_PATH}/bin/conda" ]; then + echo " Found valid Conda installation: ${CANDIDATE_PATH}" FOUND_PATH="${CANDIDATE_PATH}" break fi done + if [ -n "${FOUND_PATH}" ]; then + echo "Setting CONDA environment variable to: ${FOUND_PATH}" echo "CONDA=${FOUND_PATH}" >> $GITHUB_ENV - # MODIFICATION: Set both CONDA and CONDA_BIN_PATH here. - echo "CONDA_BIN_PATH=${FOUND_PATH}/envs/${{ inputs.conda_env_name }}/bin" >> $GITHUB_ENV + + # 2. MODIFIED LOGIC: Use the dynamically determined TARGET_CONDA_ENV variable. + # Instead of using inputs.conda_env_name, we use the variable set in step 1. + echo "CONDA_BIN_PATH=${FOUND_PATH}/envs/${TARGET_CONDA_ENV}/bin" >> $GITHUB_ENV + + # 3. ADDED LOGIC: Also export the determined environment name itself. + # This is very useful for subsequent 'conda activate' steps. + echo "CONDA_ENV_NAME=${TARGET_CONDA_ENV}" >> $GITHUB_ENV + + echo "Successfully set environment variables." + echo "CONDA = ${FOUND_PATH}" + echo "CONDA_ENV_NAME = ${TARGET_CONDA_ENV}" + echo "CONDA_BIN_PATH = ${FOUND_PATH}/envs/${TARGET_CONDA_ENV}/bin" else echo "::error::Could not automatically find a Conda installation." exit 1 @@ -258,11 +306,12 @@ jobs: # ================================================================= # STAGE 3: MODELS TESTS (Reuses the same activated environment) # ================================================================= - - name: Find dependent test files for Models + - name: Find dependent Model test files for Models if: steps.check_skip.outputs.skip_tests == 'false' id: find-models-tests shell: bash run: | + export TEST_SCOPE="MODELS_ONLY" TEST_FILES=$($CONDA_BIN_PATH/python scripts/find_dependent_tests.py "${{ steps.changed-files.outputs.all_changed_files }}") echo "Found models test files: $TEST_FILES" echo "test_files=$TEST_FILES" >> $GITHUB_OUTPUT @@ -271,13 +320,4 @@ jobs: if: steps.find-models-tests.outputs.test_files && steps.check_skip.outputs.skip_tests == 'false' shell: bash run: | - TRITON_PRINT_AUTOTUNING=0 SKIP_TEST_CHUNK_VARLEN=1 \ - $CONDA_BIN_PATH/pytest -s -v ${{ steps.find-models-tests.outputs.test_files }} - - - name: Run pytest on models varlen test files - if: steps.find-models-tests.outputs.test_files && steps.check_skip.outputs.skip_tests == 'false' - shell: bash - run: | - TRITON_PRINT_AUTOTUNING=0 SKIP_TEST_CHUNK_VARLEN=0 \ - $CONDA_BIN_PATH/pytest -s -v ${{ steps.find-models-tests.outputs.test_files }} || \ - echo "Varlen tests failed for models (non-critical)" + $CONDA_BIN_PATH/pytest -s -v ${{ steps.find-models-tests.outputs.test_files }} diff --git a/fla/ops/attn/parallel.py b/fla/ops/attn/parallel.py index 8d2ec72f8d..9b4b1126d6 100644 --- a/fla/ops/attn/parallel.py +++ b/fla/ops/attn/parallel.py @@ -22,7 +22,7 @@ @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_warps in [2, 4] + ([8] if check_shared_mem('hopper') else []) for num_stages in [2, 3, 4, 5] ], key=['B', 'H', 'HQ', 'G', 'K', 'V', 'BK', 'BV', 'USE_G', 'IS_VARLEN'], @@ -177,7 +177,7 @@ def parallel_attn_bwd_kernel_preprocess( @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_warps in [2, 4] + ([8] if check_shared_mem('hopper') else []) for num_stages in [2, 3, 4, 5] ], key=['B', 'H', 'HQ', 'G', 'K', 'V', 'BK', 'BV', 'USE_G', 'IS_VARLEN'], @@ -319,7 +319,7 @@ def parallel_attn_bwd_kernel_dq( @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_warps in [2, 4] + ([8] if check_shared_mem('hopper') else []) for num_stages in [2, 3, 4, 5] ], key=['B', 'H', 'HQ', 'G', 'K', 'V', 'BK', 'BV', 'USE_G', 'IS_VARLEN'], diff --git a/fla/ops/common/chunk_o.py b/fla/ops/common/chunk_o.py index 7b53e48ed8..ed83fcbc73 100644 --- a/fla/ops/common/chunk_o.py +++ b/fla/ops/common/chunk_o.py @@ -22,11 +22,9 @@ }) @triton.autotune( configs=[ - triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) - for BK in BKV_LIST - for BV in BKV_LIST - for num_warps in NUM_WARPS - for num_stages in [2, 3, 4] + triton.Config({'BK': 128, 'BV': 128}, num_warps=8, num_stages=3), + triton.Config({'BK': 64, 'BV': 64}, num_warps=4, num_stages=3), + triton.Config({'BK': 32, 'BV': 32}, num_warps=2, num_stages=3), ], key=['H', 'K', 'V', 'BT'], ) diff --git a/fla/ops/generalized_delta_rule/iplr/chunk.py b/fla/ops/generalized_delta_rule/iplr/chunk.py index e67b3154fc..7c65cb30ec 100644 --- a/fla/ops/generalized_delta_rule/iplr/chunk.py +++ b/fla/ops/generalized_delta_rule/iplr/chunk.py @@ -23,7 +23,7 @@ @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps) - for num_warps in [2, 4, 8, 16] + for num_warps in [2, 4] + ([] if check_shared_mem('hopper') else [8]) ], key=['BT', 'BK', 'BV'], use_cuda_graph=use_cuda_graph, @@ -104,11 +104,10 @@ def chunk_generalized_iplr_delta_rule_fwd_kernel_h( }) @triton.autotune( configs=[ - triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps) for BK in BKV_LIST for BV in BKV_LIST for num_warps in [2, 4, 8] - for num_stages in [2, 3] ], key=['BT'], use_cuda_graph=use_cuda_graph, diff --git a/fla/ops/gsa/fused_recurrent.py b/fla/ops/gsa/fused_recurrent.py index a90a55cbf1..13b0b58ae0 100644 --- a/fla/ops/gsa/fused_recurrent.py +++ b/fla/ops/gsa/fused_recurrent.py @@ -247,7 +247,7 @@ def fused_recurrent_gsa_bwd( B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1] N = B if cu_seqlens is None else len(cu_seqlens) - 1 - BK, BV, BM = min(K, 64), min(V, 64), min(M, 64) + BK, BV, BM = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64), min(triton.next_power_of_2(M), 64) NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float) diff --git a/fla/ops/rwkv7/fused_addcmul.py b/fla/ops/rwkv7/fused_addcmul.py index eb190ed537..1684cae4bc 100644 --- a/fla/ops/rwkv7/fused_addcmul.py +++ b/fla/ops/rwkv7/fused_addcmul.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import logging +import sys from typing import Optional import torch @@ -14,6 +15,17 @@ if not check_pytorch_version('2.4'): logger.warning('PyTorch < 2.4 detected - computations may be slower due to lack of optimizations') + +def identity_decorator(fn): + return fn + + +if sys.version_info > (3, 10): + torch_compile = torch.compile(fullgraph=True) +else: + logger.warning('torch.compile is not available in Python 3.10, using identity decorator instead') + torch_compile = identity_decorator + NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32] @@ -164,7 +176,7 @@ def addcmul_bwd1(d_xr, d_xw, d_xk, d_xv, d_xa, d_xg, return g_hiddn, g_delta -@torch.compile(fullgraph=True) +@torch_compile def addcmul_bwd2(d_oxr, d_xw, d_xk, d_xv, d_xa, d_xg, delta, use_xg: bool): g_xr = (d_oxr * delta).sum(dim=(0, 1), keepdim=True) g_xw = (d_xw * delta).sum(dim=(0, 1), keepdim=True) diff --git a/fla/ops/utils/cumsum.py b/fla/ops/utils/cumsum.py index 8b0454a39e..5c73a7eea0 100644 --- a/fla/ops/utils/cumsum.py +++ b/fla/ops/utils/cumsum.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -import warnings from typing import Optional import torch @@ -165,7 +164,7 @@ def chunk_global_cumsum_scalar_kernel( b_z = tl.zeros([], dtype=tl.float32) NT = tl.cdiv(T, BT) for i_c in range(NT): - i_t = NT-1-i_c if REVERSE else i_c + i_t = NT - 1 - i_c if REVERSE else i_c if HEAD_FIRST: p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) @@ -232,7 +231,7 @@ def chunk_global_cumsum_vector_kernel( b_z = tl.zeros([BS], dtype=tl.float32) NT = tl.cdiv(T, BT) for i_c in range(NT): - i_t = NT-1-i_c if REVERSE else i_c + i_t = NT - 1 - i_c if REVERSE else i_c if HEAD_FIRST: p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) @@ -245,8 +244,7 @@ def chunk_global_cumsum_vector_kernel( if HAS_SCALE: b_c *= scale tl.store(p_o, b_c.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - if i_c >= 0: - b_z += tl.sum(b_s, 0) + b_z += tl.sum(b_s, 0) def chunk_local_cumsum_scalar( @@ -437,13 +435,6 @@ def chunk_local_cumsum( output_dtype: Optional[torch.dtype] = torch.float, **kwargs ) -> torch.Tensor: - if not head_first and g.shape[1] < g.shape[2]: - warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when `head_first=False` was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) if cu_seqlens is not None: assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" if len(g.shape) == 3: diff --git a/tests/ops/test_attn.py b/tests/ops/test_attn.py index 7475fc0f62..c52081b5d8 100644 --- a/tests/ops/test_attn.py +++ b/tests/ops/test_attn.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch from fla.ops.attn.parallel import parallel_attn from fla.ops.utils import prepare_lens -from fla.utils import COMPILER_MODE, assert_close, check_shared_mem, device +from fla.utils import assert_close, check_shared_mem, device try: from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -16,49 +17,37 @@ HAS_FLASH = False -if COMPILER_MODE: - test_b_list = [2] - test_t_list = [2048] - test_t_varlen_list = test_t_list - test_d_list = [64, 100, 128] -else: - test_b_list = [2, 4] - test_t_list = [1, 15, 63, 286, 300, 1024, 2048] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [32, 64, 100] -test_hq_list = [8, 16] -test_h_list = [2] - - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('HQ', test_hq_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [1, 0.1]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - not HAS_FLASH, - reason="Skipping test because flash-attn is not installed" +@pytest.mark.parametrize( + ('B', 'T', 'H', 'HQ', 'D', 'scale'), + [ + pytest.param(*test, id="B{}-T{}-H{}-HQ{}-D{}-scale{}".format(*test)) + for test in [ + (1, 63, 1, 1, 64, 1.0), + (3, 111, 2, 2, 100, 1.0), + (3, 1024, 2, 8, 60, 0.1), + (3, 1024, 2, 8, 128, 0.1), + (4, 2048, 2, 8, 64, 0.1) + ] + ] ) def test_parallel( B: int, + T: int, H: int, HQ: int, - T: int, D: int, - dtype: torch.dtype, scale: float, ): if not check_shared_mem('hopper') and D > 128: pytest.skip(reason="Skip test, do not have enough shard mem") + if not HAS_FLASH: + pytest.skip(reason="Skipping test because flash-attn is not installed") torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' - - q = torch.randn((B, T, HQ, D), dtype=dtype, device=device).requires_grad_(True) - k = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) - v = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device=device) + q = torch.randn((B, T, HQ, D), dtype=torch.float16, device=device).requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=torch.float16, device=device).requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=torch.float16, device=device).requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=torch.float16, device=device) ref = flash_attn_func(q=q, k=k, v=v, softmax_scale=scale, causal=True) ref.backward(do) @@ -78,37 +67,29 @@ def test_parallel( assert_close("dv", ref_dv, tri_dv, 0.005) -@pytest.mark.parametrize('N', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('HQ', test_hq_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - not HAS_FLASH, - reason="Skipping test because flash-attn is not installed" +@pytest.mark.parametrize( + ('H', 'HQ', 'D', 'cu_seqlens'), + [ + pytest.param(*test, id="H{}-HQ{}-D{}-cu_seqlens{}".format(*test)) + for test in [ + (2, 2, 64, [0, 15]), + (2, 8, 64, [0, 256, 500, 1000]), + (2, 2, 100, [0, 15, 100, 300, 1200, 2000]), + ] + ] ) def test_parallel_varlen( - N: int, - T: int, H: int, HQ: int, D: int, - dtype: torch.dtype, + cu_seqlens: List[int], ): - if not check_shared_mem('hopper') and D > 128: - pytest.skip(reason="Skip test, do not have enough shard mem") - torch.manual_seed(42) - os.environ['TRITON_F32_DEFAULT'] = 'ieee' + if not HAS_FLASH: + pytest.skip(reason="Skipping test because flash-attn is not installed") + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + dtype = torch.float16 - N = min(1, N) if T < 64 else N - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0].to(torch.int32) - # seq-first required for inputs with variable lengths q = torch.randn((1, T, HQ, D), dtype=dtype, device=device).requires_grad_() k = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() v = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() diff --git a/tests/ops/test_based.py b/tests/ops/test_based.py index 6a4fa31250..d03725c49d 100644 --- a/tests/ops/test_based.py +++ b/tests/ops/test_based.py @@ -1,38 +1,30 @@ # -*- coding: utf-8 -*- -import os - import pytest import torch from fla.ops.based import fused_chunk_based, parallel_based from fla.ops.based.naive import naive_parallel_based -from fla.utils import COMPILER_MODE, device - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_d_list = [64, 128, 256] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_d_list = [64, 32, 100, 256] -test_h_list = [2] +from fla.utils import device -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float32]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 1, 60, torch.float16), + (3, 111, 2, 64, torch.float16), + (3, 1024, 4, 100, torch.float16), + (3, 1024, 8, 128, torch.float16), + (4, 2048, 8, 256, torch.float16) + ] + ] ) def test_based( B: int, - H: int, T: int, + H: int, D: int, dtype: torch.dtype ): diff --git a/tests/ops/test_comba.py b/tests/ops/test_comba.py index cdc0c0179f..5d24595cae 100644 --- a/tests/ops/test_comba.py +++ b/tests/ops/test_comba.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch @@ -9,22 +10,7 @@ from fla.ops.comba import chunk_comba, fused_recurrent_comba from fla.ops.comba.utils import chunk_comba_cumsum_scalar_fwd -from fla.utils import COMPILER_MODE, assert_close, device, is_intel_alchemist - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] - test_gate_list = [1.0] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100, 256] - test_gate_list = [1, 0.1, 10] -test_h_list = [2] -test_hv_list = [4] +from fla.utils import assert_close, device, is_intel_alchemist def cumsum_comba_local_fwd_reference(s, reverse=False, chunk_size=128): @@ -40,16 +26,24 @@ def cumsum_comba_local_fwd_reference(s, reverse=False, chunk_size=128): return o_0, o_1 -@pytest.mark.parametrize("B", [32]) -@pytest.mark.parametrize("T", [256, 1024, 2048]) -@pytest.mark.parametrize("H", [4]) -@pytest.mark.parametrize("dtype", [torch.float, torch.float16]) -@pytest.mark.parametrize("chunk_size", [32, 64]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" +@pytest.mark.parametrize( + ('B', 'T', 'H', 'chunk_size', 'dtype'), + [ + pytest.param(*test, id='B{}-T{}-H{}-chunk_size{}-{}'.format(*test)) + for test in [ + (32, 200, 4, 64, torch.float), + (32, 1000, 4, 64, torch.float), + (32, 2048, 8, 128, torch.float), + ] + ] ) -def test_cumsum_local_scalar_fwd(B, T, H, dtype, chunk_size): +def test_cumsum_local_scalar_fwd( + B: int, + T: int, + H: int, + chunk_size: int, + dtype: torch.dtype, +): s = torch.randn((B, T, H), dtype=dtype, device=device).requires_grad_() ref_0, ref_1 = cumsum_comba_local_fwd_reference(s, chunk_size=chunk_size) tri_0, tri_1 = chunk_comba_cumsum_scalar_fwd(s, chunk_size=chunk_size) @@ -136,16 +130,21 @@ def chunk_comba_ref( return o, S -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('scale', [1]) -@pytest.mark.parametrize('dtype', [torch.float32, torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, torch.float), + (2, 1024, 4, 60, 1, 1, torch.float), + (2, 1024, 8, 128, 1, 0.1, torch.float), + (2, 1024, 8, 128, 0.1, 1, torch.float), + (2, 1024, 8, 128, 1, 10, torch.float), + (4, 2048, 8, 64, 0.1, 1, torch.float), + (2, 1024, 8, 128, 1, 0.1, torch.float16), + (2, 1024, 8, 128, 1, 10, torch.float16), + ] + ] ) def test_fused_recurrent( B: int, @@ -153,8 +152,8 @@ def test_fused_recurrent( H: int, D: int, scale: float, - dtype: torch.dtype, gate_logit_normalizer: float, + dtype: torch.dtype, ): torch.manual_seed(42) q = F.normalize(torch.randn(B, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) @@ -188,31 +187,35 @@ def test_fused_recurrent( initial_state=h0.clone(), output_final_state=True, ) - assert_close(' o', ref, tri, 0.002) - assert_close(' ht', ref_ht, tri_ht, 0.002) + assert_close('o', ref, tri, 0.002) + assert_close('ht', ref_ht, tri_ht, 0.002) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('scale', [1, 0.1]) -@pytest.mark.parametrize('mask_p', [0, 0.5]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'mask_p', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-mask_p{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, 0, torch.float16), + (2, 1000, 3, 60, 1, 1, 0, torch.float16), + (2, 1024, 3, 64, 0.1, 1, 0.5, torch.float16), + (2, 1024, 4, 100, 1, 0.1, 0, torch.float16), + (2, 1024, 4, 128, 0.1, 1, 0, torch.float16), + (2, 1024, 4, 128, 0.1, 1, 0.5, torch.float16), + (2, 1024, 4, 128, 0.1, 10, 0, torch.float16), + (4, 2048, 8, 64, 0.1, 1, 0, torch.float16) + ] + ] ) def test_chunk( B: int, T: int, H: int, D: int, - dtype: torch.dtype, scale: float, gate_logit_normalizer: float, mask_p: float, + dtype: torch.dtype, ): if is_intel_alchemist and D > 128: pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128') @@ -272,36 +275,37 @@ def test_chunk( assert_close("dh0", ref_dh0, tri_dh0, 0.008) -@pytest.mark.parametrize('N', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [1, 0.1]) -@pytest.mark.parametrize('mask_p', [0, 0.5]) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 64, 0, [0, 15], torch.float16), + (4, 64, 0, [0, 256, 500, 1000], torch.float16), + (4, 64, 0.5, [0, 256, 500, 1000], torch.float16), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' ) def test_chunk_varlen( - N: int, - T: int, H: int, D: int, - scale: float, mask_p: float, + cu_seqlens: List[int], dtype: torch.dtype, ): if is_intel_alchemist and D > 128: pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128') torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + + N = len(cu_seqlens) - 1 + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) q = torch.randn((1, T, H, D), dtype=dtype) k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) @@ -323,7 +327,6 @@ def test_chunk_varlen( p=p.clone(), beta=beta.clone(), g=g.clone(), - scale=scale, output_final_state=True, initial_state=h0.clone(), cu_seqlens=cu_seqlens, @@ -342,7 +345,6 @@ def test_chunk_varlen( p=p[:, cu_seqlens[i]:cu_seqlens[i+1]], beta=beta[:, cu_seqlens[i]:cu_seqlens[i+1]], g=g[:, cu_seqlens[i]:cu_seqlens[i+1]], - scale=scale, initial_state=h0[i], output_final_state=True, ) @@ -354,11 +356,11 @@ def test_chunk_varlen( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, g.grad, h0.grad - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.007) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.007) - assert_close(' db', ref_dbeta, tri_dbeta, 0.015) - assert_close(' dg', ref_dg, tri_dg, 0.015) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.007) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.007) + assert_close('db', ref_dbeta, tri_dbeta, 0.015) + assert_close('dg', ref_dg, tri_dg, 0.015) assert_close('dh0', ref_dh0, tri_dh0, 0.007) diff --git a/tests/ops/test_cumsum.py b/tests/ops/test_cumsum.py deleted file mode 100644 index 4ae9c3011a..0000000000 --- a/tests/ops/test_cumsum.py +++ /dev/null @@ -1,148 +0,0 @@ -# -*- coding: utf-8 -*- - -import os - -import pytest -import torch - -from fla.ops.utils.cumsum import chunk_global_cumsum, chunk_local_cumsum -from fla.utils import COMPILER_MODE, assert_close, device, device_platform - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100, 256] -test_h_list = [2] - - -def rev_cumsum(s, dim=-1): - return torch.flip(torch.cumsum(torch.flip(s, dims=[dim]), dim), dims=[dim]) - - -def cumsum_local_reference( - s: torch.Tensor, - reverse: bool = False, - chunk_size: int = 128 -): - o = torch.zeros_like(s) - T = s.size(1) - fn = torch.cumsum if not reverse else rev_cumsum - for i in range(0, T, chunk_size): - s_chunk = s[:, i:i+chunk_size] - o[:, i:i+chunk_size] = fn(s_chunk.float(), dim=1).to(o) - - return o - - -def cumsum_global_reference( - s: torch.Tensor, - reverse: bool = False, -): - fn = torch.cumsum if not reverse else rev_cumsum - return fn(s.float(), dim=1).to(s) - - -@pytest.mark.parametrize("B", test_b_list) -@pytest.mark.parametrize("T", test_t_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("chunk_size", [32, 64]) -@pytest.mark.parametrize("dtype", [torch.float, torch.float16]) -@pytest.mark.parametrize("reverse", [False, True]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" -) -def test_cumsum_local_vector( - B: int, - T: int, - H: int, - D: int, - dtype: torch.dtype, - reverse: bool, - chunk_size: int -): - s = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_() - ref = cumsum_local_reference(s, reverse=reverse, chunk_size=chunk_size) - tri = chunk_local_cumsum(s, reverse=reverse, chunk_size=chunk_size) - assert_close("local cumsum vector", ref, tri, 0.001 if dtype == torch.float else 0.003) - - -@pytest.mark.parametrize("B", test_b_list) -@pytest.mark.parametrize("T", test_t_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("dtype", [torch.float, torch.float16]) -@pytest.mark.parametrize("reverse", [True, False]) -@pytest.mark.parametrize("chunk_size", [32, 64]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" -) -def test_cumsum_local_scalar( - B: int, - T: int, - H: int, - dtype: torch.dtype, - reverse: bool, - chunk_size: int -): - s = torch.randn((B, T, H), dtype=dtype, device=device).requires_grad_() - ref = cumsum_local_reference(s, reverse=reverse, chunk_size=chunk_size) - tri = chunk_local_cumsum(s, reverse=reverse, chunk_size=chunk_size) - assert_close("local cumsum scalar", ref, tri, 0.001 if dtype == torch.float else 0.003) - - -@pytest.mark.parametrize("B", test_b_list) -@pytest.mark.parametrize("T", test_t_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("dtype", [torch.float, torch.float16]) -@pytest.mark.parametrize("reverse", [True, False]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" -) -@pytest.mark.skipif( - device_platform == 'intel', - reason="Intel Triton Failure" -) -def test_cumsum_global_vector( - B: int, - T: int, - H: int, - D: int, - dtype: torch.dtype, - reverse: bool, -): - s = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_() - ref = cumsum_global_reference(s, reverse=reverse) - tri = chunk_global_cumsum(s, reverse=reverse) - assert_close("global cumsum vector", ref, tri, 0.001 if dtype == torch.float else 0.003) - - -@pytest.mark.parametrize("B", test_b_list) -@pytest.mark.parametrize("T", test_t_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("dtype", [torch.float, torch.float16]) -@pytest.mark.parametrize("reverse", [True, False]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" -) -def test_cumsum_global_scalar( - B: int, - T: int, - H: int, - dtype: torch.dtype, - reverse: bool, -): - s = torch.randn((B, T, H), dtype=dtype, device=device).requires_grad_() - ref = cumsum_global_reference(s, reverse=reverse) - tri = chunk_global_cumsum(s, reverse=reverse) - assert_close("global cumsum scalar", ref, tri, 0.001 if dtype == torch.float else 0.003) diff --git a/tests/ops/test_delta.py b/tests/ops/test_delta.py index aa3db01996..c894bd5ebd 100644 --- a/tests/ops/test_delta.py +++ b/tests/ops/test_delta.py @@ -1,36 +1,28 @@ # -*- coding: utf-8 -*- -import os +from typing import List import pytest import torch import torch.nn.functional as F from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule -from fla.utils import COMPILER_MODE, assert_close, device, device_platform - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] -else: - test_b_list = [2] - test_t_list = [15, 63, 300, 512] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100, 256] -test_h_list = [2] - - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [1]) -@pytest.mark.parametrize('dtype', [torch.bfloat16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +from fla.utils import assert_close, device, device_platform + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, torch.float16), + (2, 100, 4, 60, 0.1, torch.float16), + (2, 1024, 3, 128, 0.1, torch.float16), + (2, 1024, 4, 128, 1, torch.float16), + (3, 2000, 4, 128, 0.1, torch.float16), + (4, 2048, 8, 64, 0.1, torch.float16), + ] + ] ) @pytest.mark.skipif( device_platform == 'intel', @@ -41,8 +33,8 @@ def test_chunk( T: int, H: int, D: int, - dtype: torch.dtype, scale: float, + dtype: torch.dtype, ): torch.manual_seed(42) q = torch.randn(B, T, H, D, dtype=dtype) @@ -79,45 +71,42 @@ def test_chunk( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - assert_close(' o', ref, tri, 0.006) - assert_close(' ht', ref_ht, tri_ht, 0.006) - assert_close(' dq', ref_dq, tri_dq, 0.008) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.008) - assert_close(' db', ref_dbeta, tri_dbeta, 0.008) + assert_close('o', ref, tri, 0.006) + assert_close('ht', ref_ht, tri_ht, 0.006) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.008) assert_close('dh0', ref_dh0, tri_dh0, 0.008) -@pytest.mark.parametrize('N', [4]) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [1]) -@pytest.mark.parametrize('dtype', [torch.bfloat16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', - reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' +@pytest.mark.parametrize( + ('H', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 64, [0, 15], torch.float16), + (3, 60, [0, 111, 500], torch.float16), + (3, 64, [0, 256, 500, 900, 1000], torch.float16), + (4, 100, [0, 15, 100, 300, 1200, 1599, 1800, 2000], torch.float16), + ] + ] ) @pytest.mark.skipif( device_platform == 'intel', reason='Intel Triton Failure' ) def test_chunk_varlen( - N: int, - T: int, H: int, D: int, - scale: float, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) - os.environ['TRITON_F32_DEFAULT'] = 'ieee' - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + # seq-first required for inputs with variable lengths q = torch.randn((1, T, H, D), dtype=dtype) k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) @@ -133,7 +122,6 @@ def test_chunk_varlen( k.clone(), v.clone(), beta.clone(), - scale=scale, output_final_state=True, initial_state=h0.clone(), cu_seqlens=cu_seqlens, @@ -147,7 +135,6 @@ def test_chunk_varlen( k.clone(), v.clone(), beta.clone(), - scale=scale, output_final_state=True, initial_state=h0.clone(), cu_seqlens=cu_seqlens, @@ -155,142 +142,10 @@ def test_chunk_varlen( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.008) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.008) - assert_close(' db', ref_dbeta, tri_dbeta, 0.008) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.008) assert_close('dh0', ref_dh0, tri_dh0, 0.008) - - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [0.1]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - device_platform == 'intel', - reason='Intel Triton Failure' -) -def test_l2_in_kernel( - B: int, - T: int, - H: int, - D: int, - dtype: torch.dtype, - scale: float, -): - q = torch.randn(B, T, H, D, dtype=dtype) - k = torch.randn(B, T, H, D, dtype=dtype) - v = torch.randn(B, T, H, D, dtype=dtype) - beta = torch.rand(B, T, H, dtype=dtype).sigmoid() - h0 = torch.randn(B, H, D, D, dtype=torch.float32) - - q, k, v, beta, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, beta, h0)) - do = torch.rand_like(v) - dht = torch.rand_like(h0) - - tri, tri_ht = chunk_delta_rule( - F.normalize(q.clone(), p=2, dim=-1).to(dtype), - F.normalize(k.clone(), p=2, dim=-1).to(dtype), - v.clone(), - beta.clone(), - scale=scale, - output_final_state=True, - initial_state=h0.clone(), - ) - ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) - tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - q.grad = k.grad = v.grad = beta.grad = h0.grad = None - - ref, ref_ht = chunk_delta_rule( - q.clone(), - k.clone(), - v.clone(), - beta.clone(), - scale=scale, - output_final_state=True, - initial_state=h0.clone(), - use_qk_l2norm_in_kernel=True - ) - ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) - ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - q.grad = k.grad = v.grad = beta.grad = h0.grad = None - assert_close(' o', ref, tri, 0.01) - assert_close(' ht', ref_ht, tri_ht, 0.01) - assert_close(' dq', ref_dq, tri_dq, 0.01) - assert_close(' dk', ref_dk, tri_dk, 0.01) - assert_close(' dv', ref_dv, tri_dv, 0.01) - assert_close(' db', ref_dbeta, tri_dbeta, 0.01) - assert_close('dh0', ref_dh0, tri_dh0, 0.01) - - tri, tri_ht = fused_recurrent_delta_rule( - F.normalize(q.clone().float(), p=2, dim=-1).to(dtype), - F.normalize(k.clone().float(), p=2, dim=-1).to(dtype), - v.clone(), - beta.clone(), - scale=scale, - output_final_state=True, - initial_state=h0.clone(), - ) - ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) - tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - q.grad = k.grad = v.grad = beta.grad = h0.grad = None - - ref, ref_ht = fused_recurrent_delta_rule( - q.clone(), - k.clone(), - v.clone(), - beta.clone(), - scale=scale, - output_final_state=True, - initial_state=h0.clone(), - use_qk_l2norm_in_kernel=True - ) - ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) - ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - q.grad = k.grad = v.grad = beta.grad = h0.grad = None - - assert_close(' o', ref, tri, 0.002) - assert_close(' ht', ref_ht, tri_ht, 0.002) - assert_close(' dq', ref_dq, tri_dq, 0.002) - assert_close(' dk', ref_dk, tri_dk, 0.002) - assert_close(' dv', ref_dv, tri_dv, 0.002) - assert_close(' db', ref_dbeta, tri_dbeta, 0.002) - assert_close('dh0', ref_dh0, tri_dh0, 0.002) - - tri, tri_ht = fused_recurrent_delta_rule( - F.normalize(q.float().clone(), p=2, dim=-1).to(dtype), - F.normalize(k.float().clone(), p=2, dim=-1).to(dtype), - v.clone(), - beta.clone(), - scale=scale, - output_final_state=True, - initial_state=h0.clone(), - ) - ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) - tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - q.grad = k.grad = v.grad = beta.grad = h0.grad = None - - ref, ref_ht = fused_recurrent_delta_rule( - q.clone(), - k.clone(), - v.clone(), - beta.clone(), - scale=scale, - output_final_state=True, - initial_state=h0.clone(), - use_qk_l2norm_in_kernel=True - ) - ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) - ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - q.grad = k.grad = v.grad = beta.grad = h0.grad = None - assert_close(' o', ref, tri, 0.002) - assert_close(' ht', ref_ht, tri_ht, 0.002) - assert_close(' dq', ref_dq, tri_dq, 0.002) - assert_close(' dk', ref_dk, tri_dk, 0.002) - assert_close(' dv', ref_dv, tri_dv, 0.002) - assert_close(' db', ref_dbeta, tri_dbeta, 0.002) - assert_close('dh0', ref_dh0, tri_dh0, 0.002) diff --git a/tests/ops/test_delta_product.py b/tests/ops/test_delta_product.py index 085cdc9428..7a7e2d3edf 100644 --- a/tests/ops/test_delta_product.py +++ b/tests/ops/test_delta_product.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -import os from typing import List import pytest @@ -10,32 +10,39 @@ from fla.ops.gated_delta_product import chunk_gated_delta_product from fla.ops.gated_delta_product.chunk_ref import chunk_gated_delta_product_ref from fla.ops.gated_delta_product.naive import naive_recurrent_gated_delta_product -from fla.utils import assert_close, device, is_intel_alchemist - - -@pytest.mark.parametrize('B', [2]) -@pytest.mark.parametrize('T', [30, 100, 1000]) -@pytest.mark.parametrize('H', [3]) -@pytest.mark.parametrize('D', [64, 100]) -@pytest.mark.parametrize('num_householder', [3]) -@pytest.mark.parametrize('scale', [1]) -@pytest.mark.parametrize('dtype', [torch.float16]) +from fla.utils import assert_close, device, device_platform + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'num_householder', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-num_householder{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, torch.float16), + (2, 1024, 4, 60, 1, 1, torch.float16), + (2, 1024, 8, 128, 1, 2, torch.float16), + (2, 1024, 8, 128, 0.1, 2, torch.float16), + (2, 1024, 8, 128, 1, 2, torch.float16), + (4, 2048, 8, 64, 0.1, 3, torch.float16), + (2, 1024, 8, 128, 1, 3, torch.float16), + (2, 1024, 8, 128, 1, 3, torch.float16), + ] + ] +) @pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' + device_platform == 'intel', + reason='Intel Triton Failure' ) def test_chunk( B: int, T: int, H: int, D: int, + scale: float, num_householder: int, dtype: torch.dtype, - scale: float, ): - if is_intel_alchemist and D > 128: - pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128') - + torch.manual_seed(42) q = torch.randn(B, T, H, D, dtype=dtype) k = F.normalize(torch.randn(B, T * num_householder, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) v = torch.randn(B, T * num_householder, H, D, dtype=dtype) @@ -74,40 +81,42 @@ def test_chunk( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.008) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.008) - assert_close(' db', ref_dbeta, tri_dbeta, 0.02) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) assert_close('dh0', ref_dh0, tri_dh0, 0.008) -@pytest.mark.parametrize('H', [2]) -@pytest.mark.parametrize('D', [128]) -@pytest.mark.parametrize('cu_seqlens', [[0, 15, 122, 229, 400, 467, 1000]]) -@pytest.mark.parametrize('scale', [1]) -@pytest.mark.parametrize('num_householder', [3, 4]) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'D', 'num_householder', 'cu_seqlens', 'dtype'), + [ + (2, 64, 3, [0, 63, ], torch.float16), + (2, 100, 2, [0, 63, 100, 500, 1000], torch.float16), + (2, 128, 2, [0, 100, 300, 800, 1500, 2000], torch.float16), + (2, 256, 3, [0, 100, 123, 300, 500, 800, 1000, 1500, 2048], torch.float16), + ] +) @pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', - reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' + device_platform == 'intel', + reason='Intel Triton Failure' ) def test_chunk_varlen( - cu_seqlens: List[int], H: int, D: int, - scale: float, num_householder: int, + cu_seqlens: List[int], dtype: torch.dtype, ): - if is_intel_alchemist and D > 128: - pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128') torch.manual_seed(42) - os.environ['TRITON_F32_DEFAULT'] = 'ieee' - cu_seqlens = torch.LongTensor(cu_seqlens).to(device) + T = cu_seqlens[-1] N = len(cu_seqlens) - 1 + cu_seqlens = torch.LongTensor(cu_seqlens).to(device) + scale = 1.0 + q = torch.nn.functional.normalize(torch.randn((1, T, H, D), dtype=dtype), dim=-1, p=2) k = torch.nn.functional.normalize(torch.randn(1, T*num_householder, H, D, dtype=dtype), dim=-1, p=2) v = torch.randn((1, T*num_householder, H, D), dtype=dtype) @@ -149,12 +158,12 @@ def test_chunk_varlen( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.007) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.007) - assert_close(' db', ref_dbeta, tri_dbeta, 0.015) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.007) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.007) + assert_close('db', ref_dbeta, tri_dbeta, 0.015) assert_close('dh0', ref_dh0, tri_dh0, 0.007) q.grad = k.grad = v.grad = beta.grad = h0.grad = None @@ -167,17 +176,17 @@ def test_chunk_varlen( v_i = v[:, start*num_householder:end*num_householder, :, :] beta_i = beta[:, start*num_householder:end*num_householder, :] o3_i, h3_i = naive_recurrent_gated_delta_product( - q_i, k_i, v_i, None, beta_i, scale=1.0, cu_seqlens=None, output_final_state=True, num_householder=num_householder + q_i, k_i, v_i, None, beta_i, scale=scale, cu_seqlens=None, output_final_state=True, num_householder=num_householder ) torch_ref[:, start:end, :, :] = o3_i torch_ref_ht[i, :, :, :] = h3_i.squeeze(0) ((torch_ref * do).sum() + (torch_ref_ht * dht).sum()).backward(retain_graph=True) - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.007) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.007) - assert_close(' db', ref_dbeta, tri_dbeta, 0.015) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.007) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.007) + assert_close('db', ref_dbeta, tri_dbeta, 0.015) assert_close('dh0', ref_dh0, tri_dh0, 0.007) diff --git a/tests/ops/test_dplr_delta.py b/tests/ops/test_dplr_delta.py index dc4619dc63..286b64188d 100644 --- a/tests/ops/test_dplr_delta.py +++ b/tests/ops/test_dplr_delta.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch @@ -8,21 +9,7 @@ from einops import rearrange from fla.ops.generalized_delta_rule.dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule -from fla.utils import COMPILER_MODE, assert_close, device, device_platform - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] - test_gate_list = [1.0] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [32, 64, 100, 256] - test_gate_list = [1, 0.1, 10] -test_h_list = [2] +from fla.utils import assert_close, device, device_platform def recurrent_dplr_delta_rule_ref( @@ -142,15 +129,19 @@ def chunk_dplr_delta_rule_ref( return o, S -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [0.25]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, torch.float), + (2, 1024, 4, 60, 1, torch.float), + (2, 1024, 8, 128, 1, torch.float), + (2, 1024, 8, 128, 0.1, torch.float), + (4, 2048, 8, 64, 0.1, torch.float), + (2, 1024, 8, 128, 1, torch.float16), + ] + ] ) def test_recurrent_fwd( B: int, @@ -197,21 +188,24 @@ def test_recurrent_fwd( initial_state=h0.clone(), output_final_state=True, ) - assert_close(' o', ref, tri, 0.001) + assert_close('o', ref, tri, 0.001) assert_close('ht', ref_ht, tri_ht, 0.001) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [0.25]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, torch.float), + (2, 1024, 4, 60, 1, torch.float), + (2, 1024, 8, 100, 1, torch.float), + (2, 1024, 8, 128, 0.1, torch.float), + (4, 2048, 8, 64, 0.1, torch.float), + ] + ] ) -def test_fused_recurrent_fwd( +def test_fused_recurrent( B: int, T: int, H: int, @@ -255,21 +249,25 @@ def test_fused_recurrent_fwd( initial_state=h0.clone(), output_final_state=True, ) - assert_close(' o', ref, tri, 0.002) + assert_close('o', ref, tri, 0.002) assert_close('ht', ref_ht, tri_ht, 0.002) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('scale', [0.25]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('compile', [False, True]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'mask_p', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-mask_p{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, 0, torch.float16), + (2, 1000, 3, 60, 1, 1, 0, torch.float16), + (2, 1024, 3, 64, 0.1, 1, 0.5, torch.float16), + (2, 1024, 4, 100, 1, 0.1, 0, torch.float16), + (2, 1024, 4, 128, 0.1, 1, 0, torch.float16), + (2, 1024, 4, 128, 0.1, 1, 0.5, torch.float16), + (2, 1024, 4, 128, 0.1, 10, 0, torch.float16), + (4, 2048, 8, 64, 0.1, 1, 0, torch.float16) + ] + ] ) @pytest.mark.skipif( device_platform == 'intel', @@ -282,8 +280,8 @@ def test_chunk( D: int, scale: float, gate_logit_normalizer: float, + mask_p: float, dtype: torch.dtype, - compile: bool, ): torch.manual_seed(42) q = torch.randn(B, T, H, D, dtype=dtype) @@ -294,7 +292,9 @@ def test_chunk( a = F.normalize(a, p=2, dim=-1) b = -a - gk = F.logsigmoid(gk) / gate_logit_normalizer + gk = F.logsigmoid(gk) + gk = gk / gate_logit_normalizer + gk = gk * (torch.rand_like(gk) > mask_p) h0 = torch.randn(B, H, D, D, dtype=torch.float) q, k, v, a, b, gk, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, a, b, gk, h0)) @@ -315,9 +315,7 @@ def test_chunk( ref_dq, ref_dk, ref_dv, ref_da, ref_db, ref_dg, ref_dh0 = q.grad, k.grad, v.grad, a.grad, b.grad, gk.grad, h0.grad q.grad = k.grad = v.grad = a.grad = b.grad = gk.grad = h0.grad = None - chunk_compiled = torch.compile(chunk_dplr_delta_rule) if compile else chunk_dplr_delta_rule - - tri, tri_ht = chunk_compiled( + tri, tri_ht = chunk_dplr_delta_rule( q=q.clone(), k=k.clone(), v=v.clone(), @@ -332,48 +330,48 @@ def test_chunk( tri_dq, tri_dk, tri_dv, tri_da, tri_db, tri_dg, tri_dh0 = q.grad, k.grad, v.grad, a.grad, b.grad, gk.grad, h0.grad q.grad = k.grad = v.grad = a.grad = b.grad = gk.grad = h0.grad = None - assert_close(' o', ref, tri, 0.007) - assert_close(' ht', ref_ht, tri_ht, 0.008) - assert_close(' dq', ref_dq, tri_dq, 0.008) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.008) - assert_close(' da', ref_da, tri_da, 0.008) - assert_close(' db', ref_db, tri_db, 0.008) + assert_close('o', ref, tri, 0.007) + assert_close('ht', ref_ht, tri_ht, 0.008) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('da', ref_da, tri_da, 0.008) + assert_close('db', ref_db, tri_db, 0.008) if gate_logit_normalizer >= 1 and ref_dg.norm() > 0.01: # otherwise it is meaningless - assert_close(' dg', ref_dg, tri_dg, 0.008) + assert_close('dg', ref_dg, tri_dg, 0.008) assert_close('dh0', ref_dh0, tri_dh0, 0.008) -@pytest.mark.parametrize('N', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [0.25]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', - reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' +@pytest.mark.parametrize( + ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 64, 0, [0, 15], torch.float16), + (4, 64, 0, [0, 256, 500, 1000], torch.float16), + (4, 64, 0.5, [0, 256, 500, 1000], torch.float16), + (4, 100, 0, [0, 15, 100, 300, 1111, 1599, 2000], torch.float16), + ] + ] ) @pytest.mark.skipif( device_platform == 'intel', reason='Intel Triton Failure' ) def test_chunk_varlen( - N: int, - T: int, H: int, D: int, - scale: float, + mask_p: float, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + + N = len(cu_seqlens) - 1 + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + # seq-first required for inputs with variable lengths q = torch.randn(1, T, H, D, dtype=dtype) k = torch.randn(1, T, H, D, dtype=dtype) @@ -383,6 +381,7 @@ def test_chunk_varlen( a = F.normalize(a, p=2, dim=-1) b = -a gk = F.logsigmoid(gk) + gk = gk * (torch.rand_like(gk) > mask_p) h0 = torch.randn(N, H, D, D, dtype=torch.float) q, k, v, a, b, gk, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, a, b, gk, h0)) @@ -393,7 +392,6 @@ def test_chunk_varlen( a=a.clone(), b=b.clone(), gk=gk.clone(), - scale=scale, output_final_state=True, initial_state=h0.clone(), cu_seqlens=cu_seqlens, @@ -414,7 +412,6 @@ def test_chunk_varlen( a=a[:, cu_seqlens[i]:cu_seqlens[i+1]], b=b[:, cu_seqlens[i]:cu_seqlens[i+1]], gk=gk[:, cu_seqlens[i]:cu_seqlens[i+1]], - scale=scale, initial_state=h0[i, None], output_final_state=True, ) @@ -426,12 +423,12 @@ def test_chunk_varlen( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_da, ref_db, ref_dg, ref_dh0 = q.grad, k.grad, v.grad, a.grad, b.grad, gk.grad, h0.grad - assert_close(' o', ref, tri, 0.007) - assert_close(' ht', ref_ht, tri_ht, 0.008) - assert_close(' dq', ref_dq, tri_dq, 0.008) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.008) - assert_close(' da', ref_da, tri_da, 0.008) - assert_close(' db', ref_db, tri_db, 0.008) - assert_close(' dg', ref_dg, tri_dg, 0.008) + assert_close('o', ref, tri, 0.007) + assert_close('ht', ref_ht, tri_ht, 0.008) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('da', ref_da, tri_da, 0.008) + assert_close('db', ref_db, tri_db, 0.008) + assert_close('dg', ref_dg, tri_dg, 0.008) assert_close('dh0', ref_dh0, tri_dh0, 0.008) diff --git a/tests/ops/test_forgetting_attn.py b/tests/ops/test_forgetting_attn.py index 9d6d81d5af..8fe4ebd9bd 100644 --- a/tests/ops/test_forgetting_attn.py +++ b/tests/ops/test_forgetting_attn.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- -import os -from typing import Optional, Tuple +from typing import List, Optional import pytest import torch @@ -9,21 +8,7 @@ from einops import rearrange, repeat from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn -from fla.utils import COMPILER_MODE, assert_close, check_shared_mem, device, is_intel_alchemist - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [1024] - test_t_varlen_list = test_t_list - test_d_list = [64, 100] -else: - test_b_list = [2] - test_t_list = [3, 15, 63, 300, 1024, 2048] - test_t_varlen_list = [63, 300, 1024, 512, 2048] - test_d_list = [64, 100] -test_fgate_logit_range_list = [(0, 5), (5, 10)] -test_hq_list = [8, 16] -test_h_list = [2] +from fla.utils import assert_close, check_shared_mem, device, is_intel_alchemist def naive_forgetting_attn( @@ -47,45 +32,40 @@ def naive_forgetting_attn( return ref -@pytest.mark.parametrize("B", test_b_list) -@pytest.mark.parametrize("T", test_t_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("HQ", test_hq_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("fgate_logit_range", test_fgate_logit_range_list) -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" -) -@pytest.mark.skipif( - is_intel_alchemist, - reason="Intel Triton Failure" +@pytest.mark.parametrize( + ('B', 'T', 'H', 'HQ', 'D', 'scale'), + [ + pytest.param(*test, id="B{}-T{}-H{}-HQ{}-D{}-scale{}".format(*test)) + for test in [ + (1, 63, 1, 1, 64, 1.0), + (3, 111, 2, 2, 100, 1.0), + (3, 1024, 2, 8, 60, 0.1), + (3, 1024, 2, 8, 128, 0.1), + (4, 2048, 2, 8, 64, 0.1) + ] + ] ) def test_parallel( B: int, + T: int, H: int, HQ: int, - T: int, D: int, - fgate_logit_range: Tuple[float, float], - dtype: torch.dtype + scale: float, ): + torch.manual_seed(42) + dtype = torch.float16 if not check_shared_mem('hopper') and D > 128: # maybe we can enable this test on Triton 3.3.0 pytest.skip("Skipping test because global shared memory is not available") - torch.manual_seed(42) - os.environ['TRITON_F32_DEFAULT'] = 'ieee' q = torch.randn((B, T, HQ, D), dtype=dtype, device=device).requires_grad_(True) k = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) v = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True) - logit_min, logit_max = fgate_logit_range - g = torch.rand((B, T, HQ), dtype=dtype, device=device) * (logit_max - logit_min) + logit_min - g = F.logsigmoid(g).requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device=device) - scale = D ** -0.5 + g = torch.randn((B, T, HQ), dtype=dtype, device=device).uniform_(-0.1, -0.01).requires_grad_(True) + + do = torch.randn((B, T, HQ, D), dtype=dtype, device=device) ref = naive_forgetting_attn(q, k, v, g, scale) ref.backward(do) ref_dq, q.grad = q.grad.clone(), None @@ -107,54 +87,40 @@ def test_parallel( assert_close("dg", ref_dg, tri_dg, 0.005) -@pytest.mark.parametrize("N", test_b_list) -@pytest.mark.parametrize("T", test_t_varlen_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("HQ", test_hq_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("fgate_logit_range", test_fgate_logit_range_list) -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "1", - reason="Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set" +@pytest.mark.parametrize( + ('H', 'HQ', 'D', 'cu_seqlens'), + [ + pytest.param(*test, id="H{}-HQ{}-D{}-cu_seqlens{}".format(*test)) + for test in [ + (2, 2, 64, [0, 15]), + (2, 8, 64, [0, 256, 500, 1000]), + (2, 2, 100, [0, 15, 100, 300, 1200, 2000]), + ] + ] ) @pytest.mark.skipif( is_intel_alchemist, reason="Intel Triton Failure" ) def test_parallel_varlen( - N: int, - T: int, H: int, HQ: int, D: int, - fgate_logit_range: Tuple[float, float], - dtype: torch.dtype, + cu_seqlens: List[int], ): - if not check_shared_mem('hopper') and D > 128: - # maybe we can enable this test on Triton 3.3.0 - pytest.skip("Skipping test because global shared memory is not available") torch.manual_seed(42) - os.environ['TRITON_F32_DEFAULT'] = 'ieee' - - N = min(1, N) if T < 64 else N - # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0].to(torch.int32) + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + dtype = torch.float16 # seq-first required for inputs with variable lengths q = torch.randn((1, T, HQ, D), dtype=dtype, device=device).requires_grad_() k = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() v = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() - logit_min, logit_max = fgate_logit_range - g = torch.rand((1, T, HQ), dtype=dtype, device=device) * (logit_max - logit_min) + logit_min - g = F.logsigmoid(g).requires_grad_(True) + g = torch.rand((1, T, HQ), dtype=dtype, device=device).uniform_(-0.1, -0.01).requires_grad_(True) do = torch.randn((1, T, HQ, D), dtype=dtype, device=device) ref = q.new_empty(1, T, HQ, D) - for bos, eos in zip(offsets[:-1], offsets[1:]): + for bos, eos in zip(cu_seqlens[:-1], cu_seqlens[1:]): ref[:, bos:eos] = naive_forgetting_attn( q=q[:, bos:eos], k=k[:, bos:eos], @@ -172,7 +138,7 @@ def test_parallel_varlen( k=k, v=v, g=g, - cu_seqlens=offsets + cu_seqlens=cu_seqlens ) tri.backward(do) tri_dq, q.grad = q.grad.clone(), None diff --git a/tests/ops/test_gated_delta.py b/tests/ops/test_gated_delta.py index 1bb06fb895..9111370601 100644 --- a/tests/ops/test_gated_delta.py +++ b/tests/ops/test_gated_delta.py @@ -9,22 +9,7 @@ from einops import rearrange, repeat from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule -from fla.utils import COMPILER_MODE, assert_close, device, is_intel_alchemist - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] - test_gate_list = [1.0] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100, 256] - test_gate_list = [1, 0.1, 10] -test_h_list = [2] -test_hv_list = [4] +from fla.utils import assert_close, device, is_intel_alchemist def recurrent_gated_delta_rule_ref( @@ -137,18 +122,21 @@ def chunk_gated_delta_rule_ref( return o, S -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('HV', test_hv_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('scale', [1]) -@pytest.mark.parametrize('use_qk_l2norm_in_kernel', [False, True]) -@pytest.mark.parametrize('dtype', [torch.float32, torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'HV', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-HV{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 1, 64, 1, 1, torch.float), + (2, 1024, 4, 4, 60, 1, 1, torch.float), + (2, 1024, 2, 8, 128, 1, 0.1, torch.float), + (2, 1024, 2, 2, 128, 0.1, 1, torch.float), + (2, 1024, 3, 3, 128, 1, 10, torch.float), + (4, 2048, 4, 4, 64, 0.1, 1, torch.float), + (2, 1024, 4, 4, 128, 1, 0.1, torch.float16), + (2, 1024, 4, 8, 128, 1, 10, torch.float16), + ] + ] ) def test_fused_recurrent( B: int, @@ -157,9 +145,8 @@ def test_fused_recurrent( HV: int, D: int, scale: float, - dtype: torch.dtype, - use_qk_l2norm_in_kernel: bool, gate_logit_normalizer: float, + dtype: torch.dtype, ): torch.manual_seed(42) q = torch.randn(B, T, H, D, dtype=torch.float32) @@ -181,41 +168,45 @@ def test_fused_recurrent( output_final_state=True, ) tri, tri_ht = fused_recurrent_gated_delta_rule( - q=F.normalize(q.clone(), p=2, dim=-1).to(dtype) if not use_qk_l2norm_in_kernel else q.clone(), - k=F.normalize(k.clone(), p=2, dim=-1).to(dtype) if not use_qk_l2norm_in_kernel else k.clone(), + q=q.clone(), + k=k.clone(), v=v.clone(), beta=beta.clone(), g=g.clone(), scale=scale, initial_state=h0.clone(), - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_qk_l2norm_in_kernel=True, output_final_state=True, ) - assert_close(' o', ref, tri, 0.002) - assert_close(' ht', ref_ht, tri_ht, 0.002) + assert_close('o', ref, tri, 0.002) + assert_close('ht', ref_ht, tri_ht, 0.002) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('scale', [1, 0.1]) -@pytest.mark.parametrize('mask_p', [0, 0.5]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'mask_p', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-mask_p{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, 0, torch.float16), + (2, 1000, 3, 60, 1, 1, 0, torch.float16), + (2, 1024, 3, 64, 0.1, 1, 0.5, torch.float16), + (2, 1024, 4, 100, 1, 0.1, 0, torch.float16), + (2, 1024, 4, 128, 0.1, 1, 0, torch.float16), + (2, 1024, 4, 128, 0.1, 1, 0.5, torch.float16), + (2, 1024, 4, 128, 0.1, 10, 0, torch.float16), + (4, 2048, 8, 64, 0.1, 1, 0, torch.float16) + ] + ] ) def test_chunk( B: int, T: int, H: int, D: int, - dtype: torch.dtype, scale: float, gate_logit_normalizer: float, mask_p: float, + dtype: torch.dtype, ): if is_intel_alchemist and D > 128: pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128') @@ -259,33 +250,38 @@ def test_chunk( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, g.grad, h0.grad - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.008) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.008) - assert_close(' db', ref_dbeta, tri_dbeta, 0.02) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) if gate_logit_normalizer >= 1 and ref_dg.norm() > 0.01: - assert_close(' dg', ref_dg, tri_dg, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) assert_close('dh0', ref_dh0, tri_dh0, 0.008) -@pytest.mark.parametrize('H', [2]) -@pytest.mark.parametrize('D', [128]) -@pytest.mark.parametrize('cu_seqlens', [[0, 122, 229, 400, 1000]]) -@pytest.mark.parametrize('scale', [1]) -@pytest.mark.parametrize('mask_p', [0.5]) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 60, 0, [0, 15], torch.float16), + (4, 64, 0, [0, 256, 500, 1000], torch.float16), + (4, 64, 0.5, [0, 256, 500, 1000], torch.float16), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' ) def test_chunk_varlen( - cu_seqlens: List[int], H: int, D: int, - scale: float, mask_p: float, + cu_seqlens: List[int], dtype: torch.dtype, ): if is_intel_alchemist and D > 128: @@ -316,7 +312,6 @@ def test_chunk_varlen( v=v.clone(), beta=beta.clone(), g=g.clone(), - scale=scale, output_final_state=True, initial_state=h0.clone(), cu_seqlens=cu_seqlens, @@ -334,7 +329,6 @@ def test_chunk_varlen( v=v[:, cu_seqlens[i]:cu_seqlens[i+1]], beta=beta[:, cu_seqlens[i]:cu_seqlens[i+1]], g=g[:, cu_seqlens[i]:cu_seqlens[i+1]], - scale=scale, initial_state=h0[i], output_final_state=True, ) @@ -346,11 +340,11 @@ def test_chunk_varlen( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, g.grad, h0.grad - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.007) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.007) - assert_close(' db', ref_dbeta, tri_dbeta, 0.015) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.007) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.007) + assert_close('db', ref_dbeta, tri_dbeta, 0.015) assert_close('dh0', ref_dh0, tri_dh0, 0.007) - assert_close(' dg', ref_dg, tri_dg, 0.015) + assert_close('dg', ref_dg, tri_dg, 0.015) diff --git a/tests/ops/test_gated_delta_product.py b/tests/ops/test_gated_delta_product.py index c8cf709413..38d604ed53 100644 --- a/tests/ops/test_gated_delta_product.py +++ b/tests/ops/test_gated_delta_product.py @@ -10,47 +10,35 @@ from fla.ops.gated_delta_product import chunk_gated_delta_product from fla.ops.gated_delta_product.chunk_ref import chunk_gated_delta_product_ref from fla.ops.gated_delta_product.naive import naive_recurrent_gated_delta_product -from fla.utils import COMPILER_MODE, assert_close, device, is_intel_alchemist - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] - test_gate_list = [1.0] -else: - test_b_list = [2] - test_t_list = [63, 300, 1000] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100, 256] - test_gate_list = [1, 0.1, 10] -test_h_list = [2] -test_hv_list = [4] - - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('num_householder', [3]) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('scale', [1]) -@pytest.mark.parametrize('mask_p', [0.5]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +from fla.utils import assert_close, device, is_intel_alchemist + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'num_householder', 'gate_logit_normalizer', 'mask_p', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-num_householder{}-gate_logit_normalizer{}-mask_p{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 0.1, 1, 1, 0, torch.float16), + (2, 200, 3, 60, 0.1, 1, 1, 0, torch.float16), + (2, 1000, 4, 64, 0.1, 2, 0.1, 0.5, torch.float16), + (2, 1024, 4, 64, 1, 2, 1, 0, torch.float16), + (2, 1024, 6, 100, 1, 2, 10, 0, torch.float16), + (4, 1500, 8, 128, 0.1, 3, 1, 0.5, torch.float16), + (2, 2048, 8, 128, 1, 3, 1, 0, torch.float16), + (2, 2048, 8, 128, 1, 3, 1, 0, torch.float16), + ] + ] ) def test_chunk( B: int, T: int, H: int, D: int, - num_householder: int, - dtype: torch.dtype, scale: float, + num_householder: int, gate_logit_normalizer: float, mask_p: float, + dtype: torch.dtype, ): if is_intel_alchemist and D > 128: pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128') @@ -96,35 +84,38 @@ def test_chunk( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, g.grad, h0.grad - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.008) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.008) - assert_close(' db', ref_dbeta, tri_dbeta, 0.02) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) if gate_logit_normalizer >= 1 and ref_dg.norm() > 0.01: - assert_close(' dg', ref_dg, tri_dg, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) assert_close('dh0', ref_dh0, tri_dh0, 0.008) -@pytest.mark.parametrize('H', [2]) -@pytest.mark.parametrize('D', [128]) -@pytest.mark.parametrize('cu_seqlens', [[0, 15, 122, 229, 400, 467, 1000]]) -@pytest.mark.parametrize('scale', [1]) -@pytest.mark.parametrize('mask_p', [0.2]) -@pytest.mark.parametrize('num_householder', [3, 4]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', - reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' +@pytest.mark.parametrize( + ('H', 'D', 'num_householder', 'mask_p', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-num_householder{}-mask_p{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 64, 3, 0, [0, 63], torch.float16), + (2, 100, 2, 0, [0, 63, 100, 500, 1000], torch.float16), + (2, 100, 2, 0, [0, 100, 256, 512, 1500, 1500], torch.float16), + (2, 128, 2, 0, [0, 100, 300, 800, 1500, 2000], torch.float16), + (2, 128, 2, 0.5, [0, 31, 111, 799, 1000, 1500, 1800, 2000], torch.float16), + (2, 128, 2, 0.5, [0, 63, 300, 800, 1000, 1399, 2048], torch.float16), + (2, 256, 3, 0, [0, 100, 123, 300, 500, 800, 1000, 1500, 2048], torch.float16), + ] + ] ) def test_chunk_varlen( - cu_seqlens: List[int], H: int, D: int, - scale: float, - mask_p: float, num_householder: int, + mask_p: float, + cu_seqlens: List[int], dtype: torch.dtype, ): if is_intel_alchemist and D > 128: @@ -146,6 +137,7 @@ def test_chunk_varlen( q, k, v, beta, g, h0 = map(lambda x: x.to(device).requires_grad_(), (q, k, v, beta, g, h0)) do = torch.randn_like(q) dht = torch.rand_like(h0) + scale = D ** -0.5 tri, tri_ht = chunk_gated_delta_product( q=q.clone(), @@ -179,14 +171,14 @@ def test_chunk_varlen( ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, g.grad, h0.grad - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.007) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.007) - assert_close(' db', ref_dbeta, tri_dbeta, 0.015) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.007) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.007) + assert_close('db', ref_dbeta, tri_dbeta, 0.015) assert_close('dh0', ref_dh0, tri_dh0, 0.007) - assert_close(' dg', ref_dg, tri_dg, 0.015) + assert_close('dg', ref_dg, tri_dg, 0.015) q.grad = k.grad = v.grad = beta.grad = g.grad = h0.grad = None torch_ref = torch.zeros_like(ref) @@ -199,18 +191,18 @@ def test_chunk_varlen( g_i = g[:, start:end, :] beta_i = beta[:, start*num_householder:end*num_householder, :] o3_i, h3_i = naive_recurrent_gated_delta_product( - q_i, k_i, v_i, g_i, beta_i, scale=1.0, cu_seqlens=None, output_final_state=True, num_householder=num_householder + q_i, k_i, v_i, g_i, beta_i, scale=scale, cu_seqlens=None, output_final_state=True, num_householder=num_householder ) torch_ref[:, start:end, :, :] = o3_i torch_ref_ht[i, :, :, :] = h3_i.squeeze(0) ((torch_ref * do).sum() + (torch_ref_ht * dht).sum()).backward(retain_graph=True) - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.007) - assert_close(' dk', ref_dk, tri_dk, 0.008) - assert_close(' dv', ref_dv, tri_dv, 0.007) - assert_close(' db', ref_dbeta, tri_dbeta, 0.015) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.007) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.007) + assert_close('db', ref_dbeta, tri_dbeta, 0.015) assert_close('dh0', ref_dh0, tri_dh0, 0.007) - assert_close(' dg', ref_dg, tri_dg, 0.015) + assert_close('dg', ref_dg, tri_dg, 0.015) diff --git a/tests/ops/test_gla.py b/tests/ops/test_gla.py index 7c181b38a5..2df9fcd9e9 100644 --- a/tests/ops/test_gla.py +++ b/tests/ops/test_gla.py @@ -80,12 +80,12 @@ def test_fused_recurrent( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) @@ -164,12 +164,12 @@ def test_fused_recurrent_varlen( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) @@ -240,12 +240,12 @@ def test_chunk( ref_dg, g.grad = g.grad.clone(), None ref_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.004) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.004) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) @@ -318,10 +318,10 @@ def test_chunk_varlen( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.004) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.004) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) diff --git a/tests/ops/test_gsa.py b/tests/ops/test_gsa.py index 91f6163260..1103d8bd09 100644 --- a/tests/ops/test_gsa.py +++ b/tests/ops/test_gsa.py @@ -83,14 +83,14 @@ def test_fused_recurrent( tri_dhk0, hk0.grad = hk0.grad.clone(), None tri_dhv0, hv0.grad = hv0.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' hkt', ref_hkt, tri_hkt, 0.005) - assert_close(' hvt', ref_hvt, tri_hvt, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' ds', ref_ds, tri_ds, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.005) + assert_close('hkt', ref_hkt, tri_hkt, 0.005) + assert_close('hvt', ref_hvt, tri_hvt, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('ds', ref_ds, tri_ds, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dhk0', ref_dhk0, tri_dhk0, 0.005) assert_close('dhv0', ref_dhv0, tri_dhv0, 0.005) @@ -181,14 +181,14 @@ def test_fused_recurrent_varlen( tri_dhk0, hk0.grad = hk0.grad.clone(), None tri_dhv0, hv0.grad = hv0.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' hkt', ref_hkt, tri_hkt, 0.005) - assert_close(' hvt', ref_hvt, tri_hvt, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' ds', ref_ds, tri_ds, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.005) + assert_close('hkt', ref_hkt, tri_hkt, 0.005) + assert_close('hvt', ref_hvt, tri_hvt, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('ds', ref_ds, tri_ds, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dhk0', ref_dhk0, tri_dhk0, 0.005) assert_close('dhv0', ref_dhv0, tri_dhv0, 0.005) @@ -273,14 +273,14 @@ def test_chunk( tri_dhk0, hk0.grad = hk0.grad.clone(), None tri_dhv0, hv0.grad = hv0.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' hkt', ref_hkt, tri_hkt, 0.005) - assert_close(' hvt', ref_hvt, tri_hvt, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' ds', ref_ds, tri_ds, 0.008) - assert_close(' dg', ref_dg, tri_dg, 0.008) + assert_close('o', ref, tri, 0.005) + assert_close('hkt', ref_hkt, tri_hkt, 0.005) + assert_close('hvt', ref_hvt, tri_hvt, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('ds', ref_ds, tri_ds, 0.008) + assert_close('dg', ref_dg, tri_dg, 0.008) assert_close('dhk0', ref_dhk0, tri_dhk0, 0.005) assert_close('dhv0', ref_dhv0, tri_dhv0, 0.005) @@ -371,14 +371,14 @@ def test_chunk_varlen( tri_dhk0, hk0.grad = hk0.grad.clone(), None tri_dhv0, hv0.grad = hv0.grad.clone(), None - assert_close(' o', ref, tri, 0.004) - assert_close(' hkt', ref_hkt, tri_hkt, 0.005) - assert_close(' hvt', ref_hvt, tri_hvt, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' ds', ref_ds, tri_ds, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.004) + assert_close('hkt', ref_hkt, tri_hkt, 0.005) + assert_close('hvt', ref_hvt, tri_hvt, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('ds', ref_ds, tri_ds, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dhk0', ref_dhk0, tri_dhk0, 0.005) assert_close('dhv0', ref_dhv0, tri_dhv0, 0.005) diff --git a/tests/ops/test_hgrn.py b/tests/ops/test_hgrn.py index 26f38c1293..a041feeecf 100644 --- a/tests/ops/test_hgrn.py +++ b/tests/ops/test_hgrn.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch @@ -8,30 +9,21 @@ from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn from fla.ops.hgrn.naive import naive_recurrent_hgrn -from fla.utils import COMPILER_MODE, assert_close, device - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [500, 1024] - test_gate_list = [1.0] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [500, 1024] - test_gate_list = [1, 0.1, 10] -test_h_list = [2] - - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +from fla.utils import assert_close, device + + +@pytest.mark.parametrize( + ('B', 'T', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 500, torch.float), + (2, 1024, 500, torch.float), + (2, 1024, 512, torch.float), + (2, 1024, 1000, torch.float), + (4, 2048, 2048, torch.float), + ] + ] ) def test_fused_recurrent( B: int, @@ -62,35 +54,36 @@ def test_fused_recurrent( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dx', ref_dx, tri_dx, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dx', ref_dx, tri_dx, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) -@pytest.mark.parametrize('N', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (500, [0, 15], torch.float), + (512, [0, 256, 500, 1000], torch.float), + (1000, [0, 15, 100, 300, 1200, 2000], torch.float), + (2048, [0, 200, 512, 1200, 2048], torch.float16), + ] + ] ) def test_fused_recurrent_varlen( - N: int, - T: int, D: int, + cu_seqlens: List[int], dtype: torch.dtype ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + + N = len(cu_seqlens) - 1 + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) x = torch.randn((1, T, D), dtype=dtype, device=device) g = torch.randn((1, T, D), dtype=dtype, device=device) @@ -123,20 +116,24 @@ def test_fused_recurrent_varlen( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dx', ref_dx, tri_dx, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dx', ref_dx, tri_dx, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 500, torch.float16), + (2, 500, 1000, torch.float16), + (2, 1000, 1024, torch.float16), + (4, 2048, 2048, torch.float16), + ] + ] ) def test_chunk( B: int, @@ -164,6 +161,6 @@ def test_chunk( tri_dx, x.grad = x.grad.clone(), None tri_dg, g.grad = g.grad.clone(), None - assert_close(' o', ref, tri, 0.005) + assert_close('o', ref, tri, 0.005) assert_close('dx', ref_dx, tri_dx, 0.005) assert_close('dg', ref_dg, tri_dg, 0.005) diff --git a/tests/ops/test_iplr_delta.py b/tests/ops/test_iplr_delta.py index 2beb10b207..789873f0cb 100644 --- a/tests/ops/test_iplr_delta.py +++ b/tests/ops/test_iplr_delta.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -import os from typing import Optional import pytest @@ -10,21 +9,7 @@ from fla.ops.generalized_delta_rule.iplr.chunk import chunk_iplr_delta_rule from fla.ops.generalized_delta_rule.iplr.fused_recurrent import fused_recurrent_iplr_delta_rule -from fla.utils import COMPILER_MODE, assert_close, device - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] - test_gate_list = [1.0] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [32, 64, 100, 256] - test_gate_list = [1, 0.1, 10] -test_h_list = [2] +from fla.utils import assert_close, device def chunk_iplr_delta_rule_ref( @@ -131,17 +116,20 @@ def recurrence_iplr_delta_rule_ref( return o.to(orig_dtype), S -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [0.25]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, torch.float), + (2, 1024, 4, 60, 1, torch.float), + (2, 1024, 8, 100, 1, torch.float), + (2, 1024, 8, 128, 0.1, torch.float), + (4, 2048, 8, 64, 0.1, torch.float), + ] + ] ) -def test_chunk( +def test_fused_recurrent( B: int, T: int, H: int, @@ -157,7 +145,7 @@ def test_chunk( a = F.normalize(a, p=2, dim=-1) b = -a h0 = torch.zeros(B, H, D, D, dtype=torch.float32) - q, k, v, a, b, h0 = map(lambda x: x.to(device).requires_grad_(), (q, k, v, a, b, h0)) + q, k, v, a, b, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, a, b, h0)) ref, ref_ht = recurrence_iplr_delta_rule_ref( q=q.clone(), k=k.clone(), @@ -168,7 +156,12 @@ def test_chunk( initial_state=h0.clone(), output_final_state=True, ) - tri, tri_ht = chunk_iplr_delta_rule( + dht = torch.rand_like(h0) + do = torch.rand_like(ref) + ((dht * ref_ht).sum() + (do * ref).sum()).backward() + dq, dk, dv, da, db, dh0 = map(lambda x: x.grad, (q, k, v, a, b, h0)) + q.grad, k.grad, v.grad, a.grad, b.grad, h0.grad = None, None, None, None, None, None + tri, tri_ht = fused_recurrent_iplr_delta_rule( q=q.clone(), k=k.clone(), v=v.clone(), @@ -178,17 +171,32 @@ def test_chunk( initial_state=h0.clone(), output_final_state=True, ) - assert_close(' o', ref, tri, 0.007) - assert_close('ht', ref_ht, tri_ht, 0.008) + ((dht * tri_ht).sum() + (do * tri).sum()).backward() + assert_close('o', ref, tri, 0.003) + assert_close('ht', ref_ht, tri_ht, 0.003) + assert_close('dq', dq, q.grad, 0.003) + assert_close('dk', dk, k.grad, 0.003) + assert_close('dv', dv, v.grad, 0.003) + assert_close('da', da, a.grad, 0.003) + assert_close('db', db, b.grad, 0.003) + assert_close('dh0', dh0, h0.grad, 0.003) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [0.25]) -@pytest.mark.parametrize('dtype', [torch.float16]) -def test_recurrent( +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, torch.float16), + (2, 500, 3, 60, 1, torch.float16), + (2, 1000, 3, 64, 0.1, torch.float16), + (2, 1024, 4, 100, 1, torch.float16), + (3, 1024, 4, 128, 0.1, torch.float16), + (4, 2048, 8, 64, 0.1, torch.float16) + ] + ] +) +def test_chunk( B: int, T: int, H: int, @@ -204,7 +212,7 @@ def test_recurrent( a = F.normalize(a, p=2, dim=-1) b = -a h0 = torch.zeros(B, H, D, D, dtype=torch.float32) - q, k, v, a, b, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, a, b, h0)) + q, k, v, a, b, h0 = map(lambda x: x.to(device).requires_grad_(), (q, k, v, a, b, h0)) ref, ref_ht = recurrence_iplr_delta_rule_ref( q=q.clone(), k=k.clone(), @@ -215,12 +223,7 @@ def test_recurrent( initial_state=h0.clone(), output_final_state=True, ) - dht = torch.rand_like(h0) - do = torch.rand_like(ref) - ((dht * ref_ht).sum() + (do * ref).sum()).backward() - dq, dk, dv, da, db, dh0 = map(lambda x: x.grad, (q, k, v, a, b, h0)) - q.grad, k.grad, v.grad, a.grad, b.grad, h0.grad = None, None, None, None, None, None - tri, tri_ht = fused_recurrent_iplr_delta_rule( + tri, tri_ht = chunk_iplr_delta_rule( q=q.clone(), k=k.clone(), v=v.clone(), @@ -230,12 +233,5 @@ def test_recurrent( initial_state=h0.clone(), output_final_state=True, ) - ((dht * tri_ht).sum() + (do * tri).sum()).backward() - assert_close(' o', ref, tri, 0.003) - assert_close(' ht', ref_ht, tri_ht, 0.003) - assert_close(' dq', dq, q.grad, 0.003) - assert_close(' dk', dk, k.grad, 0.003) - assert_close(' dv', dv, v.grad, 0.003) - assert_close(' da', da, a.grad, 0.003) - assert_close(' db', db, b.grad, 0.003) - assert_close('dh0', dh0, h0.grad, 0.003) + assert_close('o', ref, tri, 0.007) + assert_close('ht', ref_ht, tri_ht, 0.008) diff --git a/tests/ops/test_linear_attn.py b/tests/ops/test_linear_attn.py index ab98fc650a..2540609167 100644 --- a/tests/ops/test_linear_attn.py +++ b/tests/ops/test_linear_attn.py @@ -1,37 +1,25 @@ # -*- coding: utf-8 -*- -import os - import pytest import torch from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn from fla.ops.linear_attn.naive import naive_chunk_linear_attn -from fla.utils import COMPILER_MODE, assert_close, device - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 32, 128] - test_gate_list = [1.0] -else: - test_b_list = [2] - test_t_list = [64, 128] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 128] - test_gate_list = [1, 0.1, 10] -test_h_list = [2] - - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +from fla.utils import assert_close, device + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (1, 64, 1, 64, torch.float), + (2, 512, 4, 60, torch.float), + (3, 1024, 8, 128, torch.float), + (2, 2048, 8, 256, torch.float16), + (2, 2048, 4, 256, torch.float16), + ] + ] ) def test_fused_recurrent( B: int, @@ -58,20 +46,24 @@ def test_fused_recurrent( tri_dk, k.grad = k.grad.clone(), None tri_dv, v.grad = v.grad.clone(), None - assert_close(' o', ref, tri, 0.001) + assert_close('o', ref, tri, 0.001) assert_close('dq', ref_dq, tri_dq, 0.001) assert_close('dk', ref_dk, tri_dk, 0.001) assert_close('dv', ref_dv, tri_dv, 0.001) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, torch.float16), + (2, 500, 3, 60, torch.float16), + (2, 1000, 3, 128, torch.float16), + (3, 1000, 4, 64, torch.float16), + (2, 2048, 4, 256, torch.float16), + ] + ] ) def test_chunk( B: int, @@ -114,21 +106,25 @@ def test_chunk( tri_dk, k.grad = k.grad.clone(), None tri_dv, v.grad = v.grad.clone(), None - assert_close(' o', ref, tri, 0.001) + assert_close('o', ref, tri, 0.001) assert_close('ht', ref_ht, tri_ht, 0.001) assert_close('dq', ref_dq, tri_dq, 0.001) assert_close('dk', ref_dk, tri_dk, 0.001) assert_close('dv', ref_dv, tri_dv, 0.001) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, torch.float16), + (2, 500, 3, 60, torch.float16), + (2, 1000, 3, 128, torch.float16), + (3, 1000, 4, 64, torch.float16), + (2, 2048, 4, 256, torch.float16), + ] + ] ) def test_fused_chunk( B: int, @@ -143,6 +139,8 @@ def test_fused_chunk( v = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_() h0 = torch.zeros((B, H, D, D), dtype=dtype, device=device).requires_grad_() do = torch.randn_like(v) + dht = torch.randn((B, H, D, D), dtype=dtype, device=device) + ref, ref_ht = fused_recurrent_linear_attn( q.to(torch.float32), k.to(torch.float32), @@ -151,9 +149,7 @@ def test_fused_chunk( output_final_state=True, normalize=False ) - ref = ref.to(dtype) - ref_ht = ref_ht.to(dtype) - ref.backward(do) + ((ref * do).sum() + (ref_ht * dht).sum()).backward() ref_dq, q.grad = q.grad.clone(), None ref_dk, k.grad = k.grad.clone(), None ref_dv, v.grad = v.grad.clone(), None @@ -166,12 +162,12 @@ def test_fused_chunk( output_final_state=True, normalize=False ) - tri.backward(do) + ((tri * do).sum() + (tri_ht * dht).sum()).backward() tri_dq, q.grad = q.grad.clone(), None tri_dk, k.grad = k.grad.clone(), None tri_dv, v.grad = v.grad.clone(), None - assert_close(' o', ref, tri, 0.001) + assert_close('o', ref, tri, 0.001) assert_close('ht', ref_ht, tri_ht, 0.001) assert_close('dq', ref_dq, tri_dq, 0.001) assert_close('dk', ref_dk, tri_dk, 0.001) diff --git a/tests/ops/test_mesa.py b/tests/ops/test_mesa.py index f860f4df1f..dd78646abe 100644 --- a/tests/ops/test_mesa.py +++ b/tests/ops/test_mesa.py @@ -8,17 +8,7 @@ import torch.nn.functional as F from fla.ops.mesa_net import chunk_mesa_net, mesa_net_decoding_one_step, naive_mesa_net_decoding_one_step, naive_mesa_net_exact -from fla.utils import COMPILER_MODE, assert_close, device, device_platform, is_intel_alchemist - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [512] - test_d_list = [128] -else: - test_b_list = [2] - test_t_list = [15, 63, 300, 1000] - test_d_list = [128, 64, 50, 100] -test_h_list = [3] +from fla.utils import assert_close, device, device_platform, is_intel_alchemist @pytest.mark.parametrize( diff --git a/tests/ops/test_nsa.py b/tests/ops/test_nsa.py index acc813f4a1..124b2b4533 100644 --- a/tests/ops/test_nsa.py +++ b/tests/ops/test_nsa.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch @@ -9,47 +10,36 @@ from fla.ops.nsa.naive import naive_nsa from fla.ops.nsa.parallel import parallel_nsa from fla.ops.utils import prepare_token_indices -from fla.utils import COMPILER_MODE, assert_close, device - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list -else: - test_b_list = [2] - test_t_list = [256, 1024, 2000] - test_t_varlen_list = [63, 286, 300, 512] -test_h_list = [2] +from fla.utils import assert_close, device # FIXME -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('HQ', [64]) -@pytest.mark.parametrize('D', [100, 64]) -@pytest.mark.parametrize('S', [16]) -@pytest.mark.parametrize('block_size', [32]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('scale', [0.1]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'HQ', 'D', 'S', 'block_size', 'scale', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-HQ{}-D{}-S{}-block_size{}-scale{}-{}".format(*test)) + for test in [ + (1, 63, 1, 1, 64, 16, 32, 1.0, torch.float16), + (3, 111, 2, 2, 100, 16, 32, 1.0, torch.float16), + (3, 1024, 2, 8, 60, 16, 32, 0.1, torch.float16), + (3, 1024, 2, 8, 128, 16, 32, 0.1, torch.float16), + (4, 2048, 2, 8, 64, 16, 32, 0.1, torch.float16) + ] + ] ) @pytest.mark.skipif( - True, - reason='TBD' + True, reason='TBD' ) def test_parallel( B: int, + T: int, H: int, HQ: int, - T: int, D: int, S: int, block_size: int, + scale: float, dtype: torch.dtype, - scale: float ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' @@ -85,41 +75,39 @@ def test_parallel( assert_close("dv", ref_dv, tri_dv, 0.005) -@pytest.mark.parametrize('N', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('HQ', [64]) -@pytest.mark.parametrize('D', [100, 64]) -@pytest.mark.parametrize('S', [16]) -@pytest.mark.parametrize('block_size', [32]) -@pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize( + ('H', 'HQ', 'D', 'S', 'block_size', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-HQ{}-D{}-S{}-block_size{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 2, 64, 16, 32, [0, 15], torch.float16), + (2, 8, 64, 16, 32, [0, 256, 500, 1000], torch.float16), + (2, 2, 100, 16, 32, [0, 15, 100, 300, 1200, 2000], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test because SKIP_TEST_CHUNK_VARLEN is set' ) @pytest.mark.skipif( - True, - reason='TBD' + True, reason='TBD' ) def test_parallel_varlen( - N: int, - T: int, H: int, HQ: int, D: int, S: int, block_size: int, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + # seq-first required for inputs with variable lengths q = torch.randn((1, T, HQ, D), dtype=dtype, device=device).requires_grad_() k = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() @@ -162,7 +150,7 @@ def test_parallel_varlen( tri_dk, k.grad = k.grad.clone(), None tri_dv, v.grad = v.grad.clone(), None - assert_close(' o', ref, tri, 0.004) + assert_close('o', ref, tri, 0.004) assert_close('dq', ref_dq, tri_dq, 0.005) assert_close('dk', ref_dk, tri_dk, 0.005) assert_close('dv', ref_dv, tri_dv, 0.005) diff --git a/tests/ops/test_path_attn.py b/tests/ops/test_path_attn.py index 30cee2fbc0..64d9c9595a 100644 --- a/tests/ops/test_path_attn.py +++ b/tests/ops/test_path_attn.py @@ -8,19 +8,7 @@ from einops import rearrange from fla.ops.path_attn.parallel import parallel_path_attention -from fla.utils import COMPILER_MODE, assert_close, check_shared_mem, device, is_intel_alchemist - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [1024] - test_d_list = [64] -else: - test_b_list = [2] - test_t_list = [63, 300, 4095] - test_d_list = [64] -test_fgate_logit_range_list = [(0.95, 1), (1, 1)] -test_hq_list = [8, 16] -test_h_list = [2] +from fla.utils import assert_close, check_shared_mem, device, is_intel_alchemist def naive_path_attn(q, k, v, w, beta, g, scale, BT=64): @@ -74,16 +62,18 @@ def naive_path_attn(q, k, v, w, beta, g, scale, BT=64): return ref_o.to(original_dtype).transpose(1, 2) -@pytest.mark.parametrize("B", test_b_list) -@pytest.mark.parametrize("T", test_t_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("HQ", test_hq_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("use_forget_gate", [True, False]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" +@pytest.mark.parametrize( + ('B', 'T', 'H', 'HQ', 'D', 'use_forget_gate', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-HQ{}-D{}-use_forget_gate{}-{}".format(*test)) + for test in [ + (1, 63, 1, 1, 32, False, torch.float16), + (3, 111, 2, 2, 32, False, torch.float16), + (3, 1024, 2, 8, 64, True, torch.float16), + (3, 1024, 2, 8, 64, True, torch.float16), + (4, 2048, 2, 8, 64, False, torch.float16) + ] + ] ) @pytest.mark.skipif( is_intel_alchemist, @@ -146,12 +136,19 @@ def test_parallel( assert_close("db", ref_db, tri_db, 0.005) -@pytest.mark.parametrize("cu_seqlens", [[0, 19, 321, 394, 1111, 2048], [0, 621, 1024, 4222]]) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("HQ", test_hq_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("use_forget_gate", [True, False]) +@pytest.mark.parametrize( + ('H', 'HQ', 'D', 'use_forget_gate', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-HQ{}-D{}-use_forget_gate{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 2, 32, False, [0, 15], torch.float16), + (2, 8, 32, False, [0, 256, 500, 1000], torch.float16), + (2, 8, 64, True, [0, 100, 500, 800, 1000], torch.float16), + (2, 2, 64, False, [0, 15, 100, 300, 1200, 2000], torch.float16), + (2, 2, 64, True, [0, 100, 300, 1000, 1989, 2000], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", reason="Skipping test because TEST_CHUNK_VARLEN is enabled" @@ -161,11 +158,11 @@ def test_parallel( reason="Intel Triton Failure" ) def test_parallel_varlen( - cu_seqlens: List[int], H: int, HQ: int, D: int, use_forget_gate: bool, + cu_seqlens: List[int], dtype: torch.dtype ): if not check_shared_mem('hopper') and D > 128: diff --git a/tests/ops/test_retention.py b/tests/ops/test_retention.py index b6f937d7eb..9ddf6acd13 100644 --- a/tests/ops/test_retention.py +++ b/tests/ops/test_retention.py @@ -1,36 +1,28 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch from fla.ops.retention import chunk_retention, fused_recurrent_retention, parallel_retention -from fla.ops.retention.naive import naive_retention -from fla.utils import COMPILER_MODE, assert_close, device - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 32, 100] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100] -test_h_list = [2] - - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('K', test_d_list) -@pytest.mark.parametrize('expand_ratio', [1, 2]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +from fla.utils import assert_close, device + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'K', 'expand_ratio', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-K{}-expand_ratio{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, torch.float16), + (2, 1024, 3, 60, 1, torch.float16), + (2, 1024, 3, 100, 1, torch.float16), + (2, 1000, 3, 128, 2, torch.float16), + (2, 1024, 4, 256, 2, torch.float16), + (4, 2048, 4, 64, 2, torch.float16) + ] + ] ) def test_chunk( B: int, @@ -63,41 +55,43 @@ def test_chunk( tri_dk, k.grad = k.grad.clone(), None tri_dv, v.grad = v.grad.clone(), None - assert_close(' o', ref, tri, 0.005) + assert_close('o', ref, tri, 0.005) assert_close('ht', ref_ht, tri_ht, 0.005) assert_close('dq', ref_dq, tri_dq, 0.005) assert_close('dk', ref_dk, tri_dk, 0.005) assert_close('dv', ref_dv, tri_dv, 0.005) -@pytest.mark.parametrize('N', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('K', test_d_list) -@pytest.mark.parametrize('expand_ratio', [1, 2]) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'K', 'expand_ratio', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-K{}-expand_ratio{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 64, 1, [0, 15], torch.float16), + (4, 64, 2, [0, 256, 500, 1000], torch.float16), + (4, 100, 2, [0, 15, 100, 300, 1200, 2000], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' ) def test_chunk_varlen( - N: int, - T: int, H: int, K: int, expand_ratio: int, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' V = K * expand_ratio - # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + N = len(cu_seqlens) - 1 + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.long, device=device) + # seq-first required for inputs with variable lengths q = torch.randn((1, T, H, K), dtype=dtype, device=device).requires_grad_() k = torch.randn((1, T, H, K), dtype=dtype, device=device).requires_grad_() @@ -112,7 +106,7 @@ def test_chunk_varlen( v=v, initial_state=h0, output_final_state=True, - cu_seqlens=offsets, + cu_seqlens=cu_seqlens, ) ((ref * do).sum() + (ref_ht * dht).sum()).backward() ref_dq, q.grad = q.grad.clone(), None @@ -126,7 +120,7 @@ def test_chunk_varlen( v=v, initial_state=h0, output_final_state=True, - cu_seqlens=offsets, + cu_seqlens=cu_seqlens, ) ((tri * do).sum() + (tri_ht * dht).sum()).backward() tri_dq, q.grad = q.grad.clone(), None @@ -134,23 +128,27 @@ def test_chunk_varlen( tri_dv, v.grad = v.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.004) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) + assert_close('o', ref, tri, 0.004) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('K', test_d_list) -@pytest.mark.parametrize('expand_ratio', [1, 2]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'K', 'expand_ratio', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-K{}-expand_ratio{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, torch.float16), + (2, 500, 4, 60, 1, torch.float16), + (2, 1024, 8, 128, 1, torch.float16), + (3, 1024, 8, 128, 2, torch.float16), + (3, 1024, 8, 256, 2, torch.float16), + (4, 2048, 8, 64, 2, torch.float16) + ] + ] ) def test_parallel( B: int, @@ -168,7 +166,7 @@ def test_parallel( v = torch.randn((B, T, H, V), dtype=dtype, device=device).requires_grad_() do = torch.randn_like(v) - ref = naive_retention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2) + ref, _ = fused_recurrent_retention(q, k, v) ref.backward(do) ref_dq, q.grad = q.grad.clone(), None ref_dk, k.grad = k.grad.clone(), None @@ -180,7 +178,7 @@ def test_parallel( tri_dk, k.grad = k.grad.clone(), None tri_dv, v.grad = v.grad.clone(), None - assert_close(' o', ref, tri, 0.005) + assert_close('o', ref, tri, 0.005) assert_close('dq', ref_dq, tri_dq, 0.005) assert_close('dk', ref_dk, tri_dk, 0.005) assert_close('dv', ref_dv, tri_dv, 0.005) diff --git a/tests/ops/test_rwkv6.py b/tests/ops/test_rwkv6.py index 02c9993f43..c0a19e10fe 100644 --- a/tests/ops/test_rwkv6.py +++ b/tests/ops/test_rwkv6.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch @@ -8,44 +9,35 @@ from fla.ops.rwkv6 import chunk_rwkv6 from fla.ops.rwkv6.fused_recurrent import fused_recurrent_rwkv6 -from fla.utils import COMPILER_MODE, assert_close, device, device_platform - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] - test_gate_list = [1.0] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100, 256] - test_gate_list = [1, 0.1, 10] -test_h_list = [2] - - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('dtype', [torch.bfloat16]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" -) +from fla.utils import assert_close, device, device_platform + + @pytest.mark.skipif( device_platform == 'intel', reason="Intel Triton Failure" ) +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 15, 2, 60, 1.0, torch.float16), + (3, 60, 3, 64, 0.1, torch.float16), + (3, 64, 2, 64, 1, torch.float16), + (4, 500, 3, 256, 1, torch.float16), + (4, 1000, 4, 64, 10, torch.float16), + (4, 2048, 4, 64, 1, torch.float16), + (4, 2048, 4, 256, 1, torch.float16), + ] + ] +) def test_chunk( B: int, T: int, H: int, D: int, - dtype: torch.dtype, gate_logit_normalizer: float, + dtype: torch.dtype, ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' @@ -105,40 +97,39 @@ def test_chunk( tri_du, u.grad = u.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.004) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dw', ref_dw, tri_dw, 0.005) - assert_close(' du', ref_du, tri_du, 0.005) + assert_close('o', ref, tri, 0.004) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dw', ref_dw, tri_dw, 0.005) + assert_close('du', ref_du, tri_du, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) -@pytest.mark.parametrize("N", test_b_list) -@pytest.mark.parametrize("T", test_t_varlen_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "1", - reason="Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set" +@pytest.mark.parametrize( + ('H', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 64, [0, 15], torch.float16), + (4, 64, [0, 256, 500, 1000], torch.float16), + (4, 100, [0, 15, 100, 300, 1200, 2000], torch.float16), + ] + ] ) def test_chunk_varlen( - N: int, - T: int, H: int, D: int, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + N = len(cu_seqlens) - 1 + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + # seq-first required for inputs with variable lengths q = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() k = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() @@ -193,11 +184,11 @@ def test_chunk_varlen( tri_dw, w.grad = w.grad.clone(), None tri_du, u.grad = u.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.004) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dw', ref_dw, tri_dw, 0.005) - assert_close(' du', ref_du, tri_du, 0.005) + assert_close('o', ref, tri, 0.004) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dw', ref_dw, tri_dw, 0.005) + assert_close('du', ref_du, tri_du, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) diff --git a/tests/ops/test_rwkv7.py b/tests/ops/test_rwkv7.py index 5de64b0c7e..b9826138a8 100644 --- a/tests/ops/test_rwkv7.py +++ b/tests/ops/test_rwkv7.py @@ -16,8 +16,8 @@ @pytest.mark.parametrize("B", [2]) @pytest.mark.parametrize("T", [1024]) -@pytest.mark.parametrize("n_embd", [512, 1024]) -@pytest.mark.parametrize("dim_ffn", [2048, 4096]) +@pytest.mark.parametrize("n_embd", [1024]) +@pytest.mark.parametrize("dim_ffn", [4096]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("xprevdim", [2, 3]) @@ -127,7 +127,7 @@ def test_fused_mul_recurrent_fwd( initial_state=h0.clone(), output_final_state=True, ) - assert_close(' o', ref, tri, 0.002) + assert_close('o', ref, tri, 0.002) assert_close('ht', ref_ht, tri_ht, 0.002) diff --git a/tests/ops/test_simple_gla.py b/tests/ops/test_simple_gla.py index 4e27a41ee3..102acbb64a 100644 --- a/tests/ops/test_simple_gla.py +++ b/tests/ops/test_simple_gla.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch @@ -10,30 +11,25 @@ from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla from fla.ops.simple_gla.naive import naive_parallel_simple_gla, naive_recurrent_simple_gla from fla.ops.simple_gla.parallel import parallel_simple_gla -from fla.utils import COMPILER_MODE, assert_close, device - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] - test_gate_list = [1.0] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100, 256] - test_gate_list = [1, 0.1, 10] -test_h_list = [2] - - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [1, 0.1]) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('dtype', [torch.float]) +from fla.utils import assert_close, device + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, torch.float), + (2, 1024, 4, 60, 1, 1, torch.float), + (2, 1024, 8, 128, 1, 0.1, torch.float), + (2, 1024, 8, 128, 0.1, 1, torch.float), + (2, 1024, 8, 128, 1, 10, torch.float), + (4, 2048, 8, 64, 0.1, 1, torch.float), + (2, 1024, 8, 128, 1, 0.1, torch.float16), + (2, 1024, 8, 128, 1, 10, torch.float16), + ] + ] +) def test_fused_recurrent( B: int, T: int, @@ -85,39 +81,44 @@ def test_fused_recurrent( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005, err_atol=2e-4) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005, err_atol=2e-4) assert_close('dh0', ref_dh0, tri_dh0, 0.005) -@pytest.mark.parametrize('N', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('scale', [1, 0.1]) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('dtype', [torch.float]) +@pytest.mark.parametrize( + ('H', 'D', 'scale', 'gate_logit_normalizer', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-scale{}-gate_logit_normalizer{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 64, 1, 1, [0, 15], torch.float), + (4, 64, 1, 1, [0, 256, 500, 1000], torch.float), + (4, 100, 0.1, 1, [0, 15, 100, 300, 1200, 2000], torch.float), + (4, 100, 1, 1, [0, 15, 100, 300, 1200, 2000], torch.float), + (4, 100, 1, 10, [0, 15, 100, 300, 1200, 2000], torch.float), + (4, 64, 1, 1, [0, 1, 100, 300, 1200, 2048], torch.float16), + (4, 128, 1, 1, [0, 200, 512, 1200, 2048], torch.float16), + ] + ] +) def test_fused_recurrent_varlen( - N: int, - T: int, H: int, D: int, scale: float, gate_logit_normalizer: float, + cu_seqlens: List[int], dtype: torch.dtype ): torch.manual_seed(42) - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + N = len(cu_seqlens) - 1 + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + q = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() k = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() v = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() @@ -166,30 +167,37 @@ def test_fused_recurrent_varlen( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005, err_atol=2e-4) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005, err_atol=2e-4) assert_close('dh0', ref_dh0, tri_dh0, 0.005) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('scale', [1, 0.1]) +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, torch.float16), + (2, 1024, 4, 60, 1, 1, torch.float16), + (2, 1024, 8, 128, 1, 0.1, torch.float16), + (2, 1024, 8, 128, 0.1, 1, torch.float16), + (2, 1024, 8, 128, 0.1, 10, torch.float16), + (4, 2048, 8, 64, 0.1, 1, torch.float16) + ] + ] +) def test_chunk( B: int, T: int, H: int, D: int, - dtype: torch.dtype, scale: float, - gate_logit_normalizer: float + gate_logit_normalizer: float, + dtype: torch.dtype, ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' @@ -234,40 +242,43 @@ def test_chunk( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.004) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.004) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) -@pytest.mark.parametrize('N', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 64, [0, 15], torch.float16), + (4, 64, [0, 256, 500, 1000], torch.float16), + (4, 100, [0, 15, 100, 300, 1200, 2000], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' ) def test_chunk_varlen( - N: int, - T: int, H: int, D: int, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + N = len(cu_seqlens) - 1 + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + # seq-first required for inputs with variable lengths q = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() k = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() @@ -309,34 +320,37 @@ def test_chunk_varlen( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - assert_close(' o', ref, tri, 0.004) - assert_close(' ht', ref_ht, tri_ht, 0.005) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.004) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) assert_close('dh0', ref_dh0, tri_dh0, 0.005) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('gate_logit_normalizer', test_gate_list) -@pytest.mark.parametrize('scale', [0.1]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, torch.float16), + (2, 500, 3, 60, 1, 1, torch.float16), + (2, 1024, 4, 128, 0.1, 1, torch.float16), + (3, 1024, 4, 128, 0.1, 10, torch.float16), + (3, 1024, 4, 256, 0.1, 0.1, torch.float16), + (4, 2048, 4, 64, 0.1, 0.1, torch.float16) + ] + ] ) def test_parallel( B: int, - H: int, T: int, + H: int, D: int, - dtype: torch.dtype, scale: float, gate_logit_normalizer: float, + dtype: torch.dtype, ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' @@ -364,8 +378,8 @@ def test_parallel( tri_dv, v.grad = v.grad.clone(), None if USE_G: tri_dg, g.grad = g.grad.clone(), None - assert_close(' o', ref, tri, 0.005) - assert_close(' A', ref_A, tri_A, 0.005) + assert_close('o', ref, tri, 0.005) + assert_close('A', ref_A, tri_A, 0.005) assert_close('dq', ref_dq, tri_dq, 0.005) assert_close('dk', ref_dk, tri_dk, 0.005) assert_close('dv', ref_dv, tri_dv, 0.005) @@ -373,31 +387,33 @@ def test_parallel( assert_close('dg', ref_dg, tri_dg, 0.015) -@pytest.mark.parametrize('N', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 64, [0, 15], torch.float16), + (4, 64, [0, 256, 500, 1000], torch.float16), + (4, 100, [0, 15, 100, 300, 1200, 2000], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' ) def test_parallel_varlen( - N: int, - T: int, H: int, D: int, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' - # randomly split the sequence into N segments - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + q = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() k = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() v = torch.randn((1, T, H, D), dtype=dtype, device=device).requires_grad_() @@ -431,18 +447,21 @@ def test_parallel_varlen( tri_dv, v.grad = v.grad.clone(), None tri_dg, g.grad = g.grad.clone(), None - assert_close(' o', ref, tri, 0.004) - assert_close(' dq', ref_dq, tri_dq, 0.005) - assert_close(' dk', ref_dk, tri_dk, 0.005) - assert_close(' dv', ref_dv, tri_dv, 0.005) - assert_close(' dg', ref_dg, tri_dg, 0.005) + assert_close('o', ref, tri, 0.004) + assert_close('dq', ref_dq, tri_dq, 0.005) + assert_close('dk', ref_dk, tri_dk, 0.005) + assert_close('dv', ref_dv, tri_dv, 0.005) + assert_close('dg', ref_dg, tri_dg, 0.005) -@pytest.mark.parametrize('vary_A', [True, False]) -@pytest.mark.parametrize('dtype', [torch.float, torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('vary_A', 'dtype'), + [ + pytest.param(True, torch.float, id='vary_A{}-dtype{}'.format(True, torch.float)), + pytest.param(False, torch.float, id='vary_A{}-dtype{}'.format(False, torch.float)), + pytest.param(True, torch.float16, id='vary_A{}-dtype{}'.format(True, torch.float16)), + pytest.param(False, torch.float16, id='vary_A{}-dtype{}'.format(False, torch.float16)) + ] ) def test_simple_gla_to_mamba2(vary_A, dtype): try: diff --git a/tests/ops/test_solve_tril.py b/tests/ops/test_solve_tril.py index 1e721e383d..1364cf2afa 100644 --- a/tests/ops/test_solve_tril.py +++ b/tests/ops/test_solve_tril.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch @@ -8,23 +9,22 @@ from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd from fla.ops.utils.solve_tril import solve_tril -from fla.utils import COMPILER_MODE, assert_close, device, device_platform +from fla.utils import assert_close, device, device_platform -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = [[0, 64, 128, 256, 512]] -else: - test_b_list = [2] - test_t_list = [128, 200, 300, 500] - test_t_varlen_list = [[0, 63, 286, 300, 512], [0, 127, 246, 521, 1000], [0, 255, 492, 1042, 2000]] -test_h_list = [2] - -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('chunk_size', [16, 32, 64]) +@pytest.mark.parametrize( + ('B', 'T', 'H', 'chunk_size'), + [ + pytest.param(*test, id="B{}-T{}-H{}-chunk_size{}".format(*test)) + for test in [ + (1, 63, 1, 16), + (2, 500, 4, 32), + (2, 1000, 5, 64), + (3, 1024, 6, 64), + (4, 2048, 8, 64), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', reason='Skipping test because TEST_CHUNK_VARLEN is enabled' @@ -48,9 +48,19 @@ def test_solve_tril(B, T, H, chunk_size): assert_close('solve_tril', Ai, Ai_ref, 0.0001) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('cu_seqlens', test_t_varlen_list) -@pytest.mark.parametrize('chunk_size', [64, 32, 16]) +@pytest.mark.parametrize( + ('H', 'D', 'chunk_size', 'cu_seqlens'), + [ + pytest.param(*test, id="H{}-D{}-chunk_size{}-cu_seqlens{}".format(*test)) + for test in [ + (4, 64, 16, [0, 15]), + (4, 64, 32, [0, 256, 500, 1000]), + (4, 100, 64, [0, 15, 100, 300, 1200, 2000]), + (4, 64, 16, [0, 1, 100, 300, 1200, 2048]), + (4, 128, 32, [0, 200, 512, 1200, 2048]), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' @@ -59,13 +69,18 @@ def test_solve_tril(B, T, H, chunk_size): device_platform == 'intel', reason='Intel Pytorch Failure' ) -def test_solve_tril_varlen(H, cu_seqlens, chunk_size): +def test_solve_tril_varlen( + H: int, + D: int, + chunk_size: int, + cu_seqlens: List[int], +): T = cu_seqlens[-1] cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) # Construct the input. otherwise inverse's condition number might be too large to measure the error - k = F.normalize(torch.randn((1, T, H, 64), dtype=torch.bfloat16, device=device), dim=-1) + k = F.normalize(torch.randn((1, T, H, D), dtype=torch.bfloat16, device=device), dim=-1) beta = torch.randn((1, T, H), dtype=torch.bfloat16, device=device).sigmoid() - A, _ = chunk_scaled_dot_kkt_fwd(k, beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size) + A = chunk_scaled_dot_kkt_fwd(k, beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size) Ai = solve_tril(A, cu_seqlens=cu_seqlens) Ai_ref = torch.zeros_like(Ai) diff --git a/tests/ops/test_titans.py b/tests/ops/test_titans.py index afb66fb2af..9bc6098837 100644 --- a/tests/ops/test_titans.py +++ b/tests/ops/test_titans.py @@ -1,26 +1,11 @@ # -*- coding: utf-8 -*- -import os - import pytest import torch import torch.nn.functional as F -# from fla.ops.titans.fused_chunk import fused_chunk_titans_linear from fla.ops.titans.naive import chunk_titans_linear_ref -from fla.utils import COMPILER_MODE, assert_close, device - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100, 256] -test_h_list = [2] +from fla.utils import assert_close, device def initialize_chunked_param(B, H, T, BT, dtype=torch.float32): @@ -52,19 +37,28 @@ def initialize_chunked_param(B, H, T, BT, dtype=torch.float32): return theta -@pytest.mark.parametrize("B", test_b_list) -@pytest.mark.parametrize("T", test_t_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("scale", [1]) -@pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("head_first", [True, False]) +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, torch.float16), + (2, 100, 4, 60, torch.float16), + (2, 1024, 3, 128, torch.float16), + (3, 2000, 4, 128, torch.float16), + (4, 2048, 8, 64, torch.float16), + ] + ] +) @pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" + True, reason='FIXME' ) -def test_naive_chunk_fwd( - B: int, T: int, H: int, D: int, dtype: torch.dtype, scale: float, head_first: bool +def test_naive_chunk( + B: int, + T: int, + H: int, + D: int, + dtype: torch.dtype, ): BT = 64 # set seed @@ -84,13 +78,12 @@ def test_naive_chunk_fwd( w = torch.randn(H, D, dtype=dtype) b = torch.randn(H, D, dtype=dtype) h0 = torch.randn(B, H, D, D, dtype=torch.float32) - if not head_first: - q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 1, 3) - v = v.permute(0, 2, 1, 3) - theta = theta.permute(0, 2, 1, 3) - alpha = alpha.permute(0, 2, 1, 3) - eta = eta.permute(0, 2, 1, 3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + theta = theta.permute(0, 2, 1, 3) + alpha = alpha.permute(0, 2, 1, 3) + eta = eta.permute(0, 2, 1, 3) q, k, v, w, b, theta, alpha, eta = map( lambda x: x.to(device).requires_grad_(False), (q, k, v, w, b, theta, alpha, eta) ) @@ -109,7 +102,6 @@ def test_naive_chunk_fwd( output_final_state=True, chunk_size=BT, initial_state=h0.clone(), - head_first=head_first, use_chunk=False, ) ref, ref_ht = chunk_titans_linear_ref( @@ -124,97 +116,8 @@ def test_naive_chunk_fwd( output_final_state=True, chunk_size=BT, initial_state=h0.clone(), - head_first=head_first, use_chunk=True, ) assert_close(" o", ref, ref_naive, 0.006) assert_close("ht", ref_ht, ref_ht_naive, 0.005) - - -# @pytest.mark.parametrize("B", test_b_list) -# @pytest.mark.parametrize("T", test_t_list) -# @pytest.mark.parametrize("H", test_h_list) -# @pytest.mark.parametrize("D", test_d_list) -# @pytest.mark.parametrize("scale", [1]) -# @pytest.mark.parametrize("dtype", [torch.float32]) -# @pytest.mark.parametrize("head_first", [True, False]) -# def test_fused_chunk_fwd( -# B: int, T: int, H: int, D: int, dtype: torch.dtype, scale: float, head_first: bool -# ): -# BT = 1 -# # set seed -# torch.manual_seed(1) -# # we don't use such initialization in the original code -# # theta = initialize_chunked_param(B, H, T, BT, dtype) -# # alpha = initialize_chunked_param(B, H, T, BT, dtype) -# # eta = initialize_chunked_param(B, H, T, BT, dtype) -# theta = torch.rand(B, H, T, 1, dtype=dtype) -# alpha = torch.rand(B, H, T, 1, dtype=dtype) -# eta = torch.rand(B, H, T, 1, dtype=dtype) - -# if head_first: -# # titans normalize queries and keys using ℓ2-normalization -# q = F.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to( -# dtype -# ) -# k = F.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to( -# dtype -# ) -# v = torch.randn(B, H, T, D, dtype=dtype) -# w = torch.randn(H, D, dtype=dtype) -# b = torch.randn(H, D, dtype=dtype) -# h0 = torch.randn(B, H, D, D, dtype=dtype) -# else: -# q = F.normalize(torch.randn(B, T, H, D, dtype=torch.float32), p=2, dim=-1).to( -# dtype -# ) -# k = F.normalize(torch.randn(B, T, H, D, dtype=torch.float32), p=2, dim=-1).to( -# dtype -# ) -# v = torch.randn(B, T, H, D, dtype=dtype) -# w = torch.randn(H, D, dtype=dtype) -# b = torch.randn(H, D, dtype=dtype) -# h0 = torch.randn(B, H, D, D, dtype=dtype) -# # we need to reshape here because head_first is True -# theta = theta.permute(0, 2, 1, 3) -# alpha = alpha.permute(0, 2, 1, 3) -# eta = eta.permute(0, 2, 1, 3) -# q, k, v, w, b, theta, alpha, eta = map( -# lambda x: x.to(device).requires_grad_(False), (q, k, v, w, b, theta, alpha, eta) -# ) -# # in titans paper, h0 is not learnable -# h0 = h0.to(device) - -# ref_naive, ref_ht_naive = fused_chunk_titans_linear( -# q.clone(), -# k.clone(), -# v.clone(), -# w.clone(), -# b.clone(), -# theta.clone(), -# alpha.clone(), -# eta.clone(), -# output_final_state=True, -# chunk_size=BT, -# initial_state=h0.clone(), -# head_first=head_first, -# ) -# ref, ref_ht = chunk_titans_linear_ref( -# q.clone(), -# k.clone(), -# v.clone(), -# w.clone(), -# b.clone(), -# theta.clone(), -# alpha.clone(), -# eta.clone(), -# output_final_state=True, -# chunk_size=BT, -# initial_state=h0.clone(), -# head_first=head_first, -# use_chunk=True, -# ) - -# # assert_close(" o", ref, ref_naive, 0.006) -# assert_close("ht", ref_ht, ref_ht_naive, 0.005) diff --git a/tests/ops/test_ttt.py b/tests/ops/test_ttt.py index e417cc6a18..5128e70682 100755 --- a/tests/ops/test_ttt.py +++ b/tests/ops/test_ttt.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch @@ -8,62 +9,44 @@ from fla.ops.ttt import chunk_ttt_linear, fused_chunk_ttt_linear from fla.ops.ttt.naive import chunk_ttt_linear_ref -from fla.utils import COMPILER_MODE, assert_close, check_shared_mem, device +from fla.utils import assert_close, check_shared_mem, device -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [50, 64, 100, 128] -test_h_list = [2] - -@pytest.mark.parametrize("B", test_b_list) -@pytest.mark.parametrize("T", test_t_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("scale", [0.1]) -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("head_first", [True, False]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, torch.float16), + (2, 100, 4, 60, 0.1, torch.float16), + (2, 1024, 3, 128, 0.1, torch.float16), + (2, 1024, 4, 128, 1, torch.float16), + (3, 2000, 4, 128, 0.1, torch.float16), + (4, 2048, 8, 64, 0.1, torch.float16), + ] + ] ) def test_chunk( B: int, T: int, H: int, D: int, - dtype: torch.dtype, scale: float, - head_first: bool + dtype: torch.dtype, ): if D > 64 and check_shared_mem('hopper') is False: pytest.skip(reason="Current CI do not support this config") + if T > 1000: + pytest.skip(reason="Current CI do not support this config") eta_base = 5e-3 - if head_first: - q = torch.randn(B, H, T, D, dtype=dtype) - k = F.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to(dtype) - v = torch.randn(B, H, T, D, dtype=dtype) - w = torch.randn(H, D, dtype=dtype) - b = torch.randn(H, D, dtype=dtype) - eta = torch.randn(B, H, T, 1, dtype=dtype) * eta_base - h0 = torch.randn(B, H, D, D, dtype=torch.float32) - hb0 = torch.randn(B, H, 1, D, dtype=torch.float32) - else: - q = torch.randn(B, T, H, D, dtype=dtype) - k = F.normalize(torch.randn(B, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) - v = torch.randn(B, T, H, D, dtype=dtype) - w = torch.randn(H, D, dtype=dtype) - b = torch.randn(H, D, dtype=dtype) - eta = torch.randn(B, T, H, 1, dtype=dtype) * eta_base - h0 = torch.randn(B, H, D, D, dtype=torch.float32) - hb0 = torch.randn(B, H, 1, D, dtype=torch.float32) + q = torch.randn(B, T, H, D, dtype=dtype) + k = F.normalize(torch.randn(B, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) + v = torch.randn(B, T, H, D, dtype=dtype) + w = torch.randn(H, D, dtype=dtype) + b = torch.randn(H, D, dtype=dtype) + eta = torch.randn(B, T, H, 1, dtype=dtype) * eta_base + h0 = torch.randn(B, H, D, D, dtype=torch.float32) + hb0 = torch.randn(B, H, 1, D, dtype=torch.float32) q, k, v, w, b, eta, h0, hb0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, w, b, eta, h0, hb0)) do = torch.rand_like(v) @@ -81,7 +64,6 @@ def test_chunk( output_final_state=True, initial_state=h0.clone(), initial_state_bias=hb0.clone(), - head_first=head_first ) ((tri * do).sum() + (tri_ht * dht).sum() + (tri_hbt * dhbt).sum()).backward(retain_graph=True) tri_dq, tri_dk, tri_dv, tri_dw, tri_db, tri_deta, \ @@ -99,7 +81,6 @@ def test_chunk( output_final_state=True, initial_state=h0.clone(), initial_state_bias=hb0.clone(), - head_first=head_first ) ((ref * do).sum() + (ref_ht * dht).sum() + (ref_hbt * dhbt).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dw, ref_db, ref_deta, \ @@ -114,55 +95,46 @@ def test_chunk( assert_close(" dw", ref_dw, tri_dw, 0.006) assert_close(" db", ref_db, tri_db, 0.006) assert_close(" de", ref_deta, tri_deta, 0.030) # because the last element of the chunk - if head_first: - assert_close(" de0", ref_deta[:, :, :14, :], tri_deta[:, :, :14, :], 0.010) - else: - assert_close(" de0", ref_deta[:, :14, :, :], tri_deta[:, :14, :, :], 0.010) + assert_close(" de0", ref_deta[:, :14, :, :], tri_deta[:, :14, :, :], 0.010) assert_close(" dh0", ref_dh0, tri_dh0, 0.007) assert_close("dhb0", ref_dhb0, tri_dhb0, 0.005) -@pytest.mark.parametrize("B", test_b_list) -@pytest.mark.parametrize("T", test_t_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("scale", [0.1]) -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("head_first", [True, False]) -@pytest.mark.skipif( - os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", - reason="Skipping test because TEST_CHUNK_VARLEN is enabled" +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, torch.float16), + (2, 100, 4, 60, 0.1, torch.float16), + (2, 1024, 3, 128, 0.1, torch.float16), + (2, 1024, 4, 128, 1, torch.float16), + (3, 2000, 4, 128, 0.1, torch.float16), + (4, 2048, 8, 64, 0.1, torch.float16), + ] + ] ) -def test_fused_chunk_fwd( +def test_fused_chunk( B: int, T: int, H: int, D: int, - dtype: torch.dtype, scale: float, - head_first: bool + dtype: torch.dtype, ): if D > 64 and check_shared_mem('hopper') is False: pytest.skip(reason="Current CI do not support this config") + if T > 1000: + pytest.skip(reason="Current CI do not support this config") eta_base = 5e-3 - if head_first: - q = torch.randn(B, H, T, D, dtype=dtype) - k = F.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to(dtype) - v = torch.randn(B, H, T, D, dtype=dtype) - w = torch.randn(H, D, dtype=dtype) - b = torch.randn(H, D, dtype=dtype) - eta = torch.randn(B, H, T, 1, dtype=dtype) * eta_base - h0 = torch.randn(B, H, D, D, dtype=torch.float32) - hb0 = torch.randn(B, H, 1, D, dtype=torch.float32) - else: - q = torch.randn(B, T, H, D, dtype=dtype) - k = F.normalize(torch.randn(B, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) - v = torch.randn(B, T, H, D, dtype=dtype) - w = torch.randn(H, D, dtype=dtype) - b = torch.randn(H, D, dtype=dtype) - eta = torch.randn(B, T, H, 1, dtype=dtype) * eta_base - h0 = torch.randn(B, H, D, D, dtype=torch.float32) - hb0 = torch.randn(B, H, 1, D, dtype=torch.float32) + q = torch.randn(B, T, H, D, dtype=dtype) + k = F.normalize(torch.randn(B, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) + v = torch.randn(B, T, H, D, dtype=dtype) + w = torch.randn(H, D, dtype=dtype) + b = torch.randn(H, D, dtype=dtype) + eta = torch.randn(B, T, H, 1, dtype=dtype) * eta_base + h0 = torch.randn(B, H, D, D, dtype=torch.float32) + hb0 = torch.randn(B, H, 1, D, dtype=torch.float32) q, k, v, w, b, eta, h0, hb0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, w, b, eta, h0, hb0)) do = torch.rand_like(v) @@ -180,7 +152,6 @@ def test_fused_chunk_fwd( output_final_state=True, initial_state=h0.clone(), initial_state_bias=hb0.clone(), - head_first=head_first ) ((tri * do).sum() + (tri_ht * dht).sum() + (tri_hbt * dhbt).sum()).backward(retain_graph=True) tri_dq, tri_dk, tri_dv, tri_dw, tri_db, tri_deta, \ @@ -198,7 +169,6 @@ def test_fused_chunk_fwd( output_final_state=True, initial_state=h0.clone(), initial_state_bias=hb0.clone(), - head_first=head_first ) ((ref * do).sum() + (ref_ht * dht).sum() + (ref_hbt * dhbt).sum()).backward(retain_graph=True) ref_dq, ref_dk, ref_dv, ref_dw, ref_db, ref_deta, \ @@ -213,42 +183,41 @@ def test_fused_chunk_fwd( assert_close(" dw", ref_dw, tri_dw, 0.005) assert_close(" db", ref_db, tri_db, 0.005) assert_close(" de", ref_deta, tri_deta, 0.03) # because the last element of the chunk - if head_first: - assert_close(" de0", ref_deta[:, :, :14, :], tri_deta[:, :, :14, :], 0.008) - else: - assert_close(" de0", ref_deta[:, :14, :, :], tri_deta[:, :14, :, :], 0.008) + assert_close(" de0", ref_deta[:, :14, :, :], tri_deta[:, :14, :, :], 0.008) assert_close(" dh0", ref_dh0, tri_dh0, 0.006) assert_close("dhb0", ref_dhb0, tri_dhb0, 0.005) -@pytest.mark.parametrize("N", test_b_list) -@pytest.mark.parametrize("T", test_t_varlen_list) -@pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("scale", [0.1]) -@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize( + ('H', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 64, [0, 15], torch.float16), + (3, 60, [0, 111, 500], torch.float16), + (3, 64, [0, 256, 500, 900, 1000], torch.float16), + (4, 100, [0, 15, 100, 300, 1200, 1599, 1800, 2000], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv("SKIP_TEST_CHUNK_VARLEN") == "1", reason="Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set" ) -def test_chunk_varlen_fwd( - N: int, - T: int, +def test_chunk_varlen( H: int, D: int, - scale: float, + cu_seqlens: List[int], dtype: torch.dtype, ): if D > 64 and check_shared_mem('hopper') is False: pytest.skip(reason="Current CI do not support this config") torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' - # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + eta_base = 5e-3 # seq-first required for inputs with variable lengths q = torch.randn((1, T, H, D), dtype=dtype) @@ -268,12 +237,10 @@ def test_chunk_varlen_fwd( w.clone(), b.clone(), eta.clone(), - scale=scale, output_final_state=True, initial_state=h0.clone(), initial_state_bias=hb0.clone(), - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens, ) ref = [] @@ -281,17 +248,15 @@ def test_chunk_varlen_fwd( ref_hbt = [] for i in range(N): ref_i, ref_ht_i, ref_hbt_i = chunk_ttt_linear_ref( - q=q[:, offsets[i]:offsets[i+1]], - k=k[:, offsets[i]:offsets[i+1]], - v=v[:, offsets[i]:offsets[i+1]], + q=q[:, cu_seqlens[i]:cu_seqlens[i+1]], + k=k[:, cu_seqlens[i]:cu_seqlens[i+1]], + v=v[:, cu_seqlens[i]:cu_seqlens[i+1]], w=w, b=b, - eta=eta[:, offsets[i]:offsets[i+1]], - scale=scale, + eta=eta[:, cu_seqlens[i]:cu_seqlens[i+1]], initial_state=h0[i], initial_state_bias=hb0[i], output_final_state=True, - head_first=False ) ref.append(ref_i) ref_ht.append(ref_ht_i) diff --git a/tests/ops/test_utils.py b/tests/ops/test_utils.py index 201cb4d968..4f62cb06f0 100644 --- a/tests/ops/test_utils.py +++ b/tests/ops/test_utils.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os +from typing import List import pytest import torch @@ -8,19 +9,7 @@ from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum, mean_pooling from fla.ops.utils.index import prepare_lens from fla.ops.utils.pack import pack_sequence, unpack_sequence -from fla.utils import COMPILER_MODE, assert_close, device - -if COMPILER_MODE: - test_b_list = [1] - test_t_list = [4096] - test_t_varlen_list = test_t_list - test_d_list = [64, 128, 256] -else: - test_b_list = [2] - test_t_list = [1, 15, 63, 300] - test_t_varlen_list = [63, 286, 300, 512] - test_d_list = [64, 32, 100, 256] -test_h_list = [2] +from fla.utils import assert_close, device def reversed_cumsum(x, dim=-1): @@ -31,14 +20,18 @@ def reversed_cumsum(x, dim=-1): return y.to(dtype) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 1, 30, torch.float), + (2, 500, 4, 60, torch.float), + (2, 1000, 5, 128, torch.float), + (3, 1024, 6, 500, torch.float), + (4, 2048, 8, 1024, torch.float), + ] + ] ) def test_global_cumsum( B: int, @@ -59,28 +52,33 @@ def test_global_cumsum( assert_close('global_cumsum', ref, tri, 1e-3) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 60, [0, 15], torch.float), + (3, 100, [0, 256, 500, 1000], torch.float), + (4, 256, [0, 15, 100, 300, 1200, 2000], torch.float), + (4, 500, [0, 1, 100, 300, 1200, 2048], torch.float16), + (2, 1024, [0, 200, 512, 1200, 2048], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' ) def test_global_cumsum_varlen( - B: int, - T: int, H: int, D: int, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(1, T)[torch.randperm(T - 1)[:B-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + s = torch.randn(1, T, H, dtype=dtype).to(device) ref = torch.cat([s[:, start:end].float().cumsum(1) for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])], 1).to(dtype) tri = chunk_global_cumsum(s, cu_seqlens=cu_seqlens) @@ -92,14 +90,18 @@ def test_global_cumsum_varlen( assert_close('global_cumsum', ref, tri, 1e-3) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 1, 30, torch.float), + (2, 500, 4, 60, torch.float), + (2, 1000, 5, 128, torch.float), + (3, 1024, 6, 500, torch.float), + (4, 2048, 8, 1024, torch.float), + ] + ] ) def test_global_reversed_cumsum( B: int, @@ -120,28 +122,33 @@ def test_global_reversed_cumsum( assert_close('global_cumsum', ref, tri, 1e-3) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 60, [0, 15], torch.float), + (3, 100, [0, 256, 500, 1000], torch.float), + (4, 256, [0, 15, 100, 300, 1200, 2000], torch.float), + (4, 500, [0, 1, 100, 300, 1200, 2048], torch.float16), + (2, 1024, [0, 200, 512, 1200, 2048], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' ) def test_global_reversed_cumsum_varlen( - B: int, - T: int, H: int, D: int, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(1, T)[torch.randperm(T - 1)[:B-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + s = torch.randn(1, T, H, dtype=dtype).to(device) ref = torch.cat([reversed_cumsum(s[:, start:end], 1) for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])], 1).to(dtype) tri = chunk_global_cumsum(s, reverse=True, cu_seqlens=cu_seqlens) @@ -153,12 +160,19 @@ def test_global_reversed_cumsum_varlen( assert_close('global_reversed_cumsum', ref, tri, 1e-3) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('C', [32, 64]) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('B', 'T', 'H', 'C', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-C{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 1, 16, 30, torch.float), + (2, 500, 4, 32, 60, torch.float), + (2, 1000, 5, 64, 128, torch.float), + (3, 1024, 6, 64, 500, torch.float), + (4, 2048, 8, 128, 1024, torch.float), + ] + ] +) def test_local_cumsum( B: int, T: int, @@ -179,30 +193,34 @@ def test_local_cumsum( assert_close('local_cumsum', ref, tri, 1e-3) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_varlen_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('C', [32, 64]) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'C', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-C{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 32, 60, [0, 15], torch.float), + (3, 64, 100, [0, 256, 500, 1000], torch.float), + (4, 64, 256, [0, 15, 100, 300, 1200, 2000], torch.float), + (4, 128, 500, [0, 1, 100, 300, 1200, 2048], torch.float16), + (2, 128, 1024, [0, 200, 512, 1200, 2048], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' ) def test_local_cumsum_varlen( - B: int, - T: int, H: int, C: int, D: int, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(1, T)[torch.randperm(T - 1)[:B-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + s = torch.randn(1, T, H, dtype=dtype).to(device) ref = torch.cat([ torch.cat([s[:, i:min(end, i+C), :].float().cumsum(1) for i in range(start, end, C)], 1) @@ -220,15 +238,18 @@ def test_local_cumsum_varlen( assert_close('local_cumsum', ref, tri, 1e-3) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('C', [32, 64]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.skipif( - os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', - reason='Skipping test because TEST_CHUNK_VARLEN is enabled' +@pytest.mark.parametrize( + ('B', 'T', 'H', 'C', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-C{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 1, 16, 30, torch.float), + (2, 500, 4, 32, 60, torch.float), + (2, 1000, 5, 64, 128, torch.float), + (3, 1024, 6, 64, 500, torch.float), + (4, 2048, 8, 128, 1024, torch.float), + ] + ] ) def test_mean_pooling( B: int, @@ -254,30 +275,33 @@ def test_mean_pooling( assert_close('mean_pooling', ref_dx, tri_dx, 1e-3) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('C', [32, 64]) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('H', 'C', 'D', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-C{}-D{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (2, 32, 60, [0, 15], torch.float), + (3, 64, 100, [0, 256, 500, 1000], torch.float), + (4, 64, 256, [0, 15, 100, 300, 1200, 2000], torch.float), + (4, 128, 500, [0, 1, 100, 300, 1200, 2048], torch.float16), + (2, 128, 1024, [0, 200, 512, 1200, 2048], torch.float16), + ] + ] +) @pytest.mark.skipif( os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' ) def test_mean_pooling_varlen( - B: int, - T: int, H: int, C: int, D: int, + cu_seqlens: List[int], dtype: torch.dtype, ): torch.manual_seed(42) - cu_seqlens = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(1, T)[torch.randperm(T - 1)[:B-1]], - torch.tensor([T], dtype=torch.long) - ], 0).to(device).sort()[0] + T = cu_seqlens[-1] + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) x = torch.randn(1, T, H, D, dtype=dtype).to(device).requires_grad_(True) ref = torch.cat([ @@ -296,12 +320,19 @@ def test_mean_pooling_varlen( torch.testing.assert_close(ref_dx, tri_dx.to(ref_dx.dtype), rtol=1.6e-2, atol=3e-5) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('padding_side', ['left', 'right']) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'padding_side', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-padding_side{}-{}".format(*test)) + for test in [ + (1, 63, 1, 30, 'left', torch.float), + (2, 500, 4, 60, 'right', torch.float), + (2, 1000, 5, 128, 'left', torch.float), + (3, 1024, 6, 500, 'right', torch.float), + (4, 2048, 8, 1024, 'left', torch.float), + ] + ] +) def test_pack_sequence( B: int, T: int, @@ -329,16 +360,23 @@ def test_pack_sequence( tri.backward(dy) tri_dx, x.grad = x.grad.clone(), None - assert_close(' y', ref, tri, 1e-3) + assert_close('y', ref, tri, 1e-3) assert_close('dx', ref_dx, tri_dx, 1e-3) -@pytest.mark.parametrize('B', test_b_list) -@pytest.mark.parametrize('T', test_t_list) -@pytest.mark.parametrize('H', test_h_list) -@pytest.mark.parametrize('D', test_d_list) -@pytest.mark.parametrize('padding_side', ['left', 'right']) -@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'padding_side', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-padding_side{}-{}".format(*test)) + for test in [ + (1, 63, 1, 30, 'left', torch.float), + (2, 500, 4, 60, 'right', torch.float), + (2, 1000, 5, 128, 'left', torch.float), + (3, 1024, 6, 500, 'right', torch.float), + (4, 2048, 8, 1024, 'left', torch.float), + ] + ] +) def test_unpack_sequence( B: int, T: int, @@ -370,5 +408,5 @@ def test_unpack_sequence( tri.backward(dy) tri_dx, x.grad = x.grad.clone(), None - assert_close(' y', ref, tri, 1e-3) + assert_close('y', ref, tri, 1e-3) assert_close('dx', ref_dx, tri_dx, 1e-3) diff --git a/tests/test_fused_chunk.py b/tests/test_fused_chunk.py deleted file mode 100644 index 60872d00e8..0000000000 --- a/tests/test_fused_chunk.py +++ /dev/null @@ -1,127 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl - - -@triton.jit -def attention_fwd_kernel( - q, - k, - v, - h, - o, - s_qh, - s_qt, - s_qd, - s_hh, - s_ht, - T, - scale, - BT: tl.constexpr, - BD: tl.constexpr, - NT: tl.constexpr, - STORE: tl.constexpr, - IFCOND: tl.constexpr -): - i_bh = tl.program_id(0) - - # [BD, BD] - b_h = tl.zeros([BD, BD], dtype=tl.float32) - for i in range(0, tl.cdiv(T, BT)): - p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0)) - p_h = tl.make_block_ptr(h + i_bh * s_hh, (NT * BD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0)) - p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0)) - - if STORE: - tl.store(p_h, b_h.to(p_h.dtype.element_ty)) - # [BT, BD] - b_q = tl.load(p_q) - b_q = (b_q * scale).to(b_q.dtype) - # [BD, BT] - b_k = tl.load(p_k) - # [BT, BD] - b_v = tl.load(p_v) - - # [BT, BT] - b_s = tl.dot(b_q, b_k, allow_tf32=False) - # [BT, BD] - b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) - if IFCOND: - if i == 0: - b_h = tl.dot(b_k, b_v, allow_tf32=False) - else: - b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) - b_h += tl.dot(b_k, b_v, allow_tf32=False) - else: - b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) - b_h += tl.dot(b_k, b_v, allow_tf32=False) - - tl.store(p_o, b_o.to(p_o.dtype.element_ty)) - - -class AttentionFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, store=False, ifcond=False): - batch_size, n_heads, seq_len, d_head = q.shape - scale = d_head ** -0.5 - BD = q.shape[-1] - BT = 32 - NT = triton.cdiv(seq_len, BT) - num_stages = 3 if d_head <= 64 else 2 - num_warps = 4 - - h = q.new_empty(batch_size, n_heads, NT * BD, BD) - o = torch.empty_like(q) - grid = (batch_size * n_heads,) - attention_fwd_kernel[grid]( - q, k, v, h, o, - q.stride(1), q.stride(2), q.stride(3), h.stride(1), h.stride(2), - seq_len, scale, - BT=BT, BD=BD, NT=NT, STORE=store, IFCOND=ifcond, - num_warps=num_warps, - num_stages=num_stages - ) - return o - - -def sizeof_fmt(num, suffix='B'): - for unit in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'): - if abs(num) < 1024.0: - return f'{num:3.1f}{unit}{suffix}' - num /= 1024.0 - return f'{num:.1f}Yi{suffix}' - - -if __name__ == '__main__': - B, H, T, D = 2, 8, 1024, 128 - dtype = torch.float - torch.manual_seed(42) - from fla.utils import device - - # [batch_size, n_heads, seq_len, d_head] - q = torch.randn((B, H, T, D), dtype=dtype, device=device) - k = torch.randn((B, H, T, D), dtype=dtype, device=device) - v = torch.randn((B, H, T, D), dtype=dtype, device=device) - - ref = AttentionFunction.apply(q, k, v) - infos = torch.cuda.get_device_properties(q) - - def fmt(x): - if isinstance(x, (float, torch.Tensor)): - return f"{x:>16.2f}" - return f"{str(x):>16}" - print(f'{fmt("GPU Type")}{fmt("Memory")}{fmt("Cores")}\n' - f"{fmt(infos.name)}{fmt(sizeof_fmt(infos.total_memory))}{fmt(infos.multi_processor_count)}") - print(f'{"DTYPE":>16}{"STORE":>16}{"INIT CHECK":>16}{"DIFF":>16}{"PASSED":>16}') - for dtype in (torch.float, torch.bfloat16): - q, k, v = q.clone().to(dtype), k.clone().to(dtype), v.clone().to(dtype) - for store in [False, True]: - for check in [False, True]: - tri = AttentionFunction.apply(q, k, v, store, check) - diff = (ref - tri).abs().max() - print(f"{fmt(q.dtype)}{fmt(store)}{fmt(check)}{fmt(diff)}{fmt(bool(diff < 1))}")