Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions flash_attn/cute/cute_dsl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,35 @@ def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, ena
return tensor.mark_layout_dynamic(leading_dim=leading_dim)


def to_cute_aux_tensor(t, enable_tvm_ffi=True):
"""Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors.
This allows the user to specify alignment and leading dimension for aux tensors used in
custom score_mod callables.
"""
assumed_align: int = getattr(t, "__assumed_align__", None)
leading_dim: int = getattr(t, "__leading_dim__", None)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the future it would be nice to find these programmatically instead of making them user-facing (potentially)

fully_dynamic: bool = leading_dim is None

return to_cute_tensor(
t,
assumed_align=assumed_align,
leading_dim=leading_dim,
fully_dynamic=fully_dynamic,
enable_tvm_ffi=enable_tvm_ffi,
)


def get_aux_tensor_metadata(aux_tensors):
    """Summarise the layout hints attached to each aux tensor.

    For every item in ``aux_tensors`` a triple is produced:

    * the ``__assumed_align__`` attribute, or ``0`` when it is absent,
    * the ``__leading_dim__`` attribute, or ``-1`` when it is absent,
    * whether a ``__leading_dim__`` attribute is present at all.

    Returns:
        A tuple holding one such triple per input item, in input order.
    """
    summary = []
    for tensor in aux_tensors:
        align = getattr(tensor, "__assumed_align__", 0)
        leading = getattr(tensor, "__leading_dim__", -1)
        # The explicit presence flag distinguishes "unset" from a tensor
        # whose leading dim was deliberately set to the default value -1.
        has_leading = hasattr(tensor, "__leading_dim__")
        summary.append((align, leading, has_leading))
    return tuple(summary)


def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
"""Return tuple of bools indicating which dims have stride=0 (broadcast).

Expand Down
7 changes: 3 additions & 4 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,9 @@ def __init__(
self.score_mod = score_mod
self.mask_mod = mask_mod
self.qk_acc_dtype = Float32
if const_expr(has_aux_tensors):
self.vec_size: cutlass.Constexpr = 1
else:
self.vec_size: cutlass.Constexpr = 2
self.vec_size: cutlass.Constexpr = getattr(
score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
)

@staticmethod
def can_implement(
Expand Down
7 changes: 3 additions & 4 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,9 @@ def __init__(
)
self.score_mod = score_mod
self.mask_mod = mask_mod
if cutlass.const_expr(has_aux_tensors):
self.vec_size: cutlass.Constexpr = 1
else:
self.vec_size: cutlass.Constexpr = 2
self.vec_size: cutlass.Constexpr = getattr(
score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
)
# Does S1 need to wait for S0 to finish
# self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
self.s0_s1_barrier = False
Expand Down
13 changes: 9 additions & 4 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@


from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import to_cute_tensor
from flash_attn.cute.cute_dsl_utils import to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata
from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90
from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100
from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess
Expand Down Expand Up @@ -368,7 +368,11 @@ def _flash_attn_fwd(
block_size=(m_block_size, n_block_size),
q_stage=q_stage,
)

if aux_tensors is not None:
aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SUPER DUPER nit; aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) if aux_tensors else None

all real estate feels quite precious in this file

else:
aux_tensor_metadata = None

compile_key = (
dtype,
head_dim,
Expand All @@ -379,7 +383,7 @@ def _flash_attn_fwd(
mask_mod_hash,
use_block_sparsity,
block_sparse_broadcast_pattern,
len(aux_tensors) if aux_tensors is not None else 0,
aux_tensor_metadata,
lse is None,
cu_seqlens_q is None,
cu_seqlens_k is None,
Expand Down Expand Up @@ -432,8 +436,9 @@ def _flash_attn_fwd(
sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)

cute_aux_tensors = None
aux_tensor_metadata = None
if aux_tensors is not None:
cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]
cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors]

if compute_capability == 9:
assert page_table is None, "paged KV not supported on SM 9.0"
Expand Down
67 changes: 42 additions & 25 deletions flash_attn/cute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,11 @@

import quack.activation

_MIXER_ATTRS = ("__vec_size__",)

def hash_callable(func: Callable, set_cute_hash=True) -> str:
"""Hash a callable based on the source code or bytecode and closure values.

Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__``
attribute, that value is returned immediately. Code-generation backends such
as Inductor can set this attribute to avoid expensive runtime hashing.

