Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions flash_attn/cute/cute_dsl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,24 @@ def cute_compile_patched(*args, **kwargs):
return output


def assume_strides_aligned(t):
    """Return *t*'s strides with 128-bit divisibility assumed on all but the last.

    Strides that are plain Python ints (e.g. stride=0 produced by a GQA
    expand()) are static, so they are passed through untouched rather than
    wrapped in cute.assume().
    """
    # Number of elements per 128 bits for this dtype (e.g. 8 for fp16).
    divby = 128 // t.element_type.width
    assumed = []
    for s in t.stride[:-1]:
        assumed.append(s if isinstance(s, int) else cute.assume(s, divby=divby))
    # The last (contiguous) stride carries no alignment assumption.
    assumed.append(t.stride[-1])
    return tuple(assumed)


def assume_tensor_aligned(t):
    """Recreate *t* with 128-bit aligned stride assumptions; maps None to None.

    Convenience wrapper around assume_strides_aligned() that rebuilds the
    tensor's layout so downstream kernels can rely on the alignment facts.
    """
    if t is None:
        return None
    aligned_layout = cute.make_layout(t.shape, stride=assume_strides_aligned(t))
    return cute.make_tensor(t.iterator, aligned_layout)


def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
"""Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1."""
tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
Expand Down
8 changes: 4 additions & 4 deletions flash_attn/cute/flash_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import cutlass.utils as utils_basic

from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.seqlen_info import SeqlenInfoQK
Expand Down Expand Up @@ -383,10 +384,9 @@ def __call__(
# Get the data type and check if it is fp16 or bf16
self._check_type(*(t.element_type if t is not None else None
for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
# Assume all strides are divisible by 128 bits except the last stride
# Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) if not isinstance(s, int) or s != 0 else s for s in t.stride[:-1]), t.stride[-1])
mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)]
mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
]
self.varlen_q = (mCuSeqlensQ is not None)
self._setup_attributes()
SharedStorage = self._get_shared_storage_cls()
Expand Down
11 changes: 2 additions & 9 deletions flash_attn/cute/flash_bwd_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from cutlass.utils import LayoutEnum

from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import copy_utils
from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute import hopper_helpers as sm90_utils
Expand Down Expand Up @@ -211,15 +212,7 @@ def __call__(
if const_expr(mdQaccum.element_type not in [cutlass.Float32]):
raise TypeError("dQaccum tensor must be Float32")

# Assume all strides are divisible by 128 bits except the last stride
new_stride = lambda t: (
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
t.stride[-1],
)
mdQaccum, mdQ = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
for t in (mdQaccum, mdQ)
]
mdQaccum, mdQ = [assume_tensor_aligned(t) for t in (mdQaccum, mdQ)]

self.tiled_mma = self._get_tiled_mma()
self._setup_attributes()
Expand Down
13 changes: 2 additions & 11 deletions flash_attn/cute/flash_bwd_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from cutlass import Float32

from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import copy_utils
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.tile_scheduler import (
Expand Down Expand Up @@ -135,17 +136,7 @@ def __call__(
if cutlass.const_expr(mLSElog2.element_type not in [Float32]):
raise TypeError("LSElog2 tensor must be Float32")

# Assume all strides are divisible by 128 bits except the last stride
new_stride = lambda t: (
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
t.stride[-1],
)
mO, mdO, mdQaccum = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
if t is not None
else None
for t in (mO, mdO, mdQaccum)
]
mO, mdO, mdQaccum = [assume_tensor_aligned(t) for t in (mO, mdO, mdQaccum)]

self._setup_attributes()

Expand Down
25 changes: 2 additions & 23 deletions flash_attn/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from cutlass.pipeline import PipelineAsync, PipelineConsumer

from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import copy_utils
from flash_attn.cute import pipeline
from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa
Expand Down Expand Up @@ -411,29 +412,7 @@ def __call__(
assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA"
assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA"

# Assume all strides are divisible by 128 bits except the last stride
# Skip assume for Python ints (e.g., stride=0 from GQA expand)
new_stride = lambda t: (
*(
s if isinstance(s, int) else cute.assume(s, divby=128 // t.element_type.width)
for s in t.stride[:-1]
),
t.stride[-1],
)
(
mdQaccum,
mdK,
mdV,
) = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
if t is not None
else None
for t in (
mdQaccum,
mdK,
mdV,
)
]
mdQaccum, mdK, mdV = [assume_tensor_aligned(t) for t in (mdQaccum, mdK, mdV)]

# (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n)
QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
Expand Down
17 changes: 2 additions & 15 deletions flash_attn/cute/flash_bwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from cutlass.utils import LayoutEnum

from flash_attn.cute import hopper_helpers as sm90_utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
from flash_attn.cute import copy_utils
from flash_attn.cute.hopper_helpers import gemm_zero_init, gemm_w_idx
Expand Down Expand Up @@ -350,22 +351,8 @@ def __call__(
)
)

# Assume all strides are divisible by 128 bits except the last stride
# Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
new_stride = lambda t: (
*(
cute.assume(s, divby=128 // t.element_type.width)
if not isinstance(s, int) or s != 0
else s
for s in t.stride[:-1]
),
t.stride[-1],
)
mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
if t is not None
else None
for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
]

layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b)
Expand Down
34 changes: 3 additions & 31 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from quack import copy_utils as quack_copy_utils

