Skip to content

Commit b424d07

Browse files
committed
fix hopper ut
1 parent 7142d6b commit b424d07

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/attention/test_batch_prefill_kernels.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,17 @@ def test_batch_prefill_with_paged_kv_cache(
144144
logits_soft_cap=logits_soft_cap,
145145
)
146146
if return_lse:
147-
o, _ = wrapper.run(q, kv_data, return_lse=True)
147+
o, lse = wrapper.run(q, kv_data, return_lse=True)
148148
else:
149149
o = wrapper.run(q, kv_data)
150150

151151
# test with pre-allocated output
152152
o_buffer = torch.empty_like(o)
153-
wrapper.run(q, kv_data, out=o_buffer)
153+
if return_lse:
154+
lse_buffer = torch.empty_like(lse)
155+
wrapper.run(q, kv_data, out=o_buffer, lse=lse_buffer, return_lse=True)
156+
else:
157+
wrapper.run(q, kv_data, out=o_buffer)
154158
torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
155159
else:
156160
q_indptr_buffer = torch.empty(

0 commit comments

Comments
 (0)