set_cute_hash: whether or not to set func.__cute_hash__ if not present
"""
if hasattr(func, "__cute_hash__"):
return func.__cute_hash__

# Unwrap decorated functions (e.g., cute.jit wrappers).
if hasattr(func, "__wrapped__"):
base_func = func.__wrapped__
if hasattr(base_func, "__cute_hash__"):
return base_func.__cute_hash__
func = base_func

def _compute_base_hash(func: Callable) -> str:
"""Compute hash from source code or bytecode and closure values."""
try:
data = inspect.getsource(func).encode()
except (OSError, TypeError):
Expand All @@ -48,16 +33,48 @@ def hash_callable(func: Callable, set_cute_hash=True) -> str:
hasher = hashlib.sha256(data)

if hasattr(func, "__closure__") and func.__closure__ is not None:
for idx, cell in enumerate(func.__closure__):
cell_value = cell.cell_contents
hasher.update(repr(cell_value).encode())
for cell in func.__closure__:
hasher.update(repr(cell.cell_contents).encode())

return hasher.hexdigest()


def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable from its source/bytecode, closure values and metadata.

    The result is built in two stages:

    1. *Base hash*: if ``func`` (or, failing that, its ``__wrapped__`` target)
       already carries a ``__cute_hash__`` attribute, that value is used as
       is.  Otherwise the base hash is computed with ``_compute_base_hash``.
    2. *Mixing*: the attributes named in ``mixer_attrs`` are read from
       ``func``.  If none of them is set the base hash is returned unchanged;
       otherwise every ``attr=value`` pair is folded into a SHA-256 digest
       seeded with the base hash.

    Args:
        func: The callable to hash.
        mixer_attrs: Names of attributes on ``func`` whose values must
            influence the final hash.
        set_cute_hash: When true and the base hash had to be computed, store
            it on the unwrapped callable as ``__cute_hash__``.  Only the base
            hash is stored, never the mixed one, so later changes to the
            mixer attributes still change the result.

    Returns:
        A hexadecimal digest string.
    """
    # Resolve the base hash, preferring a precomputed one when available.
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap decorated functions (e.g., cute.jit wrappers).
        base_func = getattr(func, "__wrapped__", func)

        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)

            if set_cute_hash:
                base_func.__cute_hash__ = base_hash

    # Mix in mutable metadata dunders.
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)

    if all(v is None for v in mixer_values):
        return base_hash

    hasher = hashlib.sha256(base_hash.encode())

    # Pair each value with the attribute name it was actually read from.
    # Zipping against the module default instead would mislabel or drop
    # values whenever the caller passes a custom ``mixer_attrs``.
    for attr, val in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={val!r}".encode())

    return hasher.hexdigest()


def create_softcap_scoremod(softcap_val):
Expand Down
84 changes: 84 additions & 0 deletions tests/cute/score_mod_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,47 @@ def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_t
return tSrS_ssa


@cute.jit
def score_mod_identity_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Identity score_mod: hand the incoming scores back untouched.

    None of the index, sequence-length or aux-tensor arguments is read.
    """
    return tSrS_ssa


@cute.jit
def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Causal score_mod: keep a score where ``q_idx >= kv_idx``, else ``-inf``."""
    neg_inf = cute.full_like(tSrS_ssa, float("-inf"))
    keep = operator.ge(q_idx, kv_idx)
    return cute.where(keep, tSrS_ssa, neg_inf)


@cute.jit
def score_mod_causal_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Causal score_mod written against the first element of each index vector.

    Only ``q_idx[0]`` and ``kv_idx[0]`` are read; lane ``i`` is treated as key
    position ``kv_idx[0] + i``.  The number of lanes is taken from the shape
    of ``kv_idx``.  Scores whose lane fails ``q >= kv`` become ``-inf``.
    """
    q_first = q_idx[0]
    kv_first = kv_idx[0]
    keep = cute.make_rmem_tensor(kv_idx.shape, dtype=cutlass.Boolean)
    # Compile-time unrolled loop over the lanes of the mask.
    for lane in cutlass.range_constexpr(cute.size(keep.shape)):
        keep[lane] = q_first >= kv_first + lane
    neg_inf = cute.full_like(tSrS_ssa, float("-inf"))
    return cute.where(keep.load(), tSrS_ssa, neg_inf)


@cute.jit
def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Relative-position bias: add ``|q_idx - kv_idx|`` (as Float32) to the scores."""
    offset = q_idx - kv_idx
    distance = cute.TensorSSA(mlir_math.absi(offset), offset.shape, offset.dtype)
    bias = distance.to(cutlass.Float32)
    return tSrS_ssa + bias


