diff --git a/README.md b/README.md index b7e02867095..406157fed60 100755 --- a/README.md +++ b/README.md @@ -128,74 +128,47 @@ FlashAttention-2 ROCm CK backend currently supports: 3. Both forward's and backward's head dimensions up to 256. #### Triton Backend -The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress. +The Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress. -It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes. - -These features are supported in Fwd and Bwd -1) Fwd and Bwd with causal masking -2) Variable sequence lengths -3) Arbitrary Q and KV sequence lengths -4) Arbitrary head sizes -5) Multi and grouped query attention -6) Dropout -7) Rotary embeddings -8) ALiBi - -We are working on the following things -1) Paged Attention -2) Sliding Window -3) FP8 -4) Performance Improvements - -##### Getting Started -To get started with the triton backend for AMD, follow the steps below. - -First install the torch for ROCm from https://pytorch.org/get-started/locally/ if it is not installed. The torch and triton will be installed. - -Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. 
- -``` +To install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Triton and Flash Attention: +```sh +pip install triton==3.5.1 cd flash-attention FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install ``` -To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. -``` +To run the tests (note: full suite takes hours): +```sh FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py ``` -You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE -``` +For better performance, enable autotune with `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`. -###### Docker -You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image. -``` +For a quick start with Docker: +```dockerfile FROM rocm/pytorch:latest WORKDIR /workspace -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" +# install triton +RUN pip install triton==3.5.1 -RUN git clone https://github.com/ROCm/flash-attention.git &&\ +# build flash attention with triton backend +RUN git clone https://github.com/Dao-AILab/flash-attention &&\ cd flash-attention &&\ - python setup.py install + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install # set working dir WORKDIR /workspace/flash-attention -``` -To build the docker file -``` -docker build -t fa_triton . +# set env variable to use triton backend +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" ``` -To run the docker image -``` -docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton +Build and run: +```sh +docker build -t flash-attn-triton . 
+docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri flash-attn-triton ``` ## How to use FlashAttention diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 865f1db5432..a53b4a3108a 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -10,7 +10,7 @@ # We need to import the CUDA kernels after importing torch USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" if USE_TRITON_ROCM: - from .flash_attn_triton_amd import interface_fa as flash_attn_gpu + from .flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu else: import flash_attn_2_cuda as flash_attn_gpu diff --git a/flash_attn/flash_attn_triton_amd/Dockerfile b/flash_attn/flash_attn_triton_amd/Dockerfile deleted file mode 100644 index 29a2c0c43ec..00000000000 --- a/flash_attn/flash_attn_triton_amd/Dockerfile +++ /dev/null @@ -1,17 +0,0 @@ -FROM rocm/pytorch:latest - -WORKDIR /workspace - -# install triton -RUN pip install triton==3.2.0 - -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ - cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install - -# set working dir -WORKDIR /workspace/flash-attention \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md deleted file mode 100644 index 2d8fd8e70f3..00000000000 --- a/flash_attn/flash_attn_triton_amd/README.md +++ /dev/null @@ -1,113 +0,0 @@ -Flash Attention Triton Kernel -=============== - -#### Introduction -The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress. - -It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes. 
- -These features are supported in Fwd and Bwd -1) Fwd and Bwd with causal masking -2) Variable sequence lengths -3) Arbitrary Q and KV sequence lengths -4) Arbitrary head sizes -5) Multi and grouped query attention -6) Dropout -7) Rotary embeddings -8) ALiBi - -We are working on the following things -1) Paged Attention -2) Sliding Window -3) FP8 -4) Performance Improvements - -##### Getting Started -To get started with the triton backend for AMD, follow the steps below. - -First install the recommended Triton version - -``` -pip install triton==3.2.0 -``` -Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. - -``` -cd flash-attention -git checkout main_perf -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install -``` - -To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py -``` - -You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE -``` - -###### Docker -You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image. -``` -FROM rocm/pytorch:latest - -WORKDIR /workspace - -# install triton -RUN pip install triton==3.2.0 - -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ - cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install - -# set working dir -WORKDIR /workspace/flash-attention -``` - -To build the docker file -``` -docker build -t fa_triton . 
-``` - -To run the docker image -``` -docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton -``` - -###### FP8 -In our fork We have created the following api functions that use fp8 to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. To use these functions just call them with like the other api functions, the casting will be handled internally. For example - -``` -from flash_attn import flash_attn_qkvpacked_fp8_func - -# forward pass -out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - -# backward pass -do = torch.randn_like(out) -dqkv = torch.autograd.grad(out, (qkv), do) -``` - -You can use the other api functions in a similar way. - - - -##### Credits -AMD Triton kernels team - -OpenAI kernel team diff --git a/flash_attn/flash_attn_triton_amd/__init__.py b/flash_attn/flash_attn_triton_amd/__init__.py index e69de29bb2d..78f85fb268f 100644 --- a/flash_attn/flash_attn_triton_amd/__init__.py +++ b/flash_attn/flash_attn_triton_amd/__init__.py @@ -0,0 +1,4 @@ +from . import interface_v2 as flash_attn_2 +from . 
import interface_v3 as flash_attn_3 + +__all__ = ["flash_attn_2", "flash_attn_3"] diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py deleted file mode 100755 index 05e64c349be..00000000000 --- a/flash_attn/flash_attn_triton_amd/bench.py +++ /dev/null @@ -1,1223 +0,0 @@ -import os -import sys -import torch -import triton -import time -import argparse -import itertools -import pandas as pd -from logging import warning -from typing import Dict, List, Literal, Optional, Tuple -from dataclasses import dataclass -from functools import lru_cache -from utils import get_arch, input_helper - -DEBUG = False - -ENV_FLAGS = ["FLASH_ATTENTION_TRITON_AMD_ENABLE", "FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "FLASH_ATTENTION_TRITON_AMD_DEBUG"] - -FUNCTIONS = [ - "flash_attn_func", - "flash_attn_fp8_func", - "flash_attn_kvpacked_func", - "flash_attn_qkvpacked_func", - "flash_attn_qkvpacked_fp8_func", - "flash_attn_varlen_func", - "flash_attn_varlen_fp8_func", - "flash_attn_varlen_kvpacked_func", - "flash_attn_varlen_qkvpacked_func", - "flash_attn_varlen_qkvpacked_fp8_func", - "flash_attn_with_kvcache", -] - -SUPPORTED_DTYPES = { - "flash_attn_func": [torch.float16], - "flash_attn_fp8_func": [torch.float8_e4m3fnuz], - "flash_attn_kvpacked_func": [torch.float16], - "flash_attn_qkvpacked_func": [torch.float16], - "flash_attn_qkvpacked_fp8_func": [torch.float16], - "flash_attn_varlen_func": [torch.float16], - "flash_attn_varlen_fp8_func": [torch.float8_e4m3fnuz], - "flash_attn_varlen_kvpacked_func": [torch.float16], - "flash_attn_varlen_qkvpacked_func": [torch.float16], - "flash_attn_varlen_qkvpacked_fp8_func": [torch.float16], - "flash_attn_with_kvcache": [torch.float16], -} - -SUPPORTED_BACKENDS = { - "flash_attn_func": ["ck", "triton"], - "flash_attn_fp8_func": ["triton"], - "flash_attn_kvpacked_func": ["ck", "triton"], - "flash_attn_qkvpacked_func": ["ck", "triton"], - "flash_attn_qkvpacked_fp8_func": ["triton"], - "flash_attn_varlen_func": 
["ck", "triton"], - "flash_attn_varlen_fp8_func": ["triton"], - "flash_attn_varlen_kvpacked_func": ["ck", "triton"], - "flash_attn_varlen_qkvpacked_func": ["ck", "triton"], - "flash_attn_varlen_qkvpacked_fp8_func": ["triton"], - "flash_attn_with_kvcache": ["ck", "triton"], -} - -VALID_MODES = ['fwd', 'bwd', 'full'] -SUPPORTED_MODES = { - "flash_attn_func": ["fwd", "bwd", "full"], - "flash_attn_fp8_func": ["fwd", "bwd", "full"], - "flash_attn_kvpacked_func": ["fwd", "bwd", "full"], - "flash_attn_qkvpacked_func": ["fwd", "bwd", "full"], - "flash_attn_qkvpacked_fp8_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_fp8_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_kvpacked_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_qkvpacked_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_qkvpacked_fp8_func": ["fwd", "bwd", "full"], - "flash_attn_with_kvcache": ["fwd"], -} - -@dataclass -class EnvVariableConfig: - key: str - values: List[str] - backend: Optional[Literal["triton", "ck"]] = None - -ENV_VARIABLE_CONFIGS : List[EnvVariableConfig] = [ - EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"), -] - -class FunctionConfig: - def __init__(self, fn_name: str, mode: Literal["fwd", "bwd", "full"], dtype, backend: Literal["triton", "ck"], env_config: Dict): - self.fn_name = fn_name - self.mode: Literal["fwd", "bwd", "full"] = mode - self.dtype = dtype - self.backend: Literal["triton", "ck"] = backend - self.arch = get_arch() - self.env_configs = env_config - - def __str__(self): - # extract base dtype name if it's a torch dtype - dtype_str = str(self.dtype) - if "torch." 
in dtype_str: - dtype_str = dtype_str.split(".")[-1] - - if len(self.env_configs) > 0: - env_str = "" - for env_key, env_value in self.env_configs.items(): - env_str += f"{env_key}={env_value}" - return f"{self.fn_name}_{self.mode}_{dtype_str}_{self.backend}_{self.arch}_{env_str}" - else: - return f"{self.fn_name}_{self.mode}_{dtype_str}_{self.backend}_{self.arch}" - - def column_name(self): - return f"{self}_ms" - - -@lru_cache() -def available_backends(): - available = [] - - # try to load each backend - for backend in ["triton", "ck"]: - try: - # try loading the module with this backend - flash_attn = load_flash_attn_module(backend) - - # if we got here, the backend loaded successfully - available.append(backend) - except Exception as e: - # backend not available, just continue - print(f"Backend {backend} not available. Error: {e}") - - # if no backends available, default to triton - if not available: - raise ValueError("No Backends available") - - return available - -@lru_cache() -def get_fn_params(fn_name): - # get params for fn - packing = get_packing_type(fn_name) - is_varlen = True if "varlen" in fn_name else False - is_fp8 = True if "fp8" in fn_name else False - supported_dtypes = SUPPORTED_DTYPES.get(fn_name, [torch.float16]) # default to float16 if not found - supported_backends = [backend for backend in SUPPORTED_BACKENDS.get(fn_name, ["triton"]) if backend in available_backends()] # default to triton backend - supports_backward = False if fn_name in ["flash_attn_with_kvcache"] else True - supported_modes = SUPPORTED_MODES.get(fn_name, ["fwd"]) - device = "cuda" - - # get supported env configs for each backend - supported_env_configs = {} - for backend in supported_backends: - supported_env_configs[backend] = get_env_value_combinations(backend) - - # check backward pass support - if not supports_backward: - warning(f"{fn_name} does not have a backward pass so benching forward pass only.") - - return is_varlen, is_fp8, packing, supported_dtypes, 
supported_backends, supported_modes, supported_env_configs, device - -def generate_fn_inputs( - fn_name: str, - BATCH: int, - HQ: int, - HK: int, - N_CTX_Q: int, - N_CTX_K: int, - D_HEAD: int, - CAUSAL: bool, - DROPOUT_P: float, - dtype: torch.dtype, - device: Literal["cpu", "cuda"] - ): - if fn_name == "flash_attn_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) - elif fn_name == "flash_attn_kvpacked_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="kv", device=device) - elif fn_name == "flash_attn_qkvpacked_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="qkv", device=device) - elif fn_name == "flash_attn_varlen_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", device=device) - elif fn_name == "flash_attn_varlen_kvpacked_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="kv", device=device) - elif fn_name == "flash_attn_varlen_qkvpacked_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="qkv", device=device) - elif fn_name == "flash_attn_with_kvcache": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) - elif fn_name == "flash_attn_fp8_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) - elif fn_name == "flash_attn_qkvpacked_fp8_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="qkv", device=device) - elif fn_name == "flash_attn_varlen_fp8_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", 
device=device) - elif fn_name == "flash_attn_varlen_qkvpacked_fp8_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="qkv", device=device) - else: - valid_fn_names = ", ".join(FUNCTIONS) - raise ValueError(f"{fn_name} should be one of the following functions. {valid_fn_names}") - -def estimate_memory(config): - batch, hq, hk, sq, sk, d_head, causal, dropout = config - memory_estimate = batch * (hq * sq + hk * sk) * d_head * 4 # bytes - return memory_estimate - -def generate_benchmark_configs(is_varlen: bool, packing: Optional[Literal["kv", "qkv"]]): - """ - generates a small number of configs that cover the parameter space well - """ - - # define all parameter options as lists - batch_sizes = [1, 64] - if packing == "qkv": - hq_values = hk_values = [2, 8] - sq_values = sk_values = [256, 8192] - else: - if is_varlen: # make sure the seqlen is greater than the batchsize so that subsequences are greater than 0 - hq_values = [16, 32] # test mqa/gqa - hk_values = [8, 16] - sq_values = [128, 512] - sk_values = [512, 2024] - else: - hq_values = [64, 128] # test mqa/gqa - hk_values = [16, 64] - sq_values = [4, 4096] - sk_values = [4096, 16384] # test large k values for inference perf - d_head_values = [64, 128] - causal_values = [True, False] # most models usual causal True - dropout_values = [0.0, 0.1] - - # generate all fn_configs without inputs - input_configs = [] - - # one big loop to generate configs - for batch in batch_sizes: - for hq in hq_values: - for hk in hk_values: - for sq in sq_values: - for sk in sk_values: - for d_head in d_head_values: - for causal in causal_values: - for dropout in dropout_values: - # filter configs - input_config = (batch, hq, hk, sq, sk, d_head, causal, dropout) - - # skip if memory usage would be too high - if estimate_memory(input_config) > 8 * 1024 * 1024 * 1024: # 8 GB limit - continue - - # we need hq to be a multiple of hk - if hq % hk != 0: - continue - - # for 
qkvpacked functions, q and k must have same dimensions - if packing == "qkv" and (sq != sk or hq != hk): - continue - - input_configs.append(input_config) - - return input_configs - -def create_benchmark_fn( - flash_attn, - fn_name, - fn_input, - mode: Literal["fwd", "bwd", "full"] -): - if DEBUG: - print("create_benchmark_fn") - print("flash_attn:", flash_attn) - print("fn_name:", fn_name) - print("fn_input:", len(fn_input)) - print("mode:", mode) - - if fn_name == "flash_attn_func": - q, k, v, do, metadata = fn_input - if mode == "fwd": - def flash_attn_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_bench_fn(): - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) - return dq, dk, dv - elif mode == "full": - def flash_attn_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) - return dq, dk, dv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_bench_fn - - elif fn_name == "flash_attn_kvpacked_func": - q, kv, do, metadata = fn_input - if mode == "fwd": - def flash_attn_kvpacked_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( - q, - kv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - 
return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( - q, - kv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_kvpacked_bench_fn(): - dq, dkv = torch.autograd.grad(out, (q, kv), do, retain_graph=True) - return dq, dkv - elif mode == "full": - def flash_attn_kvpacked_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( - q, - kv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq, dkv = torch.autograd.grad(out, (q, kv), do, retain_graph=True) - return dq, dkv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_kvpacked_bench_fn - elif fn_name == "flash_attn_qkvpacked_func": - qkv, do, metadata = fn_input - if mode == "fwd": - def flash_attn_qkvpacked_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_qkvpacked_bench_fn(): - dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) - return dqkv - elif mode == "full": - def flash_attn_qkvpacked_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dqkv = torch.autograd.grad(out, (qkv), do, 
retain_graph=True) - return dqkv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_qkvpacked_bench_fn - elif fn_name == "flash_attn_varlen_func": - q_unpad, k_unpad, v_unpad, do_unpad, metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_bench_fn(): - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) - return dq_unpad, dk_unpad, dv_unpad - elif mode == "full": - def flash_attn_varlen_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) - return dq_unpad, dk_unpad, dv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_bench_fn - elif fn_name == "flash_attn_varlen_kvpacked_func": - 
q_unpad, kv_unpad, do_unpad, metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_kvpacked_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_kvpacked_bench_fn(): - dq_unpad, dkv_unpad = torch.autograd.grad(out_unpad, (q_unpad, kv_unpad), do_unpad, retain_graph=True) - return dq_unpad, dkv_unpad - elif mode == "full": - def flash_attn_varlen_kvpacked_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq_unpad, dkv_unpad = torch.autograd.grad(out_unpad, (q_unpad, kv_unpad), do_unpad, retain_graph=True) - return dq_unpad, dkv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_kvpacked_bench_fn - elif fn_name == "flash_attn_varlen_qkvpacked_func": - qkv_unpad, do_unpad, metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_qkvpacked_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_unpad, - 
metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_qkvpacked_bench_fn(): - dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) - return dqkv_unpad - elif mode == "full": - def flash_attn_varlen_qkvpacked_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) - return dqkv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_qkvpacked_bench_fn - elif fn_name == "flash_attn_with_kvcache": - q, k_cache, v_cache, _, metadata = fn_input - if mode == "fwd": - def flash_attn_with_kvcache_bench_fn(): - out = flash_attn.flash_attn_with_kvcache( - q, - k_cache, - v_cache, - None, - None, - rotary_cos=None, - rotary_sin=None, - cache_seqlens=None, - cache_batch_idx=None, - cache_leftpad=None, - block_table=None, - causal=metadata.causal, - window_size=(-1, -1), - rotary_interleaved=False, - alibi_slopes=None, - num_splits=0, - ) - return out - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_with_kvcache_bench_fn - elif fn_name == "flash_attn_fp8_func": - (q, descale_q), (k, descale_k), (v, descale_v), 
(do, descale_do), metadata = fn_input - if mode == "fwd": - def flash_attn_f8_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_fp8_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_fp8_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_f8_bench_fn(): - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) - return dq, dk, dv - elif mode == "full": - def flash_attn_f8_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_fp8_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) - return dq, dk, dv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_f8_bench_fn - elif fn_name == "flash_attn_qkvpacked_fp8_func": - qkv, do, metadata = fn_input - if mode == "fwd": - def flash_attn_qkvpacked_fp8_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_qkvpacked_fp8_bench_fn(): - dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) - return dqkv - elif mode == "full": - def 
flash_attn_qkvpacked_fp8_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) - return dqkv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_qkvpacked_fp8_bench_fn - elif fn_name == "flash_attn_varlen_fp8_func": - (q_unpad, descale_q), (k_unpad, descale_k), (v_unpad, descale_v), (do_unpad, descale_do), metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_fp8_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_fp8_bench_fn(): - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) - return dq_unpad, dk_unpad, dv_unpad - elif mode == "full": - def flash_attn_varlen_fp8_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - 
softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) - return dq_unpad, dk_unpad, dv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_fp8_bench_fn - elif fn_name == "flash_attn_varlen_qkvpacked_fp8_func": - qkv_unpad, do_unpad, metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_qkvpacked_fp8_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_qkvpacked_fp8_bench_fn(): - dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) - return dqkv_unpad - elif mode == "full": - def flash_attn_varlen_qkvpacked_fp8_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) - return dqkv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_qkvpacked_fp8_bench_fn - else: - valid_fn_names = ", ".join(FUNCTIONS) - raise 
ValueError(f"{fn_name} should be one of the following functions. {valid_fn_names}") - -def get_packing_type(fn_name: str) -> Optional[Literal["kv", "qkv"]]: - if "_kvpacked" in fn_name: - packing = "kv" - elif "_qkvpacked" in fn_name: - packing = "qkv" - else: - packing = None - - return packing - -def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = {}, verbose = False): - """ - Load the flash_attn module with the specified backend configuration - """ - - # remove any existing env variables first - for key in ENV_FLAGS: - if key in os.environ: - del os.environ[key] - - # set environment variable for the desired backend - if backend == "triton": - os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE" - os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "0" - os.environ["FLASH_ATTENTION_TRITON_AMD_DEBUG"] = "0" - elif backend == "ck": - os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "FALSE" - else: - raise ValueError(f"Unknown backend {backend}") - - # add custom env configs - add_env_configs(env_configs) - - if verbose: - print(f"Loading flash_attn module with {backend} backend.") - - # Remove any existing flash_attn modules from sys.modules - for module_name in list(sys.modules.keys()): - if module_name.startswith('flash_attn'): - del sys.modules[module_name] - - # Clear CUDA cache - torch.cuda.empty_cache() - - # Import and return the module - import flash_attn - - return flash_attn - -def add_env_configs(env_config: Dict): - for env_key, env_value in env_config.items(): - if env_key in os.environ: - del os.environ[env_key] # remove previous version so that env key is the latest key added - os.environ[env_key] = env_value - -def run_benchmark(func_config: FunctionConfig, input_configs): - """ - Runs the benchmark for the provided function configuration with the given input configurations. 
- """ - # print new line to seperate benchmark runs - print() - if DEBUG: - print("func_config:", func_config) - - # extract function configuration parameters - fn_name = func_config.fn_name - mode = func_config.mode - dtype = func_config.dtype - backend = func_config.backend - - # load flash attention module - flash_attn_module = load_flash_attn_module(backend, func_config.env_configs, verbose=True) - - # start timing the benchmark - start_time = time.time() - - # print bench fn - print(f"Benchmarking {func_config} ...") - - # Setup benchmark configurations - bench_configs = [ - triton.testing.Benchmark( - x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"], - x_vals=list(input_configs.keys()), - line_arg="provider", - line_vals=["triton"], - line_names=["Time (ms)"], - styles=[("red", "-")], - ylabel="ms", - plot_name=f"benchmark-{func_config}", - args={ - }, - ) - ] - - @triton.testing.perf_report(bench_configs) - def bench_function( - BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT, provider, device="cuda" - ): - if DEBUG: - print("BATCH:", BATCH) - print("HQ:", HQ) - print("HK:", HK) - print("N_CTX_Q:", N_CTX_Q) - print("N_CTX_Q:", N_CTX_Q) - print("D_HEAD:", D_HEAD) - print("CAUSAL:", CAUSAL) - print("DROPOUT:", DROPOUT) - print("mode:", mode) - print("provider:", provider) - print("device:", device) - fn_input = input_configs[(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT)] - benchmark_fn = create_benchmark_fn(flash_attn_module, fn_name, fn_input, mode) - - # run the benchmark - ms = triton.testing.do_bench(benchmark_fn, warmup=25, rep=100) - return ms - - df = bench_function.run(save_path=".", print_data=True, return_df=True)[0] - - # set the column name to reflect the function configuration - df = df.rename(columns={"Time (ms)": func_config.column_name()}) - - # calculate and print elapsed time - elapsed_time = time.time() - start_time - print(f"Total time for benchmarking {fn_name} in {mode} mode with 
{dtype}: {elapsed_time:.2f} seconds") - - return df - -def filter_modes(requested_modes, fn_name, supported_modes_for_fn): - modes_to_run = [] - if requested_modes: - for mode in requested_modes: - if mode in supported_modes_for_fn: - modes_to_run.append(mode) - else: - warning(f"Mode '{mode}' requested but not supported by function '{fn_name}'. Skipping this mode for this function.") - else: - modes_to_run = ["full" if "full" in supported_modes_for_fn else "fwd"] - return modes_to_run - -def get_env_value_combinations(current_backend: Optional[Literal["triton", "ck"]]) -> List[Dict[str, str]]: - # filter environment variations applicable to the current backend - applicable_variations = [ - var_config for var_config in ENV_VARIABLE_CONFIGS - if var_config.backend is None or var_config.backend == current_backend - ] - - if not applicable_variations: - # no applicable variations, return list with empty dict - return [{}] - - # prepare keys and value lists - variation_keys = [v.key for v in applicable_variations] - variation_value_lists = [v.values for v in applicable_variations] - - # generate all combinations as dictionaries directly - env_configs = [] - for value_combination in itertools.product(*variation_value_lists): - env_configs.append(dict(zip(variation_keys, value_combination))) - - return env_configs - -def get_input_config_set(config_type): - if config_type == "llama": - # batch, hq, hk, sq, sk, d_head, causal, dropout - input_configs = [ - # LLaMA 3 8B - (1, 32, 8, 8192, 8192, 128, True, 0.0), - # LLaMA 3 70B - (1, 64, 8, 8192, 8192, 128, True, 0.0), - ] - else: - raise ValueError(f"Unknown input config: {config_type}") - - return input_configs - - -def process_args(): - """ - Parses command-line arguments and returns function configs and input configs. 
- """ - # create parser - parser = argparse.ArgumentParser( - prog="Benchmark FlashAttention", - allow_abbrev=False, - ) - # functions - parser.add_argument( - "-benchmark_fn", - type=str, - nargs="*", - choices=FUNCTIONS, - required=True, - help=f"Function(s) to benchmark", - ) - parser.add_argument( - "--mode", - type=str, - nargs='*', - choices=VALID_MODES, - default=None, - help=f"Benchmarking mode(s) to run. If omitted, runs all supported modes for each function.", - ) - # config - parser.add_argument("-b", type=int, default=None, help="Batch size") - parser.add_argument("-hq", type=int, default=None, help="Q Number of heads") - parser.add_argument("-hk", type=int, default=None, help="K and V Number of heads") - parser.add_argument("-sq", type=int, default=None, help="Q Sequence Length") - parser.add_argument("-sk", type=int, default=None, help="K and V Sequence Length") - parser.add_argument("-d", type=int, default=None, help="Head Dimension") - parser.add_argument("-causal", action="store_true", default=None, help="Causal") - parser.add_argument("-dropout", type=float, default=None, help="Dropout") - - # parse args - args = parser.parse_args() - - # parse function args - benchmark_fns = args.benchmark_fn - requested_modes = args.mode - - # fenerate function configurations and input configurations separately - all_function_configs = [] - all_input_configs = {} # Maps function config -> input configs - for fn_name in benchmark_fns: - is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes_for_fn, supported_env_configs, device = get_fn_params(fn_name) - - # Generate or use custom input configurations - if args.b or args.hq or args.hk or args.sq or args.sk or args.d: - assert args.b and args.hq and args.sq and args.d, ( - "if custom config is specified, please provide at least batch, number of Q heads, Q sequence length, and head size." 
- ) - - batch = args.b - hq = args.hq - hk = args.hk if args.hk is not None else args.hq - sq = args.sq - sk = args.sk if args.sk is not None else args.sq - d_head = args.d - causal = args.causal if args.causal is not None else False - dropout = args.dropout if args.dropout is not None else 0.0 - input_configs = [(batch, hq, hk, sq, sk, d_head, causal, dropout)] - else: - if True: - input_configs = get_input_config_set("llama") - else: - input_configs = generate_benchmark_configs(is_varlen, packing) - - # filter by mode - modes_to_run = filter_modes(requested_modes, fn_name, supported_modes_for_fn) - if not modes_to_run: - warning(f"No valid modes to run for function '{fn_name}' based on request and function support. Skipping this function.") - continue - - # create a function config for each backend and dtype combination - for backend in supported_backends: - for dtype in supported_dtypes: - for mode in modes_to_run: - for env_config in supported_env_configs[backend]: - func_config = FunctionConfig(fn_name, mode, dtype, backend, env_config) - all_function_configs.append(func_config) - - # Generate inputs for this function configuration - fn_inputs = {} - for input_config in input_configs: - fn_inputs[input_config] = generate_fn_inputs(fn_name, *input_config, dtype, device) - - all_input_configs[func_config] = fn_inputs - - return all_function_configs, all_input_configs - -def check_environment_variables(): - for key in ENV_FLAGS: - if key in os.environ: - raise ValueError(f"Running with {key} environment variable is not recommended for the benching script. Use --help to see how to use the benching script.") - -def main(): - """ - Main function to run benchmarks. 
- """ - # check environment variables - check_environment_variables() - - # start timing the entire benchmarking process - total_start_time = time.time() - - # process args to get function configs and input configs - function_configs, all_input_configs = process_args() - - # Check if we have multiple function configurations - has_multiple_func_configs = len(function_configs) > 1 - combined_df = None - - # run benchmarks for each function configuration - for func_config in function_configs: - # run benchmark with the input configs for this function config - input_configs = all_input_configs[func_config] - df = run_benchmark(func_config, input_configs) - - # Define the columns that represent input configurations - input_config_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"] - - # merge into one final dataframe - if combined_df is None: - combined_df = df - else: - # Ensure we're joining on input configuration columns - combined_df = combined_df.merge(df, on=input_config_cols, how="outer") - - - # print new line to seperate the combined data information from the benchmark specific information - print() - - # print total time for all benchmarks - total_elapsed_time = time.time() - total_start_time - print(f"Total time for all benchmarks: {total_elapsed_time:.2f} seconds") - - # save combined data and make comparisons if we have multiple function configs - if has_multiple_func_configs: - if len(function_configs) == 2: - func1 = function_configs[0] - func2 = function_configs[1] - - # construct column names for the timing results - col1 = func1.column_name() - col2 = func2.column_name() - - # Check if we're comparing triton vs ck (in either order) - is_triton_vs_ck = ( - (func1.backend == "triton" and func2.backend == "ck") or - (func1.backend == "ck" and func2.backend == "triton") - ) - - # For triton vs ck comparisons - if is_triton_vs_ck: - # For triton vs ck comparisons, always make triton the baseline - if func1.backend == "triton" 
and func2.backend == "ck": - triton_col = col1 - ck_col = col2 - ratio_col = f"ck_to_triton_ratio" - else: - triton_col = col2 - ck_col = col1 - ratio_col = f"ck_to_triton_ratio" - - # Calculate ratio: ck_time / triton_time (values > 1 mean triton is faster) - combined_df[ratio_col] = combined_df[ck_col] / combined_df[triton_col] - - # print explanation - print(f"Comparison Results (triton vs ck):") - print(f"Ratio values: values > 1 mean triton is faster (by that factor), values < 1 mean ck is faster") - elif False: - # For other comparisons, use the standard approach - ratio_col = f"{func1}_to_{func2}_ratio" - - # Calculate the ratio - combined_df[ratio_col] = combined_df[col2] / combined_df[col1] - - # print explanation - print(f"Comparison Results ({func1} vs {func2}):") - print(f"Ratio values: values > 1 mean {func1} is faster than {func2} (by that factor), values < 1 mean slower") - - print(f"Combined data:") - print(combined_df) - - # save csv & markdown - combined_filename = f"benchmark_combined" - combined_df.to_csv(f"{combined_filename}.csv", index=False) - with open(f"{combined_filename}.md", 'w') as f: - f.write(combined_df.to_markdown(index=False, floatfmt=".2f")) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py new file mode 100755 index 00000000000..87dc49fc9bc --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -0,0 +1,4880 @@ +import os +import torch +import triton +import triton.language as tl +import warnings +from typing import Literal, Optional +from .common import compute_fp8_scaling_factors +from .utils import ( + DEBUG, + AUTOTUNE, + is_fp8, + get_arch, +) + +PREPROCESS_AUTOTUNE_KEYS = [ + "max_seqlen_q", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", +] + +CAUSAL_AUTOTUNE_KEYS = [ + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", +] + +NONCAUSAL_AUTOTUNE_KEYS = [ + "dropout_p", 
+ "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", +] + + +def get_bwd_configs(autotune: bool): + + # default config + if not autotune: + arch = get_arch() + + # configs for the kernels + if arch.name == "gfx942": + if arch.cu_count < 304: + preprocess_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 128, "waves_per_eu": 2}, num_stages=1, num_warps=4 + ), + ] + noncausal_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=8, + ), + ] + causal_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + 
"BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + ] + else: + preprocess_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=4 + ), + ] + noncausal_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + ] + elif arch.name == "gfx950": + preprocess_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=1, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=4 + ), + ] + noncausal_configs = [ + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + 
"BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 128, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 16, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + elif arch.is_rdna: + preprocess_configs = [ + triton.Config( + {"PRE_BLOCK": 32}, num_stages=1, num_warps=4 + ), + ] + noncausal_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 32, + "BLOCK_M2": 32, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 32, + "BLOCK_M2": 32, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + }, + num_stages=1, + num_warps=4, + ), + ] + else: + preprocess_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + ] + noncausal_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + + # assert constraints + for noncausal_cfg, causal_cfg 
in zip(noncausal_configs, causal_configs): + assert ( + noncausal_cfg.all_kwargs()["BLOCK_N1"] + == noncausal_cfg.all_kwargs()["BLOCK_M2"] + ), f"BLOCK_N1 ({noncausal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({noncausal_cfg.all_kwargs()['BLOCK_M2']})" + assert ( + causal_cfg.all_kwargs()["BLOCK_N1"] + == causal_cfg.all_kwargs()["BLOCK_M2"] + ), f"BLOCK_N1 ({causal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({causal_cfg.all_kwargs()['BLOCK_M2']})" + + return (preprocess_configs, causal_configs, noncausal_configs) + + # ===================== Autotune Sweep ===================== + # param options + PRE_BLOCK_OPTIONS = [64, 128] # og: 128 + PRE_WAVES_PER_EU_OPTIONS = [1, 2] + PRE_NUM_STAGES_OPTIONS = [1, 2] + PRE_NUM_WARPS_OPTIONS = [4, 8] + NUM_STAGES_OPTIONS = [1, 2] # og: 1 + NUM_WARPS_OPTIONS = [4, 8] # og: 4 + WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 + NON_CAUSAL_BLOCK_M1_OPTIONS = [16, 32, 64, 128] # og: 32 + NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [32, 64, 128, 256] # og: 128 + NON_CAUSAL_BLOCK_N2_OPTIONS = [16, 32, 64, 128] # og: 32 + CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 + 32, + 64 + ] + CAUSAL_BLOCK_N1_M2_OPTIONS = [32, 64, 128] # og: 128 + CAUSAL_BLOCK_N2_OPTIONS = [32, 64] # og: 32 + BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 + + # ==================== sweep configs ================================ + preprocess_autotune_configs = [] + for pre_num_warps in PRE_NUM_WARPS_OPTIONS: + for pre_num_stages in PRE_NUM_STAGES_OPTIONS: + for pre_waves in PRE_WAVES_PER_EU_OPTIONS: + for pre_block in PRE_BLOCK_OPTIONS: + preprocess_autotune_configs.append( + triton.Config( + { + "PRE_BLOCK": pre_block, + "waves_per_eu": pre_waves, + }, + num_stages=pre_num_stages, + num_warps=pre_num_warps, + ) + ) + + causal_autotune_configs = [] + for num_warps in NUM_WARPS_OPTIONS: + for num_stages in NUM_STAGES_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for m1 in CAUSAL_BLOCK_M1_OPTIONS: + for n1 in CAUSAL_BLOCK_N1_M2_OPTIONS: + m2 = n1 + for n2 in CAUSAL_BLOCK_N2_OPTIONS: + 
# Ensure constraint + assert ( + n1 == m2 + ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 + if m2 % n2 != 0: + continue + + # Skip configs where BLOCK_N1 % BLOCK_M1 != 0 + if n1 % m1 != 0: + continue + + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: + causal_autotune_configs.append( + triton.Config( + { + "BLOCK_M1": m1, + "BLOCK_N1": n1, + "BLOCK_M2": m2, + "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + + noncausal_autotune_configs = [] + for num_warps in NUM_WARPS_OPTIONS: + for num_stages in NUM_STAGES_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for m1 in NON_CAUSAL_BLOCK_M1_OPTIONS: + for n1 in NON_CAUSAL_BLOCK_N1_M2_OPTIONS: + m2 = n1 + for n2 in NON_CAUSAL_BLOCK_N2_OPTIONS: + # Ensure constraint + assert ( + n1 == m2 + ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 + if m2 % n2 != 0: + continue + + # Skip configs where BLOCK_N1 % BLOCK_M1 != 0 + if n1 % m1 != 0: + continue + + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: + noncausal_autotune_configs.append( + triton.Config( + { + "BLOCK_M1": m1, + "BLOCK_N1": n1, + "BLOCK_M2": m2, + "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + + return ( + preprocess_autotune_configs, + causal_autotune_configs, + noncausal_autotune_configs, + ) + +# os.environ["TRITON_PRINT_AUTOTUNING"] = "1" +( + preprocess_autotune_configs, + causal_autotune_configs, + noncausal_autotune_configs, +) = get_bwd_configs(AUTOTUNE) + + +@triton.jit +def _bwd_dq_inner_split( + dq, + q, + K, + V, + do, + m, + Delta, + sm_scale, + stride_qm, + stride_qk, + stride_kn, + stride_kk, + stride_vn, + stride_vk, + stride_dropout_m, + stride_dropout_n, + stride_deltam, + seqlen_q, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + start_m, + 
start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + RCP_LN2: tl.constexpr = 1.4426950408889634 + + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + + curr_n = start_n + step_n = BLOCK_N + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + for blk_idx in range(num_steps): + offs_n = curr_n + tl.arange(0, BLOCK_N) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < BLOCK_D_MODEL + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + # dropout + if ENABLE_DROPOUT: + philox_offs = ( + curr_philox_offset + + offs_m[:, None] * stride_dropout_m + + offs_n[None, :] * stride_dropout_n + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + # qk + if IS_FP8: + qk = tl.dot(q, kT) * descale_q * descale_k + else: + qk = tl.dot(q, kT) + p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= 
offs_n[None, :] + mask = causal_mask * mask_mn + p = tl.where(mask, p, 0.0) + + # dp + if IS_FP8: + dp = tl.dot(do.to(vT.type.element_ty), vT) * descale_v + else: + dp = tl.dot(do, vT) + + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + + # ds + delta_i = Di[:, None] + ds = p * (dp - delta_i) + + # dq + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + # Rewrite dq += ds @ kT.T as dq += (kT @ ds.T).T + # This puts FP8 tensor (kT) on LHS of dot product + # Cast the transposed ds to FP8 to match kT's dtype + ds_transposed = tl.trans(ds).to(kT.type.element_ty) + dq += tl.trans(tl.dot(kT, ds_transposed)) * descale_k + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + + curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +@triton.jit +def _bwd_dkdv_inner_split( + dk, + dv, + Q, + k, + v, + DO, + M, + D, + sm_scale, + stride_q_m, + stride_q_k, + stride_do_m, + stride_do_k, + stride_dropout_m, + stride_dropout_n, + stride_deltam, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + seqlen_q, + seqlen_k, + start_n, + start_m, + num_steps, + descale_q, + descale_k, + descale_v, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + qT_ptrs = ( + Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k + ) # [BLOCK_D_MODEL_POW2, BLOCK_M] + do_ptrs = DO + offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + 
curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # Iterate over blocks(BLOCK_M size) of Q while calculating + # a fixed block(BLOCK_N) of dk and dv. Note, during backward + # pass P has to be recomputed. However, this kernel computes + # dV and dK, so we compute we need P^T and S^T. See backward pass + # equations + # + # From Flash Attention Paper: + # ForwardPass: S = QkT, P=softmax(S), O=PV + # + # BackwardPass equations + # dV = P^TdO + # dP = dOV^T + # dS = dsoftmax(dP) + # dQ = dSK + # dK = QdS^T + for blk_idx in range(num_steps): + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + # load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + # dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + # Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + # Compute qkT + if IS_FP8: + qkT = tl.dot(k, qT) * descale_q * descale_k + else: + qkT = tl.dot(k, qT) + + # Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + # load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + # dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + dv += 
@triton.jit
def _bwd_dkdvdq_inner_atomic(
    dk,
    dv,
    Q,
    k,
    v,
    DO,
    DQ,
    M,
    D,
    sm_scale,
    stride_q_m,
    stride_q_k,
    stride_dq_m,
    stride_dq_k,
    stride_do_m,
    stride_do_k,
    stride_dropout_m,
    stride_dropout_n,
    stride_deltam,
    dropout_p,
    philox_seed,
    batch_philox_offset,
    dropout_offset,
    seqlen_q,
    seqlen_k,
    start_n,
    start_m,
    num_steps,
    descale_q,
    descale_k,
    descale_v,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr,
    BLOCK_D_MODEL_POW2: tl.constexpr,
    MASK: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
    workgroup_id: tl.int32,
):
    # Inner loop of the fused "atomic" backward kernel. For one fixed K/V
    # tile (BLOCK_N rows starting at start_n, with k/v already resident in
    # registers), it iterates over num_steps Q blocks of size BLOCK_M:
    #   - accumulates dK and dV into the float32 accumulators passed in, and
    #   - scatters each Q block's partial dQ contribution to global memory
    #     via tl.atomic_add (relaxed ordering).
    # Returns the updated (dk, dv) pair.
    #
    # Q blocks are visited in a permuted order derived from workgroup_id
    # (see start_idx below) so that workgroups sharing the same dQ rows do
    # not all hit the same memory region at the same time.
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    delta_qk = seqlen_q - seqlen_k
    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)

    # mask to make sure not OOB of seqlen_k
    mask_n = offs_n < seqlen_k

    qT_ptrs_start = (
        Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k
    )  # [BLOCK_D_MODEL_POW2, BLOCK_M]
    dq_ptrs_start = (
        DQ + offs_m[:, None] * stride_dq_m + offs_k[None, :] * stride_dq_k
    )  # [BLOCK_M, BLOCK_D_MODEL_POW2]

    do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k
    curr_m = start_m
    step_m = BLOCK_M
    curr_philox_offset = batch_philox_offset
    curr_dropout_offset = dropout_offset
    RCP_LN2: tl.constexpr = 1.4426950408889634

    # Iterate over blocks(BLOCK_M size) of Q while calculating
    # a fixed block(BLOCK_N) of dk and dv. Note, during backward
    # pass P has to be recomputed. However, this kernel computes
    # dV and dK, so we compute we need P^T and S^T. See backward pass
    # equations
    #
    # From Flash Attention Paper:
    # ForwardPass: S = QkT, P=softmax(S), O=PV
    #
    # BackwardPass equations
    # dV = P^TdO
    # dP = dOV^T
    # dS = dsoftmax(dP)
    # dQ = dSK
    # dK = QdS^T

    # Compute a starting index and step based on workgroup_id
    # Use a simple hash-like function to spread out the starting points
    # NOTE(review): assumes num_steps > 0 — `% num_steps` is evaluated
    # unconditionally; confirm callers never pass num_steps == 0.
    start_idx = (
        workgroup_id * 17
    ) % num_steps  # 17 is an arbitrary prime to spread indices
    # Ensure step is coprime with num_steps to visit all indices exactly once
    step = 1  # 3 if num_steps > 1 or num_steps==3 else 1 # coprime with num_steps

    for iter in range(num_steps):
        # Compute the permuted block index
        blk_idx = (start_idx + iter * step) % num_steps

        curr_m = start_m + blk_idx * step_m
        qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m
        dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_dq_m
        do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m

        offs_m = curr_m + tl.arange(0, BLOCK_M)
        mask_m = offs_m < seqlen_q
        mask_qT = mask_m[None, :]
        mask_do = mask_m[:, None]
        mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q)

        if PADDED_HEAD:
            mask_qT &= offs_k[:, None] < BLOCK_D_MODEL
            mask_do &= offs_k[None, :] < BLOCK_D_MODEL

        # load qT
        qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0)

        # dropout
        if ENABLE_DROPOUT:
            # NOTE: dropout is transposed because it is used to mask pT
            philox_offs = (
                curr_philox_offset
                + offs_m[None, :] * stride_dropout_m
                + offs_n[:, None] * stride_dropout_n
            )
            rand_vals = tl.rand(philox_seed, philox_offs)
            dropout_mask = rand_vals > dropout_p
            dropout_scale = 1.0 / (1 - dropout_p)

        # Load M (softmax row statistics saved by the forward pass)
        m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0)

        # Compute qkT
        if IS_FP8:
            qkT = tl.dot(k, qT) * descale_q * descale_k
        else:
            qkT = tl.dot(k, qT)

        # Compute pT(use m and also apply sm_scale)
        pT = tl.math.exp(qkT * sm_scale - m[None, :])

        if MASK:
            causal_mask = (offs_m[None, :] - delta_qk) >= (offs_n[:, None])
            mask = causal_mask & mask_nm
            pT = tl.where(mask, pT, 0.0)

        # load DO
        do = tl.load(do_ptrs, mask=mask_do, other=0.0)

        # dV
        if ENABLE_DROPOUT:
            pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale
            dv += tl.dot(pT_dropout.to(do.type.element_ty), do)
        else:
            dv += tl.dot(pT.to(do.type.element_ty), do)

        # Load delta
        Di = tl.load(D + offs_m * stride_deltam, mask=mask_m)

        # Compute dP and dS
        if IS_FP8:
            dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v
        else:
            dpT = tl.dot(v, tl.trans(do))

        if ENABLE_DROPOUT:
            dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale

        delta_i = Di[None, :]
        dsT = pT * (dpT - delta_i)

        # compute dk
        if IS_FP8:
            # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T
            # This puts FP8 tensor (qT) on LHS of dot product
            # Cast the transposed dsT to FP8 to match qT's dtype
            dsT_transposed = tl.trans(dsT).to(qT.type.element_ty)
            dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q
        else:
            dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT))

        # We can compute the dq_partial here and do a atomic add to the correct memory location
        # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before
        # (BLOCK_M, BLOCK_N) x (BLOCK_N, D)
        if IS_FP8:
            dq_partial = tl.dot(dsT.to(k.type.element_ty).T, k) * descale_k
        else:
            dq_partial = tl.dot(dsT.to(k.type.element_ty).T, k)
        tl.atomic_add(
            dq_ptrs,
            dq_partial * sm_scale,
            mask=mask_m[:, None],
            sem="relaxed",
        )

    return dk, dv
@triton.jit
def _bwd_kernel_fused_atomic_causal(
    q_ptr,
    k_ptr,
    v_ptr,
    sm_scale,
    do_ptr,
    dk_ptr,
    dv_ptr,
    dq_ptr,
    m_ptr,
    delta_ptr,
    stride_q_b,
    stride_q_h,
    stride_q_m,
    stride_q_k,
    stride_k_b,
    stride_k_h,
    stride_k_n,
    stride_k_k,
    stride_v_b,
    stride_v_h,
    stride_v_n,
    stride_v_k,
    stride_dk_b,
    stride_dk_h,
    stride_dk_n,
    stride_dk_k,
    stride_dq_b,
    stride_dq_h,
    stride_dq_m,
    stride_dq_k,
    stride_delta_b,
    stride_delta_h,
    stride_delta_m,
    stride_do_b,
    stride_do_h,
    stride_do_m,
    stride_do_k,
    stride_dropout_b,
    stride_dropout_h,
    stride_dropout_m,
    stride_dropout_n,
    stride_descale_q_z,
    stride_descale_k_z,
    stride_descale_v_z,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_mask,
    dropout_p,
    philox_seed,
    philox_offset_base,
    descale_q_ptr,
    descale_k_ptr,
    descale_v_ptr,
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    BATCH,
    NUM_K_PIDS,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr,
    BLOCK_D_MODEL_POW2: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    # Fused causal backward kernel: each workgroup owns one K/V tile of one
    # (batch, k-head) pair, loads K and V once, and for every Q head in its
    # GQA/MQA group computes dK/dV locally while scattering dQ with atomics
    # (via _bwd_dkdvdq_inner_atomic). dK is scaled by sm_scale before the
    # final store; dV is stored unscaled.
    wid = tl.program_id(0)  # workgroup id: 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1

    # workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim
    batch_idx = wid % BATCH
    head_k_idx = wid // BATCH % NUM_K_HEADS
    seq_k_blk_idx = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS

    # Determine q and k start along with seqlen_q and seqlen_k
    q_start = 0
    k_start = 0
    seqlen_q = max_seqlen_q
    seqlen_k = max_seqlen_k
    if IS_VARLEN:
        q_start = tl.load(cu_seqlens_q + batch_idx)
        q_end = tl.load(cu_seqlens_q + batch_idx + 1)
        k_start = tl.load(cu_seqlens_k + batch_idx)
        k_end = tl.load(cu_seqlens_k + batch_idx + 1)
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start

    dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)

    # Figure out causal starting block since we have seqlen_q >=< seqlen_k.
    # Unlike forward pass where we tile on M dim and iterate on N dim, so that
    # we can skip some M blocks, in backward pass, we tile on the N dim for kv
    # and iterate over the M. In this way, we cannot skip N blocks, but only to
    # determine the starting M blocks to skip some initial blocks masked by
    # causal.
    delta_qk = seqlen_q - seqlen_k

    # q > k: directly skip all the way until the start of causal block
    start_delta_q_gt_k = delta_qk

    # q < k: some blocks will have no Masked block, other needs to re-calc
    # starting position
    # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the
    # masked op
    num_blocks_skip = -delta_qk // BLOCK_N
    delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk
    start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M
    if delta_qk >= 0:
        start_delta = delta_qk
    else:
        start_delta = start_delta_q_lt_k

    start_n = seq_k_blk_idx * BLOCK_N

    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    # Mask for loading K and V
    mask_kv = offs_n[:, None] < seqlen_k
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        mask_k = offs_k < BLOCK_D_MODEL
        mask_kv &= mask_k[None, :]

    GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS
    adj_k = (
        batch_idx * stride_k_b
        + head_k_idx * stride_k_h
        + k_start * stride_k_n
        + offs_n[:, None] * stride_k_n
        + offs_k[None, :] * stride_k_k
    )
    adj_v = (
        batch_idx * stride_v_b
        + head_k_idx * stride_v_h
        + k_start * stride_v_n
        + offs_n[:, None] * stride_v_n
        + offs_k[None, :] * stride_v_k
    )
    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(k_ptr + adj_k, mask=mask_kv, other=0.0)
    v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0)

    # If MQA / GQA, set the K and V head offsets appropriately.
    for head_q_idx in range(
        head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE
    ):
        if delta_qk >= 0:
            start_m = start_n + start_delta
            len_m = BLOCK_N
        else:
            start_m = max(start_n + delta_qk, 0)
            start_m = (start_m // BLOCK_M) * BLOCK_M
            # because we might shift the masked blocks up, we are deeper into
            # the masked out region, so we would potentially increase the total
            # steps with masked operation to get out of it
            residue_m = max(start_n + delta_qk - start_m, 0)
            len_m = BLOCK_N + residue_m

        # offset input and output tensor by batch and Q/K heads
        adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m
        adj_dq = (
            batch_idx * stride_dq_b + head_q_idx * stride_dq_h + q_start * stride_dq_m
        )

        q_ptr_adj = q_ptr + adj_q
        dq_ptr_adj = dq_ptr + adj_dq

        adj_do = (
            batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m
        )
        do_ptr_adj = do_ptr + adj_do
        adj_delta = (
            batch_idx * stride_delta_b
            + head_q_idx * stride_delta_h
            + q_start * stride_delta_m
        )
        m_ptr_adj = m_ptr + adj_delta
        delta_ptr_adj = delta_ptr + adj_delta

        # batch_philox_offset is the ACTUAL dropout offset
        # dropout_offset is for debug purpose and will be removed later
        batch_philox_offset = 0
        dropout_offset = 0
        if ENABLE_DROPOUT:
            batch_philox_offset = (
                philox_offset_base
                + batch_idx * stride_dropout_b
                + head_q_idx * stride_dropout_h
            )
            dropout_offset = (
                dropout_mask
                + batch_idx * stride_dropout_b
                + head_q_idx * stride_dropout_h
            )

        MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR
        # bound the masked operation to q len so it does not have to waste cycles
        len_m = min(len_m, seqlen_q)
        num_steps = tl.cdiv(len_m, MASK_BLOCK_M)

        # when q < k, we may skip the initial masked op
        # if seq_k_blk_idx < num_blocks_skip:
        #     num_steps = 0

        if IS_FP8:
            descale_q = tl.load(
                descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx
            )
            descale_k = tl.load(
                descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx
            )
            descale_v = tl.load(
                descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx
            )
        else:
            descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

        # if unaligned start_m is negative, the current N-tile has no block on the
        # diagonal of causal mask, so everything have no causal mask
        dk, dv = _bwd_dkdvdq_inner_atomic(
            dk,
            dv,  # output tensors
            q_ptr_adj,
            k,
            v,
            do_ptr_adj,
            dq_ptr_adj,
            m_ptr_adj,
            delta_ptr_adj,
            sm_scale,  # input tensors
            stride_q_m,
            stride_q_k,  # strides for q
            stride_dq_m,
            stride_dq_k,  # strides for q
            stride_do_m,
            stride_do_k,  # strides for o
            stride_dropout_m,
            stride_dropout_n,  # strides for dropout
            stride_delta_m,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            dropout_offset,  #
            seqlen_q,
            seqlen_k,  # max sequence length for q and k
            start_n,
            start_m,
            num_steps,  # iteration numbers
            descale_q,
            descale_k,
            descale_v,
            MASK_BLOCK_M,
            BLOCK_N,  # block dim
            BLOCK_D_MODEL,
            BLOCK_D_MODEL_POW2,  # head dim
            MASK=True,  # causal masking
            ENABLE_DROPOUT=ENABLE_DROPOUT,  # activate dropout
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
            workgroup_id=seq_k_blk_idx,
        )

        start_m += num_steps * MASK_BLOCK_M
        num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M)
        end_m = start_m + num_steps * BLOCK_M

        dk, dv = _bwd_dkdvdq_inner_atomic(
            dk,
            dv,  # output tensors
            q_ptr_adj,
            k,
            v,
            do_ptr_adj,
            dq_ptr_adj,
            m_ptr_adj,
            delta_ptr_adj,
            sm_scale,  # input tensors
            stride_q_m,
            stride_q_k,  # strides for q
            stride_dq_m,
            stride_dq_k,  # strides for dq
            stride_do_m,
            stride_do_k,  # strides for o
            stride_dropout_m,
            stride_dropout_n,  # strides for dropout
            stride_delta_m,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            dropout_offset,  #
            seqlen_q,
            seqlen_k,  # max sequence length for q and k
            start_n,
            start_m,
            num_steps,  # iteration numbers
            descale_q,
            descale_k,
            descale_v,
            BLOCK_M,
            BLOCK_N,  # block dim
            BLOCK_D_MODEL,
            BLOCK_D_MODEL_POW2,  # head dim
            MASK=False,  # causal masking
            ENABLE_DROPOUT=ENABLE_DROPOUT,  # activate dropout
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
            workgroup_id=seq_k_blk_idx,
        )

    # Write back dV and dK.
    offs_dkdv = (
        batch_idx * stride_dk_b
        + head_k_idx * stride_dk_h
        + k_start * stride_dk_n
        + offs_n[:, None] * stride_dk_n
        + offs_k[None, :] * stride_dk_k
    )
    tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv)
    dk *= sm_scale
    tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv)
@triton.jit
def _bwd_kernel_split_dkdv_causal(
    q_ptr,
    k_ptr,
    v_ptr,
    sm_scale,
    do_ptr,
    dk_ptr,
    dv_ptr,
    m_ptr,
    delta_ptr,
    stride_q_b,
    stride_q_h,
    stride_q_m,
    stride_q_k,
    stride_k_b,
    stride_k_h,
    stride_k_n,
    stride_k_k,
    stride_v_b,
    stride_v_h,
    stride_v_n,
    stride_v_k,
    stride_dk_b,
    stride_dk_h,
    stride_dk_n,
    stride_dk_k,
    stride_delta_b,
    stride_delta_h,
    stride_delta_m,
    stride_do_b,
    stride_do_h,
    stride_do_m,
    stride_do_k,
    stride_dropout_b,
    stride_dropout_h,
    stride_dropout_m,
    stride_dropout_n,
    stride_descale_q_z,
    stride_descale_k_z,
    stride_descale_v_z,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_mask,
    dropout_p,
    philox_seed,
    philox_offset_base,
    descale_q_ptr,
    descale_k_ptr,
    descale_v_ptr,
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr,
    BLOCK_D_MODEL_POW2: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    # Split-kernel causal backward pass, dK/dV half: each program owns one
    # K/V tile of one (batch, k-head) pair and accumulates dK and dV over
    # all Q blocks via _bwd_dkdv_inner_split (masked diagonal slice first,
    # then the unmasked remainder). dQ is produced by the companion
    # _bwd_kernel_split_dq_causal kernel, so no atomics are needed here.
    # seq block, batch, head_k
    seq_k_blk_idx = tl.program_id(0)
    batch_idx = tl.program_id(1)
    head_k_idx = tl.program_id(2)

    # Determine q and k start along with seqlen_q and seqlen_k
    q_start = 0
    k_start = 0
    seqlen_q = max_seqlen_q
    seqlen_k = max_seqlen_k
    if IS_VARLEN:
        q_start = tl.load(cu_seqlens_q + batch_idx)
        q_end = tl.load(cu_seqlens_q + batch_idx + 1)
        k_start = tl.load(cu_seqlens_k + batch_idx)
        k_end = tl.load(cu_seqlens_k + batch_idx + 1)
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start

    dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)

    # Figure out causal starting block since we have seqlen_q >=< seqlen_k.
    # Unlike forward pass where we tile on M dim and iterate on N dim, so that
    # we can skip some M blocks, in backward pass, we tile on the N dim for kv
    # and iterate over the M. In this way, we cannot skip N blocks, but only to
    # determine the starting M blocks to skip some initial blocks masked by
    # causal.
    delta_qk = seqlen_q - seqlen_k

    # q > k: directly skip all the way until the start of causal block
    start_delta_q_gt_k = delta_qk

    # q < k: some blocks will have no Masked block, other needs to re-calc
    # starting position
    # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the
    # masked op
    num_blocks_skip = -delta_qk // BLOCK_N
    delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk
    start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M
    if delta_qk >= 0:
        start_delta = delta_qk
    else:
        start_delta = start_delta_q_lt_k

    start_n = seq_k_blk_idx * BLOCK_N

    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    # Mask for loading K and V
    mask_kv = offs_n[:, None] < seqlen_k
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        mask_k = offs_k < BLOCK_D_MODEL
        mask_kv &= mask_k[None, :]

    GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS
    adj_k = (
        batch_idx * stride_k_b
        + head_k_idx * stride_k_h
        + k_start * stride_k_n
        + offs_n[:, None] * stride_k_n
        + offs_k[None, :] * stride_k_k
    )
    adj_v = (
        batch_idx * stride_v_b
        + head_k_idx * stride_v_h
        + k_start * stride_v_n
        + offs_n[:, None] * stride_v_n
        + offs_k[None, :] * stride_v_k
    )
    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(k_ptr + adj_k, mask=mask_kv, other=0.0)
    v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0)

    # If MQA / GQA, set the K and V head offsets appropriately.
    for head_q_idx in range(
        head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE
    ):
        if delta_qk >= 0:
            start_m = start_n + start_delta
            len_m = BLOCK_N
        else:
            start_m = max(start_n + delta_qk, 0)
            start_m = start_m // BLOCK_M * BLOCK_M
            # because we might shift the masked blocks up, we are deeper into
            # the masked out region, so we would potentially increase the total
            # steps with masked operation to get out of it
            residue_m = max(start_n + delta_qk - start_m, 0)
            len_m = BLOCK_N + residue_m

        # offset input and output tensor by batch and Q/K heads
        adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m
        q_ptr_adj = q_ptr + adj_q
        adj_do = (
            batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m
        )
        do_ptr_adj = do_ptr + adj_do
        adj_delta = (
            batch_idx * stride_delta_b
            + head_q_idx * stride_delta_h
            + q_start * stride_delta_m
        )
        m_ptr_adj = m_ptr + adj_delta
        delta_ptr_adj = delta_ptr + adj_delta

        # batch_philox_offset is the ACTUAL dropout offset
        # dropout_offset is for debug purpose and will be removed later
        batch_philox_offset = 0
        dropout_offset = 0
        if ENABLE_DROPOUT:
            batch_philox_offset = (
                philox_offset_base
                + batch_idx * stride_dropout_b
                + head_q_idx * stride_dropout_h
            )
            dropout_offset = (
                dropout_mask
                + batch_idx * stride_dropout_b
                + head_q_idx * stride_dropout_h
            )

        MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR
        # bound the masked operation to q len so it does not have to waste cycles
        len_m = min(len_m, seqlen_q)
        num_steps = tl.cdiv(len_m, MASK_BLOCK_M)
        # when q < k, we may skip the initial masked op
        if seq_k_blk_idx < num_blocks_skip:
            num_steps = 0

        if IS_FP8:
            descale_q = tl.load(
                descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx
            )
            descale_k = tl.load(
                descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx
            )
            descale_v = tl.load(
                descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx
            )
        else:
            descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

        # if start_m is negative, the current N-tile has no block on the
        # diagonal of causal mask, so everything have no causal mask
        dk, dv = _bwd_dkdv_inner_split(
            dk,
            dv,  # output tensors
            q_ptr_adj,
            k,
            v,
            do_ptr_adj,
            m_ptr_adj,
            delta_ptr_adj,
            sm_scale,  # input tensors
            stride_q_m,
            stride_q_k,  # strides for q
            stride_do_m,
            stride_do_k,  # strides for o
            stride_dropout_m,
            stride_dropout_n,  # strides for dropout
            stride_delta_m,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            dropout_offset,  #
            seqlen_q,
            seqlen_k,  # max sequence length for q and k
            start_n,
            start_m,
            num_steps,  # iteration numbers
            descale_q,
            descale_k,
            descale_v,
            MASK_BLOCK_M,
            BLOCK_N,  # block dim
            BLOCK_D_MODEL,
            BLOCK_D_MODEL_POW2,  # head dim
            MASK=True,  # causal masking
            ENABLE_DROPOUT=ENABLE_DROPOUT,  # activate dropout
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
        )
        start_m += num_steps * MASK_BLOCK_M
        num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M)
        end_m = start_m + num_steps * BLOCK_M

        dk, dv = _bwd_dkdv_inner_split(
            dk,
            dv,  # output tensors
            q_ptr_adj,
            k,
            v,
            do_ptr_adj,
            m_ptr_adj,
            delta_ptr_adj,
            sm_scale,  # input tensors
            stride_q_m,
            stride_q_k,  # strides for q
            stride_do_m,
            stride_do_k,  # strides for o
            stride_dropout_m,
            stride_dropout_n,  # strides for dropout
            stride_delta_m,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            dropout_offset,  #
            seqlen_q,
            seqlen_k,  # max sequence length for q and k
            start_n,
            start_m,
            num_steps,  # iteration numbers
            descale_q,
            descale_k,
            descale_v,
            BLOCK_M,
            BLOCK_N,  # block dim
            BLOCK_D_MODEL,
            BLOCK_D_MODEL_POW2,  # head dim
            MASK=False,  # causal masking
            ENABLE_DROPOUT=ENABLE_DROPOUT,  # activate dropout
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
        )

    # Write back dV and dK.
    offs_dkdv = (
        batch_idx * stride_dk_b
        + head_k_idx * stride_dk_h
        + k_start * stride_dk_n
        + offs_n[:, None] * stride_dk_n
        + offs_k[None, :] * stride_dk_k
    )
    tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv)
    dk *= sm_scale
    tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv)
@triton.jit
def _bwd_kernel_split_dq_causal(
    q_ptr,
    k_ptr,
    v_ptr,
    sm_scale,
    do_ptr,
    dq_ptr,
    m_ptr,
    delta_ptr,
    stride_q_b,
    stride_q_h,
    stride_q_m,
    stride_q_k,
    stride_k_b,
    stride_k_h,
    stride_k_n,
    stride_k_k,
    stride_v_b,
    stride_v_h,
    stride_v_n,
    stride_v_k,
    stride_dq_b,
    stride_dq_h,
    stride_dq_m,
    stride_dq_k,
    stride_delta_b,
    stride_delta_h,
    stride_delta_m,
    stride_do_b,
    stride_do_h,
    stride_do_m,
    stride_do_k,
    stride_dropout_b,
    stride_dropout_h,
    stride_dropout_m,
    stride_dropout_n,
    stride_descale_q_z,
    stride_descale_k_z,
    stride_descale_v_z,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_mask,
    dropout_p,
    philox_seed,
    philox_offset_base,
    descale_q_ptr,
    descale_k_ptr,
    descale_v_ptr,
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr,
    BLOCK_D_MODEL_POW2: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    # Split-kernel causal backward pass, dQ half: each program owns one
    # Q tile of one batch and iterates over K/V blocks via
    # _bwd_dq_inner_split (masked diagonal slice first, then the unmasked
    # remainder), storing dQ scaled by sm_scale at the end. The companion
    # _bwd_kernel_split_dkdv_causal kernel produces dK/dV.
    seq_q_blk_idx = tl.program_id(0)
    batch_idx = tl.program_id(1)
    head_k_idx = tl.program_id(2)

    q_start = 0
    k_start = 0
    seqlen_q = max_seqlen_q
    seqlen_k = max_seqlen_k
    if IS_VARLEN:
        q_start = tl.load(cu_seqlens_q + batch_idx)
        q_end = tl.load(cu_seqlens_q + batch_idx + 1)
        k_start = tl.load(cu_seqlens_k + batch_idx)
        k_end = tl.load(cu_seqlens_k + batch_idx + 1)
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start

    # Figure out causal starting block since we have seqlen_q <=> seqlen_k.
    # Unlike forward pass where we tile on M dim and iterate on N dim, so that
    # we can skip some M blocks, in backward pass, we tile on the N dim for kv
    # and iterate over the M. In this way, we cannot skip N blocks, but only to
    # determine the starting M blocks to skip some initial blocks masked by
    # causal.
    # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we
    # can simply skip and we need to adjust starting position.
    start_m = seq_q_blk_idx * BLOCK_M
    # seqlen_q > seqlen_k, no need to process these tile for dq
    delta_qk = seqlen_q - seqlen_k
    if start_m + BLOCK_M < delta_qk:
        return

    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
    offs_m = start_m + tl.arange(0, BLOCK_M)
    # Mask for loading Q and DO
    mask_q = offs_m[:, None] < seqlen_q
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        mask_k = offs_k < BLOCK_D_MODEL
        mask_q &= mask_k[None, :]
    offs_q = offs_m[:, None] * stride_q_m + offs_k[None, :] * stride_q_k
    offs_do = offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k
    adj_k = batch_idx * stride_k_b + head_k_idx * stride_k_h + k_start * stride_k_n
    adj_v = batch_idx * stride_v_b + head_k_idx * stride_v_h + k_start * stride_v_n
    k_ptr_adj = k_ptr
    v_ptr_adj = v_ptr
    k_ptr_adj += adj_k
    v_ptr_adj += adj_v

    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS
    for head_q_idx in range(
        head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE
    ):
        # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front
        # for every M-tile
        end_n = start_m + BLOCK_M - delta_qk
        # clamp end_n at [0, seqlen_k]
        end_n = max(min(end_n, seqlen_k), 0)

        # offset input and output tensor by batch and Q/K heads
        adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m
        adj_do = (
            batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m
        )
        adj_delta = (
            batch_idx * stride_delta_b
            + head_q_idx * stride_delta_h
            + q_start * stride_delta_m
        )
        delta_ptr_adj = delta_ptr + adj_delta

        # batch_philox_offset is the ACTUAL dropout offset
        # dropout_offset is for debug purpose and will be removed later
        batch_philox_offset = 0
        dropout_offset = 0
        if ENABLE_DROPOUT:
            batch_philox_offset = (
                philox_offset_base
                + batch_idx * stride_dropout_b
                + head_q_idx * stride_dropout_h
            )
            dropout_offset = (
                dropout_mask
                + batch_idx * stride_dropout_b
                + head_q_idx * stride_dropout_h
            )

        q = tl.load(q_ptr + adj_q + offs_q, mask=mask_q, other=0.0)
        do = tl.load(do_ptr + adj_do + offs_do, mask=mask_q, other=0.0)
        m = tl.load(m_ptr + adj_delta + offs_m * stride_delta_m, mask=offs_m < seqlen_q)
        m = m[:, None]

        MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR
        # start can only be 0 at minimum
        start_n = max(end_n - BLOCK_M, 0)
        num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N)

        if IS_FP8:
            descale_q = tl.load(
                descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx
            )
            descale_k = tl.load(
                descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx
            )
            descale_v = tl.load(
                descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx
            )
        else:
            descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

        dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32)
        # Compute dQ for masked (diagonal) blocks.
        # NOTE: This code scans each row of QK^T backward (from right to left,
        # but inside each call to _bwd_dq_inner, from left to right), but that's
        # not due to anything important. I just wanted to reuse the loop
        # structure for dK & dV above as much as possible.
        dq = _bwd_dq_inner_split(
            dq,
            q,
            k_ptr_adj,
            v_ptr_adj,
            do,
            m,
            delta_ptr_adj,
            sm_scale,
            stride_q_m,
            stride_q_k,
            stride_k_n,
            stride_k_k,
            stride_v_n,
            stride_v_k,
            stride_dropout_m,
            stride_dropout_n,
            stride_delta_m,
            seqlen_q,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            dropout_offset,
            start_m,
            start_n,
            end_n,
            num_steps,
            descale_q,
            descale_k,
            descale_v,
            BLOCK_M,
            MASK_BLOCK_N,
            BLOCK_D_MODEL,
            BLOCK_D_MODEL_POW2,
            MASK=True,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
        )
        end_n -= num_steps * MASK_BLOCK_N
        num_steps = tl.cdiv(end_n, BLOCK_N)
        start_n = max(end_n - num_steps * BLOCK_N, 0)
        dq = _bwd_dq_inner_split(
            dq,
            q,
            k_ptr_adj,
            v_ptr_adj,
            do,
            m,
            delta_ptr_adj,
            sm_scale,
            stride_q_m,
            stride_q_k,
            stride_k_n,
            stride_k_k,
            stride_v_n,
            stride_v_k,
            stride_dropout_m,
            stride_dropout_n,
            stride_delta_m,
            seqlen_q,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            dropout_offset,
            start_m,
            start_n,
            end_n,
            num_steps,
            descale_q,
            descale_k,
            descale_v,
            BLOCK_M,
            BLOCK_N,
            BLOCK_D_MODEL,
            BLOCK_D_MODEL_POW2,
            MASK=False,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
        )
        # Write back dQ.
        offs_dq = (
            batch_idx * stride_dq_b
            + head_q_idx * stride_dq_h
            + q_start * stride_dq_m
            + offs_m[:, None] * stride_dq_m
            + offs_k[None, :] * stride_dq_k
        )
        dq *= sm_scale
        tl.store(dq_ptr + offs_dq, dq, mask=mask_q)
@triton.jit
def _bwd_kernel_fused_atomic_noncausal(
    Q,
    K,
    V,
    sm_scale,
    DO,
    DK,
    DV,
    DQ,
    M,
    Delta,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_vk,
    stride_dkb,
    stride_dkh,
    stride_dkn,
    stride_dkk,
    stride_dqb,
    stride_dqh,
    stride_dqm,
    stride_dqk,
    stride_deltab,
    stride_deltah,
    stride_deltam,
    stride_dob,
    stride_doh,
    stride_dom,
    stride_dok,
    stride_dropoutb,
    stride_dropouth,
    stride_dropoutm,
    stride_dropoutn,
    stride_descale_q_z,
    stride_descale_k_z,
    stride_descale_v_z,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_mask,
    dropout_p,
    philox_seed,
    philox_offset,
    descale_q_ptr,
    descale_k_ptr,
    descale_v_ptr,
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    BATCH,
    NUM_K_PIDS,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr,
    BLOCK_D_MODEL_POW2: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    # Fused non-causal backward kernel: like _bwd_kernel_fused_atomic_causal
    # but with no causal masking, so every K/V tile iterates over the full
    # range of Q blocks in a single unmasked call to
    # _bwd_dkdvdq_inner_atomic (MASK=False). dQ is accumulated with atomics
    # inside the inner loop; dK (scaled by sm_scale) and dV are stored here.
    # workgroup id
    wid = tl.program_id(0)  # 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1

    # Workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim
    # This is in order to avoid contention for the tl.atomic_add (inside _bwd_dkdvdq_inner) that happens between workgroups that share the same batch and head_k.
    bid = wid % BATCH
    hkid = wid // BATCH % NUM_K_HEADS
    pid = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS

    q_start = 0
    k_start = 0
    seqlen_q = max_seqlen_q
    seqlen_k = max_seqlen_k

    if IS_VARLEN:
        q_start = tl.load(cu_seqlens_q + bid)
        q_end = tl.load(cu_seqlens_q + bid + 1)
        k_start = tl.load(cu_seqlens_k + bid)
        k_end = tl.load(cu_seqlens_k + bid + 1)
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start

    dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)

    start_n = pid * BLOCK_N

    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    mask_kv = offs_n[:, None] < seqlen_k
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        # NOTE(review): offs_k is 1-D here and relies on broadcasting,
        # unlike the causal kernels which use mask_k[None, :] — confirm
        # this broadcasts as intended.
        mask_kv &= offs_k < BLOCK_D_MODEL

    GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS
    adj_k = (
        bid * stride_kb
        + hkid * stride_kh
        + k_start * stride_kn
        + offs_n[:, None] * stride_kn
        + offs_k[None, :] * stride_kk
    )
    adj_v = (
        bid * stride_vb
        + hkid * stride_vh
        + k_start * stride_vn
        + offs_n[:, None] * stride_vn
        + offs_k[None, :] * stride_vk
    )

    k = tl.load(K + adj_k, mask=mask_kv, other=0.0)
    v = tl.load(V + adj_v, mask=mask_kv, other=0.0)

    for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE):
        adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm
        adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm

        Q_ptr = Q + adj_q
        DQ_ptr = DQ + adj_dq

        adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom
        DO_ptr = DO + adj_do
        adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam
        M_ptr = M + adj_delta
        Delta_ptr = Delta + adj_delta

        # dropout
        batch_philox_offset = 0
        dropout_offset = 0
        if ENABLE_DROPOUT:
            batch_philox_offset = (
                philox_offset + bid * stride_dropoutb + hqid * stride_dropouth
            )
            dropout_offset = (
                dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth
            )

        if IS_FP8:
            # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid)
            # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter
            descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid)
            descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid)
            descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid)
        else:
            descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

        start_m = 0
        num_steps = tl.cdiv(seqlen_q, BLOCK_M)

        dk, dv = _bwd_dkdvdq_inner_atomic(
            dk,
            dv,
            Q_ptr,
            k,
            v,
            DO_ptr,
            DQ_ptr,
            M_ptr,
            Delta_ptr,
            sm_scale,
            stride_qm,
            stride_qk,
            stride_dqm,
            stride_dqk,
            stride_dom,
            stride_dok,
            stride_dropoutm,
            stride_dropoutn,
            stride_deltam,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            dropout_offset,
            seqlen_q,
            seqlen_k,
            start_n,
            start_m,
            num_steps,
            descale_q,
            descale_k,
            descale_v,
            BLOCK_M,
            BLOCK_N,
            BLOCK_D_MODEL,
            BLOCK_D_MODEL_POW2,
            MASK=False,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
            workgroup_id=pid,
        )

    # Write back dV and dK (dK scaled by sm_scale).
    adj_dkdv = (
        bid * stride_dkb
        + hkid * stride_dkh
        + k_start * stride_dkn
        + offs_n[:, None] * stride_dkn
        + offs_k[None, :] * stride_dkk
    )
    tl.store(DV + adj_dkdv, dv, mask=mask_kv)
    dk *= sm_scale
    tl.store(DK + adj_dkdv, dk, mask=mask_kv)
max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_kv &= offs_k < BLOCK_D_MODEL + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_k[None, :] * stride_kk + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_k[None, :] * stride_vk + ) + + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + 
q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # dropout + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + dk, dv = _bwd_dkdv_inner_split( + dk, + dv, + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, + stride_qm, + stride_qk, + stride_dom, + stride_dok, + stride_dropoutm, + stride_dropoutn, + stride_deltam, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + seqlen_q, + seqlen_k, + start_n, + start_m, + num_steps, + descale_q, + descale_k, + descale_v, + BLOCK_M, + BLOCK_N, + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + adj_dkdv = ( + bid * stride_dkb + + hkid * stride_dkh + + k_start * stride_dkn + + offs_n[:, None] * stride_dkn + + offs_k[None, :] * stride_dkk + ) + tl.store(DV + adj_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_split_dq_noncausal( + Q, + K, + V, + sm_scale, + DO, + DQ, + M, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qk, + stride_kb, + stride_kh, + stride_kn, + stride_kk, + stride_vb, + stride_vh, + stride_vn, + stride_vk, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqk, + stride_deltab, + stride_deltah, + stride_deltam, + 
stride_dob, + stride_doh, + stride_dom, + stride_dok, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + pid = tl.program_id(0) # seqlen + bid = tl.program_id(1) # batch + hkid = tl.program_id(2) # head_k + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + start_m = pid * BLOCK_M + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_m = start_m + tl.arange(0, BLOCK_M) + + # mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + K += adj_k + V += adj_v + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * 
stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + delta_ptr = delta + adj_delta + + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) + m = m[:, None] + + # FP8 + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N) + dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dq = _bwd_dq_inner_split( + dq, + q, + K, + V, + do, + m, + delta_ptr, + sm_scale, + stride_qm, + stride_qk, + stride_kn, + stride_kk, + stride_vn, + stride_vk, + stride_dropoutm, + stride_dropoutn, + stride_deltam, + seqlen_q, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + BLOCK_M, + BLOCK_N, + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +# This 
function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q) +@triton.autotune( + configs=preprocess_autotune_configs, + key=PREPROCESS_AUTOTUNE_KEYS, + use_cuda_graph=True, +) +@triton.jit +def _bwd_preprocess( + O, + DO, # noqa: E741 + Delta, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_delta_b, + stride_delta_h, + stride_delta_m, + cu_seqlens_q, + max_seqlen_q, + PRE_BLOCK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, +): + pid_m = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + # Handle varlen + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) + offs_d = tl.arange(0, HEAD_DIM_V) + # pointer offsets for O & DO + off_o = ( + bid * stride_ob + + hid * stride_oh + + q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_od + ) # noqa: E741 + off_do = ( + bid * stride_dob + + hid * stride_doh + + q_start * stride_dom + + offs_m[:, None] * stride_dom + + offs_d[None, :] * stride_dod + ) + + # create masks + mask_m = offs_m < seqlen_q + mask_md = mask_m[:, None] + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + if PADDED_HEAD_V: + mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM_V + # load + o = tl.load(O + off_o, mask=mask_md, other=0.0) + do = tl.load(DO + off_do, mask=mask_md, other=0.0) + # compute and write-back to delta + # NOTE: Both o and do are FP32 + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + off_delta = ( + bid * stride_delta_b + + hid * stride_delta_h + + q_start * 
stride_delta_m + + offs_m * stride_delta_m + ) + tl.store(Delta + off_delta, delta, mask=mask_m) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _bwd_dkdv_inner( + dk, + dv, # output + Q, + k, + v, + DO, + M, + D, + sm_scale, # input tensor + stride_qm, + stride_qk, + stride_dom, + stride_dok, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + BLOCK_M: tl.constexpr, # 16 + BLOCK_N: tl.constexpr, # 128 + HEAD_DIM_QK: tl.constexpr, # + HEAD_DIM_V: tl.constexpr, # + ACTUAL_HEAD_DIM_QK: tl.constexpr, # + ACTUAL_HEAD_DIM_V: tl.constexpr, # + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + # Filled in by the wrapper. + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal + ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, # activate exp2 + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) + offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + # Q and DO are (seqlen_q, head_dim) + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM_QK, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k_qk[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM_V), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k_v[None, :] * stride_dok + # BLOCK_N must be a multiple of 
BLOCK_M, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N % BLOCK_M == 0) + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + + for blk_idx in range(num_steps): + if DEBUG_TRITON: + print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 + offs_m = curr_m + tl.arange(0, BLOCK_M) + # update the mask because offs_m advanced + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD_QK: + mask_qT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_do &= offs_k_v[None, :] < ACTUAL_HEAD_DIM_V + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + # generate dropout mask + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[None, :] * stride_dropoutm + + offs_n[:, None] * stride_dropoutn + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + # Load m before computing qk to reduce pipeline stall. + m = tl.load(M + offs_m * stride_lse_m, mask=mask_m, other=0.0) + if IS_FP8: + qkT = tl.dot(k, qT) * descale_q * descale_k + else: + qkT = tl.dot(k, qT) + qkT_scaled = qkT * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qkT_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"qT: {qT.shape}\n", qT) + print(f"k: {k.shape}\n", k) + print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) + # TODO: remove the scaling of m later when we removed re-scaling in fwd + if USE_EXP2: + pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) + else: + pT = tl.math.exp(qkT_scaled - m[None, :]) + + # Autoregressive masking. 
+ if MASK: + # offset offs_m with delta_qk since the causal mask starts at + # bottom right of the (seqlen_q, seqlen_k) matrix + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"causal_mask: {causal_mask.shape}\n", causal_mask) + print( + f"qkT after causal: {qkT.shape}\n", + tl.where(causal_mask, qkT * sm_scale, 0.0), + ) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + # Compute dV. + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"pT: {pT.shape}\n", pT) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m * stride_delta_m, mask=mask_m) + # Compute dP and dS. + if IS_FP8: + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v + else: + dpT = tl.dot(v, tl.trans(do)) + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + if IS_FP8: + # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T + # This puts FP8 tensor (qT) on LHS of dot product + # Cast the transposed dsT to FP8 to match qT's dtype + dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) + dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_qm + do_ptrs += step_m * stride_dom + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _bwd_dq_inner( + dq, # output + q, + K, + V, + do, + m, + Delta, + sm_scale, # input + # shared by Q/K/V. 
+ stride_qm, + stride_qk, + stride_kn, + stride_kk, + stride_vn, + stride_vk, + stride_dropoutm, + stride_dropoutn, # stride for dropout + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, # + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + # Filled in by the wrapper. + start_m, + start_n, + end_n, + num_steps, # + descale_q, + descale_k, + descale_v, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k_qk[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k_v[:, None] * stride_vk + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_delta_m, mask=mask_m, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: + print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: + print( + f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}" + ) # noqa: E701 + if DEBUG_TRITON_DETAIL: + print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + mask_kT = mask_n[None, :] + mask_vT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD_QK: + mask_kT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_vT &= offs_k_v[:, None] < ACTUAL_HEAD_DIM_V + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_vT, other=0.0) + + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[:, None] * stride_dropoutm + + offs_n[None, :] * stride_dropoutn + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + if IS_FP8: + qk = tl.dot(q, kT) * descale_q * descale_k + else: + qk = tl.dot(q, kT) + qk_scaled = qk * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qk_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: + print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 + if USE_EXP2: + p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) + else: + p = tl.math.exp(qk_scaled - m) + + # Autoregressive masking. 
+ if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask & mask_mn + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + if IS_FP8: + dp = tl.dot(do.to(vT.type.element_ty), vT) * descale_v + else: + dp = tl.dot(do, vT) + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + delta_i = Di[:, None] + ds = p * (dp - delta_i) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + # Rewrite dq += ds @ kT.T as dq += (kT @ ds.T).T + # This puts FP8 tensor (kT) on LHS of dot product + # Cast the transposed ds to FP8 to match kT's dtype + ds_transposed = tl.trans(ds).to(kT.type.element_ty) + dq += tl.trans(tl.dot(kT, ds_transposed)) * descale_k + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +@triton.autotune( + configs=causal_autotune_configs, + key=CAUSAL_AUTOTUNE_KEYS, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_fused_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), batch) + Q, + K, + V, + sm_scale, + DO, + DQ, + DK, + DV, + M, + Delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + HQ, + HK, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Add seqused parameters + max_seqlen_q, + max_seqlen_k, + 
Dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + Alibi_slopes, + Descale_q, + Descale_k, + Descale_v, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) + if DEBUG_TRITON: + print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = ( + tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + ) + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = ( + tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + ) + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: + print(f"delta_qk = {delta_qk}") # noqa: E701 + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) + GROUP_SIZE: tl.constexpr = HQ // HK + + 
# align the delta_qk + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + # This section does dk and dv + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N1 + delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1 + if delta_qk >= 0: + start_delta = delta_qk + if DEBUG_TRITON: + print( + f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}" + ) # noqa: E701 + else: + start_delta = start_delta_q_lt_k + if DEBUG_TRITON: + print( + f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}" + ) # noqa: E701 + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] + + # K/V tensors not changed for the group + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_qk[None, :] * stride_kd + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_d_v[None, :] * stride_vd + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_k) + v = tl.load(V + adj_v, mask=mask_v) + # If MQA / GQA, set the K and V head offsets appropriately. 
+ # hqid = hkid + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N1 + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M1 * BLOCK_M1 + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N1 + residue_m + if DEBUG_TRITON: + print(f"residue_m = {residue_m}") # noqa: E701 + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 
1.0, 1.0 + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: + print( + f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}" + ) # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_lse_m, + stride_delta_m, + MASK_BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M1 + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) + end_m = start_m + num_steps * BLOCK_M1 + + if DEBUG_TRITON: + print( + f"start_m after Masked step: {start_m}; num_steps: {num_steps}" + ) # noqa: E701 + if DEBUG_TRITON: + print( + f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}" + ) # noqa: E701 + if DEBUG_TRITON: + print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + 
DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_lse_m, + stride_delta_m, + BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # end of GQA/MQA of dkdv + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd + dk *= sm_scale + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) + + # This part does dq + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + # seqlen_q > seqlen_k, no need to process these tile for dq + if DEBUG_TRITON: + print( + f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}" + ) # noqa: E701 + if start_m + BLOCK_M2 < delta_qk: + if DEBUG_TRITON: + print( + f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}" + ) # noqa: E701 + return + + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, 
:] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod + # NOTE: don't assume that the strides for k and v are the same! + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M2 - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: + print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) + m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < 
seqlen_q) + m = m[:, None] + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M2, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) + + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, + BLOCK_M2, + MASK_BLOCK_N2, + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + MASK=True, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N2 + num_steps = tl.cdiv(end_n, BLOCK_N2) + start_n = max(end_n - num_steps * BLOCK_N2, 0) + if DEBUG_TRITON: + print( + f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}" + ) # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, + BLOCK_M2, + BLOCK_N2, + HEAD_DIM_QK, + HEAD_DIM_V, + 
ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + # end of GQA/MQA of dq + + +@triton.autotune( + configs=noncausal_autotune_configs, + key=NONCAUSAL_AUTOTUNE_KEYS, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_fused_noncausal( + Q, + K, + V, + sm_scale, + DO, + DQ, + DK, + DV, + M, + Delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + HQ, + HK, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Add seqused parameters + max_seqlen_q, + max_seqlen_k, + Dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + Alibi_slopes, + Descale_q, + Descale_k, + Descale_v, + BLOCK_M1: tl.constexpr, # 32 + BLOCK_N1: tl.constexpr, # 128 + BLOCK_M2: tl.constexpr, # 128 + BLOCK_N2: tl.constexpr, # 32 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: 
tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) + if DEBUG_TRITON: + print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = ( + tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + ) + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = ( + tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + ) + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) + GROUP_SIZE: tl.constexpr = HQ // HK + + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= 
mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] + # NOTE: don't assume that the strides for k and v are the same! + # K/V tensors not changed for the group + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_qk[None, :] * stride_kd + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_d_v[None, :] * stride_vd + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_k) + v = tl.load(V + adj_v, mask=mask_v) + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + 
hkid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M1) + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_lse_m, + stride_delta_m, + BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd + dk *= sm_scale + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) + + # THIS PART DOES DQ + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= 
mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) + m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) + m = m[:, None] + + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = 
tl.load(Descale_v + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N2) + + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, + BLOCK_M2, + BLOCK_N2, + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def is_contiguous(x, name): + if x.is_contiguous(): + return x + else: + print(f"{name} is not contiguous") + return x.contiguous() + + +# Triton kernel debug flags derived from DEBUG level. +# Level 1: basic kernel debug prints (iteration info) +# Level 2: detailed kernel debug prints (tensor values) +# Requires TRITON_INTERPRET=1 to actually print inside kernels. 
+DEBUG_TRITON: bool = DEBUG >= 1 +DEBUG_TRITON_DETAIL: bool = DEBUG >= 2 + + +def attention_backward_triton_impl( + *, + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + delta: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + philox_seed: Optional[int] = None, + philox_offset: Optional[int] = None, + use_exp2: bool = True, + mode: Literal["fused", "fused_atomic", "split"] = "fused", +): + # get params, strides and shape + IS_VARLEN = layout == "thd" + use_dropout = dropout_p > 0.0 + + # common assertions + assert ( + 0.0 <= dropout_p <= 1.0 + ), f"dropout_p must be between 0 and 1, got {dropout_p}" + assert ( + q.device == k.device == v.device == o.device == do.device == softmax_lse.device + ), f"All tensors must be on the same device. 
Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}" + assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" + current_device = torch.cuda.current_device() + assert ( + q.is_cuda and q.device.index == current_device + ), f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" + + # get shapes and strides + if IS_VARLEN: + # shape + total_seqlen_q, nheads_q, head_size_q = q.shape + total_seqlen_k, nheads_k, head_size_k = k.shape + total_seqlen_v, nheads_v, head_size_v = v.shape + nheads_lse, total_seqlen_lse = softmax_lse.shape + + # assert shapes + assert ( + total_seqlen_lse == total_seqlen_q + ), f"softmax_lse seqlen {total_seqlen_lse} != q seqlen {total_seqlen_q}" + assert ( + cu_seqlens_q is not None + ), "cu_seqlens_q must be provided for varlen layout" + assert ( + cu_seqlens_k is not None + ), "cu_seqlens_k must be provided for varlen layout" + assert ( + max_seqlen_q is not None + ), "max_seqlen_q must be provided for varlen layout" + assert ( + max_seqlen_k is not None + ), "max_seqlen_k must be provided for varlen layout" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + assert ( + nheads_lse == nheads_q + ), f"softmax_lse heads {nheads_lse} != q heads {nheads_q}" + + # assert output shapes + assert o.shape == ( + total_seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" + assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" + assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"dk shape {dk.shape} != 
k shape {k.shape}" + assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" + + # assert cu_seqlens + assert ( + cu_seqlens_q.dtype == torch.int32 + ), f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" + assert ( + cu_seqlens_k.dtype == torch.int32 + ), f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0" + assert ( + cu_seqlens_q[-1] == total_seqlen_q + ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" + assert ( + cu_seqlens_k[-1] == total_seqlen_k + ), f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" + + # set vars + batch = len(cu_seqlens_q) - 1 + head_size_qk = head_size_q + + # strides + stride_qb, stride_qm, stride_qh, stride_qd = ( + 0, + q.stride(0), + q.stride(1), + q.stride(2), + ) + stride_kb, stride_kn, stride_kh, stride_kd = ( + 0, + k.stride(0), + k.stride(1), + k.stride(2), + ) + stride_vb, stride_vn, stride_vh, stride_vd = ( + 0, + v.stride(0), + v.stride(1), + v.stride(2), + ) + stride_ob, stride_om, stride_oh, stride_od = ( + 0, + o.stride(0), + o.stride(1), + o.stride(2), + ) + stride_dqb, stride_dqm, stride_dqh, stride_dqd = ( + 0, + dq.stride(0), + dq.stride(1), + dq.stride(2), + ) + stride_dkb, stride_dkn, stride_dkh, stride_dkd = ( + 0, + dk.stride(0), + dk.stride(1), + dk.stride(2), + ) + stride_dvb, stride_dvn, stride_dvh, stride_dvd = ( + 0, + dv.stride(0), + dv.stride(1), + dv.stride(2), + ) + stride_dob, stride_dom, stride_doh, stride_dod = ( + 0, + do.stride(0), + do.stride(1), + do.stride(2), + ) + stride_lse_b, stride_lse_h, stride_lse_m = ( + 0, + softmax_lse.stride(0), + softmax_lse.stride(1), + ) + else: + # shapes + batch_q, seqlen_q, nheads_q, head_size_q = q.shape + batch_k, seqlen_k, nheads_k, head_size_k = k.shape + batch_v, seqlen_v, nheads_v, head_size_v = v.shape + batch_lse, nheads_lse, seqlen_lse = softmax_lse.shape + 
+ # assert batch dimensions + assert ( + batch_q == batch_k == batch_v + ), f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert sequence lengths + assert ( + seqlen_k == seqlen_v + ), f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" + + # assert output shapes + assert o.shape == ( + batch_q, + seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected" + assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" + assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" + + # assert softmax_lse shape + assert softmax_lse.shape == ( + batch_q, + nheads_q, + seqlen_q, + ), f"softmax_lse shape {softmax_lse.shape} != expected" + + # set vars + batch = batch_q + head_size_qk = head_size_q + max_seqlen_q = seqlen_q + max_seqlen_k = seqlen_k + + # strides + stride_qb, stride_qm, stride_qh, stride_qd = q.stride() + stride_kb, stride_kn, stride_kh, stride_kd = k.stride() + stride_vb, stride_vn, stride_vh, stride_vd = v.stride() + stride_ob, stride_om, stride_oh, stride_od = o.stride() + stride_dqb, stride_dqm, stride_dqh, stride_dqd = dq.stride() + stride_dkb, stride_dkn, stride_dkh, stride_dkd = dk.stride() + stride_dvb, stride_dvn, stride_dvh, stride_dvd = dv.stride() + stride_dob, stride_dom, stride_doh, stride_dod = do.stride() + stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride() + + # fp8 + IS_FP8 = is_fp8([q, k, v]) + if IS_FP8: + arch = get_arch() + if not 
arch.supports_fp8: + raise RuntimeError( + f"{arch.name} does not support FP8" + ) + FP8_MAX = torch.finfo(q.dtype).max + + warnings.warn( + "FP8 tensors detected in backward pass. Backward pass supports FP8 inputs but " + "descaling factors will default to 1.0.", + UserWarning, + ) + + # For GQA/MQA, q_descale should be shaped (batch, nheads_k) to match forward pass + descale_q = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + + descale_k = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + + descale_v = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + + stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None + stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None + stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None + + if DEBUG: + print(f"FP8 path triggered in bwd.py") + else: + FP8_MAX = None + descale_q = descale_k = descale_v = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = None + + # alibi setup + use_alibi, (stride_az, stride_ah) = ( + (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + ) + + # get closest power of 2 over or equal to 32. 
+ padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_qk = max(padded_d_model_qk, 32) + padded_d_model_v = 1 << (head_size_v - 1).bit_length() + padded_d_model_v = max(padded_d_model_v, 32) + HEAD_DIM_QK = padded_d_model_qk + HEAD_DIM_V = padded_d_model_v + ACTUAL_HEAD_DIM_QK = head_size_qk + ACTUAL_HEAD_DIM_V = head_size_v + + # Validate pre-allocated delta tensor + if IS_VARLEN: + # Shape expected by interface varlen backward: (Hq, Total_Q) + total_q, _, _ = q.shape + assert ( + delta.shape[0] == nheads_q + ), f"delta.shape[0] ({delta.shape[0]}) must equal nheads_q ({nheads_q})" + assert ( + delta.shape[1] >= total_q + ), f"delta.shape[1] ({delta.shape[1]}) must be >= total_q ({total_q})" + assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" + assert delta.device == q.device, f"delta must be on same device as q" + stride_delta_b, stride_delta_h, stride_delta_m = ( + 0, + delta.stride(0), + delta.stride(1), + ) + else: + # Shape expected by dense backward: (B, Hq, Sq) + seqlen_q = q.shape[1] + assert ( + delta.shape[0] == batch + ), f"delta.shape[0] ({delta.shape[0]}) must equal batch ({batch})" + assert ( + delta.shape[1] == nheads_q + ), f"delta.shape[1] ({delta.shape[1]}) must equal nheads_q ({nheads_q})" + assert ( + delta.shape[2] >= seqlen_q + ), f"delta.shape[2] ({delta.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" + assert delta.device == q.device, f"delta must be on same device as q" + stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() + + pre_grid = lambda META: ( + triton.cdiv(max_seqlen_q, META["PRE_BLOCK"]), + batch, + nheads_q, + ) + _bwd_preprocess[pre_grid]( + o, + do, + delta, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_delta_b, + stride_delta_h, + stride_delta_m, + cu_seqlens_q, + max_seqlen_q, + HEAD_DIM_V=HEAD_DIM_V, + 
ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + ) + + if DEBUG: + print("delta:", delta, delta.shape) + + # dropout mask tensor for debugging. We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = (0, 0, 0, 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = ( + dropout_mask.stride() + ) + + # Choose which kernels to call based on mode + if mode == "fused": + seqlen = max(max_seqlen_q, max_seqlen_k) + grid = lambda META: ( + nheads_k, + (seqlen + META["BLOCK_N1"] - 1) // META["BLOCK_N1"], + batch, + ) + if causal: + if DEBUG_TRITON: + print(f"bwd_kernel: grid = {grid}") # noqa: E701 + bwd_kernel_fused_causal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + descale_q, + descale_k, + descale_v, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + 
ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + else: + bwd_kernel_fused_noncausal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + descale_q, + descale_k, + descale_v, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + elif mode == "fused_atomic": + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 + BLK_SLICE_FACTOR = 2 + BLOCK_D_MODEL_POW2 = 
max(triton.next_power_of_2(HEAD_DIM_QK), 16) + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + + # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups + BLOCK_N = ( + 128 if BLOCK_D_MODEL_POW2 < 160 else 64 + ) # larger head sizes lead to oom + config = { + "BLOCK_M": 32, + "BLOCK_N": BLOCK_N, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 1, + "BLK_SLICE_FACTOR": 2, + } + + num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N + grid_dkdvdq = (batch * nheads_k * num_k_pids,) + + if causal: + _bwd_kernel_fused_atomic_causal[grid_dkdvdq]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + else: + _bwd_kernel_fused_atomic_noncausal[grid_dkdvdq]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + 
stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + elif mode == "split": + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 + BLK_SLICE_FACTOR = 2 + BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(HEAD_DIM_QK), 16) + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + + if causal: + _bwd_kernel_split_dkdv_causal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + 
BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + _bwd_kernel_split_dq_causal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + else: + _bwd_kernel_split_dkdv_noncausal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + 
philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + + _bwd_kernel_split_dq_noncausal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + else: + raise ValueError( + f"Unknown backward mode '{mode}'. Expected 'split', 'fused_atomic' or 'fused'." 
+ ) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py deleted file mode 100644 index 44e2c294b0d..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py +++ /dev/null @@ -1,814 +0,0 @@ -from typing import Literal, Optional -import torch -import triton -import triton.language as tl -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_fp8, write_dropout_mask, create_dropout_mask - -# TODO: move this into utils.py so it's shared among kernels -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) - -@triton.jit -def _bwd_preprocess( - Out, - DO, - Delta, - stride_oz, stride_oh, stride_om, stride_ok, - stride_doz, stride_doh, stride_dom, stride_dok, - stride_deltaz, stride_deltah, stride_deltam, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - DESCALE_do, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - N_CTX_Q: tl.constexpr, - Z: tl.constexpr, - H: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, -): - pid_bh = tl.program_id(0) - pid_m = tl.program_id(1) - - # Compute batch and head indices - off_z = pid_bh // H - off_h = pid_bh % H - - if IS_VARLEN: - # Compute sequence lengths for the current batch - q_start = tl.load(cu_seqlens_q + off_z) - q_end = tl.load(cu_seqlens_q + off_z + 1) - k_start = tl.load(cu_seqlens_k + off_z) - k_end = tl.load(cu_seqlens_k + off_z + 1) - - # Compute actual sequence lengths - N_CTX_Q = q_end - q_start - N_CTX_K = k_end - k_start - else: - q_start = 0 - k_start = 0 - N_CTX_Q = max_seqlen_q - N_CTX_K = max_seqlen_k - - off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_d = tl.arange(0, BLOCK_DMODEL) - - # create masks 
- mask_m = off_m < N_CTX_Q - mask_d = off_d < ACTUAL_BLOCK_DMODEL - - # compute offsets - o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om - do_offset = DO + off_z * stride_oz + off_h * stride_oh + q_start * stride_om - - # compute pointers - out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok - do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok - - # load - o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) - do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) - - # compute delta - if IS_FP8: - stride_descale_q_z = H - descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_h) - - # NOTE: do is scaled into the fp8 range and o is in fp8 but should be in the same scale as fp32 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - - # write-back delta - delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam - delta_ptrs = delta_offset + off_m * stride_deltam - tl.store(delta_ptrs, delta, mask=mask_m) - - -@triton.jit -def _bwd_kernel_one_col_block( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - D, - q_offset, - k_offset, - v_offset, - do_offset, - dq_offset, - dk_offset, - dv_offset, - l_offset, - delta_offset, - dropout_offset, - stride_dq_all, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vn, - stride_vk, - stride_deltaz, - stride_deltah, - stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - N_CTX_Q, - N_CTX_K, - start_n, - num_block_m, - num_block_n, - dropout_p, - philox_seed, - batch_philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - 
BLOCK_N: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - CAUSAL: tl.constexpr, - DROPOUT: tl.constexpr, - USE_EXP2: tl.constexpr, - GROUP_SIZE: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - if CAUSAL: - # TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M - lo = 0 - else: - lo = 0 - - # initialize col and head offsets - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - - # masks - mask_n = offs_n < N_CTX_K - mask_d = offs_d < ACTUAL_BLOCK_DMODEL - kv_mask = mask_n[:, None] & mask_d[None, :] - - - # initialize grad accumulators - dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - - # load k and v once per column block - k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk - v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk - k = tl.load(k_ptrs, mask=kv_mask, other=0.0) - kT = tl.trans(k) - vT = tl.trans(tl.load(v_ptrs, mask=kv_mask, other=0.0)) - - # loop over rows - for start_m in range(lo, num_block_m): - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - - # update mask as row block changes - mask_m = offs_m < N_CTX_Q - q_mask = mask_m[:, None] & mask_d[None, :] - - # load q, k, v, do on-chip - q = tl.load(q_ptrs, mask=q_mask, other=0.0) - do = tl.load(do_ptrs, mask=q_mask, other=0.0) - - # recompute p = softmax(qk, dim=-1).T - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if IS_FP8: - qk += (tl.dot(q, kT) * descale_q * descale_k) - else: - qk += tl.dot(q, kT) - - if CAUSAL: - col_offset = N_CTX_Q - N_CTX_K - causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :]) - qk = tl.where(causal_mask, qk, 
float("-inf")) - - l_ptrs = l_offset + offs_m * stride_deltam - l_i = tl.load(l_ptrs, mask=mask_m) - - # compute p - if USE_EXP2: - RCP_LN2: tl.constexpr = 1.4426950408889634 - qk *= sm_scale * RCP_LN2 - l_i *= RCP_LN2 - p = tl.math.exp2(qk - l_i[:, None]) - else: - qk *= sm_scale - p = tl.math.exp(qk - l_i[:, None]) - - # mask block in the cases where the data is smaller the block size - p_mask = mask_m[:, None] & mask_n[None, :] - p = tl.where(p_mask, p, 0.0) - - if DROPOUT: - # NOTE: must create a new var p_drop to prevent p (which is used later to compute ds) from changing - philox_offset = batch_philox_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn - # print("philox_seed:", philox_seed) - # print("philox_offset:", philox_offset) - if tl_DROPOUT_USE_PYTORCH: - dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn - dropout_mask = tl.load(dropout_ptrs, mask=p_mask) - else: - rand_vals = tl.rand(philox_seed, philox_offset) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1/ (1 - dropout_p) - - if tl_DROPOUT_DUMP: - dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn - tl.store(dropout_ptrs, dropout_mask, mask=p_mask) - - # apply dropout mask - p_drop = tl.where(dropout_mask, p, 0.0) - p_drop_scaled = p_drop * dropout_scale - - # compute dv - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(p_drop_scaled, FP8_MAX) - dv += (tl.dot(tl.trans(p_drop_scaled * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) - else: - dv += tl.dot(tl.trans(p_drop_scaled).to(do.type.element_ty), do) - - # compute dp - if IS_FP8: - dp_drop_scaled = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp_drop_scaled = tl.dot(do, vT) - dp = tl.where(dropout_mask, dp_drop_scaled, 0.0) * dropout_scale - else: - - # compute dv - if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) 
- dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do) - else: - dv += tl.dot(tl.trans(p).to(do.type.element_ty), do) - - # compute dp - if IS_FP8: - dp = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp = tl.dot(do, vT) - - - # load delta - delta_ptrs = delta_offset + offs_m * stride_deltam - delta_i = tl.load(delta_ptrs, mask=mask_m) - - # compute ds - dscores_scaled = (p * (dp - delta_i[:, None])) - ds = dscores_scaled * sm_scale - ds = tl.where(p_mask, ds, 0.0) - - # compute descale_ds - if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - else: - scale_ds, descale_ds = 1.0, 1.0 - - # compute dk - if IS_FP8: - dk += (tl.dot(tl.trans(ds * scale_ds).to(q.type.element_ty), q) * descale_ds * descale_q) - else: - dk += tl.dot(tl.trans(ds).to(q.type.element_ty), q) - - # compute dq - if SEQUENCE_PARALLEL: - if IS_FP8: - dq = (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k) - else: - dq = tl.dot(ds.to(k.type.element_ty), k) - else: - dq = tl.load(dq_ptrs, mask=q_mask, other=0.0) - if IS_FP8: - dq += (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k) - else: - dq += tl.dot(ds.to(k.type.element_ty), k) - tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask) - - # write-back dv and dk - dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk - dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk - - # write-back - if GROUP_SIZE != 1: - # use atomic_add to properly accumulate gradients from multiple query heads - tl.atomic_add(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) - tl.atomic_add(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) - else: - tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) - tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) - -@triton.jit -def _bwd_kernel( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - Delta, - Dropout_mask, - DESCALE_q, 
- DESCALE_k, - DESCALE_v, - DESCALE_do, - stride_dq_all, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vn, - stride_vk, - stride_deltaz, - stride_deltah, - stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - Z, - HQ, - HK, - num_block_m, - num_block_n, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset_base, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - CAUSAL: tl.constexpr, - DROPOUT: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_VARLEN: tl.constexpr, - GROUP_SIZE: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - # program ids - off_zh = tl.program_id(0) - if SEQUENCE_PARALLEL: - start_n = tl.program_id(1) - off_z = off_zh // HQ - off_hq = off_zh % HQ - - # check if GQA/MQA - if GROUP_SIZE != 1: - off_hk = off_hq // GROUP_SIZE - else: - off_hk = off_hq - - if IS_VARLEN: - # Compute sequence lengths for the current batch - q_start = tl.load(cu_seqlens_q + off_z) - q_end = tl.load(cu_seqlens_q + off_z + 1) - k_start = tl.load(cu_seqlens_k + off_z) - k_end = tl.load(cu_seqlens_k + off_z + 1) - - # Compute actual sequence lengths - N_CTX_Q = q_end - q_start - N_CTX_K = k_end - k_start - else: - q_start = 0 - k_start = 0 - N_CTX_Q = max_seqlen_q - N_CTX_K = max_seqlen_k - - # input tensor offsets - q_offset = Q + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm - k_offset = K + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn - v_offset = V + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn - do_offset = DO + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm - l_offset = L + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam - delta_offset = Delta + off_z * stride_deltaz + off_hq * stride_deltah 
+ q_start * stride_deltam - - if DROPOUT: - batch_philox_offset = philox_offset_base + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm - dropout_offset = Dropout_mask + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm - else: - batch_philox_offset = 0 - dropout_offset = 0 - - if IS_FP8: - stride_descale_q_z = HQ - stride_descale_kv_z = HK - - descale_q = tl.load(DESCALE_q + off_z * stride_descale_q_z + off_hq) - descale_k = tl.load(DESCALE_k + off_z * stride_descale_kv_z + off_hk) - descale_v = tl.load(DESCALE_v + off_z * stride_descale_kv_z + off_hk) - descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_hq) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # output tensor offsets - dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn - dv_offset = DV + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn - if SEQUENCE_PARALLEL: - dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm - else: - dq_offset = DQ + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm - - # inner loop - if SEQUENCE_PARALLEL: - _bwd_kernel_one_col_block( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - Delta, - q_offset, - k_offset, - v_offset, - do_offset, - dq_offset, - dk_offset, - dv_offset, - l_offset, - delta_offset, - dropout_offset, - stride_dq_all, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vn, - stride_vk, - stride_deltaz, - stride_deltah, - stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - N_CTX_Q, - N_CTX_K, - start_n, - num_block_m, - num_block_n, - dropout_p, - philox_seed, - batch_philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - BLOCK_M=BLOCK_M, - BLOCK_DMODEL=BLOCK_DMODEL, - 
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, - CAUSAL=CAUSAL, - DROPOUT=DROPOUT, - USE_EXP2=USE_EXP2, - GROUP_SIZE=GROUP_SIZE, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX - ) - else: - for start_n in range(0, num_block_n): - _bwd_kernel_one_col_block( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - Delta, - q_offset, - k_offset, - v_offset, - do_offset, - dq_offset, - dk_offset, - dv_offset, - l_offset, - delta_offset, - dropout_offset, - stride_dq_all, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vn, - stride_vk, - stride_deltaz, - stride_deltah, - stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - N_CTX_Q, - N_CTX_K, - start_n, - num_block_m, - num_block_n, - dropout_p, - philox_seed, - batch_philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - BLOCK_M=BLOCK_M, - BLOCK_DMODEL=BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, - CAUSAL=CAUSAL, - DROPOUT=DROPOUT, - USE_EXP2=USE_EXP2, - GROUP_SIZE=GROUP_SIZE, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX - ) - - -# NOTE: smaller blocks have lower accuracy. more accumulation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumulation errors but no oom. 
-def attention_prefill_backward_triton_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool, - sequence_parallel: bool = True, - # fp8 - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, -): - if DEBUG: - print() - print("attention_prefill_backward_triton_impl") - print("do:", do, do.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("o:", o, o.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) - print("sm_scale:", sm_scale) - print("alibi_slopes:", alibi_slopes) - print("causal:", causal) - print("layout:", layout) - print("cu_seqlens_q:", cu_seqlens_q) - print("cu_seqlens_k:", cu_seqlens_k) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("dropout_p:", dropout_p) - print("philox_seed:", philox_seed) - print("philox_offset:", philox_offset) - print("use_exp2:", use_exp2) - print("sequence_parallel:", sequence_parallel) - print("descale_q:", descale_q) - print("descale_k:", descale_k) - print("descale_v:", descale_v) - print("descale_do:", descale_do) - - IS_FP8 = is_fp8(q) - if IS_FP8: - FP8_MAX=torch.finfo(q.dtype).max - else: - FP8_MAX=None - - # make contiguous - q = q.contiguous() - k = k.contiguous() - v = 
v.contiguous() - softmax_lse = softmax_lse.contiguous() - - # get strides and shape - batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) - q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) - stride_qz, stride_qh, stride_qm, stride_qk = q_strides - stride_kz, stride_kh, stride_kn, stride_kk = k_strides - stride_vz, stride_vh, stride_vn, stride_vk = v_strides - stride_oz, stride_oh, stride_om, stride_ok = o_strides - is_varlen = layout == "thd" - group_size = nheads_q // nheads_k - use_dropout = (dropout_p > 0.0) - - # FIXME: some configs lead to oom for some reason when using 64 x 64 blocks - if max_seqlen_q <= 32 or max_seqlen_k <= 32: - BLOCK_M = 32 - BLOCK_N = 32 - else: - BLOCK_M = 64 - BLOCK_N = 64 - - if DEBUG: - print("BLOCK_M:", BLOCK_M) - print("BLOCK_N:", BLOCK_N) - - num_warps = 4 # NOTE: original is 8. changing it to 1 caused issues be careful - num_stages = 1 - waves_per_eu = 1 - - # divide up the problem - num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M) - num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N) - - # get closest power of 2 over or equal to 32. 
- padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 16) - BLOCK_DMODEL = padded_d_model - ACTUAL_BLOCK_DMODEL = head_size - - do = do.contiguous() - - # deal with dq - if sequence_parallel: - dq = dq.unsqueeze(0).repeat(num_blocks_n, *([1] * len(q.shape))) # we do repeat instead of expand because we need to write data so views are not enough - stride_dq_all = dq.stride()[0] - - # assert contiguous - assert do.is_contiguous() - assert q.is_contiguous() - assert k.is_contiguous() - assert v.is_contiguous() - assert o.is_contiguous() - assert softmax_lse.is_contiguous() - - # init delta - delta = torch.zeros_like(softmax_lse) - if is_varlen: - stride_deltam, stride_deltah = delta.stride() - stride_deltaz = 0 - else: - stride_deltaz, stride_deltah, stride_deltam = delta.stride() - - # dropout mask tensor for debugging. We dump the dropout mask created in the kernel for testing - if use_dropout: - if DROPOUT_USE_PYTORCH: - dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlen_q, max_seqlen_k), seed = philox_seed) - else: - dropout_mask = torch.zeros((batch, nheads_q, max_seqlen_q, max_seqlen_k), device=q.device, - dtype=torch.float32) - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (dropout_mask.stride(0), dropout_mask.stride(1), dropout_mask.stride(2), dropout_mask.stride(3)) - else: - dropout_mask = None - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (0, 0 , 0 , 0) - - - _bwd_preprocess[(batch * nheads_q, num_blocks_m)]( - o, - do, - delta, - stride_oz, stride_oh, stride_om, stride_ok, - stride_oz, stride_oh, stride_om, stride_ok, # FIXME: don't share strides with derivatives this was causing a lot of issues - stride_deltaz, stride_deltah, stride_deltam, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - descale_do, - BLOCK_M=BLOCK_M, - BLOCK_DMODEL=BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, - N_CTX_Q=max_seqlen_q, - Z=batch, - 
H=nheads_q, - IS_VARLEN=is_varlen, - IS_FP8=IS_FP8 - ) - - if DEBUG: - print("delta:", delta, delta.shape) - print("group_size:", group_size) - - _bwd_kernel[(batch * nheads_q, num_blocks_n if sequence_parallel else 1)]( - q, - k, - v, - sm_scale, - o, - do, - dq, - dk, - dv, - softmax_lse, - delta, - dropout_mask, - descale_q, - descale_k, - descale_v, - descale_do, - stride_dq_all, - stride_qz, stride_qh, stride_qm, stride_qk, # FIXME: don't share strides with derivatives this was causing a lot of issues - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_deltaz, stride_deltah, stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - batch, - nheads_q, - nheads_k, - num_blocks_m, - num_blocks_n, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, philox_seed, philox_offset, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, - SEQUENCE_PARALLEL=sequence_parallel, - CAUSAL=causal, - DROPOUT=use_dropout, - USE_EXP2=use_exp2, - num_warps=num_warps, - num_stages=num_stages, - waves_per_eu = waves_per_eu, - IS_VARLEN=is_varlen, - GROUP_SIZE=group_size, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX - ) - - if sequence_parallel: - dq = dq.sum(dim=0) - - if DEBUG: - print("attention_prefill_backward_triton_impl outputs") - print("dv:", dv, dv.shape) - print("dk:", dk, dk.shape) - print("dq:", dq, dq.shape) - if use_dropout: - print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None) - print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item()) - write_dropout_mask(dropout_mask, "dropout_mask_bwd") - - return delta diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py deleted file mode 100644 index 3c018be4fa0..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py +++ 
/dev/null @@ -1,3266 +0,0 @@ -import torch -import triton -import triton.language as tl - -from typing import Optional, Tuple - -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - -@triton.jit -def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): - # compute fp8 scaling and descaling factor for a block - x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values - x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) - scale_x = fp8_max / x_amax - descale_x = x_amax / fp8_max - return scale_x, descale_x - -def is_fp8(x): - if x.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}: - if arch_supports_fp8(): - return True - else: - raise RuntimeError("This device does not support fp8") - else: - return False - - -def cast_to_fp8( - x: torch.Tensor, - fp8_dtype, - layout, - clamp_val=1e-9, -): - if len(x.shape) != 4: - raise ValueError(f"'bshd' tensor should have shape [batch, seqlen, heads, dim], got {x.shape}") - reduce_dims = (1, 3) # seq_len and dim dimensions - - # Compute the absolute max along reduce_dims, clamped to avoid 0-scale - x_abs_max = x.abs().amax(dim=reduce_dims) - x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) - - # Unsqueeze back to a shape suitable for broadcast - unsqueeze_dims = sorted(reduce_dims) - for d in unsqueeze_dims: - x_abs_max = x_abs_max.unsqueeze(d) - - # compute scale and descale - fp8_max = torch.finfo(fp8_dtype).max - scale = fp8_max / x_abs_max - descale_factor = x_abs_max / fp8_max - - # cast to FP8, optionally setting requires_grad - x_fp8 = (x * scale).to(fp8_dtype) - - return x_fp8, descale_factor - - -def cast_varlen_to_fp8( - x: torch.Tensor, - fp8_dtype: torch.dtype, - cu_seqlens, - clamp_val: float = 1e-9, -) -> tuple[torch.Tensor, torch.Tensor]: - # validate tensor shape - if len(x.shape) != 3: - raise ValueError(f"tensor should have shape [total_seqlen, heads, dim], got {x.shape}") - num_heads = x.shape[1] - - # Get batch size from cu_seqlens - 
batch = cu_seqlens.shape[0] - 1 - fp8_max = torch.finfo(fp8_dtype).max - - # Compute scale and descale factors per sequence - x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) - descale_factors = torch.zeros((batch, num_heads), device=x.device, dtype=torch.float32) - - for i in range(batch): - start = cu_seqlens[i] - end = cu_seqlens[i + 1] - x_slice = x[start:end] # Slice for current sequence - - # Standard tensor (0: seq_len, 2: head_dim) - x_abs_max = x_slice.abs().amax(dim=(0, 2)) # [heads] - - # apply minimum clamping - x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) - - # compute scale and descale factors - scale_i = fp8_max / x_abs_max - descale_i = x_abs_max / fp8_max - - # store descale factors - descale_factors[i, :] = descale_i - - scale_reshape = scale_i.reshape(1, num_heads, 1) - - # scale and cast to FP8 - x_fp8[start:end] = (x_slice * scale_reshape).to(fp8_dtype) - - return x_fp8, descale_factors - - -#TODO Move this to a common folder. Will need to add future arch list -def get_arch(): - return triton.runtime.driver.active.get_current_target().arch - -def is_hip(): - return triton.runtime.driver.active.get_current_target().backend == "hip" - -def arch_supports_fp8(): - return is_hip() and get_arch() in ('gfx942') - -@triton.jit -def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): - if offset_first is not None and offset_second is not None: - mask = (offset_first[:, None] < boundary_first) & \ - (offset_second[None, :] < boundary_second) - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_first is not None: - mask = offset_first[:, None] < boundary_first - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_second is not None: - mask = offset_second[None, :] < boundary_second - tensor = tl.load(ptrs, mask=mask, other=0.0) - else: - tensor = tl.load(ptrs) - return tensor - -@triton.jit -def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): - # when seqlen_k and 
seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix - # for casual mask we want something like this where (1 is kept and 0 is masked) - # seqlen_q = 2 and seqlen_k = 5 - # 1 1 1 1 0 - # 1 1 1 1 1 - # seqlen_q = 5 and seqlen_k = 2 - # 0 0 - # 0 0 - # 0 0 - # 1 0 - # 1 1 - # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal - # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False - # 1. offs_m[:,None] = [[0], - # [1], - # 2. offs_m[:,None] + seqlen_k = [[5], - # [6], - # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], - # [4], - # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], - # [4], [ 4, 3, 2, 1, 0]] - # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], - # [ -4, -3, -2, -1, 0]], - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - if transpose: - return alibi_block.T - else: - return alibi_block - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - stride_kn, - stride_vk, - stride_sn, - start_m, - seqlen_k, - seqlen_q, - dropout_p, - sd_mask_ptrs, - dropout_mask_ptrs, - philox_seed, - philox_ptrs, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - alibi_slope, - descale_q, - descale_k, - descale_v, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_POW2: tl.constexpr, - SM_SCALE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_SCORES: tl.constexpr, - PADDED_HEAD: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - RCP_LN2: tl.constexpr = 1.4426950408889634 - - # loop over 
k, v, and update accumulator - - for start_n in range(block_min, block_max, BLOCK_N): - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - if MASK_STEPS: - k_offs_n = start_n + tl.arange(0, BLOCK_N) - else: - k_offs_n = None - k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2) - k = load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. 
- if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask = size_n < boundary_m[:, None] - qk = tl.where(mask, qk, float("-inf")) - - # compute masks - q_mask = (OFFS_M[:, None] < seqlen_q) - k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < seqlen_k) - p_mask = q_mask & k_mask - - # -- compute qk ---- - if IS_FP8: - qk += (tl.dot(q, k) * descale_q * descale_k) - else: - qk += tl.dot(q, k) - qk_scaled = qk * SM_SCALE - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) - - if alibi_slope is not None: - # Compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, global_m_positions, - global_n_positions) - qk_scaled += alibi_block - # get max scores so far - m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) - - # scale and subtract max - q_shifted = qk_scaled - m_ij[:, None] - - # Compute scaled QK and softmax probabilities - p = tl.math.exp2(q_shifted * RCP_LN2) - - # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance - dropout_mask = rng_output > dropout_p - tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) - - # return scores with negative values for dropped vals - sd_mask = tl.where(dropout_mask, p, -p) - tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) - - # apply dropout mask in place - p = tl.where(dropout_mask, p, 0.0) - elif RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that - tl.store(sd_mask_ptrs, p, mask=p_mask) - - # -- update output accumulator -- - # alpha is an adjustment factor for acc and li as we loop and find new maxes - # store the diff in maxes to adjust acc and li as we discover new maxes - m_diff = m_i - m_ij - alpha = tl.math.exp2(m_diff * RCP_LN2) - acc = acc * alpha[:, None] - v = load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL) - # -- update m_i and l_i - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - - if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) - else: - acc += tl.dot(p.to(v.type.element_ty), v) - - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - if RETURN_SCORES: - sd_mask_ptrs += BLOCK_N * stride_sn - - if ENABLE_DROPOUT: - dropout_mask_ptrs += BLOCK_N * stride_sn - philox_ptrs += BLOCK_N * stride_sn - - return acc, l_i, m_i - - -@triton.jit -def _attn_fwd(q_ptr: torch.Tensor, - k_ptr: torch.Tensor, - v_ptr: torch.Tensor, - descale_q_ptr: torch.Tensor, - descale_k_ptr: torch.Tensor, - descale_v_ptr: torch.Tensor, - out_ptr: torch.Tensor, - alibi_slopes_ptr: torch.Tensor, - s_dmask_ptr: torch.Tensor, - dropout_mask_ptr: torch.Tensor, - softmax_lse_ptr: torch.Tensor, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_oz, stride_oh, stride_om, stride_on, - stride_alibi_z, stride_alibi_h, - stride_sd_z, stride_sd_h, stride_sd_m, stride_sd_n, - stride_lse_z, stride_lse_h, stride_lse_m, - sm_scale, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset, - SEQLEN_Q: tl.constexpr, - SEQLEN_K: tl.constexpr, - IS_CAUSAL: tl.constexpr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - 
BLOCK_DMODEL_POW2: tl.constexpr, - RETURN_SCORES: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - VARLEN: tl.constexpr, -): - #calculate offsets - start_m = tl.program_id(0) #seqlen_q - off_q_head = tl.program_id(1) #num_q_heads - off_z = tl.program_id(2) #batch - - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_POW2) - - if VARLEN: - cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) - cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. - if start_m * BLOCK_M > seqlen_q: - return - cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) - cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - else: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = SEQLEN_Q - seqlen_k = SEQLEN_K - - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - if (IS_CAUSAL): - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. 
- - # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - - # If we have no blocks after adjusting for seqlen deltas, this WG is part of - # the blocks that are all 0. We exit early. - if n_blocks <= 0: - offs_out = (off_z * stride_oz + - off_q_head * stride_oh + - cu_seqlens_q_start * stride_om + - offs_m[:, None] * stride_om + - offs_d[None, :] * stride_on) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=out_ptr.type.element_ty) - out_mask = (offs_m[:, None] < seqlen_q) & (offs_d < BLOCK_DMODEL) - tl.store(out_ptr + offs_out, acc, mask=out_mask) - - if softmax_lse_ptr is not None: - offs_lse = (off_z * stride_lse_z + - off_q_head * stride_lse_h + - cu_seqlens_q_start * stride_lse_m + - offs_m*stride_lse_m - ) - lse_mask = offs_m < SEQLEN_Q - lse = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) - tl.store(softmax_lse_ptr + offs_lse, lse, mask=lse_mask) - # TODO: Should dropout and return encoded softmax be handled here too? 
- - return - - grp_sz:tl.constexpr = NUM_Q_HEADS // NUM_K_HEADS - if grp_sz != 1: #Grouped Query Attention - off_k_head = off_q_head // grp_sz - else: - off_k_head = off_q_head - - #q,k,v offsets - q_offs = (off_z * stride_qz + - off_q_head * stride_qh + - cu_seqlens_q_start * stride_qm + - offs_m[:, None] * stride_qm + offs_d[None, :]*stride_qk - ) - q_ptrs = q_ptr + q_offs - - k_offs = (off_z * stride_kz + - off_k_head * stride_kh + - cu_seqlens_k_start * stride_kn + - offs_d[:, None] * stride_kk + offs_n[None, :]*stride_kn - ) - k_ptrs = k_ptr + k_offs - - v_offs = (off_z * stride_vz + - off_k_head * stride_vh + - cu_seqlens_k_start * stride_vn + - offs_n[:, None] * stride_vn + offs_d[None, :]*stride_vk - ) - v_ptrs = v_ptr + v_offs - - #alibi slopes - if alibi_slopes_ptr is not None: - alibi_offs = off_z * stride_alibi_z + off_q_head * stride_alibi_h - alibi_slope = tl.load(alibi_slopes + alibi_offs) - else: - alibi_slope = None - - #s_dmask (return_scores) - if s_dmask_ptr is not None: - s_dmask_offs = (off_z * stride_sd_z + - off_q_head * stride_sd_h + - offs_m[:, None] * stride_sd_m + - offs_n[None, :] * stride_sd_n - ) - s_dmask_ptrs = s_dmask_ptr + s_dmask_offs - else: - s_dmask_ptrs = None - - #dropout - if dropout_mask_ptr is not None: - dropout_mask_offs = (off_z * stride_sd_z + - off_q_head * stride_sd_h + - offs_m[:, None] * stride_sd_m + - offs_n[None, :] * stride_sd_n - ) - dropout_mask_ptrs = dropout_mask_ptr + dropout_mask_offs - philox_ptrs = (philox_offset + - off_z * stride_sd_z + - off_q_head * stride_sd_h + - offs_m[:, None] * stride_sd_m + - offs_n[None, :] * stride_sd_n - ) - else: - dropout_mask_ptrs = None - philox_ptrs = None - - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32) - if (BLOCK_DMODEL == BLOCK_DMODEL_POW2): - q_mask = (offs_m[:, None] < seqlen_q) - else: - q_mask = (offs_m[:, None] < seqlen_q) & 
(offs_d[None, :] < BLOCK_DMODEL) - q = tl.load(q_ptrs, mask=q_mask, other=0.0) - if IS_FP8: - descale_q = tl.load(descale_q_ptr + off_z * stride_descale_q_z + off_q_head) - descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head) - descale_v = tl.load(descale_v_ptr + off_z * stride_descale_v_z + off_k_head) - else: - descale_q, descale_k ,descale_v = 1.0, 1.0, 1.0 - - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N -seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - - #if CAUSAL, then determine masked_blocks and full blocks - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. - # In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its actual - # value because there is no masking. Similarly we do not need padding. 
- if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - stride_kn, - stride_vn, - stride_sd_n, - start_m, - seqlen_k, - seqlen_q, - dropout_p, - s_dmask_ptrs, dropout_mask_ptrs, philox_seed, philox_ptrs, - block_min, block_max, 0, 0, 0, alibi_slope, - descale_q, descale_k, descale_v, - offs_m, offs_n, BLOCK_M, BLOCK_N, BLOCK_DMODEL,BLOCK_DMODEL_POW2, - sm_scale, False, MASK_STEPS=False, ENABLE_DROPOUT=ENABLE_DROPOUT, - RETURN_SCORES=RETURN_SCORES, PADDED_HEAD=BLOCK_DMODEL!=BLOCK_DMODEL_POW2, - IS_FP8=IS_FP8, FP8_MAX=FP8_MAX - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - # Remaining blocks, if any, are full / not masked. - if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 - k_ptrs += n_full_blocks * BLOCK_N * stride_kn - v_ptrs += n_full_blocks * BLOCK_N * stride_vn - if RETURN_SCORES: - s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n - if ENABLE_DROPOUT: - dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n - acc, l_i, m_i = _attn_fwd_inner(acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - stride_kn, stride_vn, stride_sd_n, - start_m, seqlen_k, seqlen_q, - dropout_p, - s_dmask_ptrs, dropout_mask_ptrs, philox_seed, philox_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, - descale_q, descale_k, descale_v, - offs_m, offs_n, BLOCK_M, BLOCK_N, BLOCK_DMODEL,BLOCK_DMODEL_POW2, - sm_scale, IS_CAUSAL, MASK_STEPS=True, ENABLE_DROPOUT=ENABLE_DROPOUT, - RETURN_SCORES=RETURN_SCORES, PADDED_HEAD=BLOCK_DMODEL!=BLOCK_DMODEL_POW2, - IS_FP8=IS_FP8, FP8_MAX=FP8_MAX - ) - # epilogue - # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. 
- l_recip = 1 / l_i[:, None] - acc = acc * l_recip - if ENABLE_DROPOUT: - dropout_scale = 1 / (1 - dropout_p) - acc = acc * dropout_scale - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - if IS_CAUSAL: - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL_POW2, ), causal_start_idx, dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - - # write back LSE(Log Sum Exponents), the log of the normalization constant - overflow_size = end_m_idx - seqlen_q - if softmax_lse_ptr is not None: - RCP_LN2: tl.constexpr = 1.4426950408889634 - LN2: tl.constexpr = 0.6931471824645996 - # compute log-sum-exp in base 2 units - mi_base2 = m_i * RCP_LN2 - softmax_lse = mi_base2 + tl.math.log2(l_i) - # convert back to natural units - softmax_lse *= LN2 - - if IS_CAUSAL: - # zero out nans caused by -infs when doing causal - lse_causal_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx - softmax_lse = tl.where(lse_causal_mask, 0.0, softmax_lse) - - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. - # This is only true for the last M block. 
For others, overflow_size will be -ve - offs_lse = off_z * stride_lse_z + off_q_head * stride_lse_h + cu_seqlens_q_start * stride_lse_m + offs_m*stride_lse_m - if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) - lse_mask = tl.arange(0, BLOCK_M) < boundary - tl.store(softmax_lse_ptr + offs_lse, softmax_lse, mask=lse_mask) # the log of the normalization constant - else: - tl.store(softmax_lse_ptr + offs_lse, softmax_lse) # the log of the normalization constant - - # write back O - offs_out = (off_z * stride_oz + - off_q_head * stride_oh + - cu_seqlens_q_start * stride_om + - offs_m[:, None] * stride_om + - offs_d[None, :] * stride_on) - out_mask = tl.full([BLOCK_M, BLOCK_DMODEL_POW2], 1, dtype=tl.int1) - if overflow_size > 0: - out_mask = out_mask & (offs_m[:, None] < seqlen_q) - if BLOCK_DMODEL != BLOCK_DMODEL_POW2: - out_mask = out_mask & (offs_d[None, :] < BLOCK_DMODEL) - op = acc.to(out_ptr.dtype.element_ty) - tl.store(out_ptr + offs_out, op, mask=out_mask) - -def _flash_attn_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - alibi_slopes: Optional[torch.Tensor], - return_lse: bool, - return_softmax: bool, - max_seqlen_q: int, - max_seqlen_k: int, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - #FP8 - IS_FP8 = is_fp8(q) - FP8_MAX: tl.constexpr=torch.finfo(q.dtype).max - is_varlen = True if cu_seqlens_q is not None else False - - if IS_FP8: - o = torch.zeros_like(q, dtype=torch.float32) - else: - o = torch.zeros_like(q) - if is_varlen: - #Layout for q,k,v is thd ie [total_tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = len(cu_seqlens_q) - 
1, max_seqlen_q, q.shape[1], q.shape[2] - seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - #Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - seqlen_k = k.shape[1] - num_k_heads = k.shape[2] - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - - #padding for head_dim. Power of 2 or 16 - BLOCK_DMODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_DMODEL_POW2 = max(BLOCK_DMODEL_POW2, 16) - - #softmax_lse [batch, num_q_heads, seqlen_q] - if return_lse: - if is_varlen: - softmax_lse = torch.zeros((q.shape[0], num_q_heads), device=q.device, dtype=torch.float32) - stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(1), softmax_lse.stride(0) - else: - softmax_lse = torch.zeros((batch, num_q_heads, max_seqlen_q), device=q.device, dtype=torch.float32) - stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() - else: - softmax_lse = None - - #exp_scores [batch, num_q_heads, seqlen_q, seqlen_k] - enable_dropout = dropout_p > 0.0 - if enable_dropout: - philox_seed = torch.randint(0, 0xffffff, (1,))[0].item() #No specific reason to restrict range to 0xffffff - philox_offset = torch.randint(0, 0xffffff, (1,))[0].item() #Pass in an int, not Tensor - else: - philox_seed = 0 - philox_offset = 0 - if return_softmax or enable_dropout: - s_dmask = torch.zeros((batch, num_q_heads, max_seqlen_q, max_seqlen_k), device=q.device, dtype=torch.float32) - dropout_mask = torch.zeros((batch, num_q_heads, max_seqlen_q, max_seqlen_k), device=q.device, 
dtype=torch.float32) - else: - s_dmask = None - dropout_mask = None - - - # Best config from ROCm/triton/python/perf-kernels/flash_attention.py::attn_fwd autotuning is BLOCK_M: 128, BLOCK_N: 64, waves_per_eu: 2, num_warps: 4, num_ctas: 1, num_stages: 1 - # Tuned for MI300x - config = { - 'BLOCK_M': 128, - 'BLOCK_N': 32, # BLOCK_N: 64 spills for _attn_fwd - 'waves_per_eu': 2, - 'num_warps': 4, - 'num_ctas': 1, - 'num_stages': 1, - } - - grid = lambda META:(triton.cdiv(seqlen_q, META['BLOCK_M']), num_q_heads, batch) - _attn_fwd[grid](q, - k, - v, - descale_q, - descale_k, - descale_v, - o, - alibi_slopes, - s_dmask, - dropout_mask, - softmax_lse, - *q_strides, - *k_strides, - *v_strides, - descale_q.stride(0) if descale_q is not None else 0, - descale_k.stride(0) if descale_k is not None else 0, - descale_v.stride(0) if descale_v is not None else 0, - *o_strides, - alibi_slopes.stride(0) if alibi_slopes is not None else 0, - alibi_slopes.stride(1) if alibi_slopes is not None else 0, - s_dmask.stride(0) if s_dmask is not None else 0, - s_dmask.stride(1) if s_dmask is not None else 0, - s_dmask.stride(2) if s_dmask is not None else 0, - s_dmask.stride(3) if s_dmask is not None else 0, - stride_lse_z if softmax_lse is not None else 0, - stride_lse_h if softmax_lse is not None else 0, - stride_lse_m if softmax_lse is not None else 0, - softmax_scale, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset, - SEQLEN_Q=max_seqlen_q, - SEQLEN_K=max_seqlen_k, - IS_CAUSAL=causal, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_DMODEL=head_sz, - BLOCK_DMODEL_POW2=BLOCK_DMODEL_POW2, - RETURN_SCORES=return_softmax, - ENABLE_DROPOUT=enable_dropout, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - VARLEN=is_varlen, - **config - ) - - return o, softmax_lse, s_dmask, philox_seed, philox_offset - -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, 
max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -@triton.jit -def _bwd_preprocess( - o_ptr, do_ptr, # noqa: E741 - delta_ptr, - stride_o_b, stride_o_h, stride_o_m, stride_o_k, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_descale_do_z, - cu_seqlens_q, max_seqlen_q, - descale_do_ptr, - BLOCK_M: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr -): - pid_m = tl.program_id(0) #seqlen - bid = tl.program_id(1) #batch - hid = tl.program_id(2) #head - - # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # Offset O/DO by batch, head and q_start - offs = (bid * stride_o_b + - hid * stride_o_h + - q_start * stride_o_m + offs_m[:, None] * stride_o_m + - offs_k[None, :] * stride_o_k) - - # create masks - mask_m = offs_m < seqlen_q - mask = mask_m[:, None] - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask &= offs_k[None, :] < BLOCK_D_MODEL - - # load [BLOCK_M, BLOCK_D_MODEL_POW2] - o = tl.load(o_ptr + offs, mask=mask, other=0.0) - do = tl.load(do_ptr + offs, mask=mask, other=0.0) - - # compute and write-back to delta - if IS_FP8: - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - - offs_delta = (bid * stride_delta_b + - hid * stride_delta_h + - q_start * stride_delta_m + offs_m * stride_delta_m) - tl.store(delta_ptr + offs_delta, delta, mask=mask_m) - -@triton.jit -def 
_bwd_dq_inner( - dq, - q, K, V, do, m, Delta, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropout_m, stride_dropout_n, - stride_deltam, - seqlen_q, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - RCP_LN2: tl.constexpr = 1.4426950408889634 - - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) - offs_n = start_n + tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # mask to make sure not OOB of seqlen_q - mask_m = offs_m < seqlen_q - - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk - - # D (= delta) is pre-divided by ds_scale. 
- Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) - - curr_n = start_n - step_n = BLOCK_N - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - for blk_idx in range(num_steps): - offs_n = curr_n + tl.arange(0, BLOCK_N) - # end_n is needed because the end of causal True might not be perfectly - # aligned with the end of the block - mask_n = offs_n < end_n - mask_kT = mask_n[None, :] - mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD: - mask_kT &= offs_k[:, None] < BLOCK_D_MODEL - - kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) - - #dropout - if ENABLE_DROPOUT: - philox_offs = (curr_philox_offset + - offs_m[:, None] * stride_dropout_m + - offs_n[None, :] * stride_dropout_n) - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1 / (1 - dropout_p) - - #qk - if IS_FP8: - qk = tl.dot(q, kT) * descale_q * descale_k - else: - qk = tl.dot(q, kT) - p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) - - if MASK: - causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] - mask = causal_mask * mask_mn - p = tl.where(mask, p, 0.0) - - #dp - if IS_FP8: - dp = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp = tl.dot(do, vT) - - if ENABLE_DROPOUT: - dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale - - #ds - delta_i = Di[:, None] - ds = p * (dp - delta_i) - - #dq - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 
- if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += (tl.dot((ds*scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) - else: - dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) - - curr_n += step_n - kT_ptrs += step_n * stride_kn - vT_ptrs += step_n * stride_vn - return dq - - -@triton.jit -def _bwd_dkdv_inner( - dk, dv, - Q, k, v, DO, M, D, sm_scale, - stride_q_m, stride_q_k, - stride_do_m, stride_do_k, - stride_dropout_m, stride_dropout_n, - stride_deltam, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - seqlen_q, seqlen_k, - start_n, start_m, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) - offs_n = start_n + tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - qT_ptrs = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] - do_ptrs = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k - curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 - - #Iterate over blocks(BLOCK_M size) of Q while calculating - #a fixed block(BLOCK_N) of dk and dv. Note, during backward - #pass P has to be recomputed. However, this kernel computes - #dV and dK, so we compute we need P^T and S^T. 
See backward pass - #equations - # - #From Flash Attention Paper: - #ForwardPass: S = QkT, P=softmax(S), O=PV - # - #BackwardPass equations - #dV = P^TdO - #dP = dOV^T - #dS = dsoftmax(dP) - #dQ = dSK - #dK = QdS^T - for blk_idx in range(num_steps): - offs_m = curr_m + tl.arange(0, BLOCK_M) - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < BLOCK_D_MODEL - mask_do &= offs_k[None, :] < BLOCK_D_MODEL - - #load qT - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - - #dropout - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = (curr_philox_offset + - offs_m[None, :] * stride_dropout_m + - offs_n[:, None] * stride_dropout_n) - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1.0 / (1 - dropout_p) - - #Load M - m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) - - #Compute qkT - if IS_FP8: - qkT = (tl.dot(k, qT) * descale_q * descale_k) - else: - qkT = tl.dot(k, qT) - - #Compute pT(use m and also apply sm_scale) - pT = tl.math.exp(qkT * sm_scale - m[None, :]) - - if MASK: - causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] - mask = causal_mask & mask_nm - pT = tl.where(mask, pT, 0.0) - - #load DO - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - - #dV - if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) - dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) - else: - dv += 
tl.dot(pT.to(do.type.element_ty), do) - - #Load delta - Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) - - #Compute dP and dS - if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do - else: - dpT = tl.dot(v, tl.trans(do)) - - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - - #compute dk - if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - - #increment pointers - curr_m += step_m - qT_ptrs += step_m * stride_q_m - do_ptrs += step_m * stride_do_m - - return dk, dv - - -@triton.jit -def _bwd_dkdvdq_inner( - dk, dv, - Q, k, v, DO, DQ, M, D, sm_scale, - stride_q_m, stride_q_k, - stride_do_m, stride_do_k, - stride_dropout_m, stride_dropout_n, - stride_deltam, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - seqlen_q, seqlen_k, - start_n, start_m, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - workgroup_id: tl.int32, -): - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) - offs_n = start_n + tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - - qT_ptrs_start = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] - dq_ptrs_start = DQ + offs_m[:, None] * stride_q_m + offs_k[None,:] * stride_q_k #[BLOCK_M, BLOCK_D_MODEL_POW2] - - do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k - 
curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 - - #Iterate over blocks(BLOCK_M size) of Q while calculating - #a fixed block(BLOCK_N) of dk and dv. Note, during backward - #pass P has to be recomputed. However, this kernel computes - #dV and dK, so we compute we need P^T and S^T. See backward pass - #equations - # - #From Flash Attention Paper: - #ForwardPass: S = QkT, P=softmax(S), O=PV - # - #BackwardPass equations - #dV = P^TdO - #dP = dOV^T - #dS = dsoftmax(dP) - #dQ = dSK - #dK = QdS^T - - # Compute a starting index and step based on workgroup_id - # Use a simple hash-like function to spread out the starting points - start_idx = (workgroup_id * 17) % num_steps # 17 is an arbitrary prime to spread indices - # Ensure step is coprime with num_steps to visit all indices exactly once - step = 1 # 3 if num_steps > 1 or num_steps==3 else 1 # coprime with num_steps - - - for iter in range(num_steps): - # Compute the permuted block index - blk_idx = (start_idx + iter * step) % num_steps - - curr_m = start_m + blk_idx * step_m - qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m - dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_q_m - do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m - - offs_m = curr_m + tl.arange(0, BLOCK_M) - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < BLOCK_D_MODEL - mask_do &= offs_k[None, :] < BLOCK_D_MODEL - - #load qT - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - - #dropout - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = (curr_philox_offset + - offs_m[None, :] * stride_dropout_m + - offs_n[:, None] * stride_dropout_n) - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - 
dropout_scale = 1.0 / (1 - dropout_p) - - #Load M - m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) - - #Compute qkT - if IS_FP8: - qkT = (tl.dot(k, qT) * descale_q * descale_k) - else: - qkT = tl.dot(k, qT) - - #Compute pT(use m and also apply sm_scale) - pT = tl.math.exp(qkT * sm_scale - m[None, :]) - - if MASK: - causal_mask = (offs_m[None, :] - delta_qk) >= (offs_n[:, None]) - mask = causal_mask & mask_nm - pT = tl.where(mask, pT, 0.0) - - #load DO - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - - #dV - if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) - dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) - - #Load delta - Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) - - #Compute dP and dS - if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do - else: - dpT = tl.dot(v, tl.trans(do)) - - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - - #compute dk - if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - - - # We can compute the dq_partial here and do a atomic add to the correct memory location - # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before - # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) - if IS_FP8: - dq_partial = 
tl.dot((dsT * scale_dsT).to(k.dtype).T, k) * descale_dsT * descale_k - else: - dq_partial = tl.dot(dsT.to(k.dtype).T, k) - tl.atomic_add( - dq_ptrs, - dq_partial * sm_scale, - mask=mask_m[:, None], - sem="relaxed", - ) - - return dk, dv - - -@triton.jit -def _bwd_kernel_dkdvdq_causal( - q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dk_ptr, dv_ptr, dq_ptr, - m_ptr, delta_ptr, - stride_q_b, stride_q_h, stride_q_m, stride_q_k, - stride_k_b, stride_k_h, stride_k_n, stride_k_k, - stride_v_b, stride_v_h, stride_v_n, stride_v_k, - stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_do_b, stride_do_h, stride_do_m, stride_do_k, - stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BATCH, - NUM_K_PIDS, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - wid = tl.program_id(0) # workgoup id: 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 - - # workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim - batch_idx = wid % BATCH - head_k_idx = wid // BATCH % NUM_K_HEADS - seq_k_blk_idx = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS - - #Determine q and k start along with seqlen_q and seqlen_k - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + batch_idx) - q_end = tl.load(cu_seqlens_q + batch_idx + 1) - k_start = tl.load(cu_seqlens_k + batch_idx) - k_end = 
tl.load(cu_seqlens_k + batch_idx + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - # Figure out causal starting block since we have seqlen_q >=< seqlen_k. - # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - delta_qk = seqlen_q - seqlen_k - - # q > k: diretcly skip all the way until the start of causal block - start_delta_q_gt_k = delta_qk - - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N - delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M - if delta_qk >= 0: - start_delta = delta_qk - else: - start_delta = start_delta_q_lt_k - - start_n = seq_k_blk_idx * BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_k = offs_k < BLOCK_D_MODEL - mask_kv &= mask_k[None, :] - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = (batch_idx * stride_k_b + - head_k_idx * stride_k_h + - k_start * stride_k_n + offs_n[:, None] * stride_k_n + - offs_k[None, :] * stride_k_k) - adj_v = (batch_idx * stride_v_b + - head_k_idx * stride_v_h + - k_start * stride_v_n + offs_n[:, None] * stride_v_n + - offs_k[None, :] * stride_v_k) - # load K and V: they stay in SRAM throughout the inner loop. 
- k = tl.load(k_ptr + adj_k , mask=mask_kv, other=0.0) - v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) - - # If MQA / GQA, set the K and V head offsets appropriately. - for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N - else: - start_m = max(start_n + delta_qk, 0) - start_m = (start_m // BLOCK_M) * BLOCK_M - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N + residue_m - - # offset input and output tensor by batch and Q/K heads - adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m - - q_ptr_adj = q_ptr + adj_q - dq_ptr_adj = dq_ptr + adj_q - - adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m - do_ptr_adj = do_ptr + adj_do - adj_delta = batch_idx * stride_delta_b + head_q_idx * stride_delta_h + q_start * stride_delta_m - m_ptr_adj = m_ptr + adj_delta - delta_ptr_adj = delta_ptr + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = (philox_offset_base + batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - dropout_offset = (dropout_mask + batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - - MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M) - - - # when q < k, we may skip the initial masked op - # if seq_k_blk_idx < num_blocks_skip: - # num_steps = 0 - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + 
head_q_idx) - descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) - descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) - descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # if start_m is negative, the current N-tile has no block on the - # diagonal of causal mask, so everything have no causal mask - dk, dv = _bwd_dkdvdq_inner( - dk, dv, # output tensors - q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors - stride_q_m, stride_q_k, # strides for q - stride_do_m, stride_do_k, # strides for o - stride_dropout_m, stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK_BLOCK_M, BLOCK_N, # block dim - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - workgroup_id=seq_k_blk_idx, - ) - start_m += num_steps * MASK_BLOCK_M - num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) - end_m = start_m + num_steps * BLOCK_M - - dk, dv = _bwd_dkdvdq_inner( - dk, dv, # output tensors - q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors - stride_q_m, stride_q_k, # strides for q - stride_do_m, stride_do_k, # strides for o - stride_dropout_m, stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - BLOCK_M, BLOCK_N, # 
block dim - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - workgroup_id=seq_k_blk_idx, - ) - - # Write back dV and dK. - offs_dkdv = (batch_idx * stride_dk_b + - head_k_idx * stride_dk_h + - k_start * stride_dk_n + offs_n[:, None] * stride_dk_n + - offs_k[None, :] * stride_dk_k) - tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) - - -@triton.jit -def _bwd_kernel_dkdv_causal( - q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dk_ptr, dv_ptr, - m_ptr, delta_ptr, - stride_q_b, stride_q_h, stride_q_m, stride_q_k, - stride_k_b, stride_k_h, stride_k_n, stride_k_k, - stride_v_b, stride_v_h, stride_v_n, stride_v_k, - stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_do_b, stride_do_h, stride_do_m, stride_do_k, - stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - #seq block, batch, head_k - seq_k_blk_idx = tl.program_id(0) - batch_idx = tl.program_id(1) - head_k_idx = tl.program_id(2) - - #Determine q and k start along with seqlen_q and seqlen_k - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + batch_idx) - q_end = tl.load(cu_seqlens_q + batch_idx + 1) - k_start = 
tl.load(cu_seqlens_k + batch_idx) - k_end = tl.load(cu_seqlens_k + batch_idx + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - # Figure out causal starting block since we have seqlen_q >=< seqlen_k. - # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - delta_qk = seqlen_q - seqlen_k - - # q > k: diretcly skip all the way until the start of causal block - start_delta_q_gt_k = delta_qk - - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N - delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M - if delta_qk >= 0: - start_delta = delta_qk - else: - start_delta = start_delta_q_lt_k - - start_n = seq_k_blk_idx *BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_k = offs_k < BLOCK_D_MODEL - mask_kv &= mask_k[None, :] - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = (batch_idx * stride_k_b + - head_k_idx * stride_k_h + - k_start * stride_k_n + offs_n[:, None] * stride_k_n + - offs_k[None, :] * stride_k_k) - adj_v = (batch_idx * stride_v_b + - head_k_idx * stride_v_h + - k_start * stride_v_n + offs_n[:, None] * stride_v_n + - offs_k[None, :] * stride_v_k) - # load K and V: they stay in SRAM throughout the inner loop. 
- k = tl.load(k_ptr + adj_k , mask=mask_kv, other=0.0) - v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) - - # If MQA / GQA, set the K and V head offsets appropriately. - for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N - else: - start_m = max(start_n + delta_qk, 0) - start_m = start_m // BLOCK_M * BLOCK_M - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N + residue_m - - # offset input and output tensor by batch and Q/K heads - adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m - q_ptr_adj = q_ptr + adj_q - adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m - do_ptr_adj = do_ptr + adj_do - adj_delta = batch_idx * stride_delta_b + head_q_idx * stride_delta_h + q_start * stride_delta_m - m_ptr_adj = m_ptr + adj_delta - delta_ptr_adj = delta_ptr + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = (philox_offset_base + batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - dropout_offset = (dropout_mask + batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - - MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M) - # when q < k, we may skip the initial masked op - if seq_k_blk_idx < num_blocks_skip: - num_steps = 0 - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) - descale_k = tl.load(descale_k_ptr + 
batch_idx * stride_descale_k_z + head_k_idx) - descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) - descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # if start_m is negative, the current N-tile has no block on the - # diagonal of causal mask, so everything have no causal mask - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - q_ptr_adj, k, v, do_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors - stride_q_m, stride_q_k, # strides for q - stride_do_m, stride_do_k, # strides for o - stride_dropout_m, stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK_BLOCK_M, BLOCK_N, # block dim - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - start_m += num_steps * MASK_BLOCK_M - num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) - end_m = start_m + num_steps * BLOCK_M - - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - q_ptr_adj, k, v, do_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors - stride_q_m, stride_q_k, # strides for q - stride_do_m, stride_do_k, # strides for o - stride_dropout_m, stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - BLOCK_M, BLOCK_N, # block dim - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim - MASK=False, # causal masking - 
ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - - # Write back dV and dK. - offs_dkdv = (batch_idx * stride_dk_b + - head_k_idx * stride_dk_h + - k_start * stride_dk_n + offs_n[:, None] * stride_dk_n + - offs_k[None, :] * stride_dk_k) - tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) - -@triton.jit -def _bwd_kernel_dq_causal( - q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dq_ptr, - m_ptr, delta_ptr, - stride_q_b, stride_q_h, stride_q_m, stride_q_k, - stride_k_b, stride_k_h, stride_k_n, stride_k_k, - stride_v_b, stride_v_h, stride_v_n, stride_v_k, - stride_dq_b, stride_dq_h, stride_dq_m, stride_dq_k, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_do_b, stride_do_h, stride_do_m, stride_do_k, - stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - seq_q_blk_idx = tl.program_id(0) - batch_idx = tl.program_id(1) - head_k_idx = tl.program_id(2) - - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + batch_idx) - q_end = tl.load(cu_seqlens_q + batch_idx + 1) - k_start = tl.load(cu_seqlens_k + batch_idx) - k_end = tl.load(cu_seqlens_k + batch_idx + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - # Figure out causal starting block since we have seqlen_q <=> seqlen_k. 
- # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we - # can simply skip and we need to adjust starting position. - start_m = seq_q_blk_idx * BLOCK_M - # seqlen_q > seqlen_k, no need to process these tile for dq - delta_qk = seqlen_q - seqlen_k - if start_m + BLOCK_M < delta_qk: - return - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_m = start_m + tl.arange(0, BLOCK_M) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_k = offs_k < BLOCK_D_MODEL - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_q_m + offs_k[None, :] * stride_q_k - offs_do = offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k - adj_k = batch_idx * stride_k_b + head_k_idx * stride_k_h + k_start * stride_k_n - adj_v = batch_idx * stride_v_b + head_k_idx * stride_v_h + k_start * stride_v_n - k_ptr_adj = k_ptr - v_ptr_adj = v_ptr - k_ptr_adj += adj_k - v_ptr_adj += adj_v - - # If MQA / GQA, set the K and V head offsets appropriately. 
- GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): - # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front - # for every M-tile - end_n = start_m + BLOCK_M - delta_qk - # clamp end_n at [0, seqlen_k] - end_n = max(min(end_n, seqlen_k), 0) - - # offset input and output tensor by batch and Q/K heads - adj_q = (batch_idx * stride_q_b + - head_q_idx * stride_q_h + - q_start * stride_q_m) - adj_do = (batch_idx * stride_do_b + - head_q_idx * stride_do_h + - q_start * stride_do_m) - adj_delta = (batch_idx * stride_delta_b + - head_q_idx * stride_delta_h + - q_start * stride_delta_m) - delta_ptr_adj = delta_ptr + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = (philox_offset_base + - batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - dropout_offset = (dropout_mask + - batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - - q = tl.load(q_ptr + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(do_ptr + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(m_ptr + adj_delta + offs_m * stride_delta_m, - mask=offs_m < seqlen_q) - m = m[:, None] - - MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR - # start can only be 0 at minimum - start_n = max(end_n - BLOCK_M, 0) - num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) - descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) - descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) - descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - dq = tl.zeros([BLOCK_M, 
BLOCK_D_MODEL_POW2], dtype=tl.float32) - # Compute dQ for masked (diagonal) blocks. - # NOTE: This code scans each row of QK^T backward (from right to left, - # but inside each call to _bwd_dq_inner, from left to right), but that's - # not due to anything important. I just wanted to reuse the loop - # structure for dK & dV above as much as possible. - dq = _bwd_dq_inner( - dq, - q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale, - stride_q_m, stride_q_k, stride_k_n, stride_k_k, stride_v_n, stride_v_k, - stride_dropout_m, stride_dropout_n, - stride_delta_m, - seqlen_q, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, MASK_BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=True, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - end_n -= num_steps * MASK_BLOCK_N - num_steps = tl.cdiv(end_n, BLOCK_N) - start_n = max(end_n - num_steps * BLOCK_N, 0) - dq = _bwd_dq_inner( - dq, - q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale, - stride_q_m, stride_q_k, stride_k_n, stride_k_k, stride_v_n, stride_v_k, - stride_dropout_m, stride_dropout_n, - stride_delta_m, - seqlen_q, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - # Write back dQ. 
- offs_dq = (batch_idx * stride_dq_b + - head_q_idx * stride_dq_h + - q_start * stride_dq_m + - offs_m[:, None] * stride_dq_m + - offs_k[None, :] * stride_dq_k) - dq *= sm_scale - tl.store(dq_ptr + offs_dq, dq, mask=mask_q) - - -@triton.jit -def _bwd_kernel_dkdvdq_noncausal( - Q, K, V, sm_scale, DO, DK, DV, DQ, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BATCH, - NUM_K_PIDS, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - # workgroup id - wid = tl.program_id(0) # 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 - - # Workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim - # This is in order to avoid contention for the tl.atomic_add (inside _bwd_dkdvdq_inner) that happens between workgroups that share the same batch and head_k. 
- bid = wid % BATCH - hkid = wid // BATCH % NUM_K_HEADS - pid = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS - - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - start_n = pid * BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_kv &= offs_k < BLOCK_D_MODEL - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = (bid * stride_kb + - hkid * stride_kh + - k_start * stride_kn + - offs_n[:, None] * stride_kn + - offs_k[None, :] * stride_kk) - adj_v = (bid * stride_vb + - hkid * stride_vh + - k_start * stride_vn + - offs_n[:, None] * stride_vn + - offs_k[None, :] * stride_vk) - - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) - - Q_ptr = Q + adj_q - DQ_ptr = DQ + adj_q - - adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) - DO_ptr = DO + adj_do - adj_delta = (bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam) - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - #dropout - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid 
* stride_descale_q_z + hqid) - descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) - descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - start_m = 0 - num_steps = tl.cdiv(seqlen_q, BLOCK_M) - - dk, dv = _bwd_dkdvdq_inner( - dk, dv, - Q_ptr, k, v, DO_ptr, DQ_ptr, M_ptr, Delta_ptr, sm_scale, - stride_qm, stride_qk, - stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, - stride_deltam, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - seqlen_q, seqlen_k, - start_n, start_m, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - workgroup_id=pid, - ) - - adj_dkdv = (bid * stride_dkb + - hkid * stride_dkh + - k_start * stride_dkn + offs_n[:, None] * stride_dkn + - offs_k[None, :] * stride_dkk) - tl.store(DV + adj_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv, dk, mask=mask_kv) - - - -@triton.jit -def _bwd_kernel_dkdv_noncausal( - Q, K, V, sm_scale, DO, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: 
tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - start_n = pid * BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_kv &= offs_k < BLOCK_D_MODEL - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = (bid * stride_kb + - hkid * stride_kh + - k_start * stride_kn + - offs_n[:, None] * stride_kn + - offs_k[None, :] * stride_kk) - adj_v = (bid * stride_vb + - hkid * stride_vh + - k_start * stride_vn + - offs_n[:, None] * stride_vn + - offs_k[None, :] * stride_vk) - - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) - Q_ptr = Q + adj_q - adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) - DO_ptr = DO + adj_do - adj_delta = (bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam) - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - #dropout - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset + bid * stride_dropoutb + \ - hqid * stride_dropouth - 
dropout_offset = dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) - descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) - descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - start_m = 0 - num_steps = tl.cdiv(seqlen_q, BLOCK_M) - dk, dv = _bwd_dkdv_inner( - dk, dv, - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, - stride_qm, stride_qk, - stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, - stride_deltam, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - seqlen_q, seqlen_k, - start_n, start_m, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - - adj_dkdv = (bid * stride_dkb + - hkid * stride_dkh + - k_start * stride_dkn + offs_n[:, None] * stride_dkn + - offs_k[None, :] * stride_dkk) - tl.store(DV + adj_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv, dk, mask=mask_kv) - - -@triton.jit -def _bwd_kernel_dq_noncausal( - Q, K, V, sm_scale, DO, DQ, - M, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: 
tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - pid = tl.program_id(0) #seqlen - bid = tl.program_id(1) #batch - hkid = tl.program_id(2) #head_k - - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - start_m = pid * BLOCK_M - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_m = start_m + tl.arange(0, BLOCK_M) - - #mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_k = offs_k < BLOCK_D_MODEL - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn - K += adj_k - V += adj_v - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - delta_ptr = delta + adj_delta - - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = (philox_offset_base + - bid * stride_dropoutb + - hqid * stride_dropouth) - dropout_offset = ( - dropout_mask + bid * stride_dropoutb + 
hqid * stride_dropouth) - - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) - m = m[:, None] - - #FP8 - if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) - descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) - descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - start_n = 0 - end_n = seqlen_k - num_steps = tl.cdiv(seqlen_k, BLOCK_N) - dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, delta_ptr, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, - stride_deltam, - seqlen_q, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - -def _flash_attn_backward( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - philox_seed: Optional[int] = 0, - philox_offset: Optional[int] = 0, - descale_q: 
Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, - fused: bool = False, -): - IS_FP8 = is_fp8(q) - if IS_FP8: - FP8_MAX = torch.finfo(q.dtype).max - descale_strides = (descale_q.stride(0),descale_k.stride(0),descale_v.stride(0),descale_do.stride(0) ) - else: - FP8_MAX = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_do_z = None - descale_strides = (stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z) - - IS_VARLEN = True if cu_seqlens_q is not None else False - - #get strides and shape - if IS_VARLEN: - #Layout for q,k,v is thd ie [total tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] - seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2)) - dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2)) - dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2)) - do_strides = (0, do.stride(1), do.stride(0), do.stride(2)) - else: - #Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - seqlen_k, num_k_heads = k.shape[1], k.shape[2] - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3)) - dk_strides = (dk.stride(0), dk.stride(2), 
dk.stride(1), dk.stride(3)) - dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)) - do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3)) - - #BLOCK_D_MODEL, BLOCK_D_MODEL_POW2 - #padding for head_dim. Power of 2 or 16 - BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16) - - #Configs - #PRE_BLOCK, BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 - #BLK_SLICE_FACTOR - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 - BLK_SLICE_FACTOR = 2 - - #init delta - delta = torch.zeros_like(softmax_lse) - if IS_VARLEN: - #[total_tokens, num_q_heads, seqlen_q] - delta_strides = (0, delta.stride(1), delta.stride(0)) - else: - #[batch, num_q_heads, seqlen_q] - delta_strides = delta.stride() - - #preprocess - #compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise. - pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, num_q_heads) - _bwd_preprocess[pre_grid]( - o, do, - delta, - *o_strides, - *delta_strides, - descale_strides[3], - cu_seqlens_q, max_seqlen_q, - descale_do, - BLOCK_M=PRE_BLOCK, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8 - ) - - #dropout_mask - use_dropout = (dropout_p > 0.0) - if use_dropout: - dropout_mask = torch.zeros( - (batch, num_q_heads, max_seqlen_q, max_seqlen_k), - device=q.device, - dtype=torch.float32) - dropout_strides = dropout_mask.stride() - else: - dropout_mask = None - dropout_strides = (0, 0, 0, 0) - - grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, num_k_heads) - grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, num_k_heads) - - if fused: # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups - - BLOCK_N = 128 - config = { - "BLOCK_M": 32, - "BLOCK_N": BLOCK_N, - "num_warps": 4, - 
"num_stages": 1, - "waves_per_eu": 1, - "BLK_SLICE_FACTOR": 2, - } - - num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N - grid_dkdvdq = (batch * num_k_heads * num_k_pids,) - - if causal: - _bwd_kernel_dkdvdq_causal[grid_dkdvdq]( - q, k, v, sm_scale, do, dk, dv, dq, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BATCH=batch, - NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - **config, - ) - else: - _bwd_kernel_dkdvdq_noncausal[grid_dkdvdq]( - q, k, v, sm_scale, do, dk, dv, dq, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BATCH=batch, - NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - **config, - ) - - return delta - - # split kernels solution: one kernel computes dk, dv and the other computes dq - - if causal: - _bwd_kernel_dkdv_causal[grid_dkdv]( - q, k, v, sm_scale, do, dk, dv, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, 
descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M1, - BLOCK_N=BLOCK_N1, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - _bwd_kernel_dq_causal[grid_dq]( - q, k, v, sm_scale, do, dq, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M2, - BLOCK_N=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - else: - _bwd_kernel_dkdv_noncausal[grid_dkdv]( - q, k, v, sm_scale, do, dk, dv, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M1, - BLOCK_N=BLOCK_N1, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - - _bwd_kernel_dq_noncausal[grid_dq]( - q, k, v, sm_scale, do, dq, - softmax_lse, delta, - 
*q_strides, - *k_strides, - *v_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M2, - BLOCK_N=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - - return delta - - -class FlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - is_grad_enabled, - fused_backward, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q,k,v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - ) - - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse) - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes 
- ctx.deterministic = deterministic - ctx.fused_backward = fused_backward - - - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - _flash_attn_backward( - do_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - fused=ctx.fused_backward, - ) - dq = dq[..., : q.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k.shape[-1]] - dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - -def flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1,-1), - alibi_slopes=None, - deterministic=True, - return_lse=False, - return_attn_probs=False, - fused_backward=False, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. 
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_lse=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). 
It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - torch.is_grad_enabled(), - fused_backward, - ) - - -class FlashAttnFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q,k,v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # cast input to fp8 - fp8_dtype = torch.float8_e4m3fnuz - q_fp8, descale_q = cast_to_fp8(q, fp8_dtype, "bshd") - k_fp8, descale_k = cast_to_fp8(k, fp8_dtype, "bshd") - v_fp8, descale_v = cast_to_fp8(v, fp8_dtype, "bshd") - - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - cu_seqlens_q=None, - cu_seqlens_k=None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v - ) - - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, descale_q, descale_k, descale_v) - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - 
ctx.alibi_slopes = alibi_slopes - - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q_fp8, k_fp8, v_fp8, out, softmax_lse, descale_q, descale_k, descale_v = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q_fp8, dtype=torch.float32), torch.zeros_like(k_fp8, dtype=torch.float32), torch.zeros_like(v_fp8, dtype=torch.float32) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - - fp8_dtype = torch.float8_e4m3fnuz - do_padded_fp8, descale_do = cast_to_fp8(do_padded, fp8_dtype, "bshd") - _flash_attn_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q_fp8.shape[1], - max_seqlen_k=k_fp8.shape[1], - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - ) - #dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension - #dk = dk[..., : k_fp8.shape[-1]] - #dv = dv[..., : v_fp8.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None - -def flash_attn_fp8_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False -): - return FlashAttnFP8Func.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - torch.is_grad_enabled() - ) - -class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - 
cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - block_table, - is_grad_enabled, - fused_backward, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0.0, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.fused_backward = fused_backward - out = out_padded[..., :head_size_og] - - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_og = do.size(2) - do_padded = do - if head_size_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) - 
_flash_attn_backward( - do_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - fused=ctx.fused_backward, - ) - dq = dq[..., : q.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k.shape[-1]] - dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1,-1), - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - block_table=None, - fused_backward=False, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. 
- - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). 
It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - fused_backward, - ) - - -class FlashAttnVarlenFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - block_table, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # cast input to fp8 - fp8_dtype = torch.float8_e4m3fnuz - q_fp8, descale_q = cast_varlen_to_fp8(q, fp8_dtype, cu_seqlens=cu_seqlens_q) - k_fp8, descale_k = cast_varlen_to_fp8(k, fp8_dtype, cu_seqlens=cu_seqlens_k) - v_fp8, descale_v = cast_varlen_to_fp8(v, fp8_dtype, cu_seqlens=cu_seqlens_k) - - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v - ) - if is_grad: - 
ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, descale_q, descale_k, descale_v) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q_fp8, k_fp8, v_fp8, out, softmax_lse, cu_seqlens_q, cu_seqlens_q, descale_q, descale_k, descale_v = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q, dtype=torch.float32), torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - - fp8_dtype = torch.float8_e4m3fnuz - do_padded_fp8, descale_do = cast_varlen_to_fp8(dout_padded, fp8_dtype, "thd", cu_seqlens_q) - - _flash_attn_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do - ) - dq = dq[..., : q.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k.shape[-1]] - dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None - -def flash_attn_varlen_fp8_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - 
causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - block_table=None -): - return FlashAttnVarlenFP8Func.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - block_table, - torch.is_grad_enabled() - ) \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py deleted file mode 100644 index 3f650d288db..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py +++ /dev/null @@ -1,1091 +0,0 @@ -import torch -import triton # type: ignore -import triton.language as tl # type: ignore -from typing import Literal, Optional -from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shapes_from_layout, compute_fp8_scaling_factors, \ - get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_rdna - -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) - - -def get_autotune_configs(): - if False: - if is_cdna(): - # shared meta-parameters - NUM_STAGES = 1 - NUM_WARPS = 4 - WAVES_PER_EU = 2 - MATRIX_INSTR_NONKDIM = 16 - - preprocess_autotune_configs = [ - triton.Config({"PRE_BLOCK": 128, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config - triton.Config({"PRE_BLOCK": 64, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({"PRE_BLOCK": 32, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": 
MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({"PRE_BLOCK": 16, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - preprocess_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - causal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config - triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - causal_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - noncausal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config - triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({'BLOCK_M1': 16, 
'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - noncausal_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - - return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) - else: - raise ValueError("Unknown Device Type") - else: - # meta-parameters - # TODO: fix num_stages later - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - - assert BLOCK_N1 == BLOCK_M2 - - # configs for the kernels - preprocess_autotune_configs = [ - triton.Config({"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - preprocess_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - causal_autotune_configs = [ - triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - causal_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - noncausal_autotune_configs = [ - triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), 
- ] - noncausal_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) - - - -(preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) = get_autotune_configs() - - -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -# fwd_prefill.py line 607 -@triton.autotune( - configs=preprocess_autotune_configs, - key=preprocess_autotune_keys, - use_cuda_graph=True, -) -@triton.jit -def _bwd_preprocess( - O, DO, # noqa: E741 - Delta, - stride_ob, stride_oh, stride_om, stride_ok, - stride_deltab, stride_deltah, stride_deltam, - stride_descale_do_z, - cu_seqlens_q, max_seqlen_q, - Descale_do, - PRE_BLOCK: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr -): - pid_m = tl.program_id(0) - bid = tl.program_id(1) - hid = tl.program_id(2) - # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) - offs_k = tl.arange(0, HEAD_DIM) - # Offset O/DO by batch, head and q_start - O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 - DO += bid * stride_ob + hid * stride_oh + q_start * stride_om - # create masks - mask_m = offs_m < seqlen_q - mask_md = mask_m[:, None] - PADDED_HEAD: tl.constexpr = 
(ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM - # compute pointers - offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok - out_ptrs = O + offs_do - do_ptrs = DO + offs_do - # load - o = tl.load(out_ptrs, mask=mask_md, other=0.0) - do = tl.load(do_ptrs, mask=mask_md, other=0.0) - # compute and write-back to delta - if IS_FP8: - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam - tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) - - -# The main inner-loop logic for computing dK and dV. -@triton.jit -def _bwd_dkdv_inner( - dk, dv, # output - Q, k, v, DO, M, D, sm_scale, # input tensor - stride_qm, stride_qk, - stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, # - stride_deltam, - BLOCK_M: tl.constexpr, # 16 - BLOCK_N: tl.constexpr, # 128 - HEAD_DIM: tl.constexpr, # - ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - # Filled in by the wrapper. 
- start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal - ENABLE_DROPOUT: tl.constexpr, # activate dropout - USE_EXP2: tl.constexpr, # activate exp2 - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) - offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) - offs_k = tl.arange(0, HEAD_DIM) - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - # Q and DO are (seqlen_q, head_dim) - # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q - qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk - # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed - do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. 
- tl.static_assert(BLOCK_N % BLOCK_M == 0) - curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - - for blk_idx in range(num_steps): - if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 - offs_m = curr_m + tl.arange(0, BLOCK_M) - # update the mask because offs_m advanced - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM - mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - # generate dropout mask - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = curr_philox_offset + \ - offs_m[None, :] * stride_dropoutm + \ - offs_n[:, None] * stride_dropoutn - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = offs_m[None, :] * stride_dropoutm + \ - offs_n[:, None] * stride_dropoutn - dropout_mask = tl.load( - curr_dropout_offset + dropout_offs, - mask=mask_nm - ) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1.0 / (1 - dropout_p) - # Load m before computing qk to reduce pipeline stall. - m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) - if IS_FP8: - qkT = (tl.dot(k, qT) * descale_q * descale_k) - else: - qkT = tl.dot(k, qT) - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"qT: {qT.shape}\n", qT) - print(f"k: {k.shape}\n", k) - print(f"qkT scaled: {qkT.shape}\n", qkT * sm_scale) - # TODO: remove the scaling of m later when we removed re-scaling in fwd - if USE_EXP2: - pT = tl.math.exp2(qkT * sm_scale * RCP_LN2 - m[None, :] * RCP_LN2) - else: - pT = tl.math.exp(qkT * sm_scale - m[None, :]) - - # Autoregressive masking. 
- if MASK: - # offset offs_m with delta_qk since the causal mask starts at - # bottom right of the (seqlen_q, seqlen_k) matrix - causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] - mask = causal_mask & mask_nm - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"causal_mask: {causal_mask.shape}\n", causal_mask) - print(f"qkT after causal: {qkT.shape}\n", tl.where(causal_mask, qkT * sm_scale, 0.0)) - pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - # Compute dV. - if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) - dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do)* descale_p_dropout * descale_do) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) - - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"pT: {pT.shape}\n", pT) - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) - # Compute dP and dS. - if IS_FP8: - dpT = (tl.dot(v, tl.trans(do)) * descale_v * descale_do) - else: - dpT = tl.dot(v, tl.trans(do)) - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - # Increment pointers. 
- curr_m += step_m - qT_ptrs += step_m * stride_qm - do_ptrs += step_m * stride_dom - return dk, dv - -# the main inner-loop logic for computing dQ -@triton.jit -def _bwd_dq_inner( - dq, # output - q, K, V, do, m, Delta, sm_scale, # input - # shared by Q/K/V. - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, # stride for dropout - stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - # Filled in by the wrapper. - start_m, start_n, end_n, num_steps, # - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M2) - offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) - - # mask to make sure not OOB of seqlen_q - mask_m = offs_m < seqlen_q - - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
- tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - curr_n = start_n - step_n = BLOCK_N2 - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - for blk_idx in range(num_steps): - if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 - offs_n = curr_n + tl.arange(0, BLOCK_N2) - # end_n is needed because the end of causal True might not be perfectly - # aligned with the end of the block - mask_n = offs_n < end_n - if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 - if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 - mask_kT = mask_n[None, :] - mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD: - mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM - - kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) - - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = curr_philox_offset + \ - offs_m[:, None] * stride_dropoutm + \ - offs_n[None, :] * stride_dropoutn - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = offs_m[:, None] * stride_dropoutm + \ - offs_n[None, :] * stride_dropoutn - dropout_mask = tl.load( - curr_dropout_offset + dropout_offs, - mask=mask_mn) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1 / (1 - dropout_p) - - if IS_FP8: - qk = (tl.dot(q, kT) * descale_q * descale_k) - else: - qk = tl.dot(q, kT) - if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk * sm_scale) # noqa: E701 - if USE_EXP2: - p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) - else: - p = tl.math.exp(qk * sm_scale - m) - - # Autoregressive masking. - if MASK: - causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] - mask = causal_mask & mask_mn - p = tl.where(mask, p, 0.0) - # Compute dP and dS. 
- if IS_FP8: - dp = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp = tl.dot(do, vT) - if ENABLE_DROPOUT: - dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale - delta_i = Di[:, None] - ds = p * (dp -delta_i) - # Compute dQ. - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. - if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += (tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) - else: - dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) - # Increment pointers. - curr_n += step_n - kT_ptrs += step_n * stride_kn - vT_ptrs += step_n * stride_vn - return dq - -@triton.autotune( - configs=causal_autotune_configs, - key=causal_autotune_keys, - use_cuda_graph=True, -) -@triton.jit -def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) - Q, K, V, sm_scale, DO, DQ, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - BLOCK_M1: tl.constexpr, - BLOCK_N1: tl.constexpr, - BLOCK_M2: tl.constexpr, - BLOCK_N2: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_EXP2: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 - # figure out 
varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - delta_qk = seqlen_q - seqlen_k - if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_k = tl.arange(0, HEAD_DIM) - GROUP_SIZE: tl.constexpr = HQ // HK - - # align the delta_qk - start_n = pid * BLOCK_N1 - if start_n < seqlen_k: - # This section does dk and dv - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - - # q > k: diretcly skip all the way until the start of causal block - start_delta_q_gt_k = delta_qk - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N1 - delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1 - if delta_qk >= 0: - start_delta = delta_qk - if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}") # noqa: E701 - else: - start_delta = start_delta_q_lt_k - if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}") # noqa: E701 - - offs_n = start_n + tl.arange(0, BLOCK_N1) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk - - # K/V tensors not changed for the group - adj_kv = bid * stride_kb + hkid * 
stride_kh + k_start * stride_kn - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) - v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) - # If MQA / GQA, set the K and V head offsets appropriately. - # hqid = hkid - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N1 - else: - start_m = max(start_n + delta_qk, 0) - start_m = start_m // BLOCK_M1 * BLOCK_M1 - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N1 + residue_m - if DEBUG_TRITON: print(f"residue_m = {residue_m}") # noqa: E701 - - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - Q_ptr = Q + adj_q - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = bid * stride_deltab + hqid * stride_deltah + \ - q_start * stride_deltam - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) - # when q < k, we may skip the initial masked op - if pid < num_blocks_skip: - num_steps = 0 - - # if start_m is negative, the current N-tile has no block on the - 
# diagonal of causal mask, so everything have no causal mask - if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, - MASK_BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - None, None, None, None, - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - start_m += num_steps * MASK_BLOCK_M1 - num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) - end_m = start_m + num_steps * BLOCK_M1 - - if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 - if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}") # noqa: E701 - if DEBUG_TRITON: print("unMasked") # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, - BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - None, None, None, None, - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - 
USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # end of GQA/MQA of dkdv - # Write back dV and dK. - adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) - - # This part does dq - start_m = pid * BLOCK_M2 - if start_m < seqlen_q: - # seqlen_q > seqlen_k, no need to process these tile for dq - if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}") # noqa: E701 - if start_m + BLOCK_M2 < delta_qk: - if DEBUG_TRITON: print(f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}") # noqa: E701 - return - - offs_m = start_m + tl.arange(0, BLOCK_M2) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - K += adj_kv - V += adj_kv - # If MQA / GQA, set the K and V head offsets appropriately. 
- for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front - # for every M-tile - end_n = start_m + BLOCK_M2 - delta_qk - # clamp end_n at [0, seqlen_k] - end_n = max(min(end_n, seqlen_k), 0) - if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = \ - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - Delta_ptr = Delta + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + \ - bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = \ - dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, - mask=offs_m < seqlen_q) - m = m[:, None] - - MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR - # start can only be 0 at minimum - start_n = max(end_n - BLOCK_M2, 0) - num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, # - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, # - stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2, MASK_BLOCK_N2, # - HEAD_DIM, ACTUAL_HEAD_DIM, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - start_m, start_n, end_n, num_steps, # - None, None, None, None, - MASK=True, # - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_EXP2=USE_EXP2, - 
IS_FP8=False, - FP8_MAX=None, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - end_n -= num_steps * MASK_BLOCK_N2 - num_steps = tl.cdiv(end_n, BLOCK_N2) - start_n = max(end_n - num_steps * BLOCK_N2, 0) - if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 - dq = _bwd_dq_inner( - dq, # - q, K, V, do, m, Delta_ptr, sm_scale, # - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # - stride_dropoutm, stride_dropoutn, # - stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2, BLOCK_N2, # - HEAD_DIM, ACTUAL_HEAD_DIM, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - start_m, start_n, end_n, num_steps, # - None, None, None, None, - MASK=False, # - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # Write back dQ. - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - # end of GQA/MQA of dq - -@triton.autotune( - configs=noncausal_autotune_configs, - key=noncausal_autotune_keys, - use_cuda_graph=True, -) -@triton.jit -def bwd_kernel_noncausal( - Q, K, V, sm_scale, DO, DQ, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - BLOCK_M1: tl.constexpr, # 32 - BLOCK_N1: tl.constexpr, # 
128 - BLOCK_M2: tl.constexpr, # 128 - BLOCK_N2: tl.constexpr, # 32 - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_EXP2: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_k = tl.arange(0, HEAD_DIM) - GROUP_SIZE: tl.constexpr = HQ // HK - - start_n = pid * BLOCK_N1 - if start_n < seqlen_k: - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - - offs_n = start_n + tl.arange(0, BLOCK_N1) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk - - # K/V tensors not changed for the group - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) - v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) - # If MQA / GQA, set the K and V head offsets appropriately. 
- for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - Q_ptr = Q + adj_q - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - # because there is no causal, we always start from the beginning - start_m = 0 - num_steps = tl.cdiv(seqlen_q, BLOCK_M1) - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, - BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - None, None, None, None, - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - # Write back dV and dK. 
- adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) - - # THIS PART DOES DQ - start_m = pid * BLOCK_M2 - if start_m < seqlen_q: - offs_m = start_m + tl.arange(0, BLOCK_M2) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - K += adj_kv - V += adj_kv - # If MQA / GQA, set the K and V head offsets appropriately. - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = \ - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - Delta_ptr = Delta + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + \ - bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = \ - dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, - mask=offs_m < seqlen_q) - m = m[:, None] - - # start can only be 0 at minimum - start_n = 0 - end_n = seqlen_k - num_steps = tl.cdiv(seqlen_k, BLOCK_N2) - - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - dq 
= _bwd_dq_inner( - dq, # - q, K, V, do, m, Delta_ptr, sm_scale, # - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # - stride_dropoutm, stride_dropoutn, # - stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2, BLOCK_N2, # - HEAD_DIM, ACTUAL_HEAD_DIM, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - start_m, start_n, end_n, num_steps, # - None, None, None, None, - MASK=False, # - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # Write back dQ. - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - - -def attention_prefill_backward_triton_split_oneKernel_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool, -): - # debug - DEBUG_TRITON: bool = False - DEBUG_TRITON_DETAIL: bool = False - - # get strides and shape - batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ - get_shapes_from_layout( - q, k, layout, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k - ) - q_strides, k_strides, v_strides, o_strides = \ - get_strides_from_layout(q, k, v, o, layout) - stride_qb, stride_qh, stride_qm, stride_qk = q_strides - stride_kb, stride_kh, stride_kn, stride_kk = k_strides - stride_vb, stride_vh, stride_vn, stride_vk = v_strides - stride_ob, 
stride_oh, stride_om, stride_ok = o_strides - dq_strides, dk_strides, _, do_strides = \ - get_strides_from_layout(dq, dk, dv, do, layout) - stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides - stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides - stride_dob, stride_doh, stride_dom, stride_dok = do_strides - IS_VARLEN = layout == "thd" - use_dropout = (dropout_p > 0.0) - - # get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 16) - HEAD_DIM = padded_d_model - ACTUAL_HEAD_DIM = head_size - - # init delta - delta = torch.empty_like(softmax_lse) - if IS_VARLEN: - stride_deltab = 0 - stride_deltam, stride_deltah = delta.stride() - else: - stride_deltab, stride_deltah, stride_deltam = delta.stride() - pre_grid = lambda META: (triton.cdiv(max_seqlen_q_final, META['PRE_BLOCK']), batch, nheads_q) - _bwd_preprocess[pre_grid]( - o, do, - delta, - stride_ob, stride_oh, stride_om, stride_ok, - stride_deltab, stride_deltah, stride_deltam, - 0, - cu_seqlens_q, max_seqlen_q_final, - None, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, - IS_VARLEN=IS_VARLEN, - IS_FP8=False - ) - - # dropout mask tensor for debugging. 
We dump the dropout mask created in - # the kernel for testing - dropout_mask = None - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ - (0, 0 , 0 , 0) - if use_dropout: - dropout_mask = torch.zeros( - (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), - device=q.device, - dtype=torch.float32 - ) - - if DROPOUT_USE_PYTORCH: - if not IS_VARLEN: - dropout_mask = create_dropout_mask( - dropout_p, - (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), - seed = philox_seed - ) - else: - dropout_mask = create_dropout_mask_varlen( - dropout_p, batch, nheads_q, - cu_seqlens_q, cu_seqlens_k, philox_seed - ) - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ - dropout_mask.stride() - - seqlen = max(max_seqlen_q_final, max_seqlen_k_final) - grid = lambda META: ((seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, nheads_k) - if causal: - if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}" ) # noqa: E701 - bwd_kernel_causal[grid]( - q, k, v, sm_scale, do, dq, dk, dv, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_EXP2=use_exp2, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - else: - bwd_kernel_noncausal[grid]( - q, k, v, sm_scale, do, dq, dk, dv, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, 
stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_EXP2=use_exp2, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - return delta \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py deleted file mode 100644 index 5cc93edc5e4..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +++ /dev/null @@ -1,1354 +0,0 @@ -import torch -import triton # type: ignore -import triton.language as tl # type: ignore -from typing import Literal, Optional -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, \ - get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_fp8 - -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) - -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -# fwd_prefill.py line 607 -@triton.jit -def _bwd_preprocess( - O, DO, # noqa: E741 - Delta, - stride_ob, stride_oh, stride_om, stride_ok, - stride_deltab, stride_deltah, 
stride_deltam, - stride_descale_do_z, - cu_seqlens_q, max_seqlen_q, - Descale_do, - BLOCK_M: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr -): - pid_m = tl.program_id(0) - bid = tl.program_id(1) - hid = tl.program_id(2) - # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, HEAD_DIM) - # Offset O/DO by batch, head and q_start - O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 - DO += bid * stride_ob + hid * stride_oh + q_start * stride_om - # create masks - mask_m = offs_m < seqlen_q - mask_md = mask_m[:, None] - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM - # compute pointers - offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok - out_ptrs = O + offs_do - do_ptrs = DO + offs_do - # load - o = tl.load(out_ptrs, mask=mask_md, other=0.0) - do = tl.load(do_ptrs, mask=mask_md, other=0.0) - # compute and write-back to delta - if IS_FP8: - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam - tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) - - -# The main inner-loop logic for computing dK and dV. 
-@triton.jit -def _bwd_dkdv_inner( - dk, dv, # output - Q, k, v, DO, M, D, sm_scale, # input tensor - stride_qm, stride_qk, - stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, - stride_deltam, - BLOCK_M: tl.constexpr, # 16 - BLOCK_N: tl.constexpr, # 128 - HEAD_DIM: tl.constexpr, # - ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - # Filled in by the wrapper. - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal - ENABLE_DROPOUT: tl.constexpr, # activate dropout - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, # activate exp2 - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) - offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) - offs_k = tl.arange(0, HEAD_DIM) - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - # Q and DO are (seqlen_q, head_dim) - # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q - qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk - # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed - do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. 
- tl.static_assert(BLOCK_N % BLOCK_M == 0) - curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - - for blk_idx in range(num_steps): - if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 - offs_m = curr_m + tl.arange(0, BLOCK_M) - # update the mask because offs_m advanced - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM - mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - # generate dropout mask - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = curr_philox_offset + \ - offs_m[None, :] * stride_dropoutm + \ - offs_n[:, None] * stride_dropoutn - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = offs_m[None, :] * stride_dropoutm + \ - offs_n[:, None] * stride_dropoutn - dropout_mask = tl.load( - curr_dropout_offset + dropout_offs, - mask=mask_nm - ) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1.0 / (1 - dropout_p) - # Load m before computing qk to reduce pipeline stall. 
- m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) - if IS_FP8: - qkT = (tl.dot(k, qT) * descale_q * descale_k) - else: - qkT = tl.dot(k, qT) - qkT_scaled = qkT * sm_scale - - if USE_ALIBI: - relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - qkT_scaled += alibi_block - - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"qT: {qT.shape}\n", qT) - print(f"k: {k.shape}\n", k) - print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) - # TODO: remove the scaling of m later when we removed re-scaling in fwd - if USE_EXP2: - pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) - else: - pT = tl.math.exp(qkT_scaled - m[None, :]) - - # Autoregressive masking. - if MASK: - # offset offs_m with delta_qk since the causal mask starts at - # bottom right of the (seqlen_q, seqlen_k) matrix - causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] - mask = causal_mask & mask_nm - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"causal_mask: {causal_mask.shape}\n", causal_mask) - print(f"qkT after causal: {qkT.shape}\n", tl.where(causal_mask, qkT * sm_scale, 0.0)) - pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - # Compute dV. - if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) - dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do)* descale_p_dropout * descale_do) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) - - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"pT: {pT.shape}\n", pT) - # D (= delta) is pre-divided by ds_scale. 
- Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) - # Compute dP and dS. - if IS_FP8: - dpT = (tl.dot(v, tl.trans(do)) * descale_v * descale_do) - else: - dpT = tl.dot(v, tl.trans(do)) - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - # Increment pointers. - curr_m += step_m - qT_ptrs += step_m * stride_qm - do_ptrs += step_m * stride_dom - return dk, dv - - -# grid = (max_seqlen_k // BLOCK_N, batch, nheads_q) -@triton.jit -def _bwd_kernel_dkdv_causal( - Q, K, V, sm_scale, DO, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M: tl.constexpr, # 32 - BLOCK_N: tl.constexpr, # 128 - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - 
# figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) - # Figure out causal starting block since we have seqlen_q >=< seqlen_k. - # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - delta_qk = seqlen_q - seqlen_k - if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") - if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") - # q > k: diretcly skip all the way until the start of causal block - start_delta_q_gt_k = delta_qk - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N - delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M - if delta_qk >= 0: - start_delta = delta_qk - if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}") - else: - start_delta = start_delta_q_lt_k - if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}") - # align the delta_qk - start_n = pid * BLOCK_N - - offs_k = tl.arange(0, HEAD_DIM) - offs_n = start_n + tl.arange(0, BLOCK_N) - # Mask for loading K and V - mask_kv 
= offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - - GROUP_SIZE = HQ // HK - # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k , mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - # If MQA / GQA, set the K and V head offsets appropriately. - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N - else: - start_m = max(start_n + delta_qk, 0) - start_m = start_m // BLOCK_M * BLOCK_M - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N + residue_m - if DEBUG_TRITON: print(f"residue_m = {residue_m}") - - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - Q_ptr = Q + adj_q - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - 
batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = Dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M) - # when q < k, we may skip the initial masked op - if pid < num_blocks_skip: - num_steps = 0 - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # if start_m is negative, the current N-tile has no block on the - # diagonal of causal mask, so everything have no causal mask - if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, - MASK_BLOCK_M, BLOCK_N, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - start_m += num_steps * MASK_BLOCK_M - num_steps = 
tl.cdiv(seqlen_q - start_m, BLOCK_M) - end_m = start_m + num_steps * BLOCK_M - - if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 - if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}") # noqa: E701 - if DEBUG_TRITON: print("unMasked") # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, - BLOCK_M, BLOCK_N, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - # Write back dV and dK. - adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) - - -# the main inner-loop logic for computing dQ -@triton.jit -def _bwd_dq_inner( - dq, # output - q, K, V, do, m, Delta, sm_scale, # input - # shared by Q/K/V. 
- stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, # stride for dropout - stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - # Filled in by the wrapper. - start_m, start_n, end_n, num_steps, # - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M2) - offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) - - # mask to make sure not OOB of seqlen_q - mask_m = offs_m < seqlen_q - - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
- tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - curr_n = start_n - step_n = BLOCK_N2 - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - for blk_idx in range(num_steps): - if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 - offs_n = curr_n + tl.arange(0, BLOCK_N2) - # end_n is needed because the end of causal True might not be perfectly - # aligned with the end of the block - mask_n = offs_n < end_n - if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 - if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 - mask_kT = mask_n[None, :] - mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD: - mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM - - kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) - - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = curr_philox_offset + \ - offs_m[:, None] * stride_dropoutm + \ - offs_n[None, :] * stride_dropoutn - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = offs_m[:, None] * stride_dropoutm + \ - offs_n[None, :] * stride_dropoutn - dropout_mask = tl.load( - curr_dropout_offset + dropout_offs, - mask=mask_mn) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1 / (1 - dropout_p) - - if IS_FP8: - qk = (tl.dot(q, kT) * descale_q * descale_k) - else: - qk = tl.dot(q, kT) - qk_scaled = qk * sm_scale - - if USE_ALIBI: - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - qk_scaled += alibi_block - - if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 - if USE_EXP2: - p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) - else: - p = tl.math.exp(qk_scaled 
- m) - - # Autoregressive masking. - if MASK: - causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] - mask = causal_mask & mask_mn - p = tl.where(mask, p, 0.0) - # Compute dP and dS. - if IS_FP8: - dp = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp = tl.dot(do, vT) - if ENABLE_DROPOUT: - dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale - delta_i = Di[:, None] - ds = p * (dp -delta_i) - # Compute dQ. - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. - if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += (tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) - else: - dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) - # Increment pointers. - curr_n += step_n - kT_ptrs += step_n * stride_kn - vT_ptrs += step_n * stride_vn - return dq - - -# grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) -@triton.jit -def _bwd_kernel_dq_causal( - Q, K, V, sm_scale, DO, DQ, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - 
DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - # Figure out causal starting block since we have seqlen_q <=> seqlen_k. - # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we - # can simply skip and we need to adjust starting position. 
- start_m = pid * BLOCK_M - # seqlen_q > seqlen_k, no need to process these tile for dq - delta_qk = seqlen_q - seqlen_k - if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M}") # noqa: E701 - if start_m + BLOCK_M < delta_qk: - if DEBUG_TRITON: print(f"start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M} < delta_qk of {delta_qk}") # noqa: E701 - return - - offs_k = tl.arange(0, HEAD_DIM) - offs_m = start_m + tl.arange(0, BLOCK_M) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn - K += adj_k - V += adj_v - # If MQA / GQA, set the K and V head offsets appropriately. 
- GROUP_SIZE = HQ // HK - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front - # for every M-tile - end_n = start_m + BLOCK_M - delta_qk - # clamp end_n at [0, seqlen_k] - end_n = max(min(end_n, seqlen_k), 0) - if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = \ - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + \ - bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = \ - Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, - mask=offs_m < seqlen_q) - m = m[:, None] - - MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR - # start can only be 0 at minimum - start_n = max(end_n - BLOCK_M, 0) - num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - dq = 
tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - if DEBUG_TRITON: print(f"pid: {pid}; end_n: {end_n}, start_m: {start_m}") # noqa: E701 - # Compute dQ for masked (diagonal) blocks. - # NOTE: This code scans each row of QK^T backward (from right to left, - # but inside each call to _bwd_dq_inner, from left to right), but that's - # not due to anything important. I just wanted to reuse the loop - # structure for dK & dV above as much as possible. - if DEBUG_TRITON: print(f"Masked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, - stride_deltam, - seqlen_q, seqlen_k, - BLOCK_M, MASK_BLOCK_N, - HEAD_DIM, ACTUAL_HEAD_DIM, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - MASK=True, - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - end_n -= num_steps * MASK_BLOCK_N - num_steps = tl.cdiv(end_n, BLOCK_N) - start_n = max(end_n - num_steps * BLOCK_N, 0) - if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, - stride_deltam, - seqlen_q, seqlen_k, - BLOCK_M, BLOCK_N, - HEAD_DIM, ACTUAL_HEAD_DIM, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - 
FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # Write back dQ. - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - - -@triton.jit -def _bwd_kernel_dkdv_noncausal( - Q, K, V, sm_scale, DO, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M: tl.constexpr, # 32 - BLOCK_N: tl.constexpr, # 128 - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, 
HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) - - start_n = pid * BLOCK_N - - offs_k = tl.arange(0, HEAD_DIM) - offs_n = start_n + tl.arange(0, BLOCK_N) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - - GROUP_SIZE = HQ // HK - # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - # If MQA / GQA, set the K and V head offsets appropriately. - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - Q_ptr = Q + adj_q - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = Dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + 
hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # because there is no causal, we always start from the beginning - start_m = 0 - num_steps = tl.cdiv(seqlen_q, BLOCK_M) - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, - BLOCK_M, BLOCK_N, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - # Write back dV and dK. 
- adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) - - -@triton.jit -def _bwd_kernel_dq_noncausal( - Q, K, V, sm_scale, DO, DQ, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - start_m = pid * BLOCK_M - - offs_k = tl.arange(0, HEAD_DIM) - offs_m = start_m + 
tl.arange(0, BLOCK_M) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn - K += adj_k - V += adj_v - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE = HQ // HK - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = \ - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + \ - bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = \ - Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, - mask=offs_m < seqlen_q) - m = m[:, None] - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = 
tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # start can only be 0 at minimum - start_n = 0 - end_n = seqlen_k - num_steps = tl.cdiv(seqlen_k, BLOCK_N) - dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, - stride_deltam, - seqlen_q, seqlen_k, - BLOCK_M, BLOCK_N, - HEAD_DIM, ACTUAL_HEAD_DIM, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # Write back dQ. - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - - -def attention_prefill_backward_triton_split_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool, - # fp8 - descale_q: Optional[torch.Tensor], - descale_k: Optional[torch.Tensor], - descale_v: Optional[torch.Tensor], - descale_o: Optional[torch.Tensor], - descale_do: Optional[torch.Tensor], - descale_dq: Optional[torch.Tensor], - descale_dk: 
Optional[torch.Tensor], - descale_dv: Optional[torch.Tensor], -): - # debug - DEBUG_TRITON: bool = False - DEBUG_TRITON_DETAIL: bool = False - - # fp8 - IS_FP8 = is_fp8(q) - if IS_FP8: - FP8_MAX = torch.finfo(q.dtype).max - # assert that the main inputs are fp8 - assert is_fp8(do) and is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: do.dtype={do.dtype}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." - if is_fp8(o): - FP8_OUTPUT = True - assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." - assert descale_dq is not None, f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." - assert descale_dk is not None, f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." - assert descale_dv is not None, f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." - else: - FP8_OUTPUT = False - - stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None - stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None - stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None - stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None - else: - FP8_MAX = None - FP8_OUTPUT = False - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None - - - # get strides and shape - batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ - get_shapes_from_layout( - q, k, layout, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k - ) - q_strides, k_strides, v_strides, o_strides = \ - get_strides_from_layout(q, k, v, o, layout) - stride_qb, stride_qh, stride_qm, stride_qk = q_strides - stride_kb, stride_kh, stride_kn, stride_kk = 
k_strides - stride_vb, stride_vh, stride_vn, stride_vk = v_strides - stride_ob, stride_oh, stride_om, stride_ok = o_strides - dq_strides, dk_strides, dv_strides, do_strides = \ - get_strides_from_layout(dq, dk, dv, do, layout) - stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides - stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides - stride_dvb, stride_dvh, stride_dvn, stride_dvk = dv_strides - stride_dob, stride_doh, stride_dom, stride_dok = do_strides - IS_VARLEN = layout == "thd" - use_dropout = (dropout_p > 0.0) - use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) - - # get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 32) # NOTE: the causal path expects a min of 32. It will cause a compiler assert. - HEAD_DIM = padded_d_model - ACTUAL_HEAD_DIM = head_size - # meta-parameters - # TODO: fix num_stages later - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - - # init delta - delta = torch.zeros_like(softmax_lse) - if IS_VARLEN: - stride_deltab = 0 - stride_deltah, stride_deltam = delta.stride() - else: - stride_deltab, stride_deltah, stride_deltam = delta.stride() - pre_grid = (triton.cdiv(max_seqlen_q_final, PRE_BLOCK), batch, nheads_q) - _bwd_preprocess[pre_grid]( - o, do, - delta, - stride_ob, stride_oh, stride_om, stride_ok, - stride_deltab, stride_deltah, stride_deltam, - stride_descale_do_z, - cu_seqlens_q, max_seqlen_q_final, - descale_do, - BLOCK_M=PRE_BLOCK, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8 - ) - - if DEBUG: - print("delta:", delta, delta.shape) - - # dropout mask tensor for debugging. 
We dump the dropout mask created in - # the kernel for testing - dropout_mask = None - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ - (0, 0 , 0 , 0) - if use_dropout: - dropout_mask = torch.zeros( - (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), - device=q.device, - dtype=torch.float32 - ) - - if DROPOUT_USE_PYTORCH: - if not IS_VARLEN: - dropout_mask = create_dropout_mask( - dropout_p, - (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), - seed = philox_seed - ) - else: - dropout_mask = create_dropout_mask_varlen( - dropout_p, batch, nheads_q, - cu_seqlens_q, cu_seqlens_k, philox_seed - ) - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ - dropout_mask.stride() - - grid_dkdv = ((max_seqlen_k_final + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) - grid_dq = ((max_seqlen_q_final + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) - if causal: - if DEBUG_TRITON: print(f"_bwd_kernel_dkdv: grid = {grid_dkdv}, block_size = ({BLOCK_M1, BLOCK_N1})", ) # noqa: E701 - _bwd_kernel_dkdv_causal[grid_dkdv]( - q, k, v, sm_scale, do, dk, dv, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, - HEAD_DIM, ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - 
FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu = WAVES_PER_EU, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - if DEBUG_TRITON: print(f"\n_bwd_kernel_dq: grid = {grid_dq}, block_size = ({BLOCK_M2, BLOCK_N2})", ) # noqa: E701 - _bwd_kernel_dq_causal[grid_dq]( - q, k, v, sm_scale, do, dq, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, - HEAD_DIM, ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu = WAVES_PER_EU, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - else: - _bwd_kernel_dkdv_noncausal[grid_dkdv]( - q, k, v, sm_scale, do, dk, dv, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - 
stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, - HEAD_DIM, ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu = WAVES_PER_EU, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - _bwd_kernel_dq_noncausal[grid_dq]( - q, k, v, sm_scale, do, dq, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, - HEAD_DIM, ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu = WAVES_PER_EU, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - return delta diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py deleted file mode 100644 index 90a98ce4fcc..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_ref.py +++ 
/dev/null @@ -1,478 +0,0 @@ -import torch -import math -from typing import Literal, Optional -from .utils import DEBUG, compute_alibi_tensor_ref - -DEBUG_CORE = False - -def attention_backward_core_ref_impl( - do, q, k, v, o, softmax_lse, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 -): - if DEBUG_CORE: - print() - print("attention_backward_core_ref_impl") - print("do:", do, do.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("o:", o, o.shape) # is a bad number - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("sm_scale:", sm_scale) - print("causal:", causal) - print("dropout_p:", dropout_p) - print("philox_seed:", philox_seed) - print("philox_offset:", philox_offset) - print("use_exp2:", use_exp2) - - # cast to float32 - do = do.to(torch.float32) - q = q.to(torch.float32) - k = k.to(torch.float32) - v = v.to(torch.float32) - o = o.to(torch.float32) - softmax_lse = softmax_lse.to(torch.float32) - - - # recompute attention_scores. Make sure it matches the forward impl. i.e. 
It use float32 - attention_scores = torch.matmul(q, k.transpose(-2, -1)) - if DEBUG_CORE: - print("attention_scores:", attention_scores, attention_scores.shape) - - # scale scores - attention_scaled_scores = sm_scale * attention_scores - if DEBUG_CORE: - print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) - - if alibi_slopes is not None: - L_q, L_k = q.shape[1], k.shape[1] - if DEBUG_CORE: - print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) - alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) - alibi_bias = alibi_bias.reshape(-1, L_q, L_k) - if True: - print("alibi_bias:", alibi_bias, alibi_bias.shape) - attention_scaled_scores = attention_scaled_scores + alibi_bias - if DEBUG_CORE: - print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) - - # Apply causal mask if necessary - if causal: - L_q, L_k = q.shape[1], k.shape[1] - row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) - col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) - col_offset = L_q-L_k - causal_mask = row_idx >= (col_offset + col_idx) - if DEBUG_CORE: - print("causal_mask:", causal_mask) - # set -inf to places the causal mask is false - attention_scaled_scores = attention_scaled_scores.masked_fill( - torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') - ) - if DEBUG_CORE: - print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) - - # compute probabilities using softmax_lse - if use_exp2: - RCP_LN = 1 / math.log(2) - attention_scaled_scores_base2 = attention_scaled_scores * RCP_LN - softmax_lse_base2 = softmax_lse * RCP_LN - softmax_lse_3d = softmax_lse_base2.unsqueeze(-1) - p = torch.exp2(attention_scaled_scores_base2 - softmax_lse_3d) - else: - softmax_lse_3d = softmax_lse.unsqueeze(-1) - p = torch.exp(attention_scaled_scores - softmax_lse_3d) - if DEBUG_CORE: - print("softmax_lse_3d:", softmax_lse_3d, softmax_lse_3d.shape) - 
print("p:", p, p.shape) - - if dropout_p > 0.0: - rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) - dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) - if DEBUG: - print("dropout_scale:", dropout_scale) - print("dropout_mask:", dropout_mask) - - p_drop = torch.where(dropout_mask, p, torch.zeros_like(p)) - p_drop_scaled = p_drop * dropout_scale - if DEBUG_CORE: - print("dropout_scale:", dropout_scale) - print("p_drop:", p_drop, p_drop.shape) - print("p_drop_scaled:", p_drop_scaled, p_drop_scaled.shape) - - # compute dv - dv = torch.matmul(p_drop_scaled.transpose(-2, -1), do) - if DEBUG_CORE: - print("dv:", dv, dv.shape) - - # compute dp - dp_dropout = torch.matmul(do, v.transpose(-2, -1)) - dp = torch.where(dropout_mask, dp_dropout , torch.zeros_like(dp_dropout)) * dropout_scale - if DEBUG_CORE: - print("dp_dropout:", dp_dropout, dp_dropout.shape) - print("dp:", dp, dp.shape) - else: - # compute dv - dv = torch.matmul(p.transpose(-2, -1), do) - if DEBUG_CORE: - print("dv:", dv, dv.shape) - - # compute dp - dp = torch.matmul(do, v.transpose(-2, -1)) - if DEBUG_CORE: - print("dp:", dp, dp.shape) - - # calculate ds - if False: - delta = torch.sum(o * do, axis=-1).unsqueeze(-1) - else: - delta = torch.sum(p * dp, axis=-1).unsqueeze(-1) - if DEBUG: - print("delta:", delta, delta.shape) - dscores_scaled = p * (dp - delta) - ds = dscores_scaled * sm_scale - if DEBUG_CORE: - print("dscores_scaled:", dscores_scaled, dscores_scaled.shape) - print("ds:", ds, ds.shape) - - # compute gradient wrt k & q - dk = torch.matmul(ds.transpose(-2, -1), q) - dq = torch.matmul(ds, k) - if DEBUG_CORE: - print("dk:", dk, dk.shape) - print("dq:", dq, dq.shape) - - # cast back to original dtype - dq = dq.to(torch.float16) - dk = dk.to(torch.float16) - dv = dv.to(torch.float16) - # remove d dim with size 1 - delta = delta.squeeze(-1) - - if DEBUG_CORE: - 
print("attention_backward_core_ref_impl output") - print("delta:", delta, delta.shape) - print("dv:", dv, dv.shape) - print("dk:", dk, dk.shape) - print("dq:", dq, dq.shape) - - return dq, dk, dv, delta - -def attention_varlen_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, -): - # Ensure the layout is 'thd' - if layout != 'thd': - raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") - - batch_size = cu_seqlens_q.shape[0] - 1 - nheads_q, head_dim = q.shape[1], q.shape[2] - nheads_k = k.shape[1] - - group_size = nheads_q // nheads_k - if nheads_q % nheads_k != 0: - raise ValueError("nheads_q must be divisible by nheads_k") - - # Pre-allocate outputs - total_L_q = q.shape[0] - total_L_k = k.shape[0] - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - # delta has the same shape as softmax_lse: [total_L_q, nheads_q] - delta = torch.zeros((total_L_q, nheads_q), dtype=torch.float32, device=o.device) - - for i in range(batch_size): - # Get the start and end indices for the current sequence - start_q = cu_seqlens_q[i].item() - end_q = cu_seqlens_q[i + 1].item() - start_k = cu_seqlens_k[i].item() - end_k = cu_seqlens_k[i + 1].item() - - # Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i - q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] - k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - do_i = do[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] - o_i = o[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] - softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, nheads_q] - - if group_size != 1: - # MQA or GQA case - # Reshape tensors to include group dimension - q_i = q_i.view(q_i.shape[0], nheads_k, group_size, head_dim) - do_i = do_i.view(do_i.shape[0], 
nheads_k, group_size, head_dim) - o_i = o_i.view(o_i.shape[0], nheads_k, group_size, head_dim) - softmax_lse_i = softmax_lse_i.view(softmax_lse_i.shape[0], nheads_k, group_size) - # Expand k_i and v_i to match group_size - k_i = k_i.unsqueeze(2).expand(-1, -1, group_size, -1) - v_i = v_i.unsqueeze(2).expand(-1, -1, group_size, -1) - # Flatten the nheads_k and group_size dimensions - q_i = q_i.reshape(q_i.shape[0], nheads_k * group_size, head_dim) - do_i = do_i.reshape(do_i.shape[0], nheads_k * group_size, head_dim) - o_i = o_i.reshape(o_i.shape[0], nheads_k * group_size, head_dim) - softmax_lse_i = softmax_lse_i.reshape(softmax_lse_i.shape[0], nheads_k * group_size) - k_i = k_i.reshape(k_i.shape[0], nheads_k * group_size, head_dim) - v_i = v_i.reshape(v_i.shape[0], nheads_k * group_size, head_dim) - # Permute to [nheads_total, L, head_dim] - q_i = q_i.permute(1, 0, 2) - k_i = k_i.permute(1, 0, 2) - v_i = v_i.permute(1, 0, 2) - do_i = do_i.permute(1, 0, 2) - o_i = o_i.permute(1, 0, 2) - softmax_lse_i = softmax_lse_i.transpose(0, 1) - if alibi_slopes is not None: - alibi_slopes_i = alibi_slopes[i] - else: - alibi_slopes_i = None - - # Call the core backward function for this sequence - dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl( - do_i, - q_i, - k_i, - v_i, - o_i, - softmax_lse_i, - sm_scale, - causal, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes_i, - use_exp2 - ) - - # Convert back to 'thd' layout - dq_i = dq_i.permute(1, 0, 2) # [L_q_i, nheads_total, head_dim] - dk_i = dk_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] - dv_i = dv_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] - delta_i = delta_i.transpose(1, 0) # [L_q_i, nheads_total] - - if group_size != 1: - # Reshape dq_i and delta_i back to original shape - dq_i = dq_i.view(dq_i.shape[0], nheads_k, group_size, head_dim) - delta_i = delta_i.view(delta_i.shape[0], nheads_k, group_size) - # Sum dk_i and dv_i over group dimension - dk_i = dk_i.view(dk_i.shape[0], 
nheads_k, group_size, head_dim) - dv_i = dv_i.view(dv_i.shape[0], nheads_k, group_size, head_dim) - dk_i = dk_i.sum(dim=2) - dv_i = dv_i.sum(dim=2) - # Reshape dq_i back to [L_q_i, nheads_q, head_dim] - dq_i = dq_i.reshape(dq_i.shape[0], nheads_q, head_dim) - delta_i = delta_i.reshape(delta_i.shape[0], nheads_q) - else: - # No need to reshape - pass - - # Place outputs in pre-allocated tensors - dq[start_q:end_q, :, :] = dq_i - dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys - dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values - delta[start_q:end_q, :] = delta_i - - return dq, dk, dv, delta - -def attention_vanilla_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - layout, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, -): - if layout == "bshd": - if DEBUG: - print() - print("Changing layout to bhsd!") - do = do.transpose(1, 2).contiguous() - q = q.transpose(1, 2).contiguous() - k = k.transpose(1, 2).contiguous() - v = v.transpose(1, 2).contiguous() - o = o.transpose(1, 2).contiguous() - elif layout == "bhsd": - pass - else: - raise ValueError(f"Unknown layout {layout}") - - # Prepare tensors - batch_size, nheads_q, seq_len_q, head_dim = q.shape - batch_size, nheads_k, seq_len_k, head_dim = k.shape - - group_size = nheads_q // nheads_k - if nheads_q % nheads_k != 0: - raise ValueError("nheads_q must be divisible by nheads_k") - - if group_size != 1: - # MQA or GQA case - # Reshape do, q, o to [batch_size, nheads_k, group_size, seq_len_q, head_dim] - do = do.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - # Reshape softmax_lse to [batch_size, nheads_k, group_size, seq_len_q] - softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) - # Expand k and v to match group_size - k = 
k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) # [batch_size, nheads_k, group_size, seq_len_k, head_dim] - v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) - # Flatten the first three dimensions for computation - do = do.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) - q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) - k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) - v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) - o = o.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size * nheads_k * group_size, seq_len_q) - else: - # Standard case - do = do.reshape(batch_size * nheads_q, seq_len_q, head_dim) - q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) - k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) - v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) - o = o.reshape(batch_size * nheads_q, seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size * nheads_q, seq_len_q) - - # Call the core backward function - dq, dk, dv, delta = attention_backward_core_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2 - ) - - if group_size != 1: - # Reshape dq back to [batch_size, nheads_k, group_size, seq_len_q, head_dim] - dq = dq.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - # Reshape delta back to [batch_size, nheads_k, group_size, seq_len_q] - delta = delta.reshape(batch_size, nheads_k, group_size, seq_len_q) - # Sum dk and dv over group_size dimension, since k and v are shared across groups - dk = dk.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) - dk = dk.sum(dim=2) # Sum over group_size dimension - dv = dv.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) - dv = dv.sum(dim=2) - # Reshape dq to [batch_size, nheads_q, seq_len_q, head_dim] - dq = 
dq.reshape(batch_size, nheads_k * group_size, seq_len_q, head_dim) - delta = delta.reshape(batch_size, nheads_k * group_size, seq_len_q) - else: - # Standard case - dq = dq.reshape(batch_size, nheads_q, seq_len_q, head_dim) - dk = dk.reshape(batch_size, nheads_k, seq_len_k, head_dim) - dv = dv.reshape(batch_size, nheads_k, seq_len_k, head_dim) - delta = delta.reshape(batch_size, nheads_q, seq_len_q) - - # Go back to original layout - if layout == "bshd": - if DEBUG: - print() - print("Changing back to bshd!") - dq = dq.transpose(1, 2) - dk = dk.transpose(1, 2) - dv = dv.transpose(1, 2) - elif layout == "bhsd": - pass - else: - raise ValueError(f"Unknown layout {layout}") - - return dq, dk, dv, delta - -def attention_backward_pytorch_ref_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool -): - if layout == "thd": - dq_ref, dk_ref, dv_ref, delta = attention_varlen_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, - ) - else: - dq_ref, dk_ref, dv_ref, delta = attention_vanilla_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - layout, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, - ) - - - # copy into output tensor - dv.copy_(dv_ref.to(dv.dtype)) - dk.copy_(dk_ref.to(dk.dtype)) - dq.copy_(dq_ref.to(dq.dtype)) - - return delta 
diff --git a/flash_attn/flash_attn_triton_amd/common.py b/flash_attn/flash_attn_triton_amd/common.py new file mode 100644 index 00000000000..2f1a209383a --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/common.py @@ -0,0 +1,551 @@ +""" +Triton kernel helper functions shared across flash attention modules. + +This module contains Triton JIT-compiled helper functions that are used within +the main attention kernels (fwd_prefill, fwd_decode, bwd). These are kept +separate from utils.py to allow stricter type checking on pure Python utilities. +""" +from typing import Literal, Optional, Tuple, Union + +import torch +import triton +import triton.language as tl + +from .utils import DEBUG, get_shape_from_layout, get_stride_from_layout, is_fp8 + + +@triton.jit +def compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False +): + """ + Compute ALiBi (Attention with Linear Biases) block. + + When seqlen_k and seqlen_q are different, the diagonal sticks to the + bottom right of the attention matrix. + """ + # e.g. 
alibi_slope = 1, seqlen_q = 2, seqlen_k = 5 + # offs_m = [0, 1], offs_n = [0, 1, 2, 3, 4] + # Result: [[-3, -2, -1, 0, -1], [-4, -3, -2, -1, 0]] + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +@triton.jit +def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): + """Compute FP8 scaling and descaling factors for a block.""" + x_amax = tl.max(tl.abs(x)) + x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) + scale_x = fp8_max / x_amax + descale_x = x_amax / fp8_max + return scale_x, descale_x + + +@triton.jit +def _cast_varlen_to_fp8_kernel_2d( + X, + X_fp8, + Descale, + cu_seqlens, + H, + MAX_SEQLEN, + stride_batch, + stride_seq, + stride_head, + stride_dim, + stride_out_batch, + stride_out_seq, + stride_out_head, + stride_out_dim, + stride_desc_batch, + stride_desc_head, + FP8_CLAMP_VAL, + FP8_MAX, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + """Cast tensor to FP8 with per-(batch, head) scaling.""" + b_id = tl.program_id(0) + h_id = tl.program_id(1) + + # Get sequence bounds for this batch + if IS_VARLEN: + seq_start = tl.load(cu_seqlens + b_id) + seq_end = tl.load(cu_seqlens + b_id + 1) + seqlen = seq_end - seq_start + else: + seq_start = 0 + seqlen = MAX_SEQLEN + + # initialize max value tracker + x_max_val = 0.0 + + # STEP 1: Find max absolute value across the entire sequence + num_of_blocks = tl.cdiv(seqlen, BLOCK_SIZE) + for blk_idx in range(0, num_of_blocks): + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & mask_dim + + adj_x = ( + b_id * stride_batch + + h_id * stride_head + + seq_start * stride_seq + + offs_seq[:, None] * 
stride_seq + + offs_dim[None, :] * stride_dim + ) + x_block = tl.load(X + adj_x, mask=mask_seq, other=0.0) + block_max = tl.max(tl.abs(x_block)) + x_max_val = tl.maximum(x_max_val, block_max) + + # clamp to avoid division by zero + x_max_val = tl.maximum(x_max_val, FP8_CLAMP_VAL) + + # compute scale and descale factors + scale = FP8_MAX / x_max_val + descale = x_max_val / FP8_MAX + + # store descale factor + desc_ptr = Descale + b_id * stride_desc_batch + h_id + tl.store(desc_ptr, descale) + + # STEP 2: Apply scaling and convert to FP8 + for blk_idx in range(0, num_of_blocks): + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & mask_dim + + addr = ( + b_id * stride_batch + + h_id * stride_head + + seq_start * stride_seq + + offs_seq[:, None] * stride_seq + + offs_dim[None, :] * stride_dim + ) + x_block = tl.load(X + addr, mask=mask_seq, other=0.0) + x_fp8_block = (x_block * scale).to(X_fp8.type.element_ty) + + addr_out = ( + b_id * stride_out_batch + + h_id * stride_out_head + + seq_start * stride_out_seq + + offs_seq[:, None] * stride_out_seq + + offs_dim[None, :] * stride_out_dim + ) + tl.store(X_fp8 + addr_out, x_fp8_block, mask=mask_seq) + + +@triton.jit +def _rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + seqlen_ro, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + ROTARY_DIM: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_M: tl.constexpr, +): + """Apply rotary positional embeddings.""" + BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) + ROTARY_DIM_HALF = ROTARY_DIM // 2 + 
pid_head = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + pid_batch = tl.program_id(axis=2) + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + OUT = OUT + pid_batch * stride_out_batch + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + OUT = OUT + start_idx * stride_out_seqlen + + if pid_m * BLOCK_M >= seqlen: + return + + rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + + rk_half = tl.arange(0, BLOCK_K // 2) + COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) + cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + + if not INTERLEAVED: + X = X + ( + rh[:, None, None] * stride_x_nheads + + rm[None, :, None] * stride_x_seqlen + + rk_half[None, None, :] * stride_x_headdim + ) + OUT = OUT + ( + rh[:, None, None] * stride_out_nheads + + rm[None, :, None] * stride_out_seqlen + + rk_half[None, None, :] * stride_out_headdim + ) + mask = ( + (rh[:, None, None] < nheads) + & (rm[None, :, None] < seqlen) + & (rk_half[None, None, :] < ROTARY_DIM_HALF) + ) + x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0).to( + tl.float32 + ) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(OUT, o0, mask=mask) + tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) + else: + rk = tl.arange(0, BLOCK_K) + X = X + ( + rh[:, None, None] * stride_x_nheads + + rm[None, :, None] * stride_x_seqlen + + rk[None, None, :] * stride_x_headdim + ) + OUT = 
OUT + ( + rh[:, None, None] * stride_out_nheads + + rm[None, :, None] * stride_out_seqlen + + rk[None, None, :] * stride_out_headdim + ) + mask = ( + (rh[:, None, None] < nheads) + & (rm[None, :, None] < seqlen) + & (rk[None, None, :] < ROTARY_DIM) + ) + x = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) + tl.store(OUT, o, mask=mask) + + +# ------------------------------- +# Python wrappers for Triton kernels +# ------------------------------- + + +def cast_to_fp8( + x: torch.Tensor, + fp8_dtype: torch.dtype, + layout: Literal["bshd", "thd"], + clamp_val: float = 1e-9, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Cast tensor to FP8 with per-(batch, head) scaling factors.""" + if DEBUG > 0: + print() + print("cast_to_fp8") + print("x:", x, x.shape) + print("fp8_dtype:", fp8_dtype) + print("cu_seqlens:", cu_seqlens) + print("max_seqlen:", max_seqlen) + print("clamp_val:", clamp_val) + + assert x.dtype in { + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + } and is_fp8(fp8_dtype), f"Cannot cast {x.dtype} to {fp8_dtype}" + + batch, max_seqlen_final, num_heads, head_dim = get_shape_from_layout( + x, layout, cu_seqlens, max_seqlen + ) + is_varlen = layout == "thd" + fp8_max = torch.finfo(fp8_dtype).max + + padded_head_dim = 1 << (head_dim - 1).bit_length() + padded_head_dim = max(padded_head_dim, 32) + + x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) + descale_factors = torch.zeros( + (batch, num_heads), device=x.device, dtype=torch.float32 + ) + BLOCK_SIZE = 128 + + stride_batch, stride_head, stride_seq, stride_dim = get_stride_from_layout(x, layout) + stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim = get_stride_from_layout(x_fp8, layout) + stride_desc_batch, 
stride_desc_head = descale_factors.stride() + + grid = (batch, num_heads) + _cast_varlen_to_fp8_kernel_2d[grid]( + x, + x_fp8, + descale_factors, + cu_seqlens, + num_heads, + max_seqlen_final, + stride_batch, + stride_seq, + stride_head, + stride_dim, + stride_out_batch, + stride_out_seq, + stride_out_head, + stride_out_dim, + stride_desc_batch, + stride_desc_head, + clamp_val, + fp8_max, + BLOCK_SIZE=BLOCK_SIZE, + HEAD_DIM=padded_head_dim, + ACTUAL_HEAD_DIM=head_dim, + IS_VARLEN=is_varlen, + ) + + return x_fp8, descale_factors + + +def _apply_rotary_kernel( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + """Apply rotary positional embeddings using Triton kernel.""" + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed, max_seqlen must also be provided" + total_seqlen, nheads, headdim = x.shape + assert cu_seqlens is not None + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + assert sin.shape == cos.shape + rotary_dim = 2 * rotary_dim_half + assert rotary_dim <= headdim + assert headdim <= 256 + assert seqlen_ro >= seqlen + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in (torch.int32, torch.int64) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + out = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_M = 8 if rotary_dim <= 128 else 4 + grid = ( + triton.cdiv(nheads, 2), + 
triton.cdiv(seqlen, BLOCK_M), + batch, + ) + + with torch.cuda.device(x.device.index): + torch.library.wrap_triton(_rotary_kernel)[grid]( + out, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + seqlen_ro, + out.stride(0) if not is_varlen else 0, + out.stride(-3), + out.stride(-2), + out.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + rotary_dim, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M=BLOCK_M, + BLOCK_H=2, + ) + return out + + +class _ApplyRotary(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool, + inplace: bool, + seqlen_offsets: Union[int, torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + max_seqlen: Optional[int], + ) -> torch.Tensor: + out = _apply_rotary_kernel( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + conjugate=False, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None, None, None, None]: + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + dx = _apply_rotary_kernel( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def 
apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False, + inplace: bool = False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) -> torch.Tensor: + """Apply rotary embeddings to tensor x. + + Args: + x: (B, S, H, D) if `cu_seqlens` is None else (total_S, H, D). + cos, sin: (S_rotary, rotary_dim/2) + interleaved: GPT-J style if True. + inplace: modify x in place. + seqlen_offsets: int or (B,) tensor of starting offsets per sequence. + cu_seqlens: (B+1,) tensor enabling varlen mode. + max_seqlen: required when `cu_seqlens` is provided. + """ + original_dtype = x.dtype + is_fp8_input = original_dtype == getattr(torch, "float8_e4m3fn", None) + if is_fp8_input: + target_dtype = ( + torch.bfloat16 + if cos.dtype == torch.bfloat16 or torch.cuda.is_bf16_supported() + else torch.float16 + ) + x_up = x.to(target_dtype) + cos_up = cos.to(target_dtype) if cos.dtype != target_dtype else cos + sin_up = sin.to(target_dtype) if sin.dtype != target_dtype else sin + out_up = _ApplyRotary.apply( + x_up, cos_up, sin_up, interleaved, False, seqlen_offsets, cu_seqlens, max_seqlen + ) + if inplace: + x.copy_(out_up.to(original_dtype)) + return x + return out_up.to(original_dtype) + else: + return _ApplyRotary.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) + + +def apply_rotary( + q: torch.Tensor, + k_new: Optional[torch.Tensor], + cos: torch.Tensor, + sin: torch.Tensor, + *, + causal: bool, + local: bool, + interleaved: bool = False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Apply rotary embeddings to q and optionally k_new. + + Policy: + - If causal OR local attention: apply rotary directly on (B, S, H, D). + - Else (non-causal global): flatten heads into sequence, apply, unflatten. + - k_new is always rotated directly when provided. 
+ """ + assert q.ndim == 4, f"Expected q shape (B,S,H,D), got {q.shape}" + B, S, H, D = q.shape + use_flatten = (not causal) and (not local) + + if use_flatten: + q_flat = q.reshape(B, S * H, D).unsqueeze(1) + q_flat = apply_rotary_emb(q_flat, cos, sin, interleaved=interleaved, seqlen_offsets=seqlen_offsets) + q = q_flat.view(B, 1, S * H, D).reshape(B, S, H, D) + else: + q = apply_rotary_emb(q, cos, sin, interleaved=interleaved, seqlen_offsets=seqlen_offsets) + + if k_new is not None: + k_new = apply_rotary_emb(k_new, cos, sin, interleaved=interleaved, seqlen_offsets=seqlen_offsets) + return q, k_new diff --git a/flash_attn/flash_attn_triton_amd/fp8.py b/flash_attn/flash_attn_triton_amd/fp8.py deleted file mode 100644 index df79c7926b2..00000000000 --- a/flash_attn/flash_attn_triton_amd/fp8.py +++ /dev/null @@ -1,716 +0,0 @@ -from typing import Optional, Sequence, Tuple, Union -import torch -import torch.nn as nn -from .utils import cast_to_fp8, is_fp8 -from . import interface_fa as flash_attn_gpu - - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - -class FlashAttnFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 
input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - alibi_slopes, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. 
It might cause issue doing an empty - - # check output type - assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) - dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." 
- dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") - dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None - dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None - dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dq, - dk, - dv, - ctx.alibi_slopes, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, # gen_ - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - -def flash_attn_fp8_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None -): - return FlashAttnFP8Func.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - descale_q, - descale_k, - descale_v, - descale_do - ) - -class FlashAttnVarlenFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - 
cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - block_table, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." 
- q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=max_seqlen_q) - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - cu_seqlens_q, - cu_seqlens_k, - None, - None, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - softcap, - return_softmax, - None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. 
It might cause issue doing an empty - - # check output type - assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout_padded): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) - dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." 
- dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=ctx.max_seqlen_q) - dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None - dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None - dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.varlen_bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.alibi_slopes, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - False, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_fp8_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - block_table=None -): - return FlashAttnVarlenFP8Func.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled() - ) - -class 
FlashAttnQKVPackedFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." 
- q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - alibi_slopes, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o, - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) - else: # cast to fp8 and return 
output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") - dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None - - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ctx.alibi_slopes, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, # gen_ - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - None, - None, - None, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None - - -def flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # <=0.0 means deactivate - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - return FlashAttnQKVPackedFP8Func.apply( - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -class FlashAttnVarlenQKVPackedFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = 
None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." 
- q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - cu_seqlens, - cu_seqlens, - None, - None, - None, - alibi_slopes, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - softcap, - return_softmax, - None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout_padded): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is 
not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen) - dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.varlen_bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.alibi_slopes, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.softmax_scale, - False, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - None, - None, - None, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_qkvpacked_fp8_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - return FlashAttnVarlenQKVPackedFP8Func.apply( - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - 
softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py old mode 100644 new mode 100755 index 3f2d92c22d6..4581b3f61d8 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -1,72 +1,259 @@ +import os +import warnings import torch import triton import triton.language as tl -from typing import Literal, Optional, Union -from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna - -def get_cdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - # Fall-back config. 
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] - -def get_autotune_configs(): - if AUTOTUNE: - if is_cdna(): - autotune_configs, autotune_keys = get_cdna_autotune_configs() - fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys - reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys - return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) +from typing import Literal, Optional +from .common import apply_rotary +from .utils import ( + DEBUG, + AUTOTUNE, + get_arch, + get_padded_headsize, + get_shape_from_layout, + get_stride_from_layout, + is_fp8, +) + + +FWD_DECODE_AUTOTUNE_KEYS = [ + "N_CTX_Q", + "N_CTX_K", + "ACTUAL_BLOCK_DMODEL", + "H_q", + "H_kv", + "IS_CAUSAL", + "IS_GQA", +] + +# Maximum BLOCK_M across all configs (for intermediate tensor allocation) +MAX_BLOCK_M = 64 + + +def get_fwd_decode_configs(autotune: bool): + """ + Returns configs for both the splitK kernel and reduce kernel. 
+ + Returns: + (splitk_configs, reduce_config): Tuple of config lists for each kernel + """ + + if not autotune: + arch = get_arch() + + if arch.is_rdna: + return ( + [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=1, + num_warps=4, + ), + ], + [triton.Config({}, num_stages=1, num_warps=4)], + ) + elif arch.is_cdna: + return ( + [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1}, + num_stages=1, + num_warps=4, + ), + ], + [triton.Config({}, num_stages=1, num_warps=4)], + ) else: - raise ValueError("Unknown Device Type") + # Default / fallback + return ( + [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1}, + num_stages=1, + num_warps=4, + ), + ], + [triton.Config({}, num_stages=1, num_warps=4)], + ) + + # ===================== Autotune Sweep ===================== + arch = get_arch() + splitk_configs = [] + + BLOCK_M_OPTIONS = [64, 32, 16] + BLOCK_N_OPTIONS = [128, 64, 32, 16] + NUM_WARPS_OPTIONS = [2, 4] + NUM_STAGES_OPTIONS = [1] + WAVES_PER_EU_OPTIONS = [4, 2, 1] + + # Ensure BLOCK_M options don't exceed MAX_BLOCK_M + assert all(bm <= MAX_BLOCK_M for bm in BLOCK_M_OPTIONS), \ + f"BLOCK_M_OPTIONS {BLOCK_M_OPTIONS} exceeds MAX_BLOCK_M {MAX_BLOCK_M}" + + for bm in BLOCK_M_OPTIONS: + for bn in BLOCK_N_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for nw in NUM_WARPS_OPTIONS: + for ns in NUM_STAGES_OPTIONS: + splitk_configs.append( + triton.Config( + { + "BLOCK_M": bm, + "BLOCK_N": bn, + "waves_per_eu": waves, + }, + num_stages=ns, + num_warps=nw, + ) + ) + + # Reduce kernel configs - sweep num_warps + NUM_WARPS_REDUCE_OPTIONS = [2, 4] + reduce_configs = [ + triton.Config({}, num_stages=1, num_warps=nw) + for nw in NUM_WARPS_REDUCE_OPTIONS + ] + + return splitk_configs, reduce_configs + + +fwd_decode_splitk_configs, fwd_decode_reduce_configs = get_fwd_decode_configs(AUTOTUNE) + + +@triton.jit +def _attn_fwd_inner( + q, + kT, + v, + pos, + col_mask, + m_i, + l_i, + acc, + pid_m, + q_descale, + k_descale, + 
v_descale, # FP8 scaling factors + alibi_slope, + apply_col_mask, + IS_FP8: tl.constexpr, # FP8 flag + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + N_CTX_Q: tl.constexpr, + N_CTX_K_FINAL: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + IS_CAUSAL: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, +): + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_FP8: + qk += tl.dot(q, kT) * q_descale * k_descale # Apply FP8 scaling else: - autotune_configs, autotune_keys = [ - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL", - "VARLEN", - "HQ", - "HK", - ] + qk += tl.dot(q, kT) # noqa: F821 + + if USE_ALIBI: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = pos + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += alibi_bias * 1.44269504 + + # ------------------------------------------------------------------ + # masking + # ------------------------------------------------------------------ + if USE_SLIDING_WINDOW: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions + col_idx = pos + tl.arange(0, BLOCK_N) # k positions + row = row_idx[:, None] # [M,1] + col = col_idx[None, :] # [1,N] + + if IS_CAUSAL: + # -------- causal + window -------- + diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq + causal_ok = col <= row + diag + if WINDOW_SIZE_LEFT < 0: # only right window + win_ok = col <= row + diag + WINDOW_SIZE_RIGHT + else: # both sides + win_ok = (col >= row + diag - WINDOW_SIZE_LEFT) & ( + col <= row + diag + WINDOW_SIZE_RIGHT + ) + mask = ~(causal_ok & win_ok) # True ⇒ -inf + else: + # 
-------- non-causal window -------- + sk, sq = N_CTX_K_FINAL, N_CTX_Q + if WINDOW_SIZE_LEFT < 0: + mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT + else: + right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) + left = row + (sk - sq) - WINDOW_SIZE_LEFT + mask = (col > right) | (col < left) + qk = tl.where(mask, float("-inf"), qk) + else: + if IS_CAUSAL: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = pos + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_K_FINAL - N_CTX_Q + causal_mask = row_idx[:, None] >= (col_idx[None, :] - col_offset) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # Column mask (tail / variable-length). Instead of recomputing an arange each time, + # we accept a precomputed mask from the caller (col_valid_mask). + if apply_col_mask: + # Expect col_mask shape: [BLOCK_N]. True where column is within sequence. + qk = tl.where(col_mask[None, :], qk, float("-inf")) + + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far - fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys - reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys - return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + # rows that are *all* -inf after masking + valid = m_i_new > float("-inf") + # scale previous partial sums safely + alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) -(fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) = get_autotune_configs() + # subtract the row max only on valid rows + qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) + p = tl.math.exp2(qk) -# @triton.autotune( -# configs=fwd_auto_tune_configs, -# key=fwd_autotune_keys, -# use_cuda_graph=True, -# ) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(q.dtype) + + # -- scale and update acc -- + acc *= 
alpha[:, None] + if IS_FP8: + acc += tl.dot(p.to(v.dtype), v) * v_descale # Apply FP8 scaling for V + else: + acc += tl.dot(p.to(v.dtype), v) + + return m_i, l_i, acc + + +@triton.autotune( + configs=fwd_decode_splitk_configs, + key=FWD_DECODE_AUTOTUNE_KEYS, + use_cuda_graph=True, +) @triton.jit def _fwd_kernel_splitK( Q, K, V, + Q_Descale, # FP8 descale factors for Q + K_Descale, # FP8 descale factors for K + V_Descale, # FP8 descale factors for V sm_scale, Out_splitK, # [B*H*G, split_k, Mq, K] Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] @@ -74,6 +261,7 @@ def _fwd_kernel_splitK( V_new, Cache_seqlens, Cache_batch_idx, + Block_table, Alibi_slopes, stride_qz, stride_qm, @@ -108,13 +296,22 @@ def _fwd_kernel_splitK( stride_vn_g, stride_vn_h, stride_vn_d, - stride_az, + stride_bt_b, + stride_bt_s, + stride_az, stride_ah, + stride_q_descale_z, # FP8 descale strides + stride_q_descale_h, + stride_k_descale_z, + stride_k_descale_h, + stride_v_descale_z, + stride_v_descale_h, Z, N_CTX_Q, N_CTX_K, N_CTX_NEW, BLOCK_N_PER_SPLIT, + BLOCK_SIZE_K: tl.constexpr, H_q: tl.constexpr, H_kv: tl.constexpr, G_q: tl.constexpr, @@ -122,7 +319,6 @@ def _fwd_kernel_splitK( BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - BOUNDS_CHECKS_N: tl.constexpr, USE_CACHE_SEQLENs: tl.constexpr, USE_CACHE_BATCH_IDX: tl.constexpr, NEW_KV: tl.constexpr, @@ -131,6 +327,11 @@ def _fwd_kernel_splitK( USE_ALIBI: tl.constexpr, PADDED_HEAD: tl.constexpr, GROUP_SIZE: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + USE_BLOCK_TABLE: tl.constexpr, + IS_FP8: tl.constexpr, # FP8 flag ): # get program ids pid_m = tl.program_id(0) @@ -150,14 +351,32 @@ def _fwd_kernel_splitK( hk_id = hq_id hv_id = hq_id + # Load FP8 descale factors if needed + if IS_FP8: + if IS_GQA: + # For MQA/GQA, q_descale uses the same indexing as k/v (hk_id) + q_descale = tl.load( + Q_Descale + z_id * stride_q_descale_z 
+ hk_id * stride_q_descale_h + ) + else: + # For MHA, q_descale uses hq_id + q_descale = tl.load( + Q_Descale + z_id * stride_q_descale_z + hq_id * stride_q_descale_h + ) + k_descale = tl.load( + K_Descale + z_id * stride_k_descale_z + hk_id * stride_k_descale_h + ) + v_descale = tl.load( + V_Descale + z_id * stride_v_descale_z + hv_id * stride_v_descale_h + ) + else: + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 + # figure out seqlens lo = pid_splitk * BLOCK_N_PER_SPLIT if USE_CACHE_SEQLENs: cache_seqlen_last_idx = tl.load(Cache_seqlens + z_id) - if NEW_KV: - N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_NEW - else: - N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_K_FINAL = cache_seqlen_last_idx else: N_CTX_K_FINAL = N_CTX_K hi = tl.minimum((pid_splitk + 1) * BLOCK_N_PER_SPLIT, N_CTX_K_FINAL) @@ -176,14 +395,29 @@ def _fwd_kernel_splitK( # compute ptrs q_offset = Q + hq_id * stride_qh + z_id * stride_qz + g_id * stride_qg q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg - v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + + # Handle block table for paged attention + if USE_BLOCK_TABLE: + # K and V now point to paged cache + # Each batch has its own block table row + block_table_ptr = Block_table + z_id * stride_bt_b + else: + k_offset = ( + K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg + ) + v_offset = ( + V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + ) # compute masks if PADDED_HEAD: q_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] - kT_mask = (offs_d < ACTUAL_BLOCK_DMODEL)[:, None] & (offs_n < N_CTX_K_FINAL)[None, :] - v_mask = (offs_n < N_CTX_K_FINAL)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] + kT_mask = (offs_d < ACTUAL_BLOCK_DMODEL)[:, None] & (offs_n < N_CTX_K_FINAL)[ + None, : + ] + v_mask = (offs_n < N_CTX_K_FINAL)[:, 
None] & (offs_d < ACTUAL_BLOCK_DMODEL)[ + None, : + ] osk_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] else: q_mask = (offs_m < N_CTX_Q)[:, None] @@ -195,7 +429,7 @@ def _fwd_kernel_splitK( # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504 - + # load q: it will stay in SRAM throughout q = tl.load(q_ptrs, mask=q_mask, other=0.0) q = (q * qk_scale).to(q.dtype) @@ -207,137 +441,182 @@ def _fwd_kernel_splitK( else: alibi_slope = None - # Copy new Keys and Values into Cache - if NEW_KV: - knew_base = K_new + hk_id * stride_kn_h + z_id * stride_kn_z + g_id * stride_kn_g - - # Determine the starting position for new data in the cache - if USE_CACHE_SEQLENs: - start_idx = tl.load(Cache_seqlens + z_id) - else: - start_idx = N_CTX_K - N_CTX_NEW - - # Copy new Keys - for i in range(0, N_CTX_NEW, BLOCK_N): - # Load from K_new - k_new_block = tl.load( - knew_base + - tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + - (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, - mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), - other=0 - ) - - # Store to K - tl.store( - k_offset + - tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + - (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, - k_new_block, - mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), - ) - - # Copy new Values - vnew_base = V_new + hv_id * stride_vn_h + z_id * stride_vn_z + g_id * stride_vn_g - for i in range(0, N_CTX_NEW, BLOCK_N): - # Load from V_new - v_new_block = tl.load( - vnew_base + - (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n + - tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d, - mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), - other=0 - ) - - # Store to V - 
tl.store( - v_offset + - (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + - tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, - v_new_block, - mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), - ) - - # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn - V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd - - # load k - kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) - v = tl.load(V_ptrs, mask=v_mask, other=0.0) - - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, kT) # noqa: F821 - - if USE_ALIBI: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) - - # Compute relative positions - relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) - relative_pos = tl.abs(relative_pos) - - # Compute ALiBi bias - alibi_bias = -1 * alibi_slope * relative_pos - qk += (alibi_bias * 1.44269504) - - # Apply causal mask if IS_CAUSAL is True - if IS_CAUSAL: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) - - # create a N_CTX_Q x kv_len causal mask - col_offset = N_CTX_Q - N_CTX_K_FINAL - causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) - - # Apply the mask - qk = tl.where(causal_mask, qk, float("-inf")) - - # TODO: This is slow, and only needed at the last iteration. - # Maybe we can unroll the last iteration instead? 
- if BOUNDS_CHECKS_N: - qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) - - # -- compute scaling constant --- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - if IS_CAUSAL: - alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf"))) - else: - alpha = tl.math.exp2(m_i - m_i_new) - # cause of nan because subtracting infs - if IS_CAUSAL: - qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf")) - else: - qk = qk - m_i_new[:, None] - - p = tl.math.exp2(qk) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - p = p.to(Q.dtype.element_ty) + if USE_BLOCK_TABLE: + # Paged attention: process all KV blocks from cache + # Note: Cache should be updated externally before calling this kernel + num_kv_blocks = (N_CTX_K_FINAL + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K + + for block_idx in range(num_kv_blocks): + # Calculate sequence range for this block + block_start = block_idx * BLOCK_SIZE_K + block_end = tl.minimum(block_start + BLOCK_SIZE_K, N_CTX_K_FINAL) + + # Check if block overlaps with our split-k range [lo, hi) + if block_end > lo and block_start < hi: + # Load physical block number + physical_block = tl.load(block_table_ptr + block_idx * stride_bt_s) + + # Calculate the range within this block that overlaps with [lo, hi) + process_start = tl.maximum(lo - block_start, 0) + process_end = tl.minimum(hi - block_start, BLOCK_SIZE_K) + process_end = tl.minimum(process_end, block_end - block_start) + + # Instead of forcing a floor alignment to BLOCK_N (which can still skip + # part of the intended range if start falls mid-tile for small splits), + # start from the raw (possibly unaligned) process_start rounded *down* but + # allow the loop to begin earlier (at most BLOCK_N before) so that any + # partial tile overlapping [lo, hi) is covered. Masking below will remove + # columns < lo or >= hi ensuring numerically identical coverage without + # duplication. 
+ aligned_start = (process_start // BLOCK_N) * BLOCK_N + if aligned_start > 0 and aligned_start + BLOCK_N > process_start: + # ensure we include the tile that contains process_start + process_start = aligned_start + else: + process_start = aligned_start + + for offset in range(process_start, process_end, BLOCK_N): + # Current position (may begin slightly before logical split range; masking fixes it) + pos = block_start + offset + # Proceed unconditionally; masking below enforces [lo, hi) + # Calculate base addresses for K and V in this physical block + k_base = ( + K + + physical_block * BLOCK_SIZE_K * stride_kn + + hk_id * stride_kh + + g_id * stride_kg + ) + v_base = ( + V + + physical_block * BLOCK_SIZE_K * stride_vn + + hv_id * stride_vh + + g_id * stride_vg + ) + + # Offsets within the current block + block_offs = offset + offs_n + + # Masks for valid data respecting: + # (1) global key length (seq_mask) + # (2) block bounds (block_mask) + # (3) current split range [lo, hi) + seq_mask = (pos + offs_n) < N_CTX_K_FINAL + block_mask = block_offs < BLOCK_SIZE_K + end_mask = block_offs < process_end + split_mask = ((pos + offs_n) >= lo) & ((pos + offs_n) < hi) + col_mask = seq_mask & block_mask & end_mask & split_mask + + # Apply masks + kT_mask_final = kT_mask & col_mask[None, :] + v_mask_final = v_mask & col_mask[:, None] + + # Load K and V + kT_ptrs = ( + k_base + + offs_d[:, None] * stride_kd + + block_offs[None, :] * stride_kn + ) + v_ptrs = ( + v_base + + block_offs[:, None] * stride_vn + + offs_d[None, :] * stride_vd + ) + + kT = tl.load(kT_ptrs, mask=kT_mask_final, other=0.0) + v = tl.load(v_ptrs, mask=v_mask_final, other=0.0) + + # Unified inner function handles both paged and contiguous + m_i, l_i, acc = _attn_fwd_inner( + q, + kT, + v, + pos, + col_mask, + m_i, + l_i, + acc, + pid_m, + q_descale, + k_descale, + v_descale, + alibi_slope, + True, + IS_FP8, + BLOCK_M, + BLOCK_N, + N_CTX_Q, + N_CTX_K_FINAL, + USE_ALIBI, + USE_SLIDING_WINDOW, + IS_CAUSAL, + 
WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + ) + else: + # Non-paged attention: process KV from cache + # Note: Cache should be updated externally before calling this kernel + # Compute bounds check flag once: needed if split size not aligned to BLOCK_N or variable seqlens + bounds_checks_n = ((BLOCK_N_PER_SPLIT % BLOCK_N) > 0) | USE_CACHE_SEQLENs + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + kT_ptrs = ( + k_offset + + offs_d[:, None] * stride_kd + + (start_n + offs_n)[None, :] * stride_kn + ) + V_ptrs = ( + v_offset + + (start_n + offs_n)[:, None] * stride_vn + + offs_d[None, :] * stride_vd + ) - # -- scale and update acc -- - acc *= alpha[:, None] - acc += tl.dot(p.to(v.dtype), v) + # load k + kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) + v = tl.load(V_ptrs, mask=v_mask, other=0.0) + + # Use the same inner loop logic + # Precompute column validity mask for this tile (all True for full tiles). + # hi is the upper bound of the overall split range; start_n marks this tile's base. 
+ col_valid_mask = offs_n < (hi - start_n) + + m_i, l_i, acc = _attn_fwd_inner( + q, + kT, + v, + start_n, + col_valid_mask, + m_i, + l_i, + acc, + pid_m, + q_descale, + k_descale, + v_descale, + alibi_slope, + bounds_checks_n, + IS_FP8, + BLOCK_M, + BLOCK_N, + N_CTX_Q, + N_CTX_K_FINAL, + USE_ALIBI, + USE_SLIDING_WINDOW, + IS_CAUSAL, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + ) # write back O osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_splitk * stride_osk_s - osk_ptrs = osk_offset + offs_m[:, None] * stride_osk_m + offs_d[None, :] * stride_osk_d + osk_ptrs = ( + osk_offset + offs_m[:, None] * stride_osk_m + offs_d[None, :] * stride_osk_d + ) tl.store( osk_ptrs, acc, @@ -351,11 +630,17 @@ def _fwd_kernel_splitK( tl.store(metadata_ptr + stride_m2, l_i) -# @triton.autotune( -# configs=reduce_auto_tune_configs, -# key=reduce_autotune_keys, -# use_cuda_graph=True, -# ) +FWD_DECODE_REDUCE_AUTOTUNE_KEYS = [ + "BLOCK_DMODEL", + "split_k", +] + + +@triton.autotune( + configs=fwd_decode_reduce_configs, + key=FWD_DECODE_REDUCE_AUTOTUNE_KEYS, + use_cuda_graph=True, +) @triton.jit def _splitK_reduce( Out_splitK, # [B*H*G, split_k, Mq, K] @@ -385,7 +670,6 @@ def _splitK_reduce( split_k: tl.constexpr, splitK_pow2: tl.constexpr, MASK_SPLITK: tl.constexpr, - IS_CAUSAL: tl.constexpr, PADDED_HEAD: tl.constexpr, ): # get pids @@ -397,7 +681,6 @@ def _splitK_reduce( offs_splitK = tl.arange(0, splitK_pow2) offs_k = pid_k * K_BLOCK_SIZE + tl.arange(0, K_BLOCK_SIZE) - # compute masks if PADDED_HEAD: o_mask = offs_k < ACTUAL_BLOCK_DMODEL @@ -409,7 +692,11 @@ def _splitK_reduce( metadata_ptr = metadata_offset + offs_splitK * stride_ms + pid_m * stride_mm osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_m * stride_osk_m - osk_ptr = osk_offset + offs_splitK[:, None] * stride_osk_s + offs_k[None, :] * stride_osk_k + osk_ptr = ( + osk_offset + + offs_splitK[:, None] * stride_osk_s + + offs_k[None, :] * stride_osk_k + ) # read max values of each splitK if MASK_SPLITK: @@ 
-423,40 +710,29 @@ def _splitK_reduce( acc = tl.load(osk_ptr) g_m = tl.max(l_m, axis=0) - - if IS_CAUSAL: - l_m_offset = l_m - g_m - alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0) - else: - alpha = tl.math.exp2(l_m - g_m) + + alpha = tl.where(l_m > float("-inf"), tl.math.exp2(l_m - g_m), 0.0) # read sum l_sum *= alpha g_sum = tl.sum(l_sum, axis=0) acc = acc * alpha[:, None] - if IS_CAUSAL: - # Avoid division by zero - g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) - acc_out = tl.sum(acc, axis=0) / g_sum_safe - else: - acc_out = tl.sum(acc, axis=0) / g_sum + g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) + acc_out = tl.sum(acc, axis=0) / g_sum_safe # Store output z_id = pid_zhg // (H * G) h_id = (pid_zhg // G) % H g_id = pid_zhg % G - out_offset = Out + z_id * stride_oz + h_id * stride_oh + g_id * stride_og + out_offset = Out + z_id * stride_oz + h_id * stride_oh + g_id * stride_og out_ptr = out_offset + pid_m * stride_om + offs_k tl.store(out_ptr, acc_out, mask=o_mask) # Store lse l_ptrs = LSE + pid_zhg * stride_lse_zhg + pid_m - if IS_CAUSAL: - lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) - tl.store(l_ptrs, lse) - else: - tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + lse_val = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) + tl.store(l_ptrs, lse_val) @triton.jit @@ -468,6 +744,7 @@ def cast_uint32_to_half2(scale_shift): shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) return scale, shift + @triton.jit def dequantize( x_, @@ -477,14 +754,18 @@ def dequantize( ): # PACKED_PER_VAL is the number of values packed into # each element x_. For example, for int4 quantization - #and x_ of type int32, PACKED_PER_VAL is 8. + # and x_ of type int32, PACKED_PER_VAL is 8. 
BLOCK_N: tl.constexpr = x_.shape[0] BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] offsets = tl.arange(0, PACKED_PER_VAL) * 4 - quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) - quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) # Trick - instead of converting int4 to float16 we view it as float16 # and then multiply by 32768 * 512 == 2**24 quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) @@ -494,6 +775,7 @@ def dequantize( dequant = quant_offset * scale_512 + shift return dequant + def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: # Scale and shift are such that quantization linearly maps # int4 values range [0..15] to input values range min(k)..max(k) @@ -511,7 +793,9 @@ def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: in_bytes = in_bytes.to(torch.uint8) in_int4 = in_bytes & 0xF in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) - scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1) + scale_shift = torch.concat( + [scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1 + ) k_quant = torch.concat( [ scale_shift.flatten(start_dim=-2), @@ -528,7 +812,9 @@ def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tens ss_size = num_groups * 4 scale_shift_ui8 = k_ui8[..., 0:ss_size] - scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale_shift_ui8 = scale_shift_ui8.reshape( + *scale_shift_ui8.shape[:-1], num_groups, 4 + ) scale = scale_shift_ui8[..., 0:2].view(torch.float16) shift = scale_shift_ui8[..., 2:4].view(torch.float16) @@ -540,7 +826,11 @@ def 
dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tens k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) - out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out = torch.empty( + (*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), + dtype=torch.float16, + device=quant_k.device, + ) out[..., ::2] = k1_f16 out[..., 1::2] = k2_f16 out = out.reshape(*k_shape[:-2], -1) @@ -561,69 +851,194 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int: split_k = max(split_k, 1) return split_k -def attention_decode_forward_triton_impl( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - k_new: Optional[torch.Tensor], - v_new: Optional[torch.Tensor], - out: torch.Tensor, - sm_scale: float, - causal: bool, - alibi_slopes: Optional[torch.Tensor], - layout: Literal["bshd"], - cache_seqlens: Optional[Union[(int, torch.Tensor)]], - cache_batch_idx: Optional[torch.Tensor], + +def attention_forward_decode_triton_impl( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + out: torch.Tensor, + softmax_lse: torch.Tensor, + sm_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + alibi_slopes: Optional[torch.Tensor], + layout: Literal["bshd"], + cache_seqlens: Optional[torch.Tensor], + cache_batch_idx: Optional[torch.Tensor], + block_table: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + # rotary (optional) + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + rotary_interleaved: bool = False, + seqlens_rotary: Optional[torch.Tensor] = None, ): - # triton configs - BLOCK_M = 16 - BLOCK_N = 64 - num_stages = 1 - num_warps_fwd = 
1 - num_warps_reduce = 4 - + # Validate layout at entry + if layout != "bshd": + raise ValueError(f"{layout} layout is not supported, only 'bshd' is supported") + + # apply rotary embedding + if rotary_cos is not None and rotary_sin is not None: + # Prefer explicitly provided rotary sequence start offsets if given; fall back to cache_seqlens. + seqlen_offsets = ( + seqlens_rotary + if seqlens_rotary is not None + else (cache_seqlens if cache_seqlens is not None else 0) + ) + local = (window_size_left != -1) or (window_size_right != -1) + q, k_new = apply_rotary( + q, + k_new, + rotary_cos, + rotary_sin, + causal=causal, + local=local, + interleaved=rotary_interleaved, + seqlen_offsets=seqlen_offsets, + ) + + # handle cache updates + if k_new is not None and v_new is not None: + # Update cache with new KV values + if block_table is None: + # Non-paged attention: update cache directly + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + + if cache_seqlens is not None: + # Use cache_seqlens to determine where to insert new KV + for b in range(batch_size): + start_idx = int(cache_seqlens[b].item()) + end_idx = start_idx + seqlen_new + k_cache[b, start_idx:end_idx] = k_new[b] + v_cache[b, start_idx:end_idx] = v_new[b] + cache_seqlens[b] = end_idx + else: + # Append at the end of existing cache + seqlen_cache = k_cache.shape[1] + k_cache[:, seqlen_cache - seqlen_new :] = k_new + v_cache[:, seqlen_cache - seqlen_new :] = v_new + else: + # Paged attention: update cache using block table + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + block_size = k_cache.shape[ + 1 + ] # k_cache shape: [num_blocks, block_size, nheads, head_dim] + + # Update cache for each batch element + for b in range(batch_size): + if cache_seqlens is not None: + start_idx = int(cache_seqlens[b].item()) + else: + # If no cache_seqlens, assume we're appending at the end + # Find the last used position from block table + start_idx = 0 + for block_idx in range(block_table.shape[1]): + 
if block_table[b, block_idx] >= 0: + start_idx = (block_idx + 1) * block_size + else: + start_idx = block_idx * block_size + break + + # Copy new KV values into the paged cache + for i in range(seqlen_new): + pos = start_idx + i + block_idx = pos // block_size + within_block_idx = pos % block_size + + # Get the physical block number from block table + if block_idx < block_table.shape[1]: + physical_block = int(block_table[b, block_idx].item()) + + # Update k_cache and v_cache at the physical block location + k_cache[physical_block, within_block_idx] = k_new[b, i] + v_cache[physical_block, within_block_idx] = v_new[b, i] + + # Update cache_seqlens if provided + if cache_seqlens is not None: + cache_seqlens[b] = start_idx + seqlen_new + # kernel_configs - is_new_kv = True if k_new is not None and v_new is not None else False - use_alibi = False if alibi_slopes is None else True + is_new_kv = False # Cache has been updated, so no new KV in kernel + use_alibi, (stride_az, stride_ah) = True if alibi_slopes is not None else False, ( + alibi_slopes.stride() if alibi_slopes is not None else (None, None) + ) use_cache_seqlens = cache_seqlens is not None - SPLIT_K = None + use_sliding_window = window_size_left != -1 or window_size_right != -1 + use_block_table = block_table is not None NUM_QUANT_GROUPS = 1 # get shapes and strides - (batch_size, seqlen_q, nheads_q, dim_q), (stride_qz, stride_qh, stride_qm, stride_qd) = get_shape_and_strides_from_layout(q, layout) - (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) - (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) - if is_new_kv: - ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = get_shape_and_strides_from_layout(k_new, layout) - (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, 
stride_vn_d) = get_shape_and_strides_from_layout(v_new, layout) + batch_size, seqlen_q, nheads_q, dim_q = get_shape_from_layout(q, layout) + stride_qz, stride_qh, stride_qm, stride_qd = get_stride_from_layout(q, layout) + + # Handle paged KV cache layout + if use_block_table: + # For paged attention, k_cache and v_cache have shape [num_blocks, block_size, nheads, head_dim] + num_blocks_kc, block_size_k, nheads_kc, dim_kc = k_cache.shape + num_blocks_vc, block_size_v, nheads_vc, dim_vc = v_cache.shape + # Get the actual sequence length from cache_seqlens or block_table + if cache_seqlens is not None: + seqlen_kc = int(cache_seqlens.max().item()) + else: + # Infer from block_table shape [batch_size, num_blocks_per_seq] + assert block_table is not None + num_blocks_per_seq = block_table.shape[1] + seqlen_kc = num_blocks_per_seq * block_size_k + seqlen_vc = seqlen_kc + + # Strides for paged layout + stride_kc_z = 0 # No batch dimension in paged cache + stride_kc_n = k_cache.stride(1) # Sequence stride + stride_kc_h = k_cache.stride(2) # Head stride + stride_kc_d = k_cache.stride(3) # Dim stride + + stride_vc_z = 0 + stride_vc_n = v_cache.stride(1) + stride_vc_h = v_cache.stride(2) + stride_vc_d = v_cache.stride(3) else: - ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = (None, None, None, None), (None, None, None, None) - (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = (None, None, None, None), (None, None, None, None) - (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = get_shape_and_strides_from_layout(out, layout) - if use_alibi: - stride_az, stride_ah = alibi_slopes.stride() + _, seqlen_kc, nheads_kc, dim_kc = get_shape_from_layout(k_cache, layout) + stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d = get_stride_from_layout(k_cache, layout) + _, seqlen_vc, nheads_vc, dim_vc = get_shape_from_layout(v_cache, layout) + stride_vc_z, stride_vc_h, stride_vc_n, 
stride_vc_d = get_stride_from_layout(v_cache, layout) + block_size_k = 0 # Not used + if is_new_kv: + _, seqlen_kn, nheads_kn, dim_kn = get_shape_from_layout(k_new, layout) + stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d = get_stride_from_layout(k_new, layout) + _, seqlen_vn, nheads_vn, dim_vn = get_shape_from_layout(v_new, layout) + stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d = get_stride_from_layout(v_new, layout) else: - stride_az, stride_ah = (None, None) - - assert dim_q == dim_kc == dim_vc, f"Dimensions must match: {dim_q}, {dim_kc}, {dim_vc}" + _, seqlen_kn, nheads_kn, dim_kn = None, None, None, None + stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d = None, None, None, None + _, seqlen_vn, nheads_vn, dim_vn = None, None, None, None + stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d = None, None, None, None + _, seqlen_o, nheads_o, dim_o = get_shape_from_layout(out, layout) + stride_oz, stride_oh, stride_om, stride_od = get_stride_from_layout(out, layout) + assert ( + dim_q == dim_kc == dim_vc + ), f"Dimensions must match: {dim_q}, {dim_kc}, {dim_vc}" # add extra information needed by the kernels - if layout == "bshd": - (n_group_q, heads_per_group_q), stride_qg = (1, nheads_q), stride_qm - (n_group_k, heads_per_group_k), stride_kc_g = (1, nheads_kc), stride_kc_n - (n_group_v, heads_per_group_v), stride_vc_g = (1, nheads_vc), stride_vc_n - if is_new_kv: - (n_group_kn, heads_per_group_kn), stride_kn_g = (1, nheads_kn), stride_kn_n - (n_group_vn, heads_per_group_vn), stride_vn_g = (1, nheads_vn), stride_vn_n - else: - (n_group_kn, heads_per_group_kn), stride_kn_g = (None, None), None - (n_group_vn, heads_per_group_vn), stride_vn_g = (None, None), None - (n_group_o, heads_per_group_o), stride_og = (1, nheads_o), stride_om + (n_group_q, heads_per_group_q), stride_qg = (1, nheads_q), stride_qm + (n_group_k, heads_per_group_k), stride_kc_g = (1, nheads_kc), stride_kc_n + (n_group_v, heads_per_group_v), stride_vc_g = (1, nheads_vc), stride_vc_n + if 
is_new_kv: + (n_group_kn, heads_per_group_kn), stride_kn_g = (1, nheads_kn), stride_kn_n + (n_group_vn, heads_per_group_vn), stride_vn_g = (1, nheads_vn), stride_vn_n else: - raise ValueError(f"{layout} layout is not supported") + (n_group_kn, heads_per_group_kn), stride_kn_g = (None, None), None + (n_group_vn, heads_per_group_vn), stride_vn_g = (None, None), None + (n_group_o, heads_per_group_o), stride_og = (1, nheads_o), stride_om # get padded size - dim_padded = get_padded_headsize(dim_kc) + dim_padded = get_padded_headsize(dim_kc) is_padded_head = dim_padded != dim_kc # Handle MQA/GQA case @@ -633,54 +1048,197 @@ def attention_decode_forward_triton_impl( else: is_gqa = False - if SPLIT_K is not None: - split_k = SPLIT_K + # Use heuristics for split_k + if use_block_table: + # For paged attention, use the actual sequence length from cache_seqlens + max_seqlen = ( + int(cache_seqlens.max().item()) + if cache_seqlens is not None + else block_size_k + ) + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, max_seqlen) else: - # Use heuristics - split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) # NOTE: should the split think about seqlens? 
+ split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) split_size = (seqlen_kc + split_k - 1) // split_k - # setup grid - seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M - grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch_size * n_group_q * heads_per_group_q, split_k) - + # setup grid - use lambda to get BLOCK_M from autotune + # Use MAX_BLOCK_M for intermediate tensor allocation to ensure enough space + seqlen_q_ceil = (seqlen_q + MAX_BLOCK_M - 1) // MAX_BLOCK_M * MAX_BLOCK_M + grid = lambda META: ( + triton.cdiv(seqlen_q, META["BLOCK_M"]), + batch_size * n_group_q * heads_per_group_q, + split_k, + ) + # create intermediate tensors - out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_kc], dtype=torch.float32, device=q.device) - metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device) - lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), dtype=torch.float32, device=q.device) - + out_splitk = torch.empty( + [batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_kc], + dtype=torch.float32, + device=q.device, + ) + metadata = torch.empty( + [batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], + dtype=torch.float32, + device=q.device, + ) + + # Validate pre-allocated softmax_lse tensor + # Expected shape after view: (batch_size, n_group_q * heads_per_group_q, seqlen_q) + # Internal shape: (batch_size * n_group_q * heads_per_group_q, seqlen_q) + expected_h_total = batch_size * n_group_q * heads_per_group_q + assert ( + softmax_lse.shape[0] == batch_size + ), f"softmax_lse.shape[0] ({softmax_lse.shape[0]}) must equal batch_size ({batch_size})" + assert ( + softmax_lse.shape[1] == n_group_q * heads_per_group_q + ), f"softmax_lse.shape[1] ({softmax_lse.shape[1]}) must equal n_group_q * heads_per_group_q ({n_group_q * heads_per_group_q})" + 
assert ( + softmax_lse.shape[2] >= seqlen_q + ), f"softmax_lse.shape[2] ({softmax_lse.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert softmax_lse.device == q.device, f"softmax_lse must be on same device as q" + + # Create internal lse view for kernel use + lse = softmax_lse.view(expected_h_total, -1)[:, :seqlen_q].contiguous() + # get intermediate tensor strides stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() stride_mzhg, stride_m2, stride_ms, stride_mm = metadata.stride() stride_lse_zhg, stride_lse_m = lse.stride() - if False: - print("batch_size, seqlen_q, nheads_q, dim_q", (batch_size, seqlen_q, nheads_q, dim_q)) + # Block table strides + if use_block_table: + assert block_table is not None + stride_bt_b, stride_bt_s = block_table.stride() + else: + stride_bt_b, stride_bt_s = 0, 0 + + # FP8 support + IS_FP8 = is_fp8([q, k_cache, v_cache]) + if IS_FP8: + arch = get_arch() + if not arch.supports_fp8: + raise RuntimeError( + f"{arch.name} does not support FP8" + ) + rec_dtype = arch.recommended_fp8_dtype(q.dtype) + if ( + q.dtype != rec_dtype + or k_cache.dtype != rec_dtype + or v_cache.dtype != rec_dtype + ): + warnings.warn( + f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k_cache.dtype}, v: {v_cache.dtype}", + UserWarning, + ) + if (q_descale is None) or (k_descale is None) or (v_descale is None): + warnings.warn( + "FP8 tensors detected but descale factors not provided. 
Using default scale of 1.0", + UserWarning, + ) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones( + batch_size, nheads_q, dtype=torch.float32, device=q.device + ) + if k_descale is None: + k_descale = torch.ones( + batch_size, nheads_kc, dtype=torch.float32, device=q.device + ) + if v_descale is None: + v_descale = torch.ones( + batch_size, nheads_vc, dtype=torch.float32, device=q.device + ) + else: + # Enforce exact expected shapes; no reshaping or normalization. + assert ( + q_descale.dim() == 2 + and q_descale.shape[0] == batch_size + and q_descale.shape[1] == nheads_kc + ), f"q_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(q_descale.shape)}" + assert ( + k_descale.dim() == 2 + and k_descale.shape[0] == batch_size + and k_descale.shape[1] == nheads_kc + ), f"k_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(k_descale.shape)}" + assert ( + v_descale.dim() == 2 + and v_descale.shape[0] == batch_size + and v_descale.shape[1] == nheads_kc + ), f"v_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(v_descale.shape)}" + stride_q_descale_z, stride_q_descale_h = q_descale.stride() + stride_k_descale_z, stride_k_descale_h = k_descale.stride() + stride_v_descale_z, stride_v_descale_h = v_descale.stride() + else: + q_descale = None + k_descale = None + v_descale = None + stride_q_descale_z = 0 + stride_q_descale_h = 0 + stride_k_descale_z = 0 + stride_k_descale_h = 0 + stride_v_descale_z = 0 + stride_v_descale_h = 0 + + if DEBUG: + print( + "batch_size, seqlen_q, nheads_q, dim_q", + (batch_size, seqlen_q, nheads_q, dim_q), + ) print("_, seqlen_kc, nheads_kc, dim_kc", (_, seqlen_kc, nheads_kc, dim_kc)) print("dim_padded:", dim_padded) - print("stride_qz, stride_qm, stride_qg, stride_qh, stride_qd", (stride_qz, stride_qm, stride_qg, stride_qh, stride_qd)) - print("stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d", (stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, 
stride_kc_d)) - print("stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d", (stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d)) + print( + "stride_qz, stride_qm, stride_qg, stride_qh, stride_qd", + (stride_qz, stride_qm, stride_qg, stride_qh, stride_qd), + ) + print( + "stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d", + (stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d), + ) + print( + "stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d", + (stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d), + ) if is_new_kv: - print("stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d", (stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d)) - print("stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d", (stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d)) - print("stride_oz, stride_om, stride_og, stride_oh, stride_od", (stride_oz, stride_om, stride_og, stride_oh, stride_od)) - print("stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d", (stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d)) - print("stride_mzhg, stride_m2, stride_ms, stride_mm", (stride_mzhg, stride_m2, stride_ms, stride_mm)) + print( + "stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d", + (stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d), + ) + print( + "stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d", + (stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d), + ) + print( + "stride_oz, stride_om, stride_og, stride_oh, stride_od", + (stride_oz, stride_om, stride_og, stride_oh, stride_od), + ) + print( + "stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d", + (stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d), + ) + print( + "stride_mzhg, stride_m2, stride_ms, stride_mm", + (stride_mzhg, stride_m2, stride_ms, stride_mm), + ) print("stride_lse_zhg, stride_lse_m", (stride_lse_zhg, stride_lse_m)) - # 
TODO: enable quantization _fwd_kernel_splitK[grid]( Q=q, K=k_cache, V=v_cache, + Q_Descale=q_descale, + K_Descale=k_descale, + V_Descale=v_descale, sm_scale=sm_scale, Out_splitK=out_splitk, Metadata=metadata, - K_new=k_new, - V_new=v_new, + K_new=None, + V_new=None, Cache_seqlens=cache_seqlens, Cache_batch_idx=cache_batch_idx, + Block_table=block_table, Alibi_slopes=alibi_slopes, # q strides stride_qz=stride_qz, @@ -722,32 +1280,43 @@ def attention_decode_forward_triton_impl( stride_vn_g=stride_vn_g, stride_vn_h=stride_vn_h, stride_vn_d=stride_vn_d, + # block table strides + stride_bt_b=stride_bt_b, + stride_bt_s=stride_bt_s, # alibi strides stride_az=stride_az, stride_ah=stride_ah, + # FP8 descale strides + stride_q_descale_z=stride_q_descale_z, + stride_q_descale_h=stride_q_descale_h, + stride_k_descale_z=stride_k_descale_z, + stride_k_descale_h=stride_k_descale_h, + stride_v_descale_z=stride_v_descale_z, + stride_v_descale_h=stride_v_descale_h, Z=batch_size, H_q=heads_per_group_q, H_kv=heads_per_group_k, G_q=n_group_q, N_CTX_Q=seqlen_q, N_CTX_K=seqlen_kc, - N_CTX_NEW=seqlen_kn, + N_CTX_NEW=0, # No new KV, cache already updated BLOCK_N_PER_SPLIT=split_size, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, + BLOCK_SIZE_K=block_size_k if use_block_table else 256, BLOCK_DMODEL=dim_padded, ACTUAL_BLOCK_DMODEL=dim_kc, - BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, USE_CACHE_SEQLENs=use_cache_seqlens, USE_CACHE_BATCH_IDX=cache_batch_idx is not None, - NEW_KV=is_new_kv, + NEW_KV=False, # Cache already updated IS_GQA=is_gqa, IS_CAUSAL=causal, USE_ALIBI=use_alibi, PADDED_HEAD=is_padded_head, GROUP_SIZE=group_size, - num_warps=num_warps_fwd, - num_stages=num_stages, + USE_SLIDING_WINDOW=use_sliding_window, + WINDOW_SIZE_LEFT=window_size_left, + WINDOW_SIZE_RIGHT=window_size_right, + USE_BLOCK_TABLE=use_block_table, + IS_FP8=IS_FP8, ) if DEBUG: @@ -765,20 +1334,19 @@ def attention_decode_forward_triton_impl( k_block_num = 2 assert dim_padded % k_block_num == 0 
k_block_size = dim_padded // k_block_num - grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) - + reduce_grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) if DEBUG: print("splitK_pow2:", splitK_pow2) print("k_block_num:", k_block_num) print("k_block_size:", k_block_size) - print("grid:", grid) + print("grid:", reduce_grid) - _splitK_reduce[grid]( - out_splitk, - metadata, - out, - lse, + _splitK_reduce[reduce_grid]( + out_splitk, + metadata, + out, + lse, # Split-K output strides stride_osk_zhg=stride_osk_zhg, stride_osk_s=stride_osk_s, @@ -801,14 +1369,11 @@ def attention_decode_forward_triton_impl( K_BLOCK_SIZE=k_block_size, BLOCK_DMODEL=dim_padded, ACTUAL_BLOCK_DMODEL=dim_kc, - G=n_group_q, + G=n_group_q, H=heads_per_group_q, # TODO: Tune num_warps - split_k=split_k, - splitK_pow2=splitK_pow2, + split_k=split_k, + splitK_pow2=splitK_pow2, MASK_SPLITK=mask_split_k, - IS_CAUSAL=causal, PADDED_HEAD=is_padded_head, - num_warps=num_warps_reduce) - - return lse + ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py old mode 100644 new mode 100755 index 6f69cd02813..ef8a9d5ff45 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1,107 +1,341 @@ +import os +import warnings import torch import triton import triton.language as tl -from typing import Literal, Optional, Union -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_cdna, is_fp8, is_rdna, create_dropout_mask +from typing import Literal, Optional +from .common import compute_alibi_block, compute_fp8_scaling_factors, apply_rotary +from .utils import ( + DEBUG, + AUTOTUNE, + get_arch, + is_fp8, +) -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = 
triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) -# Convenience function to load with optional boundary checks. -# "First" is the major dim, "second" is the minor dim. -@triton.jit -def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): - if offset_first is not None and offset_second is not None: - mask = (offset_first[:, None] < boundary_first) & \ - (offset_second[None, :] < boundary_second) - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_first is not None: - mask = offset_first[:, None] < boundary_first - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_second is not None: - mask = offset_second[None, :] < boundary_second - tensor = tl.load(ptrs, mask=mask, other=0.0) - else: - tensor = tl.load(ptrs) - return tensor + +FWD_PREFILL_AUTOTUNE_KEYS = [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL_QK", + "ACTUAL_BLOCK_DMODEL_V", + "IS_VARLEN", + "HQ", + "HK", +] + + +def get_fwd_prefill_configs(autotune: bool): + # Get best config for the architecture. 
+ # NOTE: Tests expect specific BLOCK_N sizes for attention score renormalization: + # - CDNA: BLOCK_N=64 + # - RDNA: BLOCK_N=32 + # See _get_block_size_n_triton() in test_flash_attn_triton_amd.py + if not autotune: + arch = get_arch() + if arch.name == "gfx950": + return [ + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ] + elif arch.name == "gfx942": + if arch.cu_count < 304: + return [ + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ] + else: + return [ + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ] + elif arch.is_rdna: + return [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ] + else: + return [ + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ] + + # ===================== Autotune Sweep ===================== + configs = [] + BLOCK_M_OPTIONS = [128, 64, 32, 16] + BLOCK_N_OPTIONS = [128, 64, 32, 16] + NUM_WARPS_OPTIONS = [2, 4, 8] + NUM_STAGES_OPTIONS = [1, 2] + WAVES_PER_EU_OPTIONS = [4, 2, 1] + PRE_LOAD_V_OPTIONS = [False] + for bm in BLOCK_M_OPTIONS: + for bn in BLOCK_N_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for nw in NUM_WARPS_OPTIONS: + for ns in NUM_STAGES_OPTIONS: + for preload_v in PRE_LOAD_V_OPTIONS: + configs.append( + triton.Config( + { + "BLOCK_M": bm, + "BLOCK_N": bn, + "waves_per_eu": waves, + "PRE_LOAD_V": preload_v, + }, + num_stages=ns, + num_warps=nw, + ) + ) + + return configs + + +fwd_prefill_autotune_configs = get_fwd_prefill_configs(AUTOTUNE) + @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, - actual_seqlen_k, 
actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, - descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_base_ptrs, + v_base_ptrs, + bias_base_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + stride_sm, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_offset_base, + sd_mask, + stride_sz, + stride_sh, + off_z, + off_h_q, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, + block_max, + n_extra_tokens, + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_P_DESCALE: tl.constexpr, + APPLY_MASK: tl.constexpr, # True for masked blocks, False for full blocks + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + PADDED_HEAD_QK: tl.constexpr, + PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + ACCUMULATOR_TYPE, +): + """ + Unified attention forward inner loop. 
+ + APPLY_MASK controls whether causal/window masking is applied: + - False: Fast path for full blocks (no masking overhead) + - True: Masked path with causal/window masking support + """ if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 - + + # seqlen diff (only used when APPLY_MASK=True) + seqlen_delta_qk = seqlen_k - seqlen_q + # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - if MASK_STEPS: - k_offs_n = start_n + tl.arange(0, BLOCK_N) + # get ptrs + k_ptrs = k_base_ptrs + start_n * stride_kn + v_ptrs = v_base_ptrs + start_n * stride_vk + + kv_offs_n = start_n + tl.arange(0, BLOCK_N) + + # Load K - different masking for APPLY_MASK vs non-masked + if APPLY_MASK: + # For masked blocks, check seqlen bounds + k_mask = kv_offs_n[None, :] < seqlen_k + v_mask = kv_offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + k_mask = k_mask & (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK) + if PADDED_HEAD_V: + v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + k = tl.load(k_ptrs, mask=k_mask, other=0.0) + if PRE_LOAD_V: + v = tl.load(v_ptrs, mask=v_mask, other=0.0) else: - k_offs_n = None - k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) - k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) - if PRE_LOAD_V: - # We can use the same offsets as k, just with dims transposed. 
- v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + # For full blocks, only check head dimension padding + if PADDED_HEAD_QK: + k_mask = offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK + k = tl.load(k_ptrs, mask=k_mask, other=0.0) + else: + k = tl.load(k_ptrs) + if PRE_LOAD_V: + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) + + # setup qk accumulator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. 
- if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] + + # Apply extra token masking for partial blocks (only when APPLY_MASK=True) + if APPLY_MASK: + if (n_extra_tokens != 0) and (start_n + BLOCK_N == block_max): + boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) + size_n = start_n + offs_n[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) - # compute masks - q_mask = (OFFS_M[:, None] < actual_seqlen_q) - k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - p_mask = q_mask & k_mask - # -- compute qk ---- - if IS_FP8 : - qk += (tl.dot(q, k) * descale_q * descale_k) + if IS_FP8: + qk += tl.dot(q, k) * q_descale * k_descale else: qk += tl.dot(q, k) - qk_scaled = qk * SM_SCALE - - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) - if bias_ptrs is not None: - bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None - bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) - qk_scaled += bias + qk_scaled = qk * SM_SCALE if USE_ALIBI: # compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, - global_n_positions) + q_offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + alibi_block = compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, q_offs_m, kv_offs_n + ) qk_scaled += alibi_block + + # Apply causal/sliding window masking (only when APPLY_MASK=True) + if APPLY_MASK: + if USE_SLIDING_WINDOW: + if IS_CAUSAL: + # ========== CAUSAL SLIDING WINDOW MASKING ========== + row_idx = offs_m + col_idx = 
kv_offs_n + row_idx_expanded = row_idx[:, None] + col_idx_expanded = col_idx[None, :] + + causal_offset = seqlen_k - seqlen_q + causal_mask = col_idx_expanded > (row_idx_expanded + causal_offset) + + if WINDOW_SIZE_LEFT < 0: + window_mask = col_idx_expanded > ( + row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT + ) + else: + left_bound = row_idx_expanded + causal_offset - WINDOW_SIZE_LEFT + right_bound = row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT + window_mask = (col_idx_expanded < left_bound) | ( + col_idx_expanded > right_bound + ) + + mask = causal_mask | window_mask + qk_scaled = tl.where(mask, float("-inf"), qk_scaled) + else: + # ========== NON-CAUSAL SLIDING WINDOW MASKING ========== + row_idx = offs_m + col_idx = kv_offs_n + sk = seqlen_k + sq = seqlen_q + row_idx_expanded = row_idx[:, None] + col_idx_expanded = col_idx[None, :] + + if WINDOW_SIZE_LEFT < 0: + mask = col_idx_expanded > ( + row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT + ) + else: + sk_full = tl.full((1, BLOCK_N), sk, dtype=tl.int32) + right_bound_val = row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT + right_bound = tl.minimum(right_bound_val, sk_full) + left_bound = row_idx_expanded + sk - sq - WINDOW_SIZE_LEFT + mask = (col_idx_expanded > right_bound) | ( + col_idx_expanded < left_bound + ) + + qk_scaled = tl.where(mask, float("-inf"), qk_scaled) + else: + if IS_CAUSAL: + causal_boundary = start_n + offs_n - seqlen_delta_qk + causal_mask = offs_m[:, None] >= causal_boundary[None, :] + qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + + # compute qk mask for bounds checking + qk_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + # compute bias + if bias_base_ptrs is not None: + bias_ptrs = bias_base_ptrs + start_n * stride_bn + bias = tl.load(bias_ptrs, mask=qk_mask, other=0.0) + qk_scaled += bias + # get max scores so far m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) # scale and subtract max - q_shifted = qk_scaled - m_ij[:, None] - + # Handle 
the case where all values are -inf + q_shifted = tl.where( + m_ij[:, None] == float("-inf"), float("-inf"), qk_scaled - m_ij[:, None] + ) + # Compute scaled QK and softmax probabilities if USE_EXP2: p = tl.math.exp2(q_shifted * RCP_LN2) @@ -111,539 +345,1399 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - if tl_DROPOUT_USE_PYTORCH: - dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask) - else: - rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance - dropout_mask = rng_output > dropout_p - if tl_DROPOUT_DUMP: - tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) + # Compute pointers for this block + philox_base = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh + philox_ptrs = philox_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # compute dropout mask + rng_output = tl.rand(philox_seed, philox_ptrs) + dropout_mask = rng_output > dropout_p - # return scores with negative values for dropped vals - sd_mask = tl.where(dropout_mask, p, -p) - tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) + # return scores with negative values for dropped vals (only if RETURN_SCORES is True) + if RETURN_SCORES: + sd_mask_value = tl.where(dropout_mask, p, -p) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + if APPLY_MASK and IS_CAUSAL: + causal_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk) + sd_store_mask = sd_store_mask & causal_constraint + + if APPLY_MASK and USE_SLIDING_WINDOW: + if WINDOW_SIZE_LEFT < 0: + window_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT) + else: + left_bound = offs_m[:, None] + seqlen_delta_qk - 
WINDOW_SIZE_LEFT + right_bound = offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + window_constraint = (kv_offs_n[None, :] >= left_bound) & (kv_offs_n[None, :] <= right_bound) + sd_store_mask = sd_store_mask & window_constraint + + tl.store(sd_mask_ptrs, sd_mask_value, mask=sd_store_mask) # apply dropout mask in place p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - tl.store(sd_mask_ptrs, p, mask=p_mask) - + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + if APPLY_MASK and IS_CAUSAL: + causal_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk) + sd_store_mask = sd_store_mask & causal_constraint + + if APPLY_MASK and USE_SLIDING_WINDOW: + if WINDOW_SIZE_LEFT < 0: + window_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT) + else: + left_bound = offs_m[:, None] + seqlen_delta_qk - WINDOW_SIZE_LEFT + right_bound = offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + window_constraint = (kv_offs_n[None, :] >= left_bound) & (kv_offs_n[None, :] <= right_bound) + sd_store_mask = sd_store_mask & window_constraint + + tl.store(sd_mask_ptrs, p, mask=sd_store_mask) + # -- update output accumulator -- - # alpha is an adjustment factor for acc and li as we loop and find new maxes - # store the diff in maxes to adjust acc and li as we discover new maxes - m_diff = m_i - m_ij + m_diff = tl.where(m_ij == float("-inf"), float("-inf"), m_i - m_ij) if USE_EXP2: alpha = tl.math.exp2(m_diff * RCP_LN2) else: alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] + + # Load V if not preloaded if not PRE_LOAD_V: - v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, 
ACTUAL_BLOCK_DMODEL) + if APPLY_MASK: + v_mask = kv_offs_n[:, None] < seqlen_k + if PADDED_HEAD_V: + v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) + # -- update m_i and l_i l_i = l_i * alpha + l_ij - # update m_i and l_i m_i = m_ij if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + if FP8_P_DESCALE: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += ( + tl.dot((p * scale_p).to(v.type.element_ty), v) + * descale_p + * v_descale + ) + else: + acc += tl.dot(p.to(v.type.element_ty), v) * v_descale else: acc += tl.dot(p.to(v.type.element_ty), v) - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - if bias_ptrs is not None: - bias_ptrs += BLOCK_N * stride_bn - if RETURN_SCORES: - sd_mask_ptrs += BLOCK_N * stride_sn - - if ENABLE_DROPOUT: - dropout_mask_ptrs += BLOCK_N * stride_sn - philox_ptrs += BLOCK_N * stride_sn return acc, l_i, m_i -def get_cdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - # Fall-back config. 
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_rdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - # Fall-back config. 
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_autotune_configs(): - if AUTOTUNE: - if is_rdna(): - return get_rdna_autotune_configs() - elif is_cdna(): - return get_cdna_autotune_configs() +@triton.jit +def compute_window_bounds( + q_start, + q_end, + diag, + seqlen_k, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + """Calculate the window boundaries for a query block.""" + # Left boundary + if WINDOW_SIZE_LEFT < 0: + left_min = 0 + left_max = 0 + else: + left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) + left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) + + # Right boundary + if IS_CAUSAL: + # Causal cap: col ≤ row + diag + right_min = tl.minimum(seqlen_k - 1, q_start + diag) + right_max = tl.minimum(seqlen_k - 1, q_end + diag) + else: + if WINDOW_SIZE_RIGHT < 0: + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) else: - raise ValueError("Unknown Device Type") + # Non-causal doesn't have the diagonal constraint + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) + + return left_min, left_max, right_min, right_max + + +@triton.jit +def classify_window_blocks( + left_min, left_max, right_min, right_max, BLOCK_N: tl.constexpr +): + """Classify blocks based on window boundaries.""" + # First and last blocks that have ANY overlap with window + first_block = left_min // BLOCK_N + last_block = right_max // BLOCK_N + + # First block that is FULLY visible for all rows in Q block + full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) + clipped_left = tl.minimum(full_left_block, last_block + 1) + + # Last 
block that is FULLY visible for all rows in Q block + last_full_block_candidate = right_min // BLOCK_N + if (last_full_block_candidate + 1) * BLOCK_N - 1 > right_min: + last_full_block_candidate -= 1 + full_right_block = tl.maximum(last_full_block_candidate, clipped_left - 1) + + # Calculate counts + n_front_skip_blocks = first_block + n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) + n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) + n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) + + return ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + clipped_left, + ) # Return clipped_left for padded block handling + + +@triton.jit +def handle_padded_last_block( + n_extra_tokens, + last_block, + total_k_blocks, + clipped_left, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, +): + """Ensure a padded last K-block is never classified as 'full'. + + We move the padded last block (if visible) into the back-masked bucket. + If it's already back-masked, we do nothing. If it was counted in the + front-masked range, we decrement front-masked; if it was counted as full, + we decrement full. Then we increment back-masked. 
+ """ + padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) + + if padded_last_k: + # current 'full' range right edge + full_right_block = clipped_left + n_full_blocks - 1 + + # If last_block is already beyond full_right_block, it's already in back-masked → nothing to do + last_already_back_masked = last_block > full_right_block + if not last_already_back_masked: + # If the window starts past last_block, it was counted in front-masked + if clipped_left > last_block: + n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) + else: + # Otherwise it was counted 'full' → move it out of full + n_full_blocks = tl.maximum(0, n_full_blocks - 1) + # In both cases we need one more back-masked block + n_back_masked_blocks = n_back_masked_blocks + 1 + + return n_front_masked_blocks, n_full_blocks, n_back_masked_blocks + + +@triton.jit +def compute_padding_info(seqlen_k, BLOCK_N: tl.constexpr): + """Calculate padding information for the last K block.""" + # check if we will need to do masking due either BLOCK_N being bigger than seqlen_k or seqlen_k not being a factor of BLOCK_N + # n_extra_tokens = 10 % 4 = 2 + # This means the last K block has 2 valid tokens and 2 padding positions + # K blocks visualization: + # Block 0 Block 1 Block 2 (last) + # K0 K1 K2 K3 K4 K5 K6 K7 K8 K9 ?? ?? 
+ # ↑---------↑ ↑---------↑ ↑---↑ ↑---↑ + # full block full block valid pad + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N else: - return [ - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL", - "IS_VARLEN", - "HQ", - "HK", - ] - - -autotune_configs, autotune_keys = get_autotune_configs() + n_extra_tokens = 0 + return n_extra_tokens + + +@triton.jit +def compute_block_masking( + seqlen_k, + seqlen_q, + start_m, + IS_CAUSAL: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Classify K blocks for attention computation with sliding window support. + + Returns: + - n_front_skip_blocks: Blocks completely before the window + - n_front_masked_blocks: Blocks partially overlapping window front + - n_full_blocks: Blocks completely inside the window + - n_back_masked_blocks: Blocks partially overlapping window back + - n_extra_tokens: Padding tokens in last K block + """ + + # common + q_start = start_m * BLOCK_M + q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) + diag = seqlen_k - seqlen_q + total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) + n_extra_tokens = compute_padding_info(seqlen_k, BLOCK_N) + + if USE_SLIDING_WINDOW: + # get window bounds + left_min, left_max, right_min, right_max = compute_window_bounds( + q_start, + q_end, + diag, + seqlen_k, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + IS_CAUSAL, + ) + + # window vanishes → early exit + if right_max < left_min: + return 0, 0, 0, 0, n_extra_tokens + + # classify blocks + ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + clipped_left, + ) = classify_window_blocks(left_min, left_max, 
right_min, right_max, BLOCK_N) + + # handle padded last block if needed + if n_extra_tokens != 0: + last_block = right_max // BLOCK_N + n_front_masked_blocks, n_full_blocks, n_back_masked_blocks = ( + handle_padded_last_block( + n_extra_tokens, + last_block, + total_k_blocks, + clipped_left, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + ) + ) + return ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens, + ) + else: + if IS_CAUSAL: + # ========== CAUSAL MODE: Classify K Blocks ========== + # Calculate causal boundary for this Q block + # [K0 K1 K2 K3] [K4 K5 K6 K7] [K8 K9 ?? ??] + # Q0-Q3: [ 1 0 0 0] [ 0 0 0 0] [ 0 0 -- --] ← Q0 + # [ 1 1 0 0] [ 0 0 0 0] [ 0 0 -- --] ← Q1 + # [ 1 1 1 0] [ 0 0 0 0] [ 0 0 -- --] ← Q2 + # [ 1 1 1 1] [ 1 1 0 0] [ 0 0 -- --] ← Q3 + # ↑ can see up to K5 + # + # Q4-Q7: [ 1 1 1 1] [ 1 1 1 0] [ 0 0 -- --] ← Q4 + # [ 1 1 1 1] [ 1 1 1 1] [ 0 0 -- --] ← Q5 + # [ 1 1 1 1] [ 1 1 1 1] [ 1 0 -- --] ← Q6 + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -- --] ← Q7 + + # ------------------------------------------------------------ + # 1. figure out, in tokens, the right-most K position + # this Q-block may attend to + # ------------------------------------------------------------ + k_max_token = q_end + diag # last visible K index + + # this Q-block is entirely above the diagonal ⇒ nothing to do + if k_max_token < 0: + return 0, 0, 0, 0, n_extra_tokens + + k_max_token = tl.minimum(k_max_token, seqlen_k - 1) + + # ------------------------------------------------------------ + # 2. translate token indices into K-block indices + # ------------------------------------------------------------ + last_visible_k_block = k_max_token // BLOCK_N + n_visible_k_blocks = tl.minimum(last_visible_k_block + 1, total_k_blocks) + + # ------------------------------------------------------------ + # 3. 
classify those visible blocks + # – we *never* skip or mask blocks in front, because causal + # attention always starts at K0 + # – the back side can require several masked blocks: + # • intersection of the causal diagonal with K-grid + # (at most ⌈BLOCK_M / BLOCK_N⌉ blocks) + # • plus one for partial K blocks at the causal boundary + # ------------------------------------------------------------ + n_back_masked_blocks = BLOCK_M // BLOCK_N + 1 + n_back_masked_blocks = tl.minimum(n_back_masked_blocks, n_visible_k_blocks) + + n_front_skip_blocks = 0 # causal never skips the left side + n_front_masked_blocks = 0 # ditto + n_full_blocks = n_visible_k_blocks - n_back_masked_blocks + else: + # ========== NON-CAUSAL MODE ========== + # Without causal mask, all positions can attend to all positions + # Only need to handle the padding in the last block + # [K0 K1 K2 K3] [K4 K5 K6 K7] [K8 K9 ?? ??] + # Q0-Q3: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # + # Q4-Q7: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + + n_front_skip_blocks = 0 # never skips the left side + n_front_masked_blocks = 0 # ditto + if n_extra_tokens != 0: + n_back_masked_blocks = 1 # Last block needs padding mask + n_full_blocks = total_k_blocks - 1 + else: + n_back_masked_blocks = 0 # All blocks are aligned + n_full_blocks = total_k_blocks + + return ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens, + ) + @triton.autotune( - configs=autotune_configs, - key=autotune_keys, + configs=fwd_prefill_autotune_configs, + key=FWD_PREFILL_AUTOTUNE_KEYS, use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, - Descale_Q, Descale_K, Descale_V, Descale_O, stride_descale_q_z, stride_descale_k_z, 
stride_descale_v_z, stride_descale_o_z, - SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, - stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, - HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_INFERENCE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): +def attn_fwd( + Q, + K, + V, + bias, + Q_Descale, + K_Descale, + V_Descale, + stride_q_descale_z, + stride_k_descale_z, + stride_v_descale_z, + LSE, + Out, + SD_MASK, + ALIBI_SLOPES, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + stride_az, + stride_ah, + stride_sz, + stride_sh, + stride_sm, + stride_sn, + stride_lse_z, + stride_lse_h, + stride_lse_m, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Add seqused parameters + dropout_p, + philox_seed, + philox_offset_base, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + IS_VARLEN: tl.constexpr, + SM_SCALE: tl.constexpr, + IS_CAUSAL: 
tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_SCORES: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_P_DESCALE: tl.constexpr, + USE_SEQUSED: tl.constexpr, + FORCE_MASKING: tl.constexpr, +): # set params ACCUMULATOR_TYPE = tl.float32 # compute offsets - start_m = tl.program_id(0) + off_z = tl.program_id(0) off_h_q = tl.program_id(1) - off_z = tl.program_id(2) + start_m = tl.program_id(2) + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + # Determine if we need to mask the heads + PADDED_HEAD_QK: tl.constexpr = ACTUAL_BLOCK_DMODEL_QK != BLOCK_DMODEL_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_BLOCK_DMODEL_V != BLOCK_DMODEL_V + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) # handle seqlen if IS_VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = ( + tl.load(seqused_q + off_z) + if seqused_q is not None + else cu_seqlens_q_end - cu_seqlens_q_start + ) + seqlen_q = tl.minimum( + actual_seqlen_q, cu_seqlens_q_end - cu_seqlens_q_start + ) + else: + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # we have a one-size-fits-all grid in id(0). 
Some seqlens might be too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - elif IS_INFERENCE: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = MAX_SEQLENS_Q - seqlen_k = tl.load(Cache_seqlens + off_z) + + # If seqused is provided, use it to limit the actual sequence length for keys + if USE_SEQUSED: + actual_seqlen_k = ( + tl.load(seqused_k + off_z) + if seqused_k is not None + else cu_seqlens_k_end - cu_seqlens_k_start + ) + seqlen_k = tl.minimum( + actual_seqlen_k, cu_seqlens_k_end - cu_seqlens_k_start + ) + else: + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start else: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 seqlen_q = MAX_SEQLENS_Q seqlen_k = MAX_SEQLENS_K - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = tl.cdiv(seqlen_k, BLOCK_N) - if (IS_CAUSAL): - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = tl.cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. 
Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is part of - # the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - o_ptrs_mask = offs_m[:, None] < seqlen_q - # We still need to write 0s to the result - tl.store(o_ptrs, acc, mask=o_ptrs_mask) - # The tensor allocated for L is based on MAX_SEQLENS_Q as that is - # statically known. - l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m - l_ptrs = l_offset + offs_m * stride_lse_m - - l = tl.full([BLOCK_M], value=0.0, dtype=ACCUMULATOR_TYPE) - - # mask_m_offsets = start_m + tl.arange(0, BLOCK_M) - # lse_mask = mask_m_offsets < causal_start_idx - # softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) - l_ptrs_mask = offs_m < MAX_SEQLENS_Q - tl.store(l_ptrs, l, mask=l_ptrs_mask) - # TODO: Should dropout and return encoded softmax be handled here too? - return - - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - if GROUP_SIZE != 1: - off_h_k = off_h_q // GROUP_SIZE + # Load scale factors if IS_FP8. 
+ if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (off_h_k) + # For MHA (GROUP_SIZE == 1), q_descale uses off_h_q (same as off_h_k) + if GROUP_SIZE != 1: + q_descale = tl.load( + Q_Descale + off_z * stride_q_descale_z + off_h_k + ) # MQA/GQA: broadcast using k/v head index + else: + q_descale = tl.load( + Q_Descale + off_z * stride_q_descale_z + off_h_q + ) # MHA: use q head index + k_descale = tl.load(K_Descale + off_z * stride_k_descale_z + off_h_k) + v_descale = tl.load(V_Descale + off_z * stride_v_descale_z + off_h_k) else: - off_h_k = off_h_q + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 - n_extra_tokens = 0 - # print("n_extra_tokens:", n_extra_tokens) - # print("seqlen_k:", seqlen_k) - # print("BLOCK_N:", BLOCK_N) - # return - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + # figure out masking pattern + ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens, + ) = compute_block_masking( + seqlen_k, + seqlen_q, + start_m, + IS_CAUSAL, + USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + BLOCK_M, + BLOCK_N, + ) + + # ============================================================ + # PROGRAM EARLY EXIT (All K Blocks Skipped) + # ============================================================ + total_visible_blocks = n_front_masked_blocks + n_full_blocks + n_back_masked_blocks + if total_visible_blocks == 0: + """ + No K blocks visible - write zeros and exit. 
+ """ + # Write zeros to output + o_offset = ( + Out + + off_z * stride_oz + + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om + ) + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on + o_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD_V: + o_mask = o_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + tl.store( + o_ptrs, + tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=Out.type.element_ty), + mask=o_mask, + ) + # Write zeros to LSE + l_ptrs = ( + LSE + + off_z * stride_lse_z + + off_h_q * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + + offs_m * stride_lse_m + ) + tl.store(l_ptrs, tl.zeros([BLOCK_M], dtype=tl.float32), mask=offs_m < seqlen_q) + return + + # ============================================================ + # NORMAL PROCESSING (Some K Blocks Visible) + # ============================================================ + """ + This program has visible K blocks to process. + We'll use two calls to handle different block types efficiently. + """ + + # Initialize for processing # Compute pointers for all the tensors used in this kernel. 
- q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn - v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + q_offset = ( + Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + ) + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qk + k_offset = ( + K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + ) + k_ptrs = k_offset + offs_d_qk[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = ( + V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + ) + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d_v[None, :] * stride_vn if USE_BIAS: # Note: this might get large enough to overflow on some configs bias_offset = off_h_q * stride_bh - bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + bias_ptrs = ( + bias + + bias_offset + + offs_m[:, None] * stride_bm + + offs_n[None, :] * stride_bn + ) else: bias_ptrs = None if USE_ALIBI: a_offset = off_z * stride_az + off_h_q * stride_ah - alibi_slope = tl.load(alibi_slopes + a_offset) + alibi_slope = tl.load(ALIBI_SLOPES + a_offset) else: alibi_slope = None - if RETURN_SCORES: - sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm - sd_mask_ptrs = sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - else: - sd_mask_ptrs = None - - if ENABLE_DROPOUT: - dropout_mask_offset = dropout_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm - dropout_mask_ptrs = dropout_mask_offset + 
offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - batch_philox_offset = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm - philox_ptrs = batch_philox_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - else: - dropout_mask_ptrs = None - philox_ptrs = 0 # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=ACCUMULATOR_TYPE) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=ACCUMULATOR_TYPE) + # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_QK: + q_ptrs_mask = q_ptrs_mask & (offs_d_qk[None, :] < ACTUAL_BLOCK_DMODEL_QK) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - # Load scale factors if IS_FP8. - if IS_FP8: - descale_q = tl.load(Descale_Q + off_z * stride_descale_q_z + off_h_q) - descale_k = tl.load(Descale_K + off_z * stride_descale_k_z + off_h_k) - descale_v = tl.load(Descale_V + off_z * stride_descale_v_z + off_h_k) - else: - descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + # ========== Process MASKED K Blocks in the front ========== + # NOTE: we use USE_SLIDING_WINDOW as guard because the compiler will crash otherwise. Front masking is only for sliding window so that is fine. + if n_front_masked_blocks > 0 and USE_SLIDING_WINDOW: + block_min = n_front_skip_blocks * BLOCK_N + block_max = (n_front_skip_blocks + n_front_masked_blocks) * BLOCK_N - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens.
- masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. - # In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its actual - # value because there is no masking. Similarly we do not need padding. + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + stride_sm, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, # Start of front masked blocks + block_max, # End of front masked blocks + 0, # n_extra_tokens (0 for front blocks, only relevant for last block) + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8, + FP8_MAX, + FP8_P_DESCALE, + APPLY_MASK=True, # Masked blocks + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL_QK=BLOCK_DMODEL_QK, + BLOCK_DMODEL_V=BLOCK_DMODEL_V, + BLOCK_N=BLOCK_N, + PRE_LOAD_V=PRE_LOAD_V, + ENABLE_DROPOUT=ENABLE_DROPOUT, + PADDED_HEAD_QK=PADDED_HEAD_QK, + PADDED_HEAD_V=PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK=ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V=ACTUAL_BLOCK_DMODEL_V, + SM_SCALE=SM_SCALE, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, + ) + + # ========== Process FULL K Blocks (Fast Path) ========== if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = 
_attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, - # IS_CAUSAL, .... - False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 - k_ptrs += n_full_blocks * BLOCK_N * stride_kn - v_ptrs += n_full_blocks * BLOCK_N * stride_vk - if USE_BIAS: - bias_ptrs += n_full_blocks * BLOCK_N * stride_bn - if RETURN_SCORES: - sd_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn - if ENABLE_DROPOUT: - dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn - philox_ptrs += n_full_blocks * BLOCK_N * stride_sn - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, - n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, - IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) - # epilogue - # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. 
- l_recip = 1 / l_i[:, None] + block_min = (n_front_skip_blocks + n_front_masked_blocks) * BLOCK_N + block_max = ( + n_front_skip_blocks + n_front_masked_blocks + n_full_blocks + ) * BLOCK_N + + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + stride_sm, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, # Start of range: 0 + block_max, # End of range: n_full_blocks * BLOCK_N + 0, # n_extra_tokens (not used for full blocks) + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8, + FP8_MAX, + FP8_P_DESCALE, + APPLY_MASK=FORCE_MASKING, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL_QK=BLOCK_DMODEL_QK, + BLOCK_DMODEL_V=BLOCK_DMODEL_V, + BLOCK_N=BLOCK_N, + PRE_LOAD_V=PRE_LOAD_V, + ENABLE_DROPOUT=ENABLE_DROPOUT, + PADDED_HEAD_QK=PADDED_HEAD_QK, + PADDED_HEAD_V=PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK=ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V=ACTUAL_BLOCK_DMODEL_V, + SM_SCALE=SM_SCALE, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, + ) + + # ========== Process MASKED K Blocks in the back ========== + if n_back_masked_blocks > 0: + block_min = ( + n_front_skip_blocks + n_front_masked_blocks + n_full_blocks + ) * BLOCK_N + block_max = ( + n_front_skip_blocks + + n_front_masked_blocks + + n_full_blocks + + n_back_masked_blocks + ) * BLOCK_N + + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + stride_sm, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + 
off_h_q, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, # Start of range: n_full_blocks * BLOCK_N + block_max, # End of range: n_visible_k_blocks * BLOCK_N + n_extra_tokens, # Padding tokens in last block + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8, + FP8_MAX, + FP8_P_DESCALE, + APPLY_MASK=True, # Masked blocks + IS_CAUSAL=IS_CAUSAL, # Use actual causal flag + BLOCK_M=BLOCK_M, + BLOCK_DMODEL_QK=BLOCK_DMODEL_QK, + BLOCK_DMODEL_V=BLOCK_DMODEL_V, + BLOCK_N=BLOCK_N, + PRE_LOAD_V=PRE_LOAD_V, + ENABLE_DROPOUT=ENABLE_DROPOUT, + PADDED_HEAD_QK=PADDED_HEAD_QK, + PADDED_HEAD_V=PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK=ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V=ACTUAL_BLOCK_DMODEL_V, + SM_SCALE=SM_SCALE, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, + ) + + # ============================================================ + # EPILOGUE + # ============================================================ + # Handle invalid rows: rows with no valid keys to attend to. + # This occurs with sliding window or causal attention (when seqlen_q > seqlen_k). + # For invalid rows: m_i = -inf, l_i = 0, acc = 0. + # We set l_i = 1.0 to avoid division by zero and ensure LSE = -inf. + invalid_mask = m_i == float("-inf") + l_i_safe = tl.where(invalid_mask, 1.0, l_i) + l_recip = 1 / l_i_safe[:, None] acc = acc * l_recip + if ENABLE_DROPOUT: dropout_scale = 1 / (1 - dropout_p) acc = acc * dropout_scale - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. 
- end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - if IS_CAUSAL: - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE(Log Sum Exponents), the log of the normalization constant - l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m - l_ptrs = l_offset + offs_m * stride_lse_m + # compute log-sum-exp if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 LN2: tl.constexpr = 0.6931471824645996 - # compute log-sum-exp in base 2 units - mi_base2 = m_i * RCP_LN2 - softmax_lse = mi_base2 + tl.math.log2(l_i) - # convert back to natural units - softmax_lse *= LN2 + softmax_lse = (m_i * RCP_LN2 + tl.math.log2(l_i)) * LN2 else: softmax_lse = m_i + tl.math.log(l_i) - if IS_CAUSAL: - # zero out nans caused by -infs when doing causal - lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx - softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) + # Ensure invalid rows have LSE = -inf + softmax_lse = tl.where(invalid_mask, float("-inf"), softmax_lse) + + # write back LSE(Log Sum Exponents), the log of the normalization constant + l_offset = ( + LSE + + off_z * stride_lse_z + + off_h_q * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + ) + l_ptrs = l_offset + offs_m * stride_lse_m # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. - # This is only true for the last M block. For others, overflow_size will be -ve + # This is only true for the last Q block. 
For others, overflow_size will be -ve + end_m_idx = (start_m + 1) * BLOCK_M overflow_size = end_m_idx - seqlen_q if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary - tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) # the log of the normalization constant + tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) else: - tl.store(l_ptrs, softmax_lse) # the log of the normalization constant + tl.store(l_ptrs, softmax_lse) # write back O - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on - o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + o_offset = ( + Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + ) + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL_V], 1, dtype=tl.int1) if overflow_size > 0: o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) - if PADDED_HEAD: - o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_V: + o_ptrs_mask = o_ptrs_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) - if FP8_OUTPUT: - scale_acc, descale_acc = compute_fp8_scaling_factors(acc, FP8_MAX) - tl.store(Descale_O + off_z * stride_descale_o_z + off_h_q, descale_acc) - tl.store(o_ptrs, (acc * scale_acc).to(Out.type.element_ty), mask=o_ptrs_mask) - else: - tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) - - -def attention_prefill_forward_triton_impl( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - bias: Optional[torch.Tensor], - layout: Literal["bshd", "bhsd", "thd"], - # varlen - cu_seqlens_q: Optional[torch.Tensor], - 
cu_seqlens_k: Optional[torch.Tensor], - max_seqlens_q: int, - max_seqlens_k: int, - # inference - cache_seqlens: Optional[Union[(int, torch.Tensor)]], - cache_batch_idx: Optional[torch.Tensor], - # dropout - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - # misc - return_softmax: bool, - use_exp2: bool, - # fp8 - descale_q: Optional[torch.Tensor], - descale_k: Optional[torch.Tensor], - descale_v: Optional[torch.Tensor], - descale_o: Optional[torch.Tensor], + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + +def attention_forward_prefill_triton_impl( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + sd_mask: Optional[torch.Tensor], + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + window_size_left: int, + window_size_right: int, + bias: Optional[torch.Tensor], + layout: Literal["bshd", "bhsd", "thd"], + # varlen + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlens_q: int, + max_seqlens_k: int, + # dropout + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + # misc + return_scores: bool, + use_exp2: bool, + # fp8 + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + # seqused for FA v3 + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + # rotary (optional) + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + rotary_interleaved: bool = False, + seqlens_rotary: Optional[torch.Tensor] = None, ): - IS_FP8 = is_fp8(q) - if IS_FP8: - FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max + # get params, strides and shape + IS_VARLEN = layout == "thd" + + # common assertions + assert ( + 0.0 <= dropout_p <= 1.0 + ), f"dropout_p must be between 0 and 1, got {dropout_p}" + assert ( + q.device == k.device == v.device == o.device + ), f"All 
tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}" + assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" + current_device = torch.cuda.current_device() + assert ( + q.is_cuda and q.device.index == current_device + ), f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" + + # get shapes and strides + if IS_VARLEN: + # shape + total_seqlen_q, nheads_q, head_size_q = q.shape + total_seqlen_k, nheads_k, head_size_k = k.shape + total_seqlen_v, nheads_v, head_size_v = v.shape + + # assert shapes + assert ( + cu_seqlens_q is not None + ), "cu_seqlens_q must be provided for varlen layout" + assert ( + cu_seqlens_k is not None + ), "cu_seqlens_k must be provided for varlen layout" + assert ( + max_seqlens_q is not None and max_seqlens_q > 0 + ), "max_seqlens_q must be provided and positive for varlen layout" + assert ( + max_seqlens_k is not None and max_seqlens_k > 0 + ), "max_seqlens_k must be provided and positive for varlen layout" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert output shapes + assert o.shape == ( + total_seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" + + # assert cu_seqlens + assert ( + cu_seqlens_q.dtype == torch.int32 + ), f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" + assert ( + cu_seqlens_k.dtype == torch.int32 + ), f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0" + assert ( + cu_seqlens_q[-1] == 
total_seqlen_q + ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" + assert ( + cu_seqlens_k[-1] == total_seqlen_k + ), f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" + + # set vars + batch = len(cu_seqlens_q) - 1 + head_size_qk = head_size_q - assert is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." + # Assert softmax_lse tensor is large enough + assert ( + softmax_lse.shape[0] >= nheads_q + ), f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= nheads_q={nheads_q}" + assert ( + softmax_lse.shape[1] >= total_seqlen_q + ), f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= total_seqlen_q={total_seqlen_q}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert ( + softmax_lse.device == q.device + ), f"softmax_lse must be on same device as q" - if is_fp8(o): - FP8_OUTPUT = True - assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor for the output." 
+ # strides + stride_qb, stride_qh, stride_qm, stride_qd = ( + 0, + q.stride(1), + q.stride(0), + q.stride(2), + ) + stride_kb, stride_kh, stride_kn, stride_kd = ( + 0, + k.stride(1), + k.stride(0), + k.stride(2), + ) + stride_vb, stride_vh, stride_vn, stride_vd = ( + 0, + v.stride(1), + v.stride(0), + v.stride(2), + ) + stride_ob, stride_oh, stride_om, stride_od = ( + 0, + o.stride(1), + o.stride(0), + o.stride(2), + ) + stride_lse_z, stride_lse_h, stride_lse_m = ( + 0, + softmax_lse.stride(0), + softmax_lse.stride(1), + ) + else: + # shapes + batch_q, seqlen_q, nheads_q, head_size_q = q.shape + batch_k, seqlen_k, nheads_k, head_size_k = k.shape + batch_v, seqlen_v, nheads_v, head_size_v = v.shape + + # assert batch dimensions + assert ( + batch_q == batch_k == batch_v + ), f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert sequence lengths + assert ( + seqlen_k == seqlen_v + ), f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" + + # assert output shapes + assert o.shape == ( + batch_q, + seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_v)}" + + # set vars + batch = batch_q + head_size_qk = head_size_q + max_seqlens_q = seqlen_q + max_seqlens_k = seqlen_k + + # Assert softmax_lse tensor is large enough + assert ( + softmax_lse.shape[0] >= batch + ), f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= batch={batch}" + assert ( + softmax_lse.shape[1] >= nheads_q + ), f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= nheads_q={nheads_q}" + assert ( + softmax_lse.shape[2] >= seqlen_q + ), 
f"softmax_lse.shape[2]={softmax_lse.shape[2]} must be >= seqlen_q={seqlen_q}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert ( + softmax_lse.device == q.device + ), f"softmax_lse must be on same device as q" + + # strides + stride_qb, stride_qh, stride_qm, stride_qd = ( + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + ) + stride_kb, stride_kh, stride_kn, stride_kd = ( + k.stride(0), + k.stride(2), + k.stride(1), + k.stride(3), + ) + stride_vb, stride_vh, stride_vn, stride_vd = ( + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + ) + stride_ob, stride_oh, stride_om, stride_od = ( + o.stride(0), + o.stride(2), + o.stride(1), + o.stride(3), + ) + stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + + # apply rotary embeddings + if rotary_cos is not None and rotary_sin is not None: + if IS_VARLEN: + raise NotImplementedError( + "Rotary embeddings with varlen (thd layout) prefill are not implemented yet." + ) + seqlen_offsets = seqlens_rotary if seqlens_rotary is not None else 0 + local = (window_size_left != -1) or (window_size_right != -1) + q, _ = apply_rotary( + q, + None, + rotary_cos, + rotary_sin, + causal=causal, + local=local, + interleaved=rotary_interleaved, + seqlen_offsets=seqlen_offsets, + ) + + # fp8 setup and assertions + IS_FP8 = is_fp8([q, k, v]) + if IS_FP8: + arch = get_arch() + if not arch.supports_fp8: + raise RuntimeError( + f"{arch.name} does not support FP8" + ) + FP8_MAX = torch.finfo(q.dtype).max + rec_dtype = arch.recommended_fp8_dtype(q.dtype) + if q.dtype != rec_dtype or k.dtype != rec_dtype or v.dtype != rec_dtype: + warnings.warn( + f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k.dtype}, v: {v.dtype}", + UserWarning, + ) + + if (q_descale is None) or (k_descale is None) or (v_descale is None): + warnings.warn( + "FP8 tensors detected but descale factors not provided. 
Using default scale of 1.0", + UserWarning, + ) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones( + batch, nheads_q, dtype=torch.float32, device=q.device + ) + if k_descale is None: + k_descale = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + if v_descale is None: + v_descale = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) else: - FP8_OUTPUT = False + # Enforce exact expected shapes; no reshaping or normalization. + assert ( + q_descale.dim() == 2 + and q_descale.shape[0] == batch + and q_descale.shape[1] == nheads_k + ), f"q_descale expected shape ({batch}, {nheads_k}) got {tuple(q_descale.shape)}" + assert ( + k_descale.dim() == 2 + and k_descale.shape[0] == batch + and k_descale.shape[1] == nheads_k + ), f"k_descale expected shape ({batch}, {nheads_k}) got {tuple(k_descale.shape)}" + assert ( + v_descale.dim() == 2 + and v_descale.shape[0] == batch + and v_descale.shape[1] == nheads_k + ), f"v_descale expected shape ({batch}, {nheads_k}) got {tuple(v_descale.shape)}" - # Get strides for the kernel - stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None - stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None - stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + # o should be fp32 or fp16/bf16 + assert o.dtype in [ + torch.float16, + torch.bfloat16, + torch.float32, + ], f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" + + stride_q_descale_z = q_descale.stride(0) if q_descale is not None else 0 + stride_k_descale_z = k_descale.stride(0) if k_descale is not None else 0 + stride_v_descale_z = v_descale.stride(0) if v_descale is not None else 0 + + if DEBUG: + print(f"FP8 path triggered in fwd_prefill.py") else: FP8_MAX = None - FP8_OUTPUT = False - descale_q = descale_k = descale_v = 
descale_o = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = None - - # check flags - is_varlen = layout == "thd" - use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) - is_inference = False if cache_seqlens is None else True - if is_inference: - assert layout == "bshd", f"{layout} layout is not supported with inference. Use bshd layout" - if DEBUG: - print(f"is_inference:", is_inference) + q_descale = k_descale = v_descale = None + stride_q_descale_z = stride_k_descale_z = stride_v_descale_z = None - # NOTE: a large bias tensor leads to overflow during pointer arithmetic - if (bias is not None): - assert (bias.numel() < 2**31) + # check output dtype matches input dtype when not using fp8 + assert ( + o.dtype == q.dtype + ), f"Output dtype {o.dtype} must match input dtype {q.dtype} when not using fp8" - batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) - q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) + # check features + use_sliding_window = window_size_left != -1 or window_size_right != -1 + use_alibi, (stride_az, stride_ah) = ( + (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + ) + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if bias is not None: + assert bias.numel() < 2**31 - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() + # Get closest power of 2 over or equal to 32 for both QK and V dimensions + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_v = 1 << (head_size_v - 1).bit_length() # Smallest head_dim supported is 16. If smaller, the tile in the # kernel is padded - there is no padding in memory for any dims. 
- padded_d_model = max(padded_d_model, 16) - - grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) - - # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out - # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according - # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing - # only. This return holds no useful output aside from debugging. - use_dropout = (dropout_p > 0.0) - if use_dropout or return_softmax: - sd_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, - dtype=torch.float32) - if DROPOUT_USE_PYTORCH: - dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlens_q, max_seqlens_k), seed = philox_seed) - else: - dropout_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, - dtype=torch.float32) - scores_strides = (sd_mask.stride(0), sd_mask.stride(1), sd_mask.stride(2), sd_mask.stride(3)) - else: - sd_mask = None - dropout_mask = None - scores_strides = (0, 0, 0, 0) - - # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities) - if is_varlen: - total_seqlen_q, _, _ = q.shape - softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) - stride_lse_h, stride_lse_m = softmax_lse.stride() - stride_lse_z = 0 + padded_d_model_qk = max(padded_d_model_qk, 16) + padded_d_model_v = max(padded_d_model_v, 16) + + # sd_mask assertions and strides + if sd_mask is not None: + assert dropout_p > 0.0 or return_scores, "sd_mask provided but not used" + assert ( + sd_mask is not None + ), "sd_mask must be provided when return_scores=True or dropout_p > 0" + # Assert sd_mask tensor is large enough + assert ( + sd_mask.shape[0] >= batch + ), f"sd_mask.shape[0]={sd_mask.shape[0]} must be >= batch={batch}" + assert ( + 
sd_mask.shape[1] >= nheads_q + ), f"sd_mask.shape[1]={sd_mask.shape[1]} must be >= nheads_q={nheads_q}" + assert ( + sd_mask.shape[2] >= max_seqlens_q + ), f"sd_mask.shape[2]={sd_mask.shape[2]} must be >= max_seqlens_q={max_seqlens_q}" + assert ( + sd_mask.shape[3] >= max_seqlens_k + ), f"sd_mask.shape[3]={sd_mask.shape[3]} must be >= max_seqlens_k={max_seqlens_k}" + assert sd_mask.device == q.device, f"sd_mask must be on same device as q" + + stride_sz, stride_sh, stride_sm, stride_sn = ( + sd_mask.stride(0), + sd_mask.stride(1), + sd_mask.stride(2), + sd_mask.stride(3), + ) else: - softmax_lse = torch.zeros((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) - stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + stride_sz, stride_sh, stride_sm, stride_sn = (0, 0, 0, 0) if bias is not None: - bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2), - bias.stride(3)) + stride_bz, stride_bh, stride_bm, stride_bn = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) else: - bias_strides = (0, 0, 0, 0) - - attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx, - descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, - sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, - *bias_strides, stride_az, stride_ah, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, - HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference, - BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, - USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p - > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, 
IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) - - return softmax_lse, sd_mask if return_softmax else None + stride_bz, stride_bh, stride_bm, stride_bn = (0, 0, 0, 0) + + # Detect if we need to force masking for all blocks (required on some architectures) + arch = get_arch() + force_masking = arch.is_rdna + + # launch kernel + grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META["BLOCK_M"])) + attn_fwd[grid]( + q, + k, + v, + bias, + q_descale, + k_descale, + v_descale, + stride_q_descale_z, + stride_k_descale_z, + stride_v_descale_z, + softmax_lse, + o, + sd_mask, + alibi_slopes, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + stride_az, + stride_ah, + stride_sz, + stride_sh, + stride_sm, + stride_sn, + stride_lse_z, + stride_lse_h, + stride_lse_m, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL_QK=head_size_qk, + ACTUAL_BLOCK_DMODEL_V=head_size_v, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + SM_SCALE=sm_scale, + IS_CAUSAL=causal, + USE_SLIDING_WINDOW=use_sliding_window, + WINDOW_SIZE_LEFT=window_size_left, + WINDOW_SIZE_RIGHT=window_size_right, + IS_VARLEN=IS_VARLEN, + BLOCK_DMODEL_QK=padded_d_model_qk, + BLOCK_DMODEL_V=padded_d_model_v, + USE_BIAS=False if bias is None else True, + USE_ALIBI=use_alibi, + ENABLE_DROPOUT=dropout_p > 0.0, + USE_EXP2=use_exp2, + RETURN_SCORES=return_scores, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_P_DESCALE=False, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None), + FORCE_MASKING=force_masking, + ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py 
b/flash_attn/flash_attn_triton_amd/fwd_ref.py deleted file mode 100644 index baefb2410c1..00000000000 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ /dev/null @@ -1,387 +0,0 @@ -import torch -import math -from typing import Literal, Optional -from .utils import DEBUG, compute_alibi_tensor_ref - -DEBUG_CORE = False - -def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): - if DEBUG_CORE: - print() - print("attention_forward_core_ref_impl") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("sm_scale:", sm_scale) - print("causal:", causal) - print("dropout_p:", dropout_p) - print("philox_seed:", philox_seed) - print("philox_offset:", philox_offset) - print("use_exp2:", use_exp2) - - # cast to float32 - q = q.to(torch.float32) - k = k.to(torch.float32) - v = v.to(torch.float32) - - # Compute attention scores - attention_scores = torch.matmul(q, k.transpose(-2, -1)) - if DEBUG_CORE: - print("attention_scores:", attention_scores, attention_scores.shape) - - # Scale scores - attention_scaled_scores = sm_scale * attention_scores - if DEBUG_CORE: - print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) - - # Apply ALiBi if slopes are provided - if alibi_slopes is not None: - L_q, L_k = q.shape[1], k.shape[1] - if DEBUG_CORE: - print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) - alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) - if DEBUG_CORE: - print("alibi_bias:", alibi_bias, alibi_bias.shape) - alibi_bias = alibi_bias.reshape(-1, L_q, L_k) - if DEBUG_CORE: - print("alibi_bias_flat:", alibi_bias, alibi_bias.shape) - attention_scaled_scores = attention_scaled_scores + alibi_bias - if DEBUG_CORE: - print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) - - - # Apply causal mask if necessary - if causal: - L_q, L_k = q.shape[1], k.shape[1] - row_idx = 
torch.arange(L_q, device=q.device).unsqueeze(1) - col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) - col_offset = L_q-L_k - causal_mask = row_idx >= (col_offset + col_idx) - if DEBUG_CORE: - print("causal_mask:", causal_mask) - # set -inf to places the causal mask is false - attention_scaled_scores = attention_scaled_scores.masked_fill( - torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') - ) - if DEBUG_CORE: - print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) - - # Compute max for numerical stability - max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0] - if DEBUG_CORE: - print("max_scores:", max_scores, max_scores.shape) - if causal: - # Replace -inf in max_scores with zeros to avoid NaN in subtraction - max_scores = torch.where( - torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores - ) - if DEBUG: - print("max_scores if causal:", max_scores, max_scores.shape) - - # Shift scores - attention_shifted_scaled_scores = attention_scaled_scores - max_scores - if DEBUG_CORE: - print("attention_shifted_scaled_scores:", attention_shifted_scaled_scores, attention_shifted_scaled_scores.shape) - - # Exponentiate - if use_exp2: - RCP_LN = 1 / math.log(2) - exp_scores = torch.exp2(RCP_LN * attention_shifted_scaled_scores) - else: - exp_scores = torch.exp(attention_shifted_scaled_scores) - - if DEBUG_CORE: - print("exp_scores:", exp_scores, exp_scores.shape) - - # Sum of exponentials - sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True) - if DEBUG_CORE: - print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) - if causal: - # if sum of exp scores is 0.0 it means scores where -inf, we cannot compute softmax and softmax_lse. 
Setting to 1 deals with -inf case cleanly - sum_exp_scores = torch.where( - sum_exp_scores == 0, - torch.ones_like(sum_exp_scores), - sum_exp_scores - ) - if DEBUG_CORE: - print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) - - # Compute softmax probabilities - p = exp_scores / sum_exp_scores - - if DEBUG_CORE: - print("softmax:", p, p.shape) - - # apply dropout if specified - if dropout_p > 0.0: - rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) - dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) - if DEBUG_CORE: - print("dropout_scale:", dropout_scale) - print("dropout_mask:", dropout_mask) - # Apply dropout mask and scale - # Set -1 for dropped positions and 1 for kept positions in exp_scores - sd_mask = torch.where(dropout_mask, exp_scores, -exp_scores) - p = torch.where(dropout_mask, p , torch.zeros_like(p)) * dropout_scale - if DEBUG_CORE: - print("softmax after dropout:", p) - print("sd_mask:", sd_mask) - else: - sd_mask = exp_scores - - # Compute log-sum-exp - if use_exp2: - LN2 = math.log(2) - RCP_LN = 1 / math.log(2) - max_scores_base2 = max_scores * RCP_LN - softmax_lse_base2 = max_scores_base2 + torch.log2(sum_exp_scores) - softmax_lse = softmax_lse_base2 * LN2 - softmax_lse.squeeze_(-1) - else: - softmax_lse = max_scores + torch.log(sum_exp_scores) - softmax_lse = softmax_lse.squeeze(-1) - - if DEBUG_CORE: - print("softmax_lse:", softmax_lse, softmax_lse.shape) - - # Compute output - o = torch.matmul(p, v) - if DEBUG_CORE: - print("o:", o, o.shape) - - # cast back to original dtype - o = o.to(torch.float16) - # softmax_lse = softmax_lse.to(torch.float16) # NOTE: if you cast lse to fp16 it cause accuracy issues. 
keep fp32 - sd_mask = sd_mask.to(torch.float16) - - return o, softmax_lse, sd_mask - -def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): - """Compute reference output and softmax_lse using PyTorch's built-in function""" - - # Ensure the layout is 'bhsd' - if layout == "bshd": - q = q.transpose(1, 2).contiguous() - k = k.transpose(1, 2).contiguous() - v = v.transpose(1, 2).contiguous() - elif layout != "bhsd": - raise ValueError(f"Unknown layout {layout}") - - # Prepare tensors - batch_size, nheads_q, seq_len_q, head_dim = q.shape - batch_size, nheads_k, seq_len_k, head_dim = k.shape - group_size = nheads_q // nheads_k - if nheads_q % nheads_k != 0: - raise ValueError("nheads_q must be divisible by nheads_k") - - if group_size != 1: - # MQA or GQA case - # Reshape q to [batch_size, nheads_k, group_size, seq_len_q, head_dim] - q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - # Expand k and v to match group_size - k = k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) - v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) - # Flatten the first three dimensions for computation - q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) - k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) - v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) - else: - q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) - k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) - v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) - - # Call the core attention function - o, softmax_lse, sd_mask = attention_forward_core_ref_impl( - q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 - ) - - if group_size != 1: - # Reshape outputs back to original dimensions - o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - o = o.reshape(batch_size, nheads_q, 
seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) - softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) - sd_mask = sd_mask.reshape(batch_size, nheads_k, group_size, seq_len_q, seq_len_k) - sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) - else: - # Standard case - o = o.reshape(batch_size, nheads_q, seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) - sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) - - # Restore original layout if necessary - if layout == "bshd": - o = o.transpose(1, 2) - - return o, softmax_lse, sd_mask - - -def attention_varlen_forward_pytorch_ref_impl( - q, - k, - v, - sm_scale, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2 -): - # Ensure the layout is 'thd' - if layout != 'thd': - raise ValueError(f"Unsupported layout {layout}. 
Expected 'thd'.") - - batch_size = cu_seqlens_q.shape[0] - 1 - nheads_q, nheads_k = q.shape[1], k.shape[1] - head_dim = q.shape[2] - - # Pre-allocate outputs - total_L_q = q.shape[0] - total_L_k = k.shape[0] - - o = torch.zeros((total_L_q, nheads_q, head_dim), dtype=q.dtype, device=q.device) - softmax_lse = torch.zeros((total_L_q, nheads_q), dtype=torch.float32, device=q.device) - sd_mask = torch.zeros((batch_size, nheads_q, max_seqlen_q, max_seqlen_k), dtype=torch.float32, device=q.device) - - # Compute group_size for MQA/GQA handling - group_size = nheads_q // nheads_k - if nheads_q % nheads_k != 0: - raise ValueError("nheads_q must be divisible by nheads_k") - - for i in range(batch_size): - # Get the start and end indices for the current sequence - start_q = cu_seqlens_q[i].item() - end_q = cu_seqlens_q[i + 1].item() - start_k = cu_seqlens_k[i].item() - end_k = cu_seqlens_k[i + 1].item() - - seqlen_q = end_q - start_q - seqlen_k = end_k - start_k - - if DEBUG: - print(f"Batch {i} with seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, Hq= {nheads_q}, Hk = {nheads_k}") - - # Extract q_i, k_i, v_i - q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] - k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - - # Permute to [nheads, L_q_i, head_dim] - q_i = q_i.permute(1, 0, 2) - k_i = k_i.permute(1, 0, 2) - v_i = v_i.permute(1, 0, 2) - - # Handle MQA/GQA by adjusting shapes based on group_size - if group_size != 1: - # Reshape q_i to [nheads_k, group_size, L_q_i, head_dim] - q_i = q_i.reshape(nheads_k, group_size, seqlen_q, head_dim) - # Expand k_i and v_i to match group_size - k_i = k_i.unsqueeze(1).expand(-1, group_size, -1, -1) - v_i = v_i.unsqueeze(1).expand(-1, group_size, -1, -1) - # Flatten the first two dimensions for computation - q_i = q_i.reshape(nheads_k * group_size, seqlen_q, head_dim) - k_i = k_i.reshape(nheads_k * group_size, seqlen_k, head_dim) - v_i = v_i.reshape(nheads_k * 
group_size, seqlen_k, head_dim) - else: - # Standard case - q_i = q_i.reshape(nheads_q, seqlen_q, head_dim) - k_i = k_i.reshape(nheads_k, seqlen_k, head_dim) - v_i = v_i.reshape(nheads_k, seqlen_k, head_dim) - - if alibi_slopes is not None: - alibi_slopes_i = alibi_slopes[i] - else: - alibi_slopes_i = None - - # Call the core attention function for this sequence - o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes_i, use_exp2) - - # Reshape outputs back to original dimensions - if group_size != 1: - # Reshape outputs to [nheads_k, group_size, seqlen_q, head_dim] - o_i = o_i.reshape(nheads_k, group_size, seqlen_q, head_dim) - # Combine the first two dimensions back to nheads_q - o_i = o_i.reshape(nheads_q, seqlen_q, head_dim) - # Reshape softmax_lse_i similarly - softmax_lse_i = softmax_lse_i.reshape(nheads_k, group_size, seqlen_q) - softmax_lse_i = softmax_lse_i.reshape(nheads_q, seqlen_q) - else: - # Outputs are already in the correct shape - pass - - # Convert back to 'thd' layout - o_i = o_i.permute(1, 0, 2) # [L_q_i, nheads_q, head_dim] - softmax_lse_i = softmax_lse_i.permute(1, 0) # [L_q_i, nheads_q] - sd_mask_i = sd_mask_i # [nheads_q, L_q_i, L_k_i] - - # Place outputs in pre-allocated tensors - o[start_q:end_q, :, :] = o_i - softmax_lse[start_q:end_q, :] = softmax_lse_i - sd_mask[i, :, :seqlen_q, :seqlen_k] = sd_mask_i - - return o, softmax_lse, sd_mask - - - -def attention_forward_pytorch_ref_impl( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool -): - # compute reference - if layout == "thd": - o_ref, softmax_lse_ref, 
sd_mask_ref = attention_varlen_forward_pytorch_ref_impl( - q.clone(), - k.clone(), - v.clone(), - sm_scale, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, - ) - else: - o_ref, softmax_lse_ref, sd_mask_ref = attention_vanilla_forward_pytorch_ref_impl( - q.clone(), - k.clone(), - v.clone(), - sm_scale, - causal, - layout, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2) - - # copy back to ouput tensor - out.copy_(o_ref.to(out.dtype)) - - return softmax_lse_ref, sd_mask_ref diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py deleted file mode 100644 index 06ab7d24d56..00000000000 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ /dev/null @@ -1,792 +0,0 @@ -import torch -import os -from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill import attention_prefill_backward_triton_impl -from .bwd_prefill_split import attention_prefill_backward_triton_split_impl -from .bwd_prefill_fused import _flash_attn_backward as attention_prefill_backward_triton_fused_impl -from .bwd_prefill_onekernel import attention_prefill_backward_triton_split_oneKernel_impl -from .fwd_decode import attention_decode_forward_triton_impl -from .fwd_ref import attention_forward_pytorch_ref_impl -from .bwd_ref import attention_backward_pytorch_ref_impl -from .utils import DEBUG, USE_REF, MetaData, get_shapes_from_layout, is_fp8 -from einops import rearrange, repeat -from flash_attn.layers.rotary import apply_rotary_emb -from typing import Literal, Optional, Union - -def fwd(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - return_softmax: bool, - gen_: 
Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None - ): - - if DEBUG: - print() - print("flash_attn_triton_amd.py::fwd inputs") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape if out is not None else None) - print("alibi_slopes:", alibi_slopes) - print("dropout_p:", dropout_p) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("softcap:", softcap) - print("return_softmax:", return_softmax) - print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) - print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) - print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) - - if is_fp8(q): - assert out is not None, "fp8 output tensor should be passed in." 
- assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" - else: - out = torch.zeros_like(q) if out is None else out.zero_() - - # Setup metadata - metadata = MetaData(sm_scale=softmax_scale) - metadata.max_seqlens_q = q.shape[1] - metadata.max_seqlens_k = k.shape[1] - metadata.layout = "bshd" - if return_softmax: - metadata.return_scores = True - - batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, metadata.layout) - - if causal: - metadata.need_causal(True) - - if alibi_slopes is not None: - metadata.need_alibi(alibi_slopes, batch, nheads_q) - - # store rng state - metadata.need_dropout(dropout_p, return_softmax) - rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast - - # check arguments - metadata.check_args(q, k, v, out) - - # call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( - q, - k, - v, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.use_exp2) - softmax_lse=softmax_lse_ref - sd_mask=sd_mask_ref - else: - if DEBUG: - print("Using Triton implementation") - softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( - q, - k, - v, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - None, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.return_scores, - metadata.use_exp2, - descale_q, - descale_k, - 
descale_v, - descale_o) - softmax_lse=softmax_lse_triton - sd_mask=sd_mask_triton - - if DEBUG: - print("flash_attn_triton_amd.py::fwd outputs") - print("o:", out, out.shape) - if is_fp8(out): - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) - - return out, softmax_lse, sd_mask, rng_state - -BWD_MODE = os.environ.get('BWD_MODE', 'split').lower() -def bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - deterministic: bool, - gen_: Optional[torch.Tensor] = None, - rng_state:Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, - descale_dq: Optional[torch.Tensor] = None, - descale_dk: Optional[torch.Tensor] = None, - descale_dv: Optional[torch.Tensor] = None, -): - if DEBUG: - print() - print("flash_attn_triton_amd.py::bwd inputs") - print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) - print("alibi_slopes:", alibi_slopes) - print("dropout_p:", dropout_p) - print("out:", out) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - 
print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("deterministic:", deterministic) - print("gen_:", gen_) - print("rng_state:", rng_state) - print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) - print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) - print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) - print("descale_do:", descale_do, descale_do.shape if descale_do is not None else None) - print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) - print("descale_dk:", descale_dk, descale_dk.shape if descale_dk is not None else None) - print("descale_dv:", descale_dv, descale_dv.shape if descale_dv is not None else None) - - dq = torch.zeros_like(q) if dq is None else dq.zero_() - dk = torch.zeros_like(k) if dk is None else dk.zero_() - dv = torch.zeros_like(v) if dv is None else dv.zero_() - - if rng_state is not None: - philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() - else: - philox_seed, philox_offset = None, None - - # call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - - delta_ref = attention_backward_pytorch_ref_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "bshd", - None, - None, - None, - None, - dropout_p, - philox_seed, - philox_offset, - False, - ) - delta = delta_ref - else: - if DEBUG: - print("Using Triton implementation") - if BWD_MODE == "split": - delta_triton = attention_prefill_backward_triton_split_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "bshd", - None, - None, - None, - None, - dropout_p, - philox_seed, - philox_offset, - False, - descale_q, - descale_k, - descale_v, - descale_o, - 
descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - delta = delta_triton - elif BWD_MODE == "fused": - delta_triton = attention_prefill_backward_triton_fused_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - None, - None, - q.shape[1], - k.shape[1], - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - descale_o, - True, - ) - delta = delta_triton - elif BWD_MODE == "jingning": - delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "bshd", - None, - None, - None, - None, - dropout_p, - philox_seed, - philox_offset, - False - ) - delta = delta_triton - else: - raise ValueError(f"Unknown bwd mode {BWD_MODE}") - - if DEBUG: - print("flash_attn_triton_amd.py::bwd outputs") - print("dv:", dv, dv.shape) - if is_fp8(dv): - print("descale_dv:", descale_dv, descale_dv.shape if descale_dv is not None else None) - print("dk:", dk, dk.shape) - if is_fp8(dk): - print("descale_dk:", descale_dk, descale_dk.shape if descale_dk is not None else None) - print("dq:", dq, dq.shape) - if is_fp8(dq): - print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) - return dq, dk, dv, delta - -def varlen_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - seqused_k: Optional[torch.Tensor], - leftpad_k: Optional[torch.Tensor], - block_table_: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - zero_tensors: bool , - causal: bool , - window_size_left: int, - window_size_right: int, - softcap: float, - return_softmax: bool, - gen_: Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] 
= None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None - ): - - if DEBUG: - print() - print("flash_attn_triton_amd.py::varlen_fwd") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) - print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) - print("alibi_slopes:", alibi_slopes) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("dropout_p:", dropout_p) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("gen_:", gen_) - print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) - print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) - print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - - if is_fp8(q): - assert out is not None, "fp8 output tensor should be passed in." 
- assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" - else: - out = torch.zeros_like(q) if out is None else out.zero_() - - # Setup metadata - metadata = MetaData(sm_scale=softmax_scale) - if return_softmax: - metadata.return_scores = True - metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # set layout to "thd" and other metdata - assert metadata.layout is not None - - # get shapes - batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shapes_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) - - if causal: - metadata.need_causal(True) - - if alibi_slopes is not None: - metadata.need_alibi(alibi_slopes, batch, nheads_q) - - # store rng state - metadata.need_dropout(dropout_p, return_softmax) - rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast - - # Check arguments - metadata.check_args(q, k, v, out) - - # call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( - q, - k, - v, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.use_exp2) - softmax_lse=softmax_lse_ref - sd_mask=sd_mask_ref - else: - if DEBUG: - print("Using Triton implementation") - softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( - q, - k, - v, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - None, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.cache_seqlens, - metadata.cache_batch_idx, - 
metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.return_scores, - metadata.use_exp2, - descale_q, - descale_k, - descale_v, - descale_o) - softmax_lse=softmax_lse_triton - sd_mask=sd_mask_triton - - if DEBUG: - print("varlen_fwd outputs") - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) - - - return out, softmax_lse, sd_mask, rng_state - -def varlen_bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - alibi_slopes: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - zero_tensors: bool, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - deterministic: bool, - gen_ : Optional[torch.Tensor] = None, - rng_state: Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, - descale_dq: Optional[torch.Tensor] = None, - descale_dk: Optional[torch.Tensor] = None, - descale_dv: Optional[torch.Tensor] = None, -): - if DEBUG: - print() - print("varlen_bwd") - print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) - print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) - print("cu_seqlens_k:", cu_seqlens_k, 
cu_seqlens_k.shape) - print("alibi_slopes:", alibi_slopes) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("dropout_p:", dropout_p) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("deterministic:", deterministic) - print("gen_:", gen_) - print("rng_state:", rng_state) - print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) - print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) - print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - print("descale_do:", descale_do, descale_do.shape if descale_do else None) - - dq = torch.zeros_like(q) if dq is None else dq.zero_() - dk = torch.zeros_like(k) if dk is None else dk.zero_() - dv = torch.zeros_like(v) if dv is None else dv.zero_() - - if rng_state is not None: - philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() - else: - philox_seed, philox_offset = None, None - - # call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - delta_ref = attention_backward_pytorch_ref_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "thd", - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - False, - ) - delta = delta_ref - else: - if DEBUG: - print("Using Triton implementation") - delta_triton = attention_prefill_backward_triton_split_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "thd", - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - False, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - delta = delta_triton - - if DEBUG: - 
print("varlen_bwd outputs") - print("delta:", delta, delta.shape) - print("dv:", dv, dv.shape) - print("dk:", dk, dk.shape) - print("dq:", dq, dq.shape) - - return dq, dk, dv, delta - -def fwd_kvcache( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - k: Optional[torch.Tensor], - v: Optional[torch.Tensor], - cache_seqlens: Optional[Union[(int, torch.Tensor)]], - rotary_cos: Optional[torch.Tensor], - rotary_sin: Optional[torch.Tensor], - cache_batch_idx: Optional[torch.Tensor], - cache_leftpad: Optional[torch.Tensor], - block_table: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - out: Optional[torch.Tensor], - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - rotary_interleaved: bool, - num_splits: int - ): - - if DEBUG: - print() - print("flash_attn_triton_amd.py::fwd_kvcache inputs") - print("q:", q, q.shape) - print("k_cache:", k_cache, k_cache.shape) - print("v_cache:", v_cache, v_cache.shape) - print("k:", k, k.shape if k is not None else None) - print("v:", v, v.shape if v is not None else None) - print("cache_seqlens:", cache_seqlens ) - print("rotary_cos:",rotary_cos ) - print("rotary_sin:",rotary_sin) - print("cache_batch_idx:", cache_batch_idx) - print("cache_leftpad:", cache_leftpad) - print("block_table:", block_table) - print("alibi_slopes:", alibi_slopes) - print("out:", out) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("softcap:", softcap) - print("rotary_interleaved:", rotary_interleaved) - print("num_splits:", num_splits) - - # output - out = torch.zeros_like(q) if out is None else out.zero_() - - # fill metadata - metadata = MetaData(sm_scale=softmax_scale) - metadata.layout = "bshd" - metadata.max_seqlens_q = q.shape[1] - metadata.max_seqlens_k = k_cache.shape[1] - metadata.cache_seqlens = cache_seqlens - 
metadata.cache_batch_idx = cache_batch_idx - - k_new = k - v_new = v - - if causal: - metadata.need_causal(True) - - if alibi_slopes is not None: - batch, _ , nheads_q, _= q.shape - metadata.need_alibi(alibi_slopes, batch, nheads_q) - - # rotary boolean - apply_rotary = torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin) - if apply_rotary: - metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved) - - # Rotary Embedding Implementation - if apply_rotary: - if metadata.causal: # NOTE: when support is added. Add `or metadata.local` - q_ro = apply_rotary_emb( - q, - metadata.rotary_cos, - metadata.rotary_sin, - seqlen_offsets=metadata.cache_seqlens, - interleaved=metadata.rotary_interleaved, - ) - else: - q_ro = rearrange( - apply_rotary_emb( - rearrange(q, "b s h d -> b 1 (s h) d"), - metadata.rotary_cos, - metadata.rotary_sin, - seqlen_offsets=metadata.cache_seqlens, - interleaved=metadata.rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=metadata.max_seqlens_q, - ) - k_ro = apply_rotary_emb( - k_new, - metadata.rotary_cos, - metadata.rotary_sin, - seqlen_offsets=metadata.cache_seqlens, - interleaved=metadata.rotary_interleaved, - ) - - q, k_new = q_ro.to(q.dtype), k_ro.to(q.dtype) - - # launch kernel - DECODE_KERNEL= True # os.environ.get('DECODE_KERNEL', '0').lower() in ('1', 'true', 'yes') - if DECODE_KERNEL: - softmax_lse_triton = attention_decode_forward_triton_impl( - q, - k_cache, - v_cache, - k_new, - v_new, - out, - metadata.sm_scale, - metadata.causal, - metadata.alibi_slopes, - metadata.layout, - metadata.cache_seqlens, - metadata.cache_batch_idx, - ) - else: - softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( - q, - k_cache, - v_cache, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - None, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.dropout_p, 
- metadata.philox_seed, - metadata.philox_offset, - metadata.return_scores, - metadata.use_exp2, - None, - None, - None, - None) - softmax_lse = softmax_lse_triton - - if DEBUG: - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - return out, softmax_lse diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py new file mode 100644 index 00000000000..e0669779be4 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -0,0 +1,824 @@ +import torch +import os +from typing import Literal, Optional, Union +from .fwd_prefill import attention_forward_prefill_triton_impl +from .fwd_decode import attention_forward_decode_triton_impl +from .bwd import attention_backward_triton_impl +from .utils import ( + DEBUG, + USE_EXP2, + BWD_MODE, + PHILOX_SEED, + PHILOX_OFFSET, + SHAPE_EXPECTATIONS, + round_multiple, +) + + +def fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + + # Reject FP8 tensors (FA2 AMD path does not support FP8) + if str(q.dtype).startswith("torch.float8"): + raise NotImplementedError( + "FP8 tensors are not supported in the AMD Triton FA2 interface. Use the FA3 path instead." + ) + + # Unsupported features assertions (keep behavior explicit like v3 shim) + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in the AMD Triton FA2 interface (expected 0.0)." 
+ ) + + if DEBUG: + print() + print("flash_attn_triton_amd.py::fwd inputs") + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape if out is not None else None) + print("alibi_slopes:", alibi_slopes.shape if alibi_slopes is not None else None) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("return_softmax:", return_softmax) + + if out is None: + out = torch.zeros_like(q) + else: + out.zero_() + + # Layout / shapes + layout: Literal["bshd", "bhsd", "thd"] = "bshd" + max_seqlen_q = q.shape[1] + max_seqlen_k = k.shape[1] + batch, _, nheads_q, _ = q.shape + + # Normalize / validate alibi + if alibi_slopes is not None: + if alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + assert alibi_slopes.is_cuda and alibi_slopes.dim() == 2 + assert alibi_slopes.shape == (batch, nheads_q) + + # Dropout + RNG seed + philox_seed, philox_offset = PHILOX_SEED, PHILOX_OFFSET + rng_state = torch.as_tensor([philox_seed, philox_offset]) + + # argument checks + assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4 + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert q.dtype == k.dtype == v.dtype + assert out.shape[:-1] == q.shape[:-1] and out.shape[-1] == v.shape[-1] + nheads_k = k.shape[2] + assert (nheads_q % nheads_k) == 0 + + # Create output tensors based on shape expectations + if SHAPE_EXPECTATIONS == "rounded": + softmax_lse = torch.zeros( + (batch, nheads_q, round_multiple(max_seqlen_q, 128)), + device=q.device, + dtype=torch.float32, + ) + if dropout_p > 0.0 or return_softmax: + sd_mask = torch.zeros( + ( + batch, + nheads_q, + round_multiple(max_seqlen_q, 128), + round_multiple(max_seqlen_k, 128), + ), + device=q.device, + dtype=torch.float32, + ) + else: + sd_mask = None + else: + softmax_lse = torch.zeros( + 
(batch, nheads_q, max_seqlen_q), + device=q.device, + dtype=torch.float32, + ) + if dropout_p > 0.0 or return_softmax: + sd_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + else: + sd_mask = None + + # call implementation + if DEBUG: + print("Using Triton implementation") + attention_forward_prefill_triton_impl( + q, + k, + v, + out, + softmax_lse, + sd_mask, + softmax_scale, + alibi_slopes, + causal, + window_size_left, + window_size_right, + None, + layout, + None, + None, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + return_softmax, + USE_EXP2, + None, + None, + None, + None, + None, + None, + None, + ) + + if DEBUG: + print("flash_attn_triton_amd.py::fwd outputs") + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) + print("sd_mask:", sd_mask.shape if sd_mask is not None else None) + print("rng_state:", rng_state) + + # --- Assertions (shape + dtype contracts) --- + # out: (B, Sq, Hq, D) + assert out.shape == q.shape, f"[fwd] out shape {out.shape} != q shape {q.shape}" + # softmax_lse dtype + assert ( + softmax_lse.dtype == torch.float32 + ), f"[fwd] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + # softmax_lse shape depends on SHAPE_EXPECTATIONS + if SHAPE_EXPECTATIONS == "rounded": + expected_lse_shape = (q.shape[0], q.shape[2], round_multiple(q.shape[1], 128)) + else: + expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + if return_softmax: + # sd_mask: (B, Hq, Sq, Sk) + assert sd_mask is not None, "[fwd] return_softmax=True but sd_mask is None" + assert sd_mask.dim() == 4, f"[fwd] sd_mask dim {sd_mask.dim()} != 4" + if SHAPE_EXPECTATIONS == "rounded": + expected_sq = round_multiple(q.shape[1], 128) + expected_sk = round_multiple(k.shape[1], 128) + assert ( + sd_mask.shape[0] == q.shape[0] + and 
sd_mask.shape[1] == q.shape[2] + and sd_mask.shape[2] == expected_sq + and sd_mask.shape[3] == expected_sk + ), f"[fwd] sd_mask shape {sd_mask.shape} != (B={q.shape[0]}, Hq={q.shape[2]}, Sq={expected_sq}, Sk={expected_sk})" + else: + assert ( + sd_mask.shape[0] == q.shape[0] + and sd_mask.shape[1] == q.shape[2] + and sd_mask.shape[2] == q.shape[1] + ), f"[fwd] sd_mask leading dims {sd_mask.shape[:3]} mismatch (B,Hq,Sq) {(q.shape[0], q.shape[2], q.shape[1])}" + else: + assert sd_mask is None, "[fwd] return_softmax=False but sd_mask is not None" + + return out, softmax_lse, sd_mask, rng_state + + +def bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + gen_: Optional[torch.Tensor] = None, + rng_state: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in the AMD Triton FA2 interface (expected 0.0)." + ) + + # Check for sliding window - backward doesn't support it yet + is_sliding_window = (window_size_left >= 0) or (window_size_right >= 0) + if is_sliding_window: + raise NotImplementedError( + f"Sliding window attention is not yet supported in the AMD Triton backward pass " + f"(window_size_left={window_size_left}, window_size_right={window_size_right}). " + f"Use window_size=(-1, -1) for full attention." 
+ ) + + if DEBUG: + print() + print("flash_attn_triton_amd.py::bwd inputs") + print("dout:", dout.shape) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) + print("dq:", dq.shape if dq is not None else None) + print("dk:", dk.shape if dk is not None else None) + print("dv:", dv.shape if dv is not None else None) + print("alibi_slopes:", alibi_slopes.shape if alibi_slopes is not None else None) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("deterministic:", deterministic) + print("rng_state:", rng_state) + + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + # get shape + batch, seqlen_q, nheads_q, _ = q.shape + + # Create delta tensor with shape based on expectations + # delta (softmax_d) : (B, Hq, Sq) or (B, Hq, round_multiple(Sq, 128)) + if SHAPE_EXPECTATIONS == "rounded": + delta = torch.zeros( + (batch, nheads_q, round_multiple(seqlen_q, 128)), + device=q.device, + dtype=torch.float32, + ) + else: + delta = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + + # Upstream change: base seeding logic on provided rng_state instead of dropout probability. 
+ if rng_state is not None: + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None + + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") + + # call implementation + if DEBUG: + print(f"Using Triton implementation in {BWD_MODE} mode") + attention_backward_triton_impl( + do=dout, + q=q, + k=k, + v=v, + o=out, + softmax_lse=softmax_lse, + dq=dq, + dk=dk, + dv=dv, + delta=delta, + sm_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + layout="bshd", + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=seqlen_q, + max_seqlen_k=k.shape[1], + seqused_q=None, + seqused_k=None, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + use_exp2=USE_EXP2, + mode=BWD_MODE, + ) + + if DEBUG: + print("flash_attn_triton_amd.py::bwd outputs") + print("dq:", dq.shape) + print("dk:", dk.shape) + print("dv:", dv.shape) + # --- Assertions --- + assert dq.shape == q.shape, f"[bwd] dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"[bwd] dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"[bwd] dv shape {dv.shape} != v shape {v.shape}" + # delta (softmax_d) : (B, Hq, Sq) + if SHAPE_EXPECTATIONS == "rounded": + expected_delta_shape = (q.shape[0], q.shape[2], round_multiple(q.shape[1], 128)) + else: + expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + delta.shape == expected_delta_shape + ), f"[bwd] delta shape {delta.shape} != {expected_delta_shape}" + return dq, dk, dv, delta + + +def varlen_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + 
block_table_: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + + if str(q.dtype).startswith("torch.float8"): + raise NotImplementedError( + "FP8 tensors are not supported in the AMD Triton FA2 interface (varlen_fwd). Use the FA3 path instead." + ) + + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in varlen_fwd (expected 0.0)." + ) + if leftpad_k is not None: + raise NotImplementedError( + "leftpad_k is not supported in AMD Triton FA2 varlen_fwd." + ) + if block_table_ is not None: + raise NotImplementedError( + "block_table / paged attention is not supported in AMD Triton FA2 varlen_fwd." + ) + if seqused_k is not None: + raise NotImplementedError( + "seqused_k is not supported in AMD Triton FA2 varlen_fwd." 
+ ) + + if DEBUG: + print() + print("flash_attn_triton_amd.py::varlen_fwd") + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) + print("alibi_slopes:", alibi_slopes) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("gen_:", gen_) + out = torch.zeros_like(q) if out is None else out.zero_() + + # Layout and basic info for varlen + layout: Literal["bshd", "bhsd", "thd"] = "thd" + batch = len(cu_seqlens_q) - 1 + total_q, nheads_q, _ = q.shape + + # Create softmax_lse tensor - varlen always uses exact shape (Hq, Total_Q) + softmax_lse = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + + # Create sd_mask tensor if needed + if return_softmax: + # sd_mask: (B, Hq, Sq, Sk) - shape based on expectations + if SHAPE_EXPECTATIONS == "rounded": + sd_mask = torch.zeros( + ( + batch, + nheads_q, + round_multiple(max_seqlen_q, 128), + round_multiple(max_seqlen_k, 128), + ), + device=q.device, + dtype=q.dtype, + ) + else: + sd_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=q.dtype, + ) + else: + sd_mask = None + + if alibi_slopes is not None: + if alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + assert alibi_slopes.is_cuda and alibi_slopes.dim() == 2 + assert alibi_slopes.shape == (batch, nheads_q) + + philox_seed, philox_offset = PHILOX_SEED, PHILOX_OFFSET + rng_state = torch.as_tensor([philox_seed, philox_offset]) + + # Inline checks (subset appropriate for varlen) + assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3 + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert q.dtype == k.dtype == v.dtype 
+ assert out.shape == q.shape + nheads_k = k.shape[1] + assert (nheads_q % nheads_k) == 0 + + # call implementation + if DEBUG: + print("Using Triton implementation") + attention_forward_prefill_triton_impl( + q, + k, + v, + out, + softmax_lse, + sd_mask, + softmax_scale, + alibi_slopes, + causal, + window_size_left, + window_size_right, + None, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + return_softmax, + USE_EXP2, + None, + None, + None, + ) + + if DEBUG: + print("varlen_fwd outputs") + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None) + # --- Assertions --- + # out: (Total_Q, Hq, D) + assert ( + out.shape == q.shape + ), f"[varlen_fwd] out shape {out.shape} != q shape {q.shape}" + # softmax_lse: (Hq, Total_Q) + expected_lse_shape = (q.shape[1], q.shape[0]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[varlen_fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"[varlen_fwd] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + if return_softmax: + # sd_mask expected: (B, Hq, max_seqlen_q, max_seqlen_k) + assert ( + sd_mask is not None + ), "[varlen_fwd] return_softmax=True but sd_mask is None" + assert sd_mask.dim() == 4, f"[varlen_fwd] sd_mask dim {sd_mask.dim()} != 4" + batch = len(cu_seqlens_q) - 1 + assert ( + sd_mask.shape[0] == batch + ), f"[varlen_fwd] sd_mask batch {sd_mask.shape[0]} != {batch}" + assert ( + sd_mask.shape[1] == q.shape[1] + ), f"[varlen_fwd] sd_mask nheads {sd_mask.shape[1]} != {q.shape[1]}" + if SHAPE_EXPECTATIONS == "rounded": + expected_sq = round_multiple(max_seqlen_q, 128) + expected_sk = round_multiple(max_seqlen_k, 128) + assert ( + sd_mask.shape[2] == expected_sq and sd_mask.shape[3] == expected_sk + ), f"[varlen_fwd] sd_mask shape {sd_mask.shape} != 
(B={batch}, Hq={q.shape[1]}, Sq={expected_sq}, Sk={expected_sk})" + else: + assert ( + sd_mask.shape[2] == max_seqlen_q and sd_mask.shape[3] == max_seqlen_k + ), f"[varlen_fwd] sd_mask shape {sd_mask.shape} != (B={batch}, Hq={q.shape[1]}, Sq={max_seqlen_q}, Sk={max_seqlen_k})" + else: + assert ( + sd_mask is None + ), "[varlen_fwd] return_softmax=False but sd_mask is not None" + return out, softmax_lse, sd_mask, rng_state + + +def varlen_bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + gen_: Optional[torch.Tensor] = None, + rng_state: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if str(q.dtype).startswith("torch.float8"): + raise NotImplementedError( + "FP8 tensors are not supported in the AMD Triton FA2 interface (varlen_bwd). Use the FA3 path instead." + ) + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in varlen_bwd (expected 0.0)." 
+ ) + + if DEBUG: + print() + print("varlen_bwd") + print("dout:", dout.shape) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out) + print("softmax_lse:", softmax_lse.shape) + print("dq:", dq.shape if dq is not None else None) + print("dk:", dk.shape if dk is not None else None) + print("dv:", dv.shape if dv is not None else None) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) + print("alibi_slopes:", alibi_slopes) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("deterministic:", deterministic) + print("gen_:", gen_) + print("rng_state:", rng_state) + + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + # get shape + batch = len(cu_seqlens_q) - 1 + total_q, nheads_q, _ = q.shape + + # Create delta tensor with shape based on expectations + # delta (softmax_d) : (Hq, Total_Q) or (Hq, Total_Q + 128*batch) + if SHAPE_EXPECTATIONS == "rounded": + delta = torch.zeros( + (nheads_q, total_q + 128 * batch), device=q.device, dtype=torch.float32 + ) + else: + delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + + # Upstream change: base seeding logic on provided rng_state instead of dropout probability. 
+ if rng_state is not None: + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None + + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") + + # call implementation + if DEBUG: + print(f"Using Triton implementation in {BWD_MODE} mode") + attention_backward_triton_impl( + do=dout, + q=q, + k=k, + v=v, + o=out, + softmax_lse=softmax_lse, + dq=dq, + dk=dk, + dv=dv, + delta=delta, + sm_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + layout="thd", + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=None, + seqused_k=None, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + use_exp2=USE_EXP2, + mode=BWD_MODE, + ) + + if DEBUG: + print("varlen_bwd outputs") + print("delta:", delta, delta.shape) + print("dv:", dv, dv.shape) + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + # --- Assertions --- + assert dq.shape == q.shape, f"[varlen_bwd] dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"[varlen_bwd] dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"[varlen_bwd] dv shape {dv.shape} != v shape {v.shape}" + if SHAPE_EXPECTATIONS == "rounded": + batch = len(cu_seqlens_q) - 1 + expected_delta_shape = (q.shape[1], q.shape[0] + 128 * batch) + else: + expected_delta_shape = (q.shape[1], q.shape[0]) # (Hq, Total_Q) + assert ( + delta.shape == expected_delta_shape + ), f"[varlen_bwd] delta shape {delta.shape} != {expected_delta_shape}" + return dq, dk, dv, delta + + +def fwd_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + cache_seqlens: Optional[Union[int, 
torch.Tensor]], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + cache_batch_idx: Optional[torch.Tensor], + cache_leftpad: Optional[torch.Tensor], + block_table: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + out: Optional[torch.Tensor], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + rotary_interleaved: bool, + num_splits: int, +) -> tuple[torch.Tensor, torch.Tensor]: + + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in fwd_kvcache (expected 0.0)." + ) + if num_splits not in (0, 1): + raise NotImplementedError( + "num_splits > 1 not supported in AMD Triton FA2 fwd_kvcache." + ) + + if DEBUG: + print() + print("flash_attn_triton_amd.py::fwd_kvcache inputs") + print("q:", q, q.shape) + print("k_cache:", k_cache, k_cache.shape) + print("v_cache:", v_cache, v_cache.shape) + print("k:", k, k.shape if k is not None else None) + print("v:", v, v.shape if v is not None else None) + print("cache_seqlens:", cache_seqlens) + print("rotary_cos:", rotary_cos) + print("rotary_sin:", rotary_sin) + print("cache_batch_idx:", cache_batch_idx) + print("cache_leftpad:", cache_leftpad) + print("block_table:", block_table) + print("alibi_slopes:", alibi_slopes) + print("out:", out) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("rotary_interleaved:", rotary_interleaved) + print("num_splits:", num_splits) + + # output + out = torch.zeros_like(q) if out is None else out.zero_() + + # Basic layout info for decode path + layout: Literal["bshd"] = "bshd" + max_seqlen_q = q.shape[1] + max_seqlen_k = k_cache.shape[1] + cache_seqlens_tensor = ( + torch.tensor(cache_seqlens, device=q.device) + if isinstance(cache_seqlens, int) + else cache_seqlens + ) + window_left = ( + int(window_size_left.item()) + if 
isinstance(window_size_left, torch.Tensor) + else window_size_left + ) + window_right = ( + int(window_size_right.item()) + if isinstance(window_size_right, torch.Tensor) + else window_size_right + ) + + k_new = k + v_new = v + + # get shape + batch, seqlen_q, nheads_q, _ = q.shape + + # Create softmax_lse tensor - decode always uses exact shape (B, Hq, Sq) + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + + if alibi_slopes is not None: + if alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + assert alibi_slopes.is_cuda and alibi_slopes.dim() == 2 + assert alibi_slopes.shape == (batch, nheads_q) + + # launch kernel + if DEBUG: + print("Using Triton implementation") + attention_forward_decode_triton_impl( + q, + k_cache, + v_cache, + k_new, + v_new, + out, + softmax_lse, + softmax_scale, + causal, + window_left, + window_right, + alibi_slopes, + layout, + cache_seqlens_tensor, + cache_batch_idx, + block_table, + None, + None, + None, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + rotary_interleaved=rotary_interleaved, + ) + + if DEBUG: + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + # --- Assertions --- + assert ( + out.shape == q.shape + ), f"[fwd_kvcache] out shape {out.shape} != q shape {q.shape}" + expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[fwd_kvcache] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"[fwd_kvcache] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + return out, softmax_lse diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py new file mode 100755 index 00000000000..c38c190ac35 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -0,0 +1,638 @@ +import os +import warnings +import torch 
+from typing import Literal, Optional, Union, Tuple +from .fwd_prefill import attention_forward_prefill_triton_impl +from .fwd_decode import attention_forward_decode_triton_impl +from .bwd import attention_backward_triton_impl +from .utils import ( + DEBUG, + USE_EXP2, + BWD_MODE, + PHILOX_SEED, + PHILOX_OFFSET, + is_fp8, +) + + +def fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + qv: Optional[torch.Tensor], + out: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + cu_seqlens_k_new: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + page_table: Optional[torch.Tensor], + kv_batch_idx: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + seqlens_rotary: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + attention_chunk: int, + softcap: float, + rotary_interleaved: bool, + scheduler_metadata: None = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Flash Attention v3 forward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. 
+ """ + + if DEBUG: + print() + print("interface_fa_v3.py::fwd inputs") + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("k_new:", k_new.shape if k_new is not None else None) + print("v_new:", v_new.shape if v_new is not None else None) + print("qv:", qv.shape if qv is not None else None) + print("out:", out.shape if out is not None else None) + print("cu_seqlens_q:", cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("cu_seqlens_k_new:", cu_seqlens_k_new.shape if cu_seqlens_k_new is not None else None) + print("seqused_q:", seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k.shape if seqused_k is not None else None) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("page_table:", page_table.shape if page_table is not None else None) + print("kv_batch_idx:", kv_batch_idx.shape if kv_batch_idx is not None else None) + print("leftpad_k:", leftpad_k.shape if leftpad_k is not None else None) + print("rotary_cos:", rotary_cos.shape if rotary_cos is not None else None) + print("rotary_sin:", rotary_sin.shape if rotary_sin is not None else None) + print("seqlens_rotary:", seqlens_rotary.shape if seqlens_rotary is not None else None) + print("q_descale:", q_descale.shape if q_descale is not None else None) + print("k_descale:", k_descale.shape if k_descale is not None else None) + print("v_descale:", v_descale.shape if v_descale is not None else None) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("attention_chunk:", attention_chunk) + print("softcap:", softcap) + print("rotary_interleaved:", rotary_interleaved) + print("scheduler_metadata:", scheduler_metadata) + print("num_splits:", num_splits) + print("pack_gqa:", pack_gqa) + print("sm_margin:", sm_margin) 
+ + # Handle qv packed input + if qv is not None: + raise NotImplementedError( + "QV packed input is not yet supported in the AMD Triton backend" + ) + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError( + f"Softcap is not yet supported in the AMD Triton backend (got softcap={softcap}, expected 0.0)" + ) + + # Handle attention_chunk + if attention_chunk != 0 and attention_chunk != 1: + raise NotImplementedError( + f"attention_chunk is not yet supported in the AMD Triton backend (got attention_chunk={attention_chunk})" + ) + + # Handle scheduler metadata + if scheduler_metadata is not None: + raise NotImplementedError( + "Scheduler metadata is not yet supported in the AMD Triton backend" + ) + + # Handle pack_gqa + if pack_gqa is not None and pack_gqa is not False: + raise NotImplementedError( + f"pack_gqa is not yet supported in the AMD Triton backend (got pack_gqa={pack_gqa})" + ) + + # Handle num_splits + if num_splits != 1: + raise NotImplementedError( + f"Split attention (num_splits > 1) is not yet supported in the AMD Triton backend (got num_splits={num_splits})" + ) + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError( + f"sm_margin is not yet supported in the AMD Triton backend (got sm_margin={sm_margin}, expected 0)" + ) + + # Handle leftpad_k + if leftpad_k is not None: + raise NotImplementedError( + "Left padding (leftpad_k) is not yet supported in the AMD Triton backend" + ) + + # Handle cu_seqlens_k_new + if cu_seqlens_k_new is not None: + raise NotImplementedError( + "cu_seqlens_k_new is not yet supported in the AMD Triton backend" + ) + + # establish layout / varlen & max seq lens + if cu_seqlens_q is not None: + if len(q.shape) != 3: + raise ValueError( + f"cu_seqlens_q provided but q has shape {q.shape}, expected 3D tensor for varlen" + ) + layout: Literal["bshd", "thd"] = "thd" + cu_seqlens_q_local = cu_seqlens_q + assert max_seqlen_q is not None, "max_seqlen_q required for varlen mode" + max_seqlens_q_local = 
max_seqlen_q + if cu_seqlens_k is not None: + cu_seqlens_k_local = cu_seqlens_k + assert max_seqlen_k is not None, "max_seqlen_k required when cu_seqlens_k provided" + max_seqlens_k_local = max_seqlen_k + else: + cu_seqlens_k_local = None + if len(k.shape) == 4: + max_seqlens_k_local = k.shape[1] + else: + assert max_seqlen_k is not None, "max_seqlen_k required for varlen mode" + max_seqlens_k_local = max_seqlen_k + else: + layout = "bshd" + cu_seqlens_q_local = None + cu_seqlens_k_local = None + max_seqlens_q_local = q.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlens_k_local = k.shape[1] if max_seqlen_k is None else max_seqlen_k + + # Now determine if we should use decode or prefill kernel + # Decode kernel should be used for KV cache scenarios where: + # 1. k_new/v_new are provided - incremental KV cache update (primary KV cache indicator) + # 2. kv_batch_idx is provided - KV cache batch indexing (primary KV cache indicator) + # 3. seqused_k without seqused_q - indicates KV cache fill levels (not varlen masking) + # Note: In varlen, both seqused_q and seqused_k are used for sequence masking + # In KV cache, only seqused_k is used to track cache fill levels + # Detect KV cache scenarios: + # - Clear KV cache indicators (k_new, v_new, kv_batch_idx) + # - OR seqused_k without seqused_q (KV cache fill tracking, not varlen masking) + use_decode = ( + k_new is not None # Have new KV to append (KV cache indicator) + or v_new is not None # Have new KV to append (KV cache indicator) + or kv_batch_idx is not None # Have KV cache batch indexing (KV cache indicator) + or ( + seqused_k is not None and seqused_q is None + ) # KV cache fill levels (not varlen) + ) + + # Check for unsupported features with decode kernel + if use_decode: + if layout == "thd": + raise NotImplementedError( + "Varlen is not yet supported with the decode kernel in the AMD Triton backend" + ) + if kv_batch_idx is not None: + raise NotImplementedError( + "kv_batch_idx is not yet 
supported with the decode kernel in the AMD Triton backend" + ) + + if out is None: + # NOTE: Using types that are lower precision than float32 such as bfloat16 for fp8 causes mismatches on a small set of tests. + out_dtype = torch.float32 if is_fp8([q, k, v]) else q.dtype + if layout == "bshd": + out = torch.zeros( + q.shape[0], + q.shape[1], + q.shape[2], + v.shape[-1], + dtype=out_dtype, + device=q.device, + ) + elif layout == "thd": + out = torch.zeros( + q.shape[0], q.shape[1], v.shape[-1], dtype=out_dtype, device=q.device + ) + else: + raise ValueError( + f"Unsupported layout: {layout}. Only 'bshd' and 'thd' layouts are supported." + ) + else: + out = out.zero_() + + # Handle causal mask + causal_flag = bool(causal) + + # Handle alibi slopes + alibi_slopes = None + + # Handle dropout + dropout_p = 0.0 + return_softmax = False + philox_seed = PHILOX_SEED + philox_offset = PHILOX_OFFSET + + # Call implementation + if DEBUG: + print("Using Triton implementation") + + if use_decode: + if DEBUG: + print( + f"Using Decode Triton implementation (cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})" + ) + + # Create softmax_lse tensor for decode - always exact shape (B, Hq, Sq) + batch, seqlen_q, nheads_q, _ = q.shape + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + + # Decode only supports bshd layout + assert layout == "bshd", f"decode requires bshd layout, got {layout}" + attention_forward_decode_triton_impl( + q, + k, + v, + k_new, + v_new, + out, + softmax_lse, + softmax_scale, + causal_flag, + window_size_left, + window_size_right, + alibi_slopes, + layout, + seqused_k, + kv_batch_idx, + page_table, + q_descale, + k_descale, + v_descale, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + rotary_interleaved=rotary_interleaved, + seqlens_rotary=seqlens_rotary, + ) + else: + if DEBUG: + print("Using Prefill Triton implementation") 
+ + # Create softmax_lse tensor - FA3 always uses exact shapes + if layout == "thd": + # varlen: (Hq, Total_Q) + total_q, nheads_q, _ = q.shape + softmax_lse = torch.zeros( + (nheads_q, total_q), device=q.device, dtype=torch.float32 + ) + else: + # bshd: (B, Hq, Sq) + batch, seqlen_q, nheads_q, _ = q.shape + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + + # sd_mask is not returned in v3 interface + sd_mask = None + + attention_forward_prefill_triton_impl( + q, + k, + v, + out, + softmax_lse, + sd_mask, + softmax_scale, + alibi_slopes, + causal_flag, + window_size_left, + window_size_right, + None, + layout, + cu_seqlens_q_local, + cu_seqlens_k_local, + max_seqlens_q_local, + max_seqlens_k_local, + dropout_p, + philox_seed, + philox_offset, + return_softmax, + USE_EXP2, + q_descale, + k_descale, + v_descale, + seqused_q, + seqused_k, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + rotary_interleaved=rotary_interleaved, + seqlens_rotary=seqlens_rotary, + ) + + if DEBUG: + print("interface_fa_v3.py::fwd outputs") + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) + + # --- Assertions (FA3 always expects exact shapes) --- + # out: same shape as q except last dim is v's head_dim + if layout == "thd": + # varlen: (Total_Q, Hq, Dv) + assert ( + out.shape[0] == q.shape[0] + ), f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" + assert ( + out.shape[1] == q.shape[1] + ), f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" + assert ( + out.shape[2] == v.shape[-1] + ), f"[fwd_v3] out.shape[2] {out.shape[2]} != v.shape[-1] {v.shape[-1]}" + else: + # bshd: (B, Sq, Hq, Dv) + assert ( + out.shape[0] == q.shape[0] + ), f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" + assert ( + out.shape[1] == q.shape[1] + ), f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" + assert ( + out.shape[2] == q.shape[2] + ), f"[fwd_v3] out.shape[2] 
{out.shape[2]} != q.shape[2] {q.shape[2]}" + assert ( + out.shape[3] == v.shape[-1] + ), f"[fwd_v3] out.shape[3] {out.shape[3]} != v.shape[-1] {v.shape[-1]}" + + # softmax_lse dtype + assert ( + softmax_lse.dtype == torch.float32 + ), f"[fwd_v3] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + # softmax_lse shape depends on layout + expected_lse_shape: tuple[int, ...] + if layout == "thd": + # varlen: (Hq, Total_Q) + expected_lse_shape = (q.shape[1], q.shape[0]) + else: + # bshd: (B, Hq, Sq) + expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[fwd_v3] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + + # Return format compatible with v3 + # V3 returns (out, softmax_lse, out_accum, softmax_lse_accum) + # out_accum and softmax_lse_accum are None for Triton AMD (no split-k accumulation) + return out, softmax_lse, None, None + + +def bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + sm_margin: int = 0, +) -> Tuple[torch.Tensor]: + """ + Flash Attention v3 backward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. 
+ """ + + if DEBUG: + print() + print("interface_fa_v3.py::bwd inputs") + print("dout:", dout.shape) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) + print("dq:", dq.shape if dq is not None else None) + print("dk:", dk.shape if dk is not None else None) + print("dv:", dv.shape if dv is not None else None) + print("cu_seqlens_q:", cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("seqused_q:", seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k.shape if seqused_k is not None else None) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("deterministic:", deterministic) + print("sm_margin:", sm_margin) + + # Check for unsupported features in backward pass + + # Handle sliding window - backward doesn't support it yet + is_sliding_window = (window_size_left >= 0) or (window_size_right >= 0) + if is_sliding_window: + raise NotImplementedError( + f"Sliding window attention is not yet supported in the AMD Triton backward pass " + f"(window_size_left={window_size_left}, window_size_right={window_size_right}). " + f"Use window_size=(-1, -1) for full attention." 
+ ) + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError( + f"Softcap is not yet supported in the AMD Triton backend backward pass (got softcap={softcap}, expected 0.0)" + ) + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError( + f"sm_margin is not yet supported in the AMD Triton backend backward pass (got sm_margin={sm_margin}, expected 0)" + ) + + # Initialize gradient tensors if not provided + # NOTE: Using types that are lower precision than float32 such as bfloat16 for fp8 causes mismatches on a small set of tests. + grad_dtype = torch.float32 if is_fp8([q, k, v]) else q.dtype + dq = torch.zeros_like(q, dtype=grad_dtype) if dq is None else dq.zero_() + dk = torch.zeros_like(k, dtype=grad_dtype) if dk is None else dk.zero_() + dv = torch.zeros_like(v, dtype=grad_dtype) if dv is None else dv.zero_() + + # Determine layout based on cu_seqlens + layout: Literal["bshd", "bhsd", "thd"] + if cu_seqlens_q is not None and cu_seqlens_k is not None: + # Variable length sequence mode + layout = "thd" + batch = len(cu_seqlens_q) - 1 + total_q, nheads_q, _ = q.shape + # Create delta tensor - varlen: (Hq, Total_Q) + delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + else: + # Regular batch mode + layout = "bshd" + batch, seqlen_q, nheads_q, _ = q.shape + max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + # Create delta tensor - bshd: (B, Hq, Sq) + delta = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + + # V3 backward doesn't have dropout or alibi slopes + dropout_p = 0.0 + philox_seed, philox_offset = None, None + alibi_slopes = None + + # Call implementation + if DEBUG: + print(f"Using Triton implementation in {BWD_MODE} mode") + attention_backward_triton_impl( + do=dout, + q=q, + k=k, + v=v, + o=out, + softmax_lse=softmax_lse, + dq=dq, + dk=dk, + dv=dv, + delta=delta, + 
sm_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + layout=layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + use_exp2=USE_EXP2, + mode=BWD_MODE, + ) + + if DEBUG: + print("interface_fa_v3.py::bwd outputs") + print("dq:", dq.shape) + print("dk:", dk.shape) + print("dv:", dv.shape) + print("delta:", delta.shape) + + # --- Assertions (FA3 always expects exact shapes) --- + # Gradients should match input shapes + assert dq.shape == q.shape, f"[bwd_v3] dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"[bwd_v3] dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"[bwd_v3] dv shape {dv.shape} != v shape {v.shape}" + # delta (softmax_d) should match softmax_lse shape + assert ( + delta.dtype == torch.float32 + ), f"[bwd_v3] delta dtype {delta.dtype} != torch.float32" + expected_delta_shape: tuple[int, ...] + if layout == "thd": + # varlen: (Hq, Total_Q) + expected_delta_shape = (q.shape[1], q.shape[0]) + else: + # bshd: (B, Hq, Sq) + expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + delta.shape == expected_delta_shape + ), f"[bwd_v3] delta shape {delta.shape} != {expected_delta_shape}" + + # V3 expects (softmax_d, *rest) + # delta is the softmax_d in this case + return (delta,) + + +def fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> "torch.Tensor": + """ + Combine partial outputs from split attention computation. + + This is used when num_splits > 1 to combine the partial results. 
+ + Args: + out_partial: Partial output tensor from split computation + lse_partial: Partial log-sum-exp tensor + out: Optional output tensor to write to + out_dtype: Optional dtype for output + + Returns: + Combined output tensor + """ + raise NotImplementedError( + "fwd_combine is not yet implemented in the AMD Triton backend" + ) + + +def get_scheduler_metadata( + batch_size: int, + max_seqlen_q: int, + max_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + headdim: int, + headdim_v: int, + qkv_dtype: torch.dtype, + cache_seqlens: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new: int = 0, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + has_softcap: bool = False, + num_splits: int = 0, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> None: + """ + Get scheduler metadata for optimized kernel selection. + + This function is used to precompute metadata for kernel scheduling in FA3. + The AMD Triton backend currently doesn't use scheduler metadata, so this + raises an error. + + Args: + Various attention parameters used for scheduling decisions + + Returns: + None - scheduler metadata is not used in AMD Triton backend + """ + raise NotImplementedError( + "get_scheduler_metadata is not supported in the AMD Triton backend yet." 
+ ) diff --git a/flash_attn/flash_attn_triton_amd/pyproject.toml b/flash_attn/flash_attn_triton_amd/pyproject.toml new file mode 100644 index 00000000000..3a07ef28ed9 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/pyproject.toml @@ -0,0 +1,48 @@ +# mypy --config-file flash_attn/flash_attn_triton_amd/pyproject.toml +[tool.mypy] +files = [ + # Core Triton AMD backend + "flash_attn/flash_attn_triton_amd", + # Tests (based on test_flash_attn.py - looser rules, but catches import errors) + "tests/test_flash_attn_triton_amd.py", + "hopper/test_flash_attn_triton_amd.py", +] +ignore_missing_imports = true +follow_imports = "skip" +python_version = "3.9" + +# Strict checks +strict_equality = true +warn_unreachable = true +warn_redundant_casts = true +warn_unused_ignores = true +check_untyped_defs = true +warn_return_any = true +warn_unused_configs = true +no_implicit_optional = true +strict_optional = true +disallow_incomplete_defs = false # Triton kernels can't be fully typed +disallow_subclassing_any = false # torch.autograd.Function has type Any + +# Triton kernels use untyped decorators and defs +disallow_untyped_defs = false +disallow_untyped_decorators = false +disallow_untyped_calls = false + +# Follow imports for our module so test imports are validated +[[tool.mypy.overrides]] +module = ["flash_attn.flash_attn_triton_amd", "flash_attn.flash_attn_triton_amd.*"] +follow_imports = "normal" + +# Stricter settings for interface and utility modules only +[[tool.mypy.overrides]] +module = ["flash_attn.flash_attn_triton_amd.interface_v2", "flash_attn.flash_attn_triton_amd.interface_v3", "flash_attn.flash_attn_triton_amd.utils"] +disallow_incomplete_defs = true +disallow_untyped_defs = true + +# Test files - based on test_flash_attn.py, looser rules but catches import/export errors +[[tool.mypy.overrides]] +module = ["test_flash_attn_triton_amd", "hopper.test_flash_attn_triton_amd"] +strict_optional = false +check_untyped_defs = false + diff --git 
a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py deleted file mode 100644 index 58e2ae5fc7f..00000000000 --- a/flash_attn/flash_attn_triton_amd/test.py +++ /dev/null @@ -1,932 +0,0 @@ -import os -import glob -import shutil -import time -import torch -import pytest -import logging -import numpy as np -from pathlib import Path -from flash_attn import ( - flash_attn_func, - flash_attn_fp8_func, - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_qkvpacked_fp8_func, - flash_attn_varlen_func, - flash_attn_varlen_fp8_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_qkvpacked_fp8_func -) - -from .utils import DEBUG, input_helper, arch_supports_fp8 -from .fwd_ref import attention_forward_pytorch_ref_impl -from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill_split import attention_prefill_backward_triton_split_impl -from .bwd_ref import attention_backward_pytorch_ref_impl - -# set print options -# torch.set_printoptions(linewidth=5e5, edgeitems=10, sci_mode=False) -# np.set_printoptions(linewidth=5000, threshold=1e4, suppress=True, precision=4) - -# defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html -ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. -# ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues -# ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs -# ATOL, RTOL = 1e-5, 1e-3 # # default fp16. 
there will be small diffs -# ATOL_fp8, RTOL_fp8 = 1e-1, 1e-1 # to strict for larger tensors in fp8 -ATOL_fp8, RTOL_fp8 = 2.5e-1, 2.5e-1 # fp8 -# ATOL_fp8, RTOL_fp8 = 2e-2, 2e-2 # fp8 -EQUAL_NAN = True - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (1, 1, 1, 1, 1, 1), - (1, 1, 1, 2, 4, 16), - (1, 2, 2, 2, 4, 16), - (1, 4, 1, 2, 4, 16), - (1, 4, 2, 2, 4, 16), - (1, 1, 1, 4, 2, 16), - (1, 1, 1, 4, 4, 16), - (1, 2, 2, 4, 4, 16), - (2, 1, 1, 4, 4, 16), - (2, 2, 2, 4, 4, 16), - (1, 1, 1, 128, 64, 16), - (2, 2, 2, 2, 128, 1), - (2, 3, 3, 2, 128, 16), - (3, 2, 2, 256, 512, 16), - (3, 3, 3, 128, 128, 64), - (2, 4, 4, 1024, 1024, 64), - (4, 6, 6, 108, 256, 224), - (4, 8, 8, 2048, 2048, 128), - (4, 16, 16, 4096, 4096, 64), - (2, 4, 4, 8192, 8192, 32), - # fa configs - (4, 6, 1, 113, 203, 256), - (4, 6, 1, 128, 217, 256), - (4, 6, 2, 113, 211, 128), - (4, 6, 2, 108, 256, 128), - (4, 6, 1, 256, 512, 64), - (4, 6, 1, 512, 256, 64), - (4, 6, 2, 1024, 1024, 32), - (4, 6, 2, 1023, 1024, 32), - (4, 6, 6, 1024, 1023, 32), - (4, 6, 6, 2048, 2048, 32), - ], -) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('alibi_slopes', [None]) -@pytest.mark.parametrize('layout', ["bshd", "thd"]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false -@pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. 
Just use to figure out issues -def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): - torch.manual_seed(42) - device = "cuda" - - q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, dtype, layout=layout, device=device) - - if DEBUG: - if HQ // HK != 1: - print("MQA/GQA") - else: - print("MHA") - - # update metadata - metadata.use_exp2 = use_exp2 - if causal: - metadata.need_causal(True) - - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - metadata.need_dropout(dropout_p) - - - # call Triton's forward implementation directly - q_triton = q.clone() - k_triton = k.clone() - v_triton = v.clone() - o_triton = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) - softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( - q_triton, - k_triton, - v_triton, - o_triton, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.bias, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.return_scores, - metadata.use_exp2, - None, - None, - None, - None) - - # ref forward - q_ref = q.clone() - k_ref = k.clone() - v_ref = v.clone() - o_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( - q_ref, - k_ref, - v_ref, - o_ref, - metadata.sm_scale, - metadata.alibi_slopes, - causal, - layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - use_exp2 - ) - - if DEBUG: - print() - 
print("Compare Triton Impl with refernce Pytorch Impl") - - # this can be set to true manually or when using dropout - if metadata.return_scores: - if DEBUG: - print("sd_mask_triton:", sd_mask_triton, sd_mask_triton.shape) - print("sd_mask_ref:", sd_mask_ref, sd_mask_ref.shape) - torch.testing.assert_close(sd_mask_triton.to(sd_mask_ref.dtype), sd_mask_ref, atol=ATOL, rtol=RTOL) - - if DEBUG: - print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape) - print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape) - torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL) - - if DEBUG: - print("output_triton:", o_triton, o_triton.shape) - print("output_ref:", o_ref, o_ref.shape) - torch.testing.assert_close(o_triton, o_ref, atol=ATOL, rtol=RTOL) - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [ - (1, 1, 1, 1, 1, 1), - (1, 1, 1, 4, 4, 4), - (2, 1, 1, 4, 4, 16), - (1, 2, 2, 4, 4, 16), - (1, 4, 1, 2, 4, 16), - (1, 8, 1, 2, 4, 16), - (1, 16, 1, 2, 4, 16), - (1, 32, 1, 2, 4, 16), - (1, 64, 1, 2, 4, 16), - (1, 4, 2, 2, 4, 16), - (2, 2, 2, 4, 4, 16), - (1, 1, 1, 4, 4, 16), - (2, 1, 1, 4, 4 , 16), - (4, 6, 6, 8, 8 , 16), - (1, 1, 1, 4, 4, 32), - (1, 1, 1, 16, 16, 16), - (1, 1, 1, 32, 32, 16), - (1, 1, 1, 64, 64, 16), - (1, 1, 1, 64, 64, 16), - (1, 1, 1, 64, 128, 16), - (1, 1, 1, 64, 64, 32), - (1, 1, 1, 64, 128, 32), - (1, 1, 1, 128, 128, 64), - (1, 1, 1, 128, 256, 45), - (1, 1, 1, 113, 203, 192), - (1, 1, 1, 256, 256, 64), - (1, 1, 1, 256, 512, 16), - (1, 1, 1, 512, 512, 64), - (1, 1, 1, 1024, 1024, 64), - # fa configs - (2, 2, 2, 128, 128, 65), - (2, 2, 2, 128, 128, 224), - (4, 6, 6, 108, 256, 224), - (1, 1, 1, 256, 512, 16), - # old tests that work - (4, 48, 6, 1024, 1024, 64), - (4, 48, 12, 2048, 1024, 64), - (4, 48, 24, 1024, 1024, 64), - (4, 48, 48, 1024, 1024, 64), - (4, 48, 48, 1024, 1024, 73), - (4, 48, 48, 2048, 2048, 64), - (1, 24, 24, 4096, 4096, 64), - (1, 16, 16, 1024, 1024, 64), - (1, 
16, 16, 1024, 1024, 128), - # testcase new - # seqlen q == k - (1, 1, 1, 2, 2, 2), # small enough to debug - (1, 1, 1, 128, 128, 32), # only one block - (1, 1, 1, 127, 127, 32), # only one block but with masking - (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 1), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 68), # generic masking on q, k and head - (4, 1, 1, 512, 512, 128), # batch > 1 - (4, 8, 2, 512, 512, 128), # GQA - (4, 8, 2, 512, 512, 68), # non-power-of-2 head_dim - (4, 8, 2, 500, 500, 68), # comprehensive case for seqlen q == k - # seqlen q > k - (1, 1, 1, 64, 32, 8), # seqlen_q > seqlen_k - (1, 1, 1, 192, 128, 32), # seqlen_q > seqlen_k - (4, 8, 2, 1024, 512, 68), # seqlen_q < seqlen_k - (1, 1, 1, 729, 516, 68), # seqlen_q > seqlen_k - (16, 16, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k - # seqlen q < k - (1, 1, 1, 32, 64, 8), # seqlen_q > seqlen_k - (1, 1, 1, 128, 192, 32), # seqlen_q < seqlen_k - (4, 8, 2, 512, 1024, 68), # seqlen_q < seqlen_k - (1, 1, 1, 200, 413, 1), # seqlen_q < seqlen_k - (1, 1, 1, 782, 1546, 1), # seqlen_q < seqlen_k - (16, 16, 4, 1528, 2753, 68), # a comprehensive seqlen_q < seqlen_k -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('alibi_slopes', [None]) -@pytest.mark.parametrize('layout', ["bshd", "thd"]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal -@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans on larger tensors -def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): - torch.manual_seed(20) - device="cuda" - - # gen inputs - q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, dtype, 
layout=layout, device=device) - - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - metadata.need_dropout(dropout_p) - - # =============================================== Reference ============================================================== - # fwd - q_ref = q.clone() - k_ref = k.clone() - v_ref = v.clone() - output_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( - q_ref, - k_ref, - v_ref, - output_ref, - metadata.sm_scale, - metadata.alibi_slopes, - causal, - layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - use_exp2 - ) - - # bwd - do_ref = do.clone() - dq_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) - dk_ref = torch.zeros_like(k).contiguous() if DEBUG_INPUT else torch.empty_like(k) - dv_ref = torch.zeros_like(v).contiguous() if DEBUG_INPUT else torch.empty_like(v) - delta_ref = attention_backward_pytorch_ref_impl( - do_ref, - q_ref, - k_ref, - v_ref, - output_ref, - softmax_lse_ref, - dq_ref, - dk_ref, - dv_ref, - metadata.sm_scale, - metadata.alibi_slopes, - causal, - layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - use_exp2 - ) - - # =============================================== Triton ============================================================== - do_triton = do.clone() - q_triton = q.clone() - k_triton = k.clone() - v_triton = v.clone() - o_triton = output_ref.clone().contiguous() - softmax_lse_triton = softmax_lse_ref.clone().contiguous() - dq_triton = torch.zeros_like(q_triton, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros - 
dk_triton = torch.zeros_like(k_triton, dtype=k.dtype) if DEBUG_INPUT else torch.empty_like(k_triton, dtype=k.dtype) - dv_triton = torch.zeros_like(v_triton, dtype=v.dtype) if DEBUG_INPUT else torch.empty_like(v_triton, dtype=v.dtype) - delta_triton = attention_prefill_backward_triton_split_impl( - do_triton, - q_triton, - k_triton, - v_triton, - o_triton, - softmax_lse_triton, - dq_triton, - dk_triton, - dv_triton, - metadata.sm_scale, - alibi_slopes, - causal, - layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - use_exp2, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - # =============================================== Check ============================================================== - if DEBUG: - print() - if DEBUG: - print("delta_triton:", delta_triton, delta_triton.shape) - print("delta_ref:", delta_ref, delta_ref.shape) - torch.testing.assert_close(delta_triton, delta_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) - - if DEBUG: - print("dv_triton:", dv_triton, dv_triton.shape) - print("dv_ref:", dv_ref, dv_ref.shape) - torch.testing.assert_close(dv_triton, dv_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) - - if DEBUG: - print("dk_triton:", dk_triton, dk_triton.shape) - print("dk_ref:", dk_ref, dk_ref.shape) - torch.testing.assert_close(dk_triton, dk_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) - - if DEBUG: - print("dq_triton:", dq_triton, dq_triton.shape) - print("dq_ref:", dq_ref, dq_ref.shape) - torch.testing.assert_close(dq_triton, dq_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) - -def fp8_assert_close(tensor_a, tensor_b, atol=ATOL_fp8, rtol=RTOL_fp8, max_diff_percentage=0.5): - """Assert tensors are close with tolerance for small percentage of elements""" - # standard comparison - abs_diff = torch.abs(tensor_a - tensor_b) - rel_diff = abs_diff / torch.abs(tensor_b.clamp(min=1e-6)) - - # 
calculate elements that exceed tolerance - abs_check = abs_diff > atol - rel_check = rel_diff > rtol - failed_check = torch.logical_and(abs_check, rel_check) - - # calculate percentage of failed elements - failed_percentage = failed_check.sum().item() / failed_check.numel() * 100 - - # if percentage is small enough, test passes - if failed_percentage <= max_diff_percentage: - return True - - # Otherwise, provide diagnostic information - max_abs_idx = torch.argmax(abs_diff).item() - max_rel_idx = torch.argmax(rel_diff).item() - - flat_to_idx = lambda flat_idx, shape: np.unravel_index(flat_idx, shape) - - max_abs_pos = flat_to_idx(max_abs_idx, tensor_a.shape) - max_rel_pos = flat_to_idx(max_rel_idx, tensor_a.shape) - - max_abs_diff = abs_diff.flatten()[max_abs_idx].item() - max_rel_diff = rel_diff.flatten()[max_rel_idx].item() - - raise AssertionError( - f"Tensors not close enough! {failed_percentage:.6f}% elements exceed tolerance.\n" - f"Greatest absolute difference: {max_abs_diff} at index {max_abs_pos} (up to {atol} allowed)\n" - f"Greatest relative difference: {max_rel_diff} at index {max_rel_pos} (up to {rtol} allowed)" - ) - -@pytest.mark.parametrize( - "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - # seqlen q == k - (1, 1, 1, 1, 1, 1), - (1, 1, 1, 2, 2, 2), # small enough to debug - (1, 1, 1, 4, 4, 16), - (1, 2, 2, 4, 4, 16), - (2, 1, 1, 4, 4, 16), - (2, 2, 2, 4, 4, 16), - (1, 1, 1, 128, 128, 32), # only one block - (3, 3, 3, 128, 128, 64), - (1, 1, 1, 127, 127, 32), # only one block but with masking - # (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug # fails - (1, 2, 2, 129, 129, 32), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 32), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 68), # generic masking on q, k and head - (4, 1, 1, 512, 512, 128), # batch > 1 - (4, 2, 2, 512, 512, 128), - (4, 2, 2, 512, 512, 68), - (4, 2, 2, 500, 500, 68), - (2, 4, 4, 1024, 1024, 64), - (4, 8, 8, 
2048, 2048, 128), - (2, 8, 8, 4096, 4096, 64), - (2, 4, 4, 8192, 8192, 32), - # seqlen q > k - (1, 1, 1, 4, 2, 16), - (1, 1, 1, 64, 32, 8), - (1, 1, 1, 128, 64, 16), - (1, 1, 1, 192, 128, 32), - (1, 2, 2, 1024, 512, 68), - (1, 4, 4, 729, 516, 68), - (2, 4, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k - # seqlen q < k - (1, 1, 1, 2, 4, 16), - (1, 2, 2, 2, 4, 16), - (1, 4, 1, 2, 4, 16), - (1, 4, 2, 2, 4, 16), - (2, 2, 2, 2, 128, 1), - (2, 3, 3, 2, 128, 16), - (1, 1, 1, 32, 64, 8), - (1, 1, 1, 128, 192, 32), - (4, 6, 6, 108, 256, 32), - (3, 2, 2, 256, 512, 16), - (2, 2, 2, 512, 1024, 68), - (1, 1, 1, 200, 413, 32), - (1, 1, 1, 782, 1546, 32), - # gqa/mqa # mismatch issue on varlen - (4, 8, 2, 500, 500, 68), - (4, 8, 2, 512, 512, 68), - (4, 8, 2, 512, 512, 128), - (4, 8, 2, 512, 1024, 68), - (4, 8, 2, 1024, 512, 64), - (4, 16, 4, 1528, 2753, 68), - # fa configs - (2, 4, 1, 113, 203, 64), - (2, 4, 2, 128, 217, 64), - (2, 6, 2, 113, 211, 128), - (2, 6, 2, 108, 256, 128), - (2, 6, 2, 256, 512, 64), - (2, 6, 2, 512, 256, 64), - (2, 6, 2, 1024, 1024, 32), - (2, 6, 2, 1023, 1024, 32), - (2, 6, 6, 1024, 1023, 32), - (2, 6, 6, 2048, 2048, 32), - ], -) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('layout', ["bshd", "thd"]) -@pytest.mark.parametrize('packing', [None, "qkv"]) -@pytest.mark.parametrize('DEBUG_INPUT', [False]) -@pytest.mark.flaky(reruns=3, reason="Retry failures") -@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") -def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, DEBUG_INPUT): - torch.manual_seed(20) - test_backward = True - device = "cuda" - window_size = (-1, -1) - softcap = 0.0 - alibi_slopes = None - deterministic = False - ref_dtype = torch.float32 - is_varlen = True if layout == "thd" else False - - # skip QKV packing tests for uneven sequence lengths and head sizes - if packing == 'qkv': - if 
N_CTX_Q != N_CTX_K: - pytest.skip("QKV packing requires N_CTX_Q == N_CTX_K") - if HQ != HK: - pytest.skip("QKV packing requires HQ == HK") - - # test apis - if packing == 'qkv': - # generate inputs - qkv, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, packing=packing, device=device, DEBUG_INPUT=DEBUG_INPUT) - - # ---------------------------------------------------------------- - # --- FP8 --- - # ---------------------------------------------------------------- - qkv_fp8 = qkv.clone() - do_fp8= do.clone() - - if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_qkvpacked_fp8_func( - qkv_fp8, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_qkvpacked_fp8_func( - qkv_fp8, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Reference --- - # ---------------------------------------------------------------- - # reference forward pass - qkv_ref = qkv.clone() - do_ref= do.clone() - - if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_qkvpacked_func( - qkv_ref, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_ref, lse_ref, S_dmask_ref = flash_attn_qkvpacked_func( - qkv_ref, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- 
- # --- Compare --- - # ---------------------------------------------------------------- - # compare forward - if DEBUG: - print() - print(f"Compare fp8 against ref with dtype {ref_dtype}") - - if DEBUG: - print("out_ref:", out_ref, out_ref.shape) - print("out_fp8:", out_fp8, out_fp8.shape) - fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - - if DEBUG: - print("lse_ref:", lse_ref, lse_ref.shape) - print("lse_fp8:", lse_fp8, lse_fp8.shape) - fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - - if dropout_p > 0.0: - if DEBUG: - print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) - print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) - fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - if not test_backward: - return - - # fp8 backward pass - dqkv_fp8, = torch.autograd.grad(out_fp8, (qkv_fp8), do_fp8) - - # ref backward pass - dqkv_ref, = torch.autograd.grad(out_ref, (qkv_ref), do_ref) - - # compare backward gradients - if DEBUG: - print("dqkv_ref:", dqkv_ref, dqkv_ref.shape) - print("dqkv_fp8:", dqkv_fp8, dqkv_fp8.shape) - fp8_assert_close(dqkv_ref, dqkv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - elif packing is None: - # generate inputs - q, k, v, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) - - # ---------------------------------------------------------------- - # --- FP8 --- - # ---------------------------------------------------------------- - if DEBUG: - print() - print(f"Compute Fp8 Forward") - q_fp8 = q.clone() - k_fp8 = k.clone() - v_fp8 = v.clone() - do_fp8= do.clone() - - if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_fp8_func( - q_fp8, - k_fp8, - v_fp8, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - 
deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_fp8_func( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Reference --- - # ---------------------------------------------------------------- - if DEBUG: - print() - print(f"Compute Reference Forward") - # reference forward pass - q_ref = q.clone() - k_ref = k.clone() - v_ref = v.clone() - do_ref = do.clone() - - if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_func( - q_ref, - k_ref, - v_ref, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_ref, lse_ref, S_dmask_ref = flash_attn_func( - q_ref, - k_ref, - v_ref, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Compare --- - # ---------------------------------------------------------------- - # compare forward - if DEBUG: - print() - print(f"Compare fp8 against ref with dtype {ref_dtype}") - - if DEBUG: - print("out_ref:", out_ref, out_ref.shape) - print("out_fp8:", out_fp8, out_fp8.shape) - # torch.testing.assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - - if DEBUG: - print("lse_ref:", lse_ref, lse_ref.shape) - print("lse_fp8:", lse_fp8, lse_fp8.shape) - # torch.testing.assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(lse_ref, lse_fp8, 
atol=ATOL_fp8, rtol=RTOL_fp8) - - - if dropout_p > 0.0: - if DEBUG: - print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) - print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) - # torch.testing.assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - if not test_backward: - return - - if DEBUG: - print() - print(f"Compute Fp8 Backward") - # fp8 backward pass - dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8) - - if DEBUG: - print() - print(f"Compute Reference Backward") - # ref backward pass - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref) - - # compare backward gradients - if DEBUG: - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_fp8:", dv_fp8, dv_fp8.shape) - # torch.testing.assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - if DEBUG: - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_fp8:", dk_fp8, dk_fp8.shape) - # torch.testing.assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - if DEBUG: - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_fp8:", dq_fp8, dq_fp8.shape) - # torch.testing.assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (2, 4, 4, 512, 512, 128), - ], -) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.1]) -@pytest.mark.parametrize('layout', ['bshd']) -@pytest.mark.parametrize('packing', [None]) -@pytest.mark.parametrize('test_backward', [False, True]) -@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") 
-@pytest.mark.skip("Breaks on CI but works locally") -def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, test_backward): # Don't run this test in parallel. It clears the cache so it doesnot work properly if run in parallel. - torch.manual_seed(20) - device = "cuda" - window_size = (-1, -1) - softcap = 0.0 - alibi_slopes = None - deterministic = False - ref_dtype = torch.float32 - is_varlen = True if layout == "thd" else False - - # remove cache - cache_path = Path(os.path.expanduser("~/.triton/cache")) - if cache_path.exists(): - shutil.rmtree(cache_path) - os.makedirs(cache_path) - - # inputs - q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout=layout, packing=packing, device=device) - - if packing == None: - # fp8 forward pass - if is_varlen: - out, lse, S_dmask = flash_attn_varlen_fp8_func( - q, - k, - v, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out, lse, S_dmask = flash_attn_fp8_func( - q, - k, - v, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # fp8 backward pass - if test_backward: - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do) - elif packing == "qkv": - # qkv packing path - # pack input tensors (use dim=1 for varlen, else dim=2) - if is_varlen: - qkv = torch.stack([q, k, v], dim=1) - else: - qkv = torch.stack([q, k, v], dim=2) - - # fp8 forward pass for qkv-packed input - if is_varlen: - out, lse, S_dmask = flash_attn_varlen_qkvpacked_fp8_func( - qkv, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - 
alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # fp8 backward pass for qkv-packed input - if test_backward: - dqkv, = torch.autograd.grad(out, (qkv,), do) - else: - raise ValueError(f"unknown packing type {packing}") - - # search for .ttir files - max_retries = 5 - retry_delay = 0.5 - ttir_files = [] - logging.info(f"Checking for .ttir files in {cache_path}...") - for attempt in range(max_retries): - # search for .ttir files recursively within the cache path - ttir_files = glob.glob(str(cache_path) + "/**/*.ttir", recursive=True) - - if ttir_files: - # Files found, log success and exit the loop - logging.info(f"Found {len(ttir_files)} .ttir files on attempt {attempt + 1}.") - break - else: - # Files not found yet - if attempt < max_retries - 1: - # If not the last attempt, wait and log before retrying - logging.warning( - f"No .ttir files found on attempt {attempt + 1}. " - f"Retrying in {retry_delay}s..." - ) - time.sleep(retry_delay) - else: - pytest.fail( - f"FATAL: No .ttir files found in cache {cache_path} " - f"after {max_retries} attempts." 
- ) - - # check if there is fp8 - ttir_files_fp8_found_status = {} - fp8_types = ['f8E4M3', 'f8E5M2'] - for ttir_file in ttir_files: - base_name = os.path.basename(ttir_file) - with open(ttir_file, 'r') as f: - content = f.read() - - # check content for fp8 - fp8_found = False - for f8_type in fp8_types: - if f8_type in content: - fp8_found = True - ttir_files_fp8_found_status[base_name] = fp8_found - - for file, fp8_found in ttir_files_fp8_found_status.items(): - assert fp8_found, f"{fp8_types} not found in {file}" diff --git a/flash_attn/flash_attn_triton_amd/train.py b/flash_attn/flash_attn_triton_amd/train.py deleted file mode 100644 index fc5f5d0b1bf..00000000000 --- a/flash_attn/flash_attn_triton_amd/train.py +++ /dev/null @@ -1,403 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader, Dataset, random_split -import numpy as np -import pandas as pd -from tqdm import tqdm -import matplotlib.pyplot as plt -from datasets import load_dataset -from flash_attn import flash_attn_qkvpacked_func, flash_attn_qkvpacked_fp8_func, flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_fp8_func - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print(f"using device: {device}") - -# ------------------------------- -# Model -# ------------------------------- -class FlashAttention(nn.Module): - def __init__(self, dim, num_heads=8, causal=True, dropout=0.1, qkv_bias=True, use_fp8=False): - super().__init__() - self.use_fp8 = use_fp8 - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.causal = causal - self.dropout_p = dropout - - # qkv and output projections - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - def forward(self, x): - b, n, c = x.shape - # project to qkv - qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - 
- # reshape for flash attention function - qkv_packed = torch.stack([q, k, v], dim=2).reshape(b, n, 3, self.num_heads, self.head_dim) - - # use the appropriate flash attention function - if self.use_fp8: - context = flash_attn_qkvpacked_fp8_func( - qkv_packed, - dropout_p=self.dropout_p, - causal=self.causal - ) - else: - context = flash_attn_qkvpacked_func( - qkv_packed, - dropout_p=self.dropout_p, - causal=self.causal - ) - - # convert back to original shape and type - context = context.reshape(b, n, c) - - # output projection - x = self.proj(context) - - return x - -class TransformerBlock(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4.0, causal=True, dropout=0.1, use_fp8=False): - super().__init__() - self.norm1 = nn.LayerNorm(dim) - self.attn = FlashAttention(dim, num_heads=num_heads, causal=causal, dropout=dropout, use_fp8=use_fp8) - - self.norm2 = nn.LayerNorm(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = nn.Sequential( - nn.Linear(dim, mlp_hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(mlp_hidden_dim, dim), - nn.Dropout(dropout) - ) - - def forward(self, x): - x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x - -class FlashLM(nn.Module): - def __init__( - self, - vocab_size, - dim=256, - depth=6, - num_heads=8, - mlp_ratio=4.0, - causal=True, - dropout=0.1, - max_seq_len=256, - use_fp8=False - ): - super().__init__() - - # embedding layers - self.token_embedding = nn.Embedding(vocab_size, dim) - self.position_embedding = nn.Parameter(torch.zeros(1, max_seq_len, dim)) - self.dropout = nn.Dropout(dropout) - - # transformer blocks - self.blocks = nn.ModuleList([ - TransformerBlock(dim, num_heads, mlp_ratio, causal=causal, dropout=dropout, use_fp8=use_fp8) - for _ in range(depth) - ]) - - # lm head: project back to vocabulary dimension for each token - self.norm = nn.LayerNorm(dim) - self.lm_head = nn.Linear(dim, vocab_size) - - def forward(self, x): - b, n = x.shape - - # token + positional 
embedding - x = self.token_embedding(x) - x = x + self.position_embedding[:, :n, :] - x = self.dropout(x) - - # transformer blocks - for block in self.blocks: - x = block(x) - - # language modeling head - x = self.norm(x) - logits = self.lm_head(x) # shape: (b, n, vocab_size) - return logits - -# ------------------------------- -# Data -# ------------------------------- -class TextDataset(Dataset): - def __init__(self, sequences, max_len=None): - self.sequences = sequences - self.max_len = max_len - - def __len__(self): - return len(self.sequences) - - def __getitem__(self, idx): - seq = self.sequences[idx] - # input: all tokens except the last, target: all tokens except the first - return (torch.tensor(seq[:-1], dtype=torch.long), - torch.tensor(seq[1:], dtype=torch.long)) - -class VarLenTextDataset(Dataset): - def __init__(self, sequences, max_len=256): - self.sequences = sequences - self.max_len = max_len - - def __len__(self): - return len(self.sequences) - - def __getitem__(self, idx): - seq = self.sequences[idx] - # Ensure the sequence doesn't exceed max_len+1 - seq = seq[:self.max_len+1] - # input: all tokens except the last, target: all tokens except the first - return (torch.tensor(seq[:-1], dtype=torch.long), - torch.tensor(seq[1:], dtype=torch.long)) - -def prepare_dataset(batch_size, is_varlen=False, min_len=10, max_len=256, ratio_shorter=0.7): - # load the WikiText-2 - dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") - - # build vocabulary - corpus = " ".join([line for line in dataset["text"] if line.strip() != ""]) # join non-empty lines into a single corpus string - tokens = corpus.split() - vocab = sorted(set(tokens)) - word2idx = {word: idx for idx, word in enumerate(vocab)} - token_ids = [word2idx[word] for word in tokens] - - num_workers = 2 - if is_varlen: - # VARIABLE LENGTH: create sequences of different lengths - sequences = [] - for i in range(0, len(token_ids) - max_len, max_len // 2): # overlap to get more sequences - 
# Decide target length for this sequence - if np.random.random() < ratio_shorter: - # Shorter sequence - target_len = np.random.randint(min_len + 1, max_len + 1) - else: - # Full length sequence - target_len = max_len + 1 - - # Extract sequence up to target length or whatever's available - seq_end = min(i + target_len, len(token_ids)) - seq = token_ids[i:seq_end] - - # Only keep sequences that are long enough - if len(seq) > min_len + 1: # +1 because we need both input and target - sequences.append(seq) - - print(f"Created {len(sequences)} variable-length sequences") - - # Get some statistics - lens = [len(seq) for seq in sequences] - print(f"Sequence length stats: min={min(lens)}, max={max(lens)}, mean={np.mean(lens):.1f}") - - # split dataset - num_samples = len(sequences) - num_train = int(0.8 * num_samples) - num_val = num_samples - num_train - - # Use appropriate dataset class based on whether we need variable length - dataset_class = VarLenTextDataset - train_sequences = sequences[:num_train] - val_sequences = sequences[num_train:] - - train_dataset = dataset_class(train_sequences, max_len) - val_dataset = dataset_class(val_sequences, max_len) - - - # collate function - def collate_fn(batch): - """ - Collate function that creates a flat representation for variable length flash attention. 
- """ - # Separate inputs and targets - inputs, targets = zip(*batch) - - # Get sequence lengths - seq_lens = torch.tensor([len(x) for x in inputs], dtype=torch.int32) - - # Concatenate inputs and targets into single tensors - flat_inputs = torch.cat(inputs) - flat_targets = torch.cat(targets) - - # Create cumulative sequence lengths tensor - cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32) - cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0) - - # Calculate max sequence length for this batch - max_seqlen = seq_lens.max().item() - - return flat_inputs, flat_targets, seq_lens, cu_seqlens, max_seqlen - - # data loaders - train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn) - val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn) - else: - # FIXED LENGTH: create sequences of length max_len+1 - sequences = [] - for i in range(0, len(token_ids) - max_len, max_len): - seq = token_ids[i : i + max_len + 1] - if len(seq) == max_len + 1: - sequences.append(seq) - - # split dataset - num_samples = len(sequences) - num_train = int(0.8 * num_samples) - num_val = num_samples - num_train - train_dataset, val_dataset = random_split(TextDataset(sequences), [num_train, num_val]) - - # data loaders - train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) - val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) - - vocab_size = len(vocab) - print(f"vocab size: {vocab_size}, train samples: {len(train_dataset)}, validation samples: {len(val_dataset)}") - return train_dataloader, val_dataloader, vocab_size - -# ------------------------------- -# Training -# ------------------------------- -def train_lm(model, train_dataloader, val_dataloader, optimizer, criterion, num_epochs): - train_losses = [] - val_losses = [] - for epoch in 
range(num_epochs): - # Training phase - model.train() - epoch_train_loss = 0.0 - for inputs, targets in tqdm(train_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [train]"): - inputs, targets = inputs.to(device), targets.to(device) - - optimizer.zero_grad() - logits = model(inputs) - loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) - loss.backward() - optimizer.step() - - epoch_train_loss += loss.item() - - epoch_train_loss /= len(train_dataloader) - train_losses.append(epoch_train_loss) - print(f"epoch {epoch+1}/{num_epochs} - train loss: {epoch_train_loss:.4f}") - - # Validation phase - model.eval() - epoch_val_loss = 0.0 - with torch.no_grad(): - for inputs, targets in tqdm(val_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [validation]"): - inputs, targets = inputs.to(device), targets.to(device) - logits = model(inputs) - loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) - epoch_val_loss += loss.item() - epoch_val_loss /= len(val_dataloader) - val_losses.append(epoch_val_loss) - print(f"epoch {epoch+1}/{num_epochs} - validation loss: {epoch_val_loss:.4f}") - - return train_losses, val_losses - -# ------------------------------- -# Main -# ------------------------------- -def main(): - # hyperparameters - batch_size = 16 - num_epochs = 20 - learning_rate = 3e-4 - max_len = 128 # total length including both input and target tokens - is_varlen = False - causal=True - dropout=0.1 - - # prep data - print("Preparing Dataset") - train_dataloader, val_dataloader, vocab_size = prepare_dataset(batch_size, max_len=max_len, is_varlen=is_varlen) - - # create language models - print("Creating Models") - model_normal = FlashLM( - vocab_size=vocab_size, - dim=256, - depth=3, - num_heads=8, - causal=causal, - dropout=dropout, - max_seq_len=max_len, - ).to(device) - - model_fp8 = FlashLM( - vocab_size=vocab_size, - dim=256, - depth=3, - num_heads=8, - causal=causal, - dropout=dropout, - max_seq_len=max_len, - use_fp8=True - ).to(device) - 
- # Train Normal model - print("Starting training for Normal model...") - optimizer_normal = optim.AdamW(model_normal.parameters(), lr=learning_rate) - criterion = nn.CrossEntropyLoss() - normal_train_losses, normal_val_losses = train_lm( - model_normal, train_dataloader, val_dataloader, optimizer_normal, criterion, num_epochs - ) - torch.save(model_normal.state_dict(), 'flash_lm_normal.pth') - print("Normal model training complete and saved.") - - # Train FP8 model - print("Starting training for FP8 model...") - optimizer_fp8 = optim.AdamW(model_fp8.parameters(), lr=learning_rate) - fp8_train_losses, fp8_val_losses = train_lm( - model_fp8, train_dataloader, val_dataloader, optimizer_fp8, criterion, num_epochs - ) - torch.save(model_fp8.state_dict(), 'flash_lm_fp8.pth') - print("FP8 model training complete and saved.") - - # save losses to csv - epochs = range(1, num_epochs+1) - loss_data = { - "Epoch": epochs, - "Normal_Training_Loss": normal_train_losses, - "Normal_Validation_Loss": normal_val_losses, - "FP8_Training_Loss": fp8_train_losses, - "FP8_Validation_Loss": fp8_val_losses, - } - df_losses = pd.DataFrame(loss_data) - df_losses.to_csv("losses.csv", index=False) - print("Loss data saved to losses.csv") - - # plot Training Loss - plt.figure(figsize=(10, 6)) - plt.plot(epochs, normal_train_losses, label="Normal Training Loss", marker='o') - plt.plot(epochs, fp8_train_losses, label="FP8 Training Loss", marker='x') - plt.xlabel("Epoch") - plt.ylabel("Training Loss") - plt.title("Training Loss Comparison: Normal vs FP8 Flash Attention") - plt.legend() - plt.grid(True) - plt.savefig("training_loss.png") # Saves the training loss plot to disk - plt.show() - - # Plot Validation Loss - plt.figure(figsize=(10, 6)) - plt.plot(epochs, normal_val_losses, label="Normal Validation Loss", marker='o') - plt.plot(epochs, fp8_val_losses, label="FP8 Validation Loss", marker='x') - plt.xlabel("Epoch") - plt.ylabel("Validation Loss") - plt.title("Validation Loss Comparison: 
Normal vs FP8 Flash Attention") - plt.legend() - plt.grid(True) - plt.savefig("validation_loss.png") # Saves the validation loss plot to disk - plt.show() - - -if __name__ == "__main__": - main() diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 5d3bf02e1f8..358467157c7 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -1,627 +1,185 @@ -import csv -import math -import torch -import os -import random +""" +Utilities for Flash Attention Triton AMD backend. + +This module contains essential runtime utilities: +- GPU architecture detection +- Global configuration flags +- Tensor shape/stride helpers +- FP8 type detection +""" import functools -import triton -import triton.language as tl +import os +from dataclasses import dataclass from typing import Literal, Optional, Union -# ------------------------------- -# Gloabl Variables -# ------------------------------- -AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') -DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') -USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') -PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') -USE_SINGLE_BWD_KERNEL = os.environ.get('USE_SINGLE_BWD_KERNEL', '0').lower() in ('1', 'true', 'yes') -USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" -USE_TRITON_INTERPRET = os.environ.get('TRITON_INTERPRET', '0').lower() in ('1', 'true', 'yes') -DEBUG_TRITON = os.environ.get('DEBUG_TRITON', '0').lower() in ('1', 'true', 'yes') and USE_TRITON_INTERPRET -DEBUG_TRITON_DETAIL = os.environ.get('DEBUG_TRITON_DETAIL', '0').lower() in ('1', 'true', 'yes') and USE_TRITON_INTERPRET -if USE_TRITON_ROCM: # TODO remove this - random.seed(42) -DROPOUT_USE_PYTORCH = False -DROPOUT_DUMP = False 
+import torch +import triton -# ------------------------------- -# Metadata -# ------------------------------- -class MetaData(): - cu_seqlens_q: Optional[torch.Tensor] = None - cu_seqlens_k: Optional[torch.Tensor] = None - max_seqlens_q: int = 0 - max_seqlens_k: int = 0 - bias: Optional[torch.Tensor] = None - alibi_slopes: Optional[torch.Tensor] = None - causal: bool = False - num_contexts = 0 - varlen: bool = False - layout: Optional[Literal["bshd", "bhsd", "thd"]] = None - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None - cache_batch_idx = None - packing: Optional[bool] = None - return_scores: bool = False - dropout_p: float = 0.0 - philox_seed: Optional[int] = None - philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. - # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. - use_exp2: bool = False - rotary_sin: Optional[torch.Tensor] = None - rotary_cos: Optional[torch.Tensor] = None - rotary_interleaved: bool = False - rotary_conjunction: bool = False - - - def __repr__(self) -> str: - return (f"MetaData(\n" - f" sm_scale={self.sm_scale},\n" - f" cu_seqlens_q={self.cu_seqlens_q},\n" - f" cu_seqlens_k={self.cu_seqlens_k},\n" - f" max_seqlens_q={self.max_seqlens_q},\n" - f" max_seqlens_k={self.max_seqlens_k},\n" - f" bias={self.bias},\n" - f" alibi_slopes={self.alibi_slopes},\n" - f" causal={self.causal},\n" - f" num_contexts={self.num_contexts},\n" - f" varlen={self.varlen},\n" - f" layout={self.layout},\n" - f" cache_seqlens={self.cache_seqlens},\n" - f" cache_batch_idx={self.cache_batch_idx},\n" - f" dropout_p={self.dropout_p},\n" - f" return_scores={self.return_scores}\n" - f")") - - def __init__(self, sm_scale=1.0): - self.sm_scale = sm_scale - - def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k): - self.varlen = True - self.layout = 'thd' - self.cu_seqlens_q = cu_seqlens_q - self.cu_seqlens_k = 
cu_seqlens_k - self.max_seqlens_q = max_seqlen_q - self.max_seqlens_k = max_seqlen_k - - # Without "varlen", there should still be one sequence. - assert len(cu_seqlens_q) >= 2 - assert len(cu_seqlens_q) == len(cu_seqlens_k) - - def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): - assert bias.is_cuda - assert bias.dim() == 4 - assert bias.shape[0] == 1 - assert bias.shape[2:] == (seqlen_q, seqlen_k) - self.bias = bias - - def need_alibi(self, alibi_slopes, batch, nheads): - assert alibi_slopes.is_cuda - assert alibi_slopes.dim() == 2 - assert alibi_slopes.shape[0] == batch - assert alibi_slopes.shape[1] == nheads - self.alibi_slopes = alibi_slopes - - def need_causal(self, causal): - self.causal = causal - - def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False): - self.rotary_sin = sin - self.rotary_cos = cos - self.rotary_interleaved = rotary_interleaved - self.rotary_conjunction = rotary_conjunction - - def need_dropout(self, dropout_p, return_softmax = True): - self.dropout_p = dropout_p - self.return_softmax = return_softmax - self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 - - def check_args(self, q, k, v, o): - assert q.dim() == k.dim() and q.dim() == v.dim() - - batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) - if self.varlen: - assert q.dim() == 3 - assert self.cu_seqlens_q is not None - assert self.cu_seqlens_k is not None - assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) - # TODO: Remove once bias is supported with varlen - assert self.bias is None - # assert not self.return_scores - else: - assert q.dim() == 4 - assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 - assert self.cu_seqlens_q is None and self.cu_seqlens_k is None - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert 
q.dtype == k.dtype and q.dtype == v.dtype - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 - assert self.layout is not None - assert self.layout == 'thd' or not self.varlen +__all__ = [ + # Runtime info + "get_arch", + "is_hip", + # Global config + "AUTOTUNE", + "DEBUG", + "USE_TRITON_ROCM", + "BWD_MODE", + "USE_EXP2", + "PHILOX_SEED", + "PHILOX_OFFSET", + "SHAPE_EXPECTATIONS", + # FP8 + "is_fp8", + # Shape/stride helpers + "get_shape_from_layout", + "get_stride_from_layout", + "get_padded_headsize", + # Misc helpers + "round_multiple", +] + # ------------------------------- -# Input Helper +# GPU Architecture # ------------------------------- -def random_seqlens_composition(SEQ_LEN, BATCH): - # generate a random composition of N into Z positive parts. - idx = torch.randperm(SEQ_LEN - 1)[: BATCH - 1] + 1 - idx, _ = torch.sort(idx) - breakpoints = torch.cat([ - torch.tensor([0], dtype=torch.long), - idx, - torch.tensor([SEQ_LEN], dtype=torch.long), - ]) - seqlens = (breakpoints[1:] - breakpoints[:-1]).to(torch.int32) - return seqlens - -def generate_varlen_tensor( - total_seqlen: int, - num_heads: int, - head_size: int, - batch_size: Optional[int] = None, - equal_seqlens: bool = False, - device: str = "cuda", - dtype: torch.dtype = torch.float32, - DEBUG_INPUT: bool = False -): - if DEBUG: - print("total_seqlen", total_seqlen) - print("num_heads", num_heads) - print("head_size", head_size) - - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # get valid batch_size - if batch_size is None: - valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen] - batch_size = random.choice(valid_batch_sizes) - - # get seqlens - if equal_seqlens: - seqlens = torch.full( - (batch_size,), - total_seqlen // batch_size, - dtype=torch.int32, - device=device - ) - seqlens[-1] += total_seqlen % batch_size - else: - seqlens = random_seqlens_composition(total_seqlen, 
batch_size).to(device=device) - - # create cumulative sequence lengths - cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device) - max_seqlen = torch.max(seqlens).to(torch.int32).item() - - # create varlen tensor - if DEBUG_INPUT: - x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) - for i in range(batch_size): - start = cu_seqlens[i].item() - end = cu_seqlens[i+1].item() - length = end - start - - x[start:end, :, :] = ( - torch.arange(length, dtype=dtype, device=device) - .view(length, 1, 1) - .expand(length, num_heads, head_size) - ) - else: - x = torch.randn((total_seqlen, num_heads, head_size), dtype=dtype, device=device) +ArchFamily = Literal["cdna", "rdna"] - if is_fp8_dtype: - # cast to fp8 - x, descale_x = cast_to_fp8(x, og_fp8_dtype, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - x.requires_grad_() - return x, cu_seqlens, max_seqlen, descale_x - else: - x.requires_grad_() - return x, cu_seqlens, max_seqlen - -def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False): - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # gen tensor - tensor_shape = (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD) - if DEBUG_INPUT: - x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1).expand(*tensor_shape).contiguous() - else: - x = torch.randn(tensor_shape, dtype=dtype, device=device) - - if is_fp8_dtype: - # cast to fp8 - x, descale_x = cast_to_fp8(x, og_fp8_dtype, "bshd") - x.requires_grad_() - return x, descale_x - else: - x.requires_grad_() - return x - -def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False): - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # gen tensor - tensor_shape = (BATCH, NUM_HEADS, 
SEQ_LEN, D_HEAD) - if DEBUG_INPUT: - x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous() - else: - x = torch.randn(tensor_shape, dtype=dtype, device=device) - - - if is_fp8_dtype: - # cast to fp8 - x, descale_x = cast_to_fp8(x, og_fp8_dtype, "bhsd") # FIXME: I don't the casting fn supports this atm - x.requires_grad_() - return x, descale_x - else: - x.requires_grad_() - return x - -def input_helper( - BATCH: int, - HQ: int, - HK: int, - N_CTX_Q: int, - N_CTX_K: int, - D_HEAD: int, - CAUSAL: bool, - DROPOUT_P: float, - dtype: torch.dtype, - layout: Literal["bshd", "bhsd", "thd"], - packing: Optional[Literal["kv", "qkv"]] = None, - device: Literal["cpu", "cuda"] = "cuda", - DEBUG_INPUT: bool = False, -): - torch.manual_seed(20) - is_fp8_dtype = is_dtype_fp8(dtype) +CDNA_ARCHS = frozenset({"gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx950"}) +RDNA_ARCHS = frozenset({"gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"}) +FP8_ARCHS = frozenset({"gfx942", "gfx950"}) - if layout == "thd": - # set params - TOTAL_SEQLENS_Q = BATCH * N_CTX_Q - TOTAL_SEQLENS_K = BATCH * N_CTX_K - equal_seqlens=False - - # gen tensors - # TODO: the gen functions should maybe have different gen modes like random, ones, increasing seqlen - if is_fp8_dtype: - q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - v, _, _ , descale_v = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - do, _, _ , descale_do = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, 
dtype=dtype, device=device, equal_seqlens=equal_seqlens) - else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) +_RECOMMENDED_FP8_REPLACEMENTS: dict[str, dict[torch.dtype, torch.dtype]] = { + "gfx942": { + torch.float8_e4m3fn: torch.float8_e4m3fnuz, + torch.float8_e5m2: torch.float8_e5m2fnuz, + }, +} + + +@dataclass(frozen=True) +class GpuArch: + """GPU architecture information.""" + name: str # e.g., "gfx942", "gfx1100" + family: Optional[ArchFamily] = None + + @property + def is_cdna(self) -> bool: + return self.family == "cdna" + + @property + def is_rdna(self) -> bool: + return self.family == "rdna" + + @property + def supports_fp8(self) -> bool: + """Check if this architecture supports FP8.""" + return self.name in FP8_ARCHS + + def recommended_fp8_dtype(self, dtype: torch.dtype) -> torch.dtype: + """Get the recommended FP8 dtype for this architecture. 
- # setup metadata - if DEBUG_INPUT: - sm_scale = 1 - else: - sm_scale = D_HEAD**-0.5 - metadata = MetaData(sm_scale=sm_scale) - metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) - metadata.need_causal(CAUSAL) - metadata.need_dropout(DROPOUT_P) - elif layout == 'bshd' or layout == "bhsd": - # gen tensors - if layout == "bshd": - if is_fp8_dtype: - q, descale_q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k, descale_k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v, descale_v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do, descale_do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) - else: - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) - elif layout == "bhsd": - if is_fp8_dtype: - q, descale_q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k, descale_k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v, descale_v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do, descale_do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) - else: - q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, 
D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) - - # setup metadata - if DEBUG_INPUT: - sm_scale = 1 - else: - sm_scale = D_HEAD**-0.5 - metadata = MetaData(sm_scale=sm_scale) - metadata.max_seqlens_q = N_CTX_Q - metadata.max_seqlens_k = N_CTX_K - metadata.layout = layout - metadata.need_causal(CAUSAL) - metadata.need_dropout(DROPOUT_P) - else: - raise ValueError(f"Unknown layout: {layout}") - - # deal with packing - if packing is None: - if is_fp8_dtype: - return (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do), metadata - else: - return q, k, v, do, metadata - elif packing == "kv": - # pack k and v - if layout in ["bhsd", "thd"]: - kv = torch.stack([k, v], dim=1) - elif layout == "bshd": - kv = torch.stack([k, v], dim=2) - else: - raise ValueError(f"Unknown layout: {layout}") - - if is_fp8_dtype: - raise ValueError("FP8 not supported kv packing yet") - else: - return q, kv, do, metadata - elif packing == "qkv": - # qkv packing - requires same sequence length for q and k - assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" - assert HQ == HK, "For QKV packing, Q and K must have same number of heads" - - # pack q, k, and v - if layout in ["bhsd", "thd"]: - qkv = torch.stack([q, k, v], dim=1) - elif layout == "bshd": - qkv = torch.stack([q, k, v], dim=2) - else: - raise ValueError(f"Unknown layout: {layout}") - - if is_fp8_dtype: - raise ValueError("FP8 not supported qkv packing yet") - else: - return qkv, do, metadata - else: - assert False, f"Unsupported packing mode: {packing}" + Some architectures prefer different FP8 variants (e.g., fnuz vs fn). + Returns the input dtype unchanged if no replacement is recommended. 
+ """ + return _RECOMMENDED_FP8_REPLACEMENTS.get(self.name, {}).get(dtype, dtype) + + @property + def cu_count(self) -> int: + """Get the number of compute units on the current GPU.""" + return int( + torch.cuda.get_device_properties( + torch.cuda.current_device() + ).multi_processor_count + ) + # ------------------------------- -# Alibi +# Global Variables # ------------------------------- -@triton.jit -def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): - # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix - # for casual mask we want something like this where (1 is kept and 0 is masked) - # seqlen_q = 2 and seqlen_k = 5 - # 1 1 1 1 0 - # 1 1 1 1 1 - # seqlen_q = 5 and seqlen_k = 2 - # 0 0 - # 0 0 - # 0 0 - # 1 0 - # 1 1 - # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal - # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False - # 1. offs_m[:,None] = [[0], - # [1], - # 2. offs_m[:,None] + seqlen_k = [[5], - # [6], - # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], - # [4], - # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], - # [4], [ 4, 3, 2, 1, 0]] - # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], - # [ -4, -3, -2, -1, 0]], - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - if transpose: - return alibi_block.T - else: - return alibi_block +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +AUTOTUNE = os.environ.get("FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "0").lower() in ( + "1", + "true", + "yes", +) + +# Unified debug level: +# 0 = off (default) +# 1 = basic debug info (shapes, tensor stats, kernel params) +# 2 = detailed debug (includes Triton interpreter prints in kernels) +# +# Set via: FLASH_ATTENTION_TRITON_AMD_DEBUG=0|1|2 +DEBUG: int = int(os.environ.get("FLASH_ATTENTION_TRITON_AMD_DEBUG", "0")) +if AUTOTUNE or DEBUG > 0: + os.environ["TRITON_PRINT_AUTOTUNING"] = "1" +if DEBUG >= 2: + os.environ["TRITON_INTERPRET"] = "1" +BWD_MODE: Literal["fused", "fused_atomic", "split"] = "fused" +USE_EXP2 = True +PHILOX_SEED = 0x1BF58 +PHILOX_OFFSET = 0x1D4B49 +SHAPE_EXPECTATIONS: Literal["exact", "rounded"] = "exact" + # ------------------------------- # FP8 # ------------------------------- -def is_dtype_fp8(dtype): - if dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}: - if arch_supports_fp8(): +_FP8_DTYPES = frozenset({ + torch.float8_e4m3fnuz, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e5m2fnuz, +}) + + +def is_fp8( + x: Union[torch.dtype, torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]], +) -> bool: + """Check if dtype/tensor(s) are FP8. + + This is a pure function - it only checks dtypes, not architecture support. + Use `get_arch().supports_fp8` to check if the current GPU supports FP8. + + Args: + x: A dtype, tensor, or list/tuple of tensors to check. + + Returns: + True if FP8, False otherwise. + + Rules for multiple tensors: + - If all tensors are FP8 -> return True. 
+ - If none are FP8 -> return False. + - If a mix of FP8 and non-FP8 -> raise ValueError. + + Empty list/tuple returns False. + """ + # Handle dtype directly + if isinstance(x, torch.dtype): + return x in _FP8_DTYPES + + # Handle single tensor + if isinstance(x, torch.Tensor): + return x.dtype in _FP8_DTYPES + + # Handle list/tuple of tensors + if isinstance(x, (list, tuple)): + if len(x) == 0: + return False + flags = [t.dtype in _FP8_DTYPES for t in x] + if all(flags): return True - else: - raise RuntimeError("This device doesnot support fp8") - else: - return False - -def is_fp8(x): - return is_dtype_fp8(x.dtype) - -@triton.jit -def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): - # compute fp8 scaling and descaling factor for a block - x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values - x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) - scale_x = fp8_max / x_amax - descale_x = x_amax / fp8_max - return scale_x, descale_x - -@triton.jit -def _cast_varlen_to_fp8_kernel_2d( - X, X_fp8, Descale, - cu_seqlens, H, MAX_SEQLEN, - stride_batch, stride_seq, stride_head, stride_dim, - stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim, - stride_desc_batch, stride_desc_head, - FP8_CLAMP_VAL, - FP8_MAX, - BLOCK_SIZE: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - IS_VARLEN: tl.constexpr - ): - # Process one (batch, head) pair per kernel - b_id = tl.program_id(0) - h_id = tl.program_id(1) - - # Get sequence bounds for this batch - if IS_VARLEN: - seq_start = tl.load(cu_seqlens + b_id) - seq_end = tl.load(cu_seqlens + b_id + 1) - seqlen = seq_end - seq_start - else: - seq_start = 0 - seqlen = MAX_SEQLEN - - # initialize max value tracker - x_max_val = 0.0 - - # STEP 1: Find max absolute value across the entire sequence - num_of_blocks = tl.cdiv(seqlen, BLOCK_SIZE) - for blk_idx in range(0, num_of_blocks): - # print("blk_idx:", blk_idx) - # offsets - offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - 
offs_dim = tl.arange(0, HEAD_DIM) - - # Create mask for valid elements - mask_seq = offs_seq[:, None] < seqlen - if ACTUAL_HEAD_DIM != HEAD_DIM: - mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM - mask_seq = mask_seq & mask_dim - - # Load block - adj_x = b_id * stride_batch + h_id * stride_head + seq_start * stride_seq + offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim - x_block = tl.load(X + adj_x, mask=mask_seq, other=0.0) - # print("x_block:", x_block) - - # Find max absolute value in this block - block_max = tl.max(tl.abs(x_block)) - # print("block_max:", block_max) - - # Update overall max - x_max_val = tl.maximum(x_max_val, block_max) - # print("x_max_val:", x_max_val) - - # clamp to avoid division by zero issues - x_max_val = tl.maximum(x_max_val, FP8_CLAMP_VAL) - - # compute scale and descale factors for the entire sequence - scale = FP8_MAX / x_max_val - descale = x_max_val / FP8_MAX - - # store descale factor for this (batch, head) pair - desc_ptr = Descale + b_id * stride_desc_batch + h_id# * stride_desc_head - tl.store(desc_ptr, descale) - - # STEP 2: Apply scaling to the entire sequence and convert to FP8 - for blk_idx in range(0, num_of_blocks): - # offsets - offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_dim = tl.arange(0, HEAD_DIM) - - # Create mask for valid elements - mask_seq = offs_seq[:, None] < seqlen - if ACTUAL_HEAD_DIM != HEAD_DIM: - mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM - mask_seq = mask_seq & mask_dim - - # Load block - Using the fixed addressing - addr = b_id * stride_batch + h_id * stride_head + seq_start * stride_seq + offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim - x_block = tl.load(X + addr, mask=mask_seq, other=0.0) - - # Apply scale and convert to FP8 - x_fp8_block = (x_block * scale).to(X_fp8.type.element_ty) - - # Store results - addr_out = b_id * stride_out_batch + h_id * stride_out_head + seq_start * stride_out_seq + offs_seq[:, None] * stride_out_seq + offs_dim[None, 
:] * stride_out_dim - tl.store(X_fp8 + addr_out, x_fp8_block, mask=mask_seq) + if not any(flags): + return False + raise ValueError( + "Mixed FP8 and non-FP8 tensors provided; either all or none must be FP8." + ) + + raise TypeError(f"Expected dtype, Tensor, or sequence of Tensors, got {type(x)}") -def cast_to_fp8( - x: torch.Tensor, - fp8_dtype: torch.dtype, - layout: Literal["bshd", "thd"], - clamp_val: float = 1e-9, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None -) -> tuple[torch.Tensor, torch.Tensor]: - if False: - print() - print("cast_to_fp8") - print("x:", x, x.shape) - print("fp8_dtype:", fp8_dtype) - print("cu_seqlens:", cu_seqlens) - print("max_seqlen:", max_seqlen) - print("clamp_val:", clamp_val) - - # check types are valid - assert x.dtype in {torch.float16, torch.float32, torch.float64, torch.bfloat16} and is_dtype_fp8(fp8_dtype), f"Cannot cast {x.dtype} to {fp8_dtype}" - - # extract dimensions - batch, max_seqlen_final, num_heads, head_dim = get_shape_from_layout(x, layout, cu_seqlens, max_seqlen) - is_varlen = layout == "thd" - fp8_max = torch.finfo(fp8_dtype).max - if False: - print("batch:", batch) - print("max_seqlen_final:", max_seqlen_final) - print("num_heads:", num_heads) - print("head_dim:", head_dim) - - # get closest power of 2 for head_dim - padded_head_dim = 1 << (head_dim - 1).bit_length() - padded_head_dim = max(padded_head_dim, 32) - - # kernel params - x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) - descale_factors = torch.zeros((batch, num_heads), device=x.device, dtype=torch.float32) - BLOCK_SIZE = 128 - - # calculate strides - stride_batch, stride_head, stride_seq, stride_dim = get_stride_from_layout(x, layout) - stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim = get_stride_from_layout(x_fp8, layout) - stride_desc_batch, stride_desc_head = descale_factors.stride() - - if False: - print("stride_batch", stride_batch) - print("stride_head", stride_head) - print("stride_seq", 
stride_seq) - print("stride_dim", stride_dim) - print("stride_out_batch", stride_out_batch) - print("stride_out_head", stride_out_head) - print("stride_out_seq", stride_out_seq) - print("stride_out_dim", stride_out_dim) - print("stride_desc_batch", stride_desc_batch) - print("stride_desc_head", stride_desc_head) - - grid = (batch, num_heads) - _cast_varlen_to_fp8_kernel_2d[grid]( - x, x_fp8, descale_factors, - cu_seqlens, num_heads, max_seqlen_final, - stride_batch, stride_seq, stride_head, stride_dim, - stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim, - stride_desc_batch, stride_desc_head, - clamp_val, fp8_max, - BLOCK_SIZE=BLOCK_SIZE, - HEAD_DIM=padded_head_dim, - ACTUAL_HEAD_DIM=head_dim, - IS_VARLEN=is_varlen - ) - - if False: - print("x_fp8:", x_fp8, x_fp8.shape) - print("descale_factors:", descale_factors, descale_factors.shape) - return x_fp8, descale_factors # ------------------------------- -# Misc +# Shape/Stride Helpers # ------------------------------- def get_shape_from_layout( x: torch.Tensor, @@ -629,147 +187,78 @@ def get_shape_from_layout( cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, ) -> tuple[int, int, int, int]: - if layout == 'bhsd': + """Extract (batch, max_seqlen, num_heads, head_dim) from tensor based on layout.""" + if layout == "bhsd": batch, num_heads, max_seqlen_final, head_dim = x.shape - elif layout == 'bshd': + elif layout == "bshd": batch, max_seqlen_final, num_heads, head_dim = x.shape - elif layout == 'thd': + elif layout == "thd": total_seqlen, num_heads, head_dim = x.shape if cu_seqlens is None: - raise ValueError("cu_seqlens must be provided for varlen (thd) layout") + raise ValueError("cu_seqlens must be provided for varlen (thd) layout") if max_seqlen is None: raise ValueError("max_seqlen must be provided for varlen (thd) layout") - - batch, max_seqlen_final, num_heads, head_dim = len(cu_seqlens) - 1, max_seqlen, num_heads, head_dim + + batch, max_seqlen_final, num_heads, 
head_dim = ( + len(cu_seqlens) - 1, + max_seqlen, + num_heads, + head_dim, + ) else: - assert False, "Got unsupported layout." + raise ValueError(f"Got unsupported layout: {layout}") return batch, max_seqlen_final, num_heads, head_dim -def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None): - batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout(q, layout, cu_seqlens_q, max_seqlen_q) - batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout(k, layout, cu_seqlens_k, max_seqlen_k) - - # assert - assert batch_q == batch_k - assert head_size_q == head_size_k - - return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k - -def get_stride_from_layout(x: torch.Tensor, layout:Literal["bshd", "bhsd", "thd"]): - if layout == 'thd': - strides = (0, x.stride(1), x.stride(0), x.stride(2)) - elif layout == 'bhsd': +def get_stride_from_layout( + x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"] +) -> tuple[int, int, int, int]: + """Get strides in (batch, head, seq, dim) order for the given layout.""" + if layout == "thd": + strides = (0, x.stride(1), x.stride(0), x.stride(2)) + elif layout == "bhsd": strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3)) - elif layout == 'bshd': + elif layout == "bshd": strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3)) else: - assert False, 'Got unsupported layout.' 
+ raise ValueError(f"Got unsupported layout: {layout}") return strides -def get_shape_and_strides_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None): - return get_shape_from_layout(x, layout, cu_seqlens, max_seqlen), get_stride_from_layout(x, layout) -def get_strides_from_layout(q, k, v, o, layout): - q_strides = get_stride_from_layout(q, layout) - k_strides = get_stride_from_layout(k, layout) - v_strides = get_stride_from_layout(v, layout) - o_strides = get_stride_from_layout(o, layout) - return q_strides, k_strides, v_strides, o_strides - -def get_padded_headsize(size): - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (size - 1).bit_length() +def get_padded_headsize(size: int) -> int: + """Get closest power of 2 over or equal to 32.""" # Smallest head_dim supported is 16. If smaller, the tile in the # kernel is padded - there is no padding in memory for any dims. + padded_d_model = 1 << (size - 1).bit_length() padded_d_model = max(padded_d_model, 16) return padded_d_model -def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): - q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) - k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) - relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) - return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) # ------------------------------- -# Dropouts +# Misc helpers # ------------------------------- -def create_dropout_mask(dropout_p, shape, seed): - device = "cuda" - rand_vals = torch.rand(shape, generator=torch.Generator(device=device).manual_seed(seed), device=device, dtype=torch.float32) - return rand_vals > dropout_p - -def create_dropout_mask_varlen(dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed): - device = "cuda" - qlens = 
(cu_seqlens_q[1:] - cu_seqlens_q[:-1]) - klens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) - max_qlen = qlens.max() - max_klen = klens.max() - dropout_mask = torch.zeros((batch, nheads_q, max_qlen, max_klen), device=device) - for b in range(batch): - qlen = qlens[b] - klen = klens[b] - rand_vals = torch.rand((nheads_q, qlen, klen), generator=torch.Generator(device=device).manual_seed(philox_seed), device=device, dtype=torch.float32) - submask = rand_vals > dropout_p - dropout_mask[b, :, :qlen, :klen] = submask - - return dropout_mask - -def write_dropout_mask(x, tensor_name = "tensor"): - batch, head, seqlen_m, seqlen_n = x.shape - x = x.tolist() - - with open(f'{tensor_name}.csv', 'w') as f: - writer = csv.writer(f) - for b in range(batch): - for h in range(head): - dropout_mask = x[b][h] - if True: - BLOCK_M = 64 - BLOCK_N = 64 - - # Calculate number of blocks in each dimension - m_blocks = math.ceil(seqlen_m / BLOCK_M) - n_blocks = math.ceil(seqlen_n / BLOCK_N) - - # Process each block - for m_block in range(m_blocks): - # Calculate row range for current block - row_start = m_block * BLOCK_M - row_end = min(row_start + BLOCK_M, seqlen_m) - - for n_block in range(n_blocks): - # Calculate column range for current block - col_start = n_block * BLOCK_N - col_end = min(col_start + BLOCK_N, seqlen_n) - - # Extract and write the current block - for row_idx in range(row_start, row_end): - row_data = dropout_mask[row_idx][col_start:col_end] - writer.writerow(row_data) - else: - writer.writerows(dropout_mask) +def round_multiple(x: int, m: int) -> int: + """Round x up to the nearest multiple of m.""" + return (x + m - 1) // m * m + # ------------------------------- # Runtime info # ------------------------------- @functools.cache -def is_hip(): - return triton.runtime.driver.active.get_current_target().backend == "hip" - -@functools.cache -def get_arch(): - return triton.runtime.driver.active.get_current_target().arch +def is_hip() -> bool: + """Check if running on HIP (AMD) 
backend.""" + return bool(triton.runtime.driver.active.get_current_target().backend == "hip") -@functools.cache -def is_cdna(): - return is_hip() and get_arch() in ('gfx908', 'gfx90a', 'gfx940', 'gfx941', 'gfx942', 'gfx950') - -@functools.cache -def is_rdna(): - return is_hip() and get_arch() in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201") @functools.cache -def arch_supports_fp8(): - return is_hip() and get_arch() in ('gfx942') +def get_arch() -> GpuArch: + """Get the current GPU architecture.""" + name: str = triton.runtime.driver.active.get_current_target().arch + if name in CDNA_ARCHS: + return GpuArch(name=name, family="cdna") + elif name in RDNA_ARCHS: + return GpuArch(name=name, family="rdna") + else: + return GpuArch(name=name) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py old mode 100644 new mode 100755 index 44d1f027cb0..92a014624e1 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -2,16 +2,27 @@ from typing import Optional, Union, List, Tuple +import os +import sys +from pathlib import Path import torch import torch.nn as nn -# isort: off -# We need to import the CUDA kernels after importing torch -import flash_attn_3._C # Registers operators with PyTorch -# isort: on +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + repo_root = Path(__file__).resolve().parent.parent + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + from flash_attn.flash_attn_triton_amd import flash_attn_3 as flash_attn_3_gpu # type: ignore +else: + # isort: off + # We need to import the CUDA kernels after importing torch + import flash_attn_3._C # Registers operators with PyTorch -flash_attn_3_cuda = torch.ops.flash_attn_3 + # isort: on + + flash_attn_3_gpu = torch.ops.flash_attn_3 def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -90,7 +101,7 @@ def _flash_attn_forward( ] 
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd( + out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_gpu.fwd( q, k, v, @@ -268,7 +279,7 @@ def _flash_attn_backward( ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - softmax_d, *rest = flash_attn_3_cuda.bwd( + softmax_d, *rest = flash_attn_3_gpu.bwd( dout, q, k, @@ -922,7 +933,7 @@ def flash_attn_varlen_func( def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None): - return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype) + return flash_attn_3_gpu.fwd_combine(out_partial, lse_partial, out, out_dtype) def flash_attn_with_kvcache( @@ -1110,7 +1121,7 @@ def get_scheduler_metadata( cache_seqlens = maybe_contiguous(cache_seqlens) if headdim_v is None: headdim_v = headdim - scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata( + scheduler_metadata = flash_attn_3_gpu.get_scheduler_metadata( batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, qkv_dtype, cache_seqlens, diff --git a/hopper/setup.py b/hopper/setup.py old mode 100644 new mode 100755 index 95729edabe2..36359229766 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -43,6 +43,10 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# ROCm specific settings +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + SKIP_CUDA_BUILD = True DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" DISABLE_SPLIT = 
os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" @@ -421,10 +425,10 @@ def nvcc_threads_args(): cmdclass = {} ext_modules = [] - # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. -subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) +if not USE_TRITON_ROCM: + subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) if not SKIP_CUDA_BUILD: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) diff --git a/hopper/test_flash_attn_triton_amd.py b/hopper/test_flash_attn_triton_amd.py new file mode 100755 index 00000000000..73e54dce066 --- /dev/null +++ b/hopper/test_flash_attn_triton_amd.py @@ -0,0 +1,1173 @@ +import os +import math +import itertools + +import pytest +import torch +import torch.nn.functional as F +from torch._C import parse_schema + +from einops import rearrange, repeat +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from padding import pad_input, unpad_input +from test_util import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, +) + +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata + + +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "TRUE") == "TRUE" +DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "TRUE") == "TRUE" +DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "TRUE") == "TRUE" +DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "TRUE") == "TRUE" 
+DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 +DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" +DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" +DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" +DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" +DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" + +COMPILED_HDIMS = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) +) + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("V_colmajor", [False, True]) +@pytest.mark.parametrize("V_colmajor", 
[False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype +): + if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): + pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure 
the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # 
qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
+ assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + # import flash_attn_3_cuda + # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( + # g, + # q, + # k, + # v, + # out, + # lse, + # None, + # None, + # None, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") 
+ print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# 
@pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype +): + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + nheads = 6 + # batch_size = 2 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + 
cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, 
+ window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = 
torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if 
softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +# @pytest.mark.parametrize("new_kv", [True]) +@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +# @pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +@pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) +# @pytest.mark.parametrize("rotary_interleaved", [True]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) +# @pytest.mark.parametrize("page_size", [None]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_leftpad", [False]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("varlen_q", [False]) +# 
@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + (128, 128), + (256, 512), # To test appending KV with more than 1 block + (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] 
if (causal or local) and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + has_qv = d == 64 and dv >= 256 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = 
torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = 
torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... 
-> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, 
precompute_metadata_vals): + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + max_seqlen_k_new=seqlen_new, page_size=page_size, + causal=causal, window_size=window_size, attention_chunk=attention_chunk, + num_splits=num_splits + ) + else: + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 
= torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + v_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('d', [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (64, 8192), + ], +) +def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + torch.random.manual_seed(0) + batch_size = 2 + nheads = 16 + nheads_kv = 4 + # There was a bug where this would cause "unspecified launch failure" due to Cluster + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + for _ in range(100): + flash_attn_func(q, k, v, causal=causal) + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128]) +# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [80]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (239, 1), + (3, 799), + (799, 3), + (1024, 128), + (97, 97), + (128, 128), + (200, 200), + (256, 256), + (257, 257), + (384, 384), + (512, 512), + (768, 768), + (1024, 1024), + (2048, 2048), + ], +) +@pytest.mark.skip(reason="Cannot be run in parallel with other tests due to memory usage") +def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): + device = 
"cuda" + # set seed + torch.random.manual_seed(0) + # Simulate under memory load + dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device) + batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger + nheads = 4 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + torch.random.manual_seed(42) + out0 = flash_attn_func(q, k, v, causal=causal) + g = torch.randn_like(out0) + dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) + # Numerical error if we just do any arithmetic on dq + dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() + + for i in range(1000): + torch.random.manual_seed(42) + out = flash_attn_func(q, k, v, causal=causal) + assert torch.equal(out, out0) + # assert torch.equal(lse, lse0) + + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_equal = torch.allclose(dq, dq0, atol=dq_atol) + if not dq_equal: + print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") + # breakpoint() + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert dq_equal + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, nheads, seqlen) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) 
+@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [128]) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + if DISABLE_SPLIT: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") + out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + multiple = 2 + assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, 
rtol=1e-5) + + # from flash_attn.utils.benchmark import pytorch_profiler + # # pytorch_profiler(torch.sum, lse_partial) + # pytorch_profiler(flash_attn_combine, out_partial, lse_partial) + # pytorch_profiler(torch.sum, out_partial) + +@pytest.mark.skip(reason="AMD Triton backend doesn't use torch ops registration") +def test_flash3_bw_compatibility() -> None: + # Let's try to always stay backward compatible! This will make life easier + # for downstream libaries, users, and exported models. + # 1/ Instead of removing arguments, error out if their value is no longer supported + # 2/ When adding arguments, add them at the end with a default value + assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, " + "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, " + "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, " + "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, " + "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, " + "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, " + "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, " + "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) " + "-> (Tensor(out!), Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, " + "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, " + "Tensor? cu_seqlens_k=None, Tensor? 
seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, " + "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, " + "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) " + "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, " + "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)" + )) + assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, " + "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, " + "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, " + "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? 
pack_gqa=None, " + "int sm_margin=0) -> Tensor" + )) diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index b5e026803c2..ac1ca579d0f 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -16,7 +16,19 @@ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb -from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_rdna +from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, get_arch + + +def _get_block_size_n_triton(device, head_dim, is_dropout, is_causal): + """Get block size for Triton AMD kernel.""" + arch = get_arch() + if arch.is_rdna: + return 32 + elif arch.is_cdna: + return 64 + # Fall back to CUDA kernel block sizes + return _get_block_size_n(device, head_dim, is_dropout, is_causal) + MAX_HEADDIM_SM8x = 192 @@ -26,6 +38,8 @@ is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0) is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0) +skip_bfloat16 = True if is_sm75 or is_hip() else False + def attn_bias_from_alibi_slopes( slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None @@ -505,7 +519,7 @@ def normalize_flash_attn_S( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias.to(dtype=scores.dtype) - block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) + block_size_n = _get_block_size_n_triton(scores.device, head_dim, is_dropout, causal) scores_block = scores.split(block_size_n, dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) lse = torch.logsumexp(lse_block, dim=-1) @@ -565,7 +579,7 @@ def get_dropout_fraction( return dropped.sum() / valid.sum() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] 
if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("deterministic", [False]) @@ -714,7 +728,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("deterministic", [True]) @@ -862,9 +876,9 @@ def test_flash_attn_varlen_qkvpacked( assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -@pytest.mark.parametrize("kvpacked", [False]) +@pytest.mark.parametrize("kvpacked", [True, False]) # @pytest.mark.parametrize("kvpacked", [False]) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @@ -1139,7 +1153,7 @@ def test_flash_attn_output( @pytest.mark.parametrize("kvpacked", [False]) # @pytest.mark.parametrize('kvpacked', [False]) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize('mha_type', ["mqa"]) @@ -1459,7 +1473,7 @@ def test_flash_attn_varlen_output( assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() -@pytest.mark.parametrize("dtype", 
([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @@ -1489,7 +1503,7 @@ def test_flash_attn_varlen_output( # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): if USE_TRITON_ROCM: - if is_rdna(): + if get_arch().is_rdna: if seqlen_q == 1 and seqlen_k == 239 and d == 256: pytest.skip("This config doesnot work on RDNA Devices.") if ( @@ -1572,7 +1586,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @@ -1741,7 +1755,7 @@ def test_flash_attn_varlen_causal( assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("deterministic", [True]) @@ -1871,7 +1885,7 @@ def test_flash_attn_splitkv( assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", 
[torch.float16]) @pytest.mark.parametrize("num_splits", [1, 0]) # @pytest.mark.parametrize("num_splits", [1]) @@ -1891,7 +1905,7 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) -@pytest.mark.parametrize("paged_kv_block_size", [None]) +@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) # @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) # @pytest.mark.parametrize("paged_kv_block_size", [None]) @pytest.mark.parametrize("has_leftpad", [False]) @@ -2183,7 +2197,7 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [True]) @@ -2310,7 +2324,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): ).abs().max().item() + 1e-3 -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @@ -2400,7 +2414,7 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): assert not v.grad.isnan().any() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", 
[True]) @@ -2459,7 +2473,7 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc assert torch.equal(dq, dq0) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True])