
Commit 549347d

[SYCLTLA] Fix FlashAttention FWD performance on PVC (#2415)
I missed an `else` when launching the kernel, so the kernel was launched twice on PVC.
1 parent 12b6ab4 commit 549347d
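
For context, the forward dispatch picked a specialized tile configuration inside an `if` branch but then fell through into the default launch because the `else` was missing, so devices taking the first branch (PVC here) launched the kernel twice. Below is a minimal, self-contained sketch of that control-flow bug and the fix; the condition and launch functions are placeholders for illustration, not the actual `run_mha_fwd_` code:

```cpp
#include <iostream>

// Placeholder launches standing in for run_mha_fwd_specialized with
// different tile shapes; the real code selects CUTLASS TileShape types.
void launch_tuned_config() { std::cout << "tuned config launched\n"; }
void launch_default_config() { std::cout << "default config launched\n"; }

void dispatch_buggy(bool use_tuned_config) {
  if (use_tuned_config) {
    launch_tuned_config();
  } // missing `else`: execution falls through to the default launch below
  launch_default_config(); // runs even when the tuned config already ran
}

void dispatch_fixed(bool use_tuned_config) {
  if (use_tuned_config) {
    launch_tuned_config();
  } else { // the fix: the default launch becomes the alternative branch
    launch_default_config();
  }
}

int main() {
  dispatch_buggy(true); // prints both lines: the kernel is launched twice
  dispatch_fixed(true); // prints only the tuned line
}
```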

3 files changed (+20 -22 lines)

src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h

Lines changed: 2 additions & 2 deletions
@@ -191,7 +191,7 @@ class FMHAPrefill {
   }
 
   // Find the length of the longest non masked sequence within that subgroup
-  int calculate_longest_non_masked_length(
+  CUTLASS_DEVICE int calculate_longest_non_masked_length(
       const int& seq_len_kv,
       const int& seq_len_qo,
       const int& last_seq_coord,
@@ -222,7 +222,7 @@ class FMHAPrefill {
   }
 
   template <class Tensor>
-  void handle_corner_cases(
+  CUTLASS_DEVICE void handle_corner_cases(
       Tensor& tSr,
       const int& thread_idx,
       const int& SubgroupSize,

src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp

Lines changed: 3 additions & 4 deletions
@@ -1461,18 +1461,17 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward_sycltla(
           .get_info<
               sycl::ext::oneapi::experimental::info::device::architecture>();
   constexpr auto supported_architectures =
-      std::array<sycl::ext::oneapi::experimental::architecture, 4>{
+      std::array<sycl::ext::oneapi::experimental::architecture, 3>{
          sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
          sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31};
+          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
   if (std::find(
           supported_architectures.begin(),
           supported_architectures.end(),
           device_architecture) == supported_architectures.end()) {
     TORCH_CHECK(
         false,
-        "XPU device architecture does not support flash attention backward. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31.");
+        "XPU device architecture does not support flash attention backward. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21.");
   }
 
   auto grad_query = at::empty_like(query);

src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp

Lines changed: 15 additions & 16 deletions
@@ -333,19 +333,19 @@ void run_mha_fwd_(
           TileShapeOutPut,
           SubgroupLayout,
           PipelineStages);
+    } else {
+      constexpr int PipelineStages = 2;
+      using TileShapeQK = Shape<_256, _32, _64>;
+      using TileShapePV = Shape<_256, _32, _32>;
+      using TileShapeOutPut = Shape<_256, _128, _32>;
+      using SubgroupLayout = Layout<Shape<_16, _1, _1>, Stride<_1, _1, _1>>;
+      run_mha_fwd_specialized(
+          TileShapeQK,
+          TileShapePV,
+          TileShapeOutPut,
+          SubgroupLayout,
+          PipelineStages);
     }
-
-    constexpr int PipelineStages = 2;
-    using TileShapeQK = Shape<_256, _32, _64>;
-    using TileShapePV = Shape<_256, _32, _32>;
-    using TileShapeOutPut = Shape<_256, _128, _32>;
-    using SubgroupLayout = Layout<Shape<_16, _1, _1>, Stride<_1, _1, _1>>;
-    run_mha_fwd_specialized(
-        TileShapeQK,
-        TileShapePV,
-        TileShapeOutPut,
-        SubgroupLayout,
-        PipelineStages);
   } else if (headdim == 192) {
     constexpr int PipelineStages = 2;
     using TileShapeQK = Shape<_256, _64, _64>;
@@ -537,18 +537,17 @@ flash_attention_forward_sycltla(
           .get_info<
               sycl::ext::oneapi::experimental::info::device::architecture>();
   constexpr auto supported_architectures =
-      std::array<sycl::ext::oneapi::experimental::architecture, 4>{
+      std::array<sycl::ext::oneapi::experimental::architecture, 3>{
          sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
          sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31};
+          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
   if (std::find(
           supported_architectures.begin(),
           supported_architectures.end(),
           device_architecture) == supported_architectures.end()) {
     TORCH_CHECK(
         false,
-        "XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31.");
+        "XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21.");
   }
 
   auto problem_shape = ProblemShapeRegular(