@cute.jit
def score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Relative-position bias computed from the first index of each vector.

    Only ``q_idx[0]`` and ``kv_idx[0]`` are read; lane ``i`` uses the offset
    ``(q_idx[0] - kv_idx[0]) - i``.  The absolute offsets are converted to
    Float32 and added to the scores.  The lane count comes from the shape of
    ``kv_idx``.
    """
    base_offset = q_idx[0] - kv_idx[0]
    distance = cute.make_rmem_tensor(kv_idx.shape, dtype=base_offset.dtype)
    # Compile-time unrolled loop over the lanes.
    for lane in cutlass.range_constexpr(cute.size(kv_idx.shape)):
        lane_offset = base_offset - lane
        distance[lane] = mlir_math.absi(lane_offset)
    bias = distance.load().to(cutlass.Float32)
    return tSrS_ssa + bias


@cute.jit
def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
diff = q_idx - kv_idx
Expand All @@ -36,10 +64,25 @@ def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au
return tSrS_ssa + scaled.to(cutlass.Float32)


@cute.jit
def score_mod_rel_bias_x2_vectorized(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Doubled relative-position bias computed from the first index of each vector.

    Only ``q_idx[0]`` and ``kv_idx[0]`` are read; lane ``i`` contributes
    ``2 * |(q_idx[0] - kv_idx[0]) - i|``, converted to Float32 and added to
    the scores.  The lane count comes from the shape of ``kv_idx``.
    """
    base_offset = q_idx[0] - kv_idx[0]
    doubled = cute.make_rmem_tensor(kv_idx.shape, dtype=base_offset.dtype)
    # Compile-time unrolled loop over the lanes.
    for lane in cutlass.range_constexpr(cute.size(kv_idx.shape)):
        lane_offset = base_offset - lane
        doubled[lane] = mlir_math.absi(lane_offset) * 2
    bias = doubled.load().to(cutlass.Float32)
    return tSrS_ssa + bias


@cute.jit
def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Scale every score by two; no index or aux argument is read."""
    factor = cute.full_like(tSrS_ssa, 2)
    return tSrS_ssa * factor

score_mod_times_two_vectorized = score_mod_times_two

@cute.jit
def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
Expand All @@ -53,6 +96,21 @@ def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tens
abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32)
return score - slope * abs_diff

@cute.jit
def score_mod_alibi_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
score = tSrS_ssa.to(cutlass.Float32)
slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8)
slope = cute.math.exp2(
slope_exp.to(cutlass.Float32)
* cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634)
)
diff0 = q_idx[0] - kv_idx[0]
abs_diff = cute.make_rmem_tensor(kv_idx.shape, diff0.dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we write a note somewhere that vec_width for fwd score-mod is always encoded in kv_idx shape?

for i in cutlass.range_constexpr(cute.size(abs_diff.shape)):
diffi = diff0 - i
abs_diff[i] = mlir_math.absi(diffi)
return score - slope * abs_diff.load().to(cutlass.Float32)


@cute.jit
def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
Expand Down Expand Up @@ -88,6 +146,16 @@ def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux
bias_val = (bias_frag.load()).to(cutlass.Float32)
return tSrS_ssa + bias_val

@cute.jit
def score_mod_batch_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
batch_bias = aux_tensors[0]
dtype = batch_bias.element_type
b_idx0 = b_idx[0]
bias_frag = cute.make_rmem_tensor(1, dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and to triple check: this is actually not vectorized, right?

bias_frag[0] = batch_bias[b_idx0]
bias_val = (bias_frag.load()).to(cutlass.Float32)
return tSrS_ssa + bias_val
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe it is, and this `+` is doing broadcasting? If so, should we also have some docs on this pattern for aux_tensor vectorization?



@cute.jit
def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
Expand All @@ -109,6 +177,22 @@ def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au

return tSrS_ssa + head_val + pos_val

@cute.jit
def score_mod_dual_buffer_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Add a per-head bias and a per-query-position bias to the scores.

    ``aux_tensors[0]`` is indexed with ``h_idx[0]`` and ``aux_tensors[1]``
    with ``q_idx[0]``.  Each looked-up value is staged through a one-element
    register tensor, converted to Float32 and added to ``tSrS_ssa``.

    NOTE(review): the single-element values are added to the whole score
    vector, which presumably relies on broadcasting -- confirm.
    """
    head_bias = aux_tensors[0]
    pos_bias = aux_tensors[1]

    # Use make_rmem_tensor, matching the other vectorized helpers in this
    # file, and size each staging buffer with its own tensor's element type
    # so the two aux tensors are not required to share a dtype.
    head_val_frag = cute.make_rmem_tensor(1, head_bias.element_type)
    head_val_frag[0] = head_bias[h_idx[0]]
    head_val = (head_val_frag.load()).to(cutlass.Float32)

    pos_val_frag = cute.make_rmem_tensor(1, pos_bias.element_type)
    pos_val_frag[0] = pos_bias[q_idx[0]]
    pos_val = (pos_val_frag.load()).to(cutlass.Float32)

    return tSrS_ssa + head_val + pos_val


# =============================================================================
# Score_mod functions that use global indices
Expand Down
Loading