Use is_family_of for sm_90 and sm_103 arch checks #2589
Follow-up to Dao-AILab#2572. This applies the same is_family_of pattern to the two remaining range-style arch checks for consistency:
- flash_fwd_sm90.py:69 (SM 9.x assert)
- flash_fwd_sm100.py:195 (is_sm103 flag)

Same semantic narrowing as Dao-AILab#2572: bare-base SMs (sm_90, sm_103) are excluded. These kernels rely on wgmma / UMMA / 2CTA paths that require the a/f PTX variant anyway, so bare-base targets could not compile.
Cc @ocss884 @lingolin128, could you help review?
  # Does S1 need to wait for S0 to finish
  # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
- is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f
+ is_sm103 = self.arch.is_family_of(Arch.sm_103f)
is_family_of tests (self.major == arch.major and self.minor >= arch.minor), so it is correct for now. If another sm_10x part with x > 3 appears in the future, the check would match it as well, and is_sm103 would then be wrong.
@ocss884 @lingolin128 Good catch. I added an inline comment in 5833281 to make the forward-inclusive semantics explicit. The flag gates ex2 emulation, and since sm_103 (B300) brought fast hardware ex2, any future sm_10x Blackwell variant should inherit the same behavior. So is_family_of's minor >= 3 behavior is actually what we want here, even though the literal name is_sm103 is a bit misleading.
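For readers following the thread, here is a minimal sketch of the rule quoted in the review (same major, minor at least the base's minor). `ToyArch` is a hypothetical stand-in written for this illustration, not the DSL's real `Arch` type, and the "future" sm_10x part is made up; only the major/minor comparison is modeled.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyArch:
    """Hypothetical stand-in for Arch; only major/minor are modeled."""
    major: int
    minor: int

    def is_family_of(self, base: "ToyArch") -> bool:
        # Rule as quoted in the review: same major, minor >= the base's minor.
        return self.major == base.major and self.minor >= base.minor


sm_100 = ToyArch(10, 0)
sm_103 = ToyArch(10, 3)
sm_10x_future = ToyArch(10, 7)  # made-up future part, purely illustrative
sm_120 = ToyArch(12, 0)

assert not sm_100.is_family_of(sm_103)     # lower minor: no match
assert sm_103.is_family_of(sm_103)         # the target itself matches
assert sm_10x_future.is_family_of(sm_103)  # forward-inclusive: also matches
assert not sm_120.is_family_of(sm_103)     # different major: no match
```

Under this rule, a flag named is_sm103 is effectively "sm_103 or any later minor in the same major", which is the behavior the reply above describes as intended.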
LGTM, @ocss884's consideration is completely correct in terms of code.
* Use is_family_of for sm_90 and sm_103 arch checks

  Follow-up to Dao-AILab#2572. This applies the same is_family_of pattern to the two remaining range-style arch checks for consistency:
  - flash_fwd_sm90.py:69 (SM 9.x assert)
  - flash_fwd_sm100.py:195 (is_sm103 flag)

  Same semantic narrowing as Dao-AILab#2572: bare-base SMs (sm_90, sm_103) are excluded. These kernels rely on wgmma / UMMA / 2CTA paths that require the a/f PTX variant anyway, so bare-base targets could not compile.

* Clarify is_sm103 forward-inclusive semantics

  is_family_of(sm_103f) also matches any future sm_10x with x > 3, not just sm_103a/f. This was raised in PR review (@ocss884). An inline comment now clarifies that this forward-inclusive behavior is intentional: the flag gates ex2 emulation, sm_103 (B300) has fast hardware ex2, and later Blackwell variants in the same family are assumed to inherit it. No code-behavior change.
Follow-up to #2572. This applies the same is_family_of pattern to the two remaining range-style arch checks for consistency:
- flash_fwd_sm90.py:69 (SM 9.x assert)
- flash_fwd_sm100.py:195 (is_sm103 flag)

Same semantic narrowing as #2572: bare-base SMs (sm_90, sm_103) are excluded. They were accepted by the old range check but cannot actually run these kernels, since the wgmma / UMMA / 2CTA paths require the a/f PTX variant. The new assertion fails earlier with a clearer error instead of letting compilation reach ptxas before failing.
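The early-failure behavior described above could look roughly like the sketch below. Everything in it is an assumption made for illustration: `ToyArch`, its `suffix` field, the way bare-base targets are rejected, the choice of sm_90a as the base, and the helper name `check_sm90_arch` are all hypothetical. The thread does not show how the real is_family_of excludes bare-base targets.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyArch:
    # Hypothetical model: suffix "" is the bare-base target, "a"/"f" are variants.
    major: int
    minor: int
    suffix: str = ""

    def __str__(self) -> str:
        return f"sm_{self.major}{self.minor}{self.suffix}"

    def is_family_of(self, base: "ToyArch") -> bool:
        # Assumed rule: bare-base targets never match; otherwise require the
        # same major and a minor at least as large as the base's minor.
        return (
            self.suffix in ("a", "f")
            and self.major == base.major
            and self.minor >= base.minor
        )


def check_sm90_arch(arch: ToyArch) -> None:
    # Hypothetical helper: fail at kernel-construction time with a readable
    # message instead of surfacing a later ptxas error.
    assert arch.is_family_of(ToyArch(9, 0, "a")), (
        f"{arch} is not supported: an SM 9.x a/f target such as sm_90a is required"
    )


check_sm90_arch(ToyArch(9, 0, "a"))  # accepted

try:
    check_sm90_arch(ToyArch(9, 0))   # bare sm_90 is rejected up front
except AssertionError as err:
    print(err)
```

The point of the design is only where the failure surfaces: the unsupported target is reported by name before any compilation starts.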