Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,11 @@ def __init__(
self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1)
# Does S1 need to wait for S0 to finish
# self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f
# NOTE: is_family_of also matches any future sm_10x with x > 3 — intentional.
# The flag gates ex2 emulation; sm_103 (B300) has fast hardware ex2 and later
# Blackwell variants are assumed to inherit this, so forward-inclusion is correct
# despite the literal `is_sm103` name.
is_sm103 = self.arch.is_family_of(Arch.sm_103f)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_family_of will test (self.major == arch.major and self.minor >= arch.minor), so it is correct for now. If there is another sm10.x hardware with x > 3 in the future then it would be wrong.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ocss884 @lingolin128 Good catch — added an inline comment in 5833281 to make the forward-inclusive semantics explicit. The flag gates ex2 emulation, and since sm_103 (B300) brought fast hardware ex2, any future sm_10x Blackwellvariant should inherit the same behavior. So is_family_of's minor >= 3 behavior is actually what we want here, even though the literal name is_sm103 is a bit misleading.

self.is_sm103 = is_sm103
# enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic
_default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103
Expand Down
2 changes: 1 addition & 1 deletion flash_attn/cute/flash_fwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
"Paged KV does not support irregular head dim"
)
self.cluster_shape_mn = (1, 1)
assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported"
assert self.arch.is_family_of(Arch.sm_90a), "Only SM 9.x is supported"

def _get_smem_layout_atom(self):
sQ_layout_atom = warpgroup.make_smem_layout_atom(
Expand Down