Skip to content

Commit af2f402

Browse files
committed
Remove head size 16 examples
1 parent eb7ea6b commit af2f402

File tree

4 files changed

+10
-1
lines changed

4 files changed

+10
-1
lines changed

applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -285,6 +285,7 @@ class FMHAFwdEpilogue {
285285
}
286286
}
287287
}
288+
288289
return std::make_tuple(rA, rA_sum, active);
289290
}
290291
}

applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -293,9 +293,11 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
293293
if (check_remainder_k && K == blk_k1 - 1) {
294294
FragSRow k_rem_mask;
295295
int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
296+
CUTLASS_PRAGMA_UNROLL
296297
for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
297298
k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
298299
}
300+
CUTLASS_PRAGMA_UNROLL
299301
for (int i = 0; i < tSrS.size(); i++) {
300302
tSrS(i) = sycl::fmin(tSrS(i), broadcast<1>(k_rem_mask, tSrS, i));
301303
}
@@ -312,6 +314,11 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
312314
for (int i = 0; i < tArP.size(); i++)
313315
tArP(i) = static_cast<typename TiledMMAPV::ValTypeA>(tSrS(i));
314316
#endif
317+
#define XSHOW(x) if (cute::thread0()) { print(#x ": "); print(x); print("\n"); }
318+
XSHOW(tSrS.layout())
319+
XSHOW(tSrS.tv_layout())
320+
XSHOW(tArP.layout())
321+
XSHOW(tArP.tv_layout())
315322

316323
/* GEMM 2: A += P * V, split in v dimension */
317324
CUTLASS_PRAGMA_UNROLL

examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -133,6 +133,7 @@ int main(int argc, const char **argv) {
133133
using SubgroupLayoutQK = Layout<Shape<_1, _8, _1>>;
134134

135135
#elif HEAD_DIM == 192
136+
/* FIXME: does not compile yet due to 1x24 output tile */
136137
using ShapeQK = Shape<_1, _512, _64>;
137138
using ShapePV = Shape<_1, _32, _512>;
138139
using ShapeOut = Shape<_1, _192>;

examples/06_bmg_flash_attention/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ set(CUTLASS_APPLICATIONS_DIR ${CMAKE_SOURCE_DIR}/applications)
3232
set(TEST_NO_PAGED "")
3333
set(TEST_PAGED "--use_paged_kv")
3434

35-
foreach(HEAD_DIM 16 64 96 128 192)
35+
foreach(HEAD_DIM 64 96 128 192)
3636

3737
cutlass_example_add_executable(
3838
06_xe_fmha_fwd_prefill_hdim${HEAD_DIM}

0 commit comments

Comments
 (0)