from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import hopper_helpers as sm90_utils
from flash_attn.cute import utils
from flash_attn.cute import copy_utils
Expand Down Expand Up @@ -660,21 +661,7 @@ def __call__(
self.use_tma_O = self.arch >= 90
self._setup_attributes()
SharedStorage = self._get_shared_storage_cls()
# Assume all strides are divisible by 128 bits except the last stride
# Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
new_stride = lambda t: (
*(
cute.assume(s, divby=128 // t.element_type.width)
if not isinstance(s, int) or s != 0
else s
for s in t.stride[:-1]
),
t.stride[-1],
)
mQ, mK, mV, mO = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
for t in (mQ, mK, mV, mO)
]
mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
mQ, mK, mV, mO = [
cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0]))
for t in (mQ, mK, mV, mO)
Expand Down Expand Up @@ -1303,22 +1290,7 @@ def __call__(
)
)

# Assume all strides are divisible by 128 bits except the last stride
# Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
new_stride = lambda t: (
*(
cute.assume(s, divby=128 // t.element_type.width)
if not isinstance(s, int) or s != 0
else s
for s in t.stride[:-1]
),
t.stride[-1],
)

mQ, mK, mV, mO = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
for t in (mQ, mK, mV, mO)
]
mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
Expand Down
11 changes: 2 additions & 9 deletions flash_attn/cute/flash_fwd_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from cutlass import Float32, Int32, const_expr

from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute.seqlen_info import SeqlenInfo
from cutlass.cute import FastDivmodDivisor

Expand Down Expand Up @@ -232,15 +233,7 @@ def __call__(
"LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
)

# Assume all strides are divisible by 128 bits except the last stride
new_stride = lambda t: (
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
t.stride[-1],
)
mO_partial, mO = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
for t in (mO_partial, mO)
]
mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)]
# (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
# or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
O_partial_layout_transpose = (
Expand Down
12 changes: 2 additions & 10 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from flash_attn.cute.paged_kv import PagedKVManager
import flash_attn.cute.utils as utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import copy_utils
import flash_attn.cute.pipeline as pipeline
from flash_attn.cute.mask import AttentionMask
Expand Down Expand Up @@ -297,16 +298,7 @@ def __call__(
self.k_dtype = mK.element_type
self.v_dtype = mV.element_type
self.o_dtype = mO.element_type
# Assume all strides are divisible by 128 bits except the last stride
# Skip assume for Python ints (e.g., stride=0 from GQA expand)
new_stride = lambda t: (
*(s if isinstance(s, int) else cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
t.stride[-1],
)
mQ, mK, mV, mO = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
for t in (mQ, mK, mV, mO)
]
mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
mQ = cute.make_tensor(mQ.iterator, cute.select(mQ.layout, mode=Q_layout_transpose))
# (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table
Expand Down
36 changes: 27 additions & 9 deletions flash_attn/cute/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import cutlass.cute as cute
from cutlass import Boolean, Int32, const_expr
from cutlass.cutlass_dsl import if_generate
from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup
from cutlass.pipeline import PipelineState, Agent, CooperativeGroup
from cutlass.pipeline import PipelineUserType, PipelineOp
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
Expand Down Expand Up @@ -150,19 +150,24 @@ def producer_acquire(
state: PipelineState,
try_acquire_token: Optional[Boolean] = None,
extra_tx_count: int = 0,
*,
loc=None,
ip=None,
):
"""
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
"""
if_generate(
try_acquire_token is None or try_acquire_token == 0,
lambda: self.sync_object_empty.wait(state.index, state.phase),
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
loc=loc,
ip=ip,
)
if const_expr(extra_tx_count == 0):
self.sync_object_full.arrive(state.index, self.producer_mask)
self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip)
else:
tx_count = self.sync_object_full.tx_count + extra_tx_count
self.sync_object_full.arrive_and_expect_tx(state.index, tx_count)
self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)


@dataclass(frozen=True)
Expand Down Expand Up @@ -207,10 +212,10 @@ def create(
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)

sync_object_full = PipelineAsync._make_sync_object(
sync_object_full = PipelineTmaUmmaOg._make_sync_object(
Comment thread
drisspg marked this conversation as resolved.
barrier_storage.align(min_align=8), num_stages, producer, tx_count
)
sync_object_empty = PipelineAsync._make_sync_object(
sync_object_empty = PipelineTmaUmmaOg._make_sync_object(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)

Expand Down Expand Up @@ -251,22 +256,35 @@ def producer_acquire(
state: PipelineState,
try_acquire_token: Optional[Boolean] = None,
extra_tx_count: int = 0,
*,
loc=None,
ip=None,
):
"""
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
"""
if_generate(
try_acquire_token is None or try_acquire_token == 0,
lambda: self.sync_object_empty.wait(state.index, state.phase),
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
loc=loc,
ip=ip,
)
if const_expr(extra_tx_count == 0):
if_generate(
self.is_leader_cta,
lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
lambda: self.sync_object_full.arrive(
state.index, self.producer_mask, loc=loc, ip=ip
),
loc=loc,
ip=ip,
)
else:
tx_count = self.sync_object_full.tx_count + extra_tx_count
if_generate(
self.is_leader_cta,
lambda: self.sync_object_full.arrive_and_expect_tx(state.index, tx_count),
lambda: self.sync_object_full.arrive_and_expect_tx(
state.index, tx_count, loc=loc, ip=ip
),
loc=loc,
ip=ip,
)
4 changes: 2 additions & 2 deletions flash_attn/cute/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ classifiers = [
]

dependencies = [
"nvidia-cutlass-dsl>=4.3.5,<4.4.0",
"nvidia-cutlass-dsl>=4.4.0.dev1",
"torch",
"einops",
"typing_extensions",
"apache-tvm-ffi>=0.1.5,<0.2",
"torch-c-dlpack-ext",
"quack-kernels==0.2.4",
"quack-kernels>=0.2.7",
]

[project.optional-dependencies]
Expand Down