diff --git a/python/examples/gluon/01-attention-forward.py b/python/examples/gluon/01-attention-forward.py index 8a0e40c95866..43ecc4195278 100644 --- a/python/examples/gluon/01-attention-forward.py +++ b/python/examples/gluon/01-attention-forward.py @@ -383,7 +383,8 @@ def get_loop_bounds(self, STAGE: gl.constexpr): @gluon.jit def _borrow_s_as_p(config, s_tmem): - p_tmem = s_tmem._reinterpret(config.dtype, [config.SPLIT_M, 2 * config.BLOCK_N], config.p_tmem_layout) + cols: gl.constexpr = s_tmem.dtype.primitive_bitwidth // config.dtype.primitive_bitwidth + p_tmem = s_tmem._reinterpret(config.dtype, [config.SPLIT_M, cols * config.BLOCK_N], config.p_tmem_layout) return p_tmem.slice(0, config.BLOCK_N) @@ -1162,7 +1163,7 @@ def attention_forward(q, k, v, causal, sm_scale, o=None, M=None, *, use_tmem_red @pytest.mark.parametrize("N_CTX", [1024, 2048, 4096, 8192]) @pytest.mark.parametrize("HEAD_DIM", [64, 128]) @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e5m2]) @pytest.mark.parametrize("use_tmem_red", [False, True] if is_blackwell_ultra() else [False]) @pytest.mark.parametrize("cta_layout", [(), ((1, 0), ), ((1, 0), (2, 0))], ids=["1cta", "2ctas", "4ctas"]) @pytest.mark.skipif(not is_blackwell(), reason="Gluon attention is only supported on Blackwell GPUs") @@ -1178,15 +1179,19 @@ def alloc_fn(size: int, alignment: int, stream): pytest.skip("TMEM reduction is only supported on Blackwell Ultra GPUs") torch.manual_seed(42) - q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_()) - k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_()) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=device).normal_(mean=0.0, std=0.5).to(dtype).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=device).normal_(mean=0.0, std=0.5).to(dtype).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=device).normal_(mean=0.0, std=0.5).to(dtype).requires_grad_()) sm_scale = 0.5 - ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal) - tri_out, _ = attention_forward(q, k, v, causal, sm_scale, use_tmem_red=use_tmem_red, cta_layout=cta_layout) - torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + if dtype == torch.float8_e5m2: + ref_out = torch.nn.functional.scaled_dot_product_attention(q.float(), k.float(), v.float(), scale=sm_scale, + is_causal=causal) + torch.testing.assert_close(ref_out.to(dtype).float(), tri_out.float(), atol=0.25, rtol=0.25) + else: + ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) # ===-----------------------------------------------------------------------===#