From 2becda44248c81efeaf6a2dfd46e3b4ce437f79d Mon Sep 17 00:00:00 2001
From: Sergey Shlyapnikov
Date: Fri, 24 Jan 2025 19:12:10 +0400
Subject: [PATCH] [GPU] Update PagedAttention output shape, add dynamic
 paddings support for mixed kernel mode execution

---
 .../intel_gpu/src/graph/paged_attention.cpp   |  3 +-
 .../kernel_selector/cl_kernels/pa_sdpa_opt.cl |  6 +-
 .../test_cases/paged_attention_gpu_test.cpp   | 79 +++++++++++++------
 3 files changed, 63 insertions(+), 25 deletions(-)

diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
index c656cb1f284ae0..a197c0798ec28f 100644
--- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
@@ -49,6 +49,7 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*node*/,
 template<typename ShapeType>
 std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) {
     auto data_layout = impl_param.get_input_layout(0);
+    data_layout.data_padding = padding();
 
     const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape();
     bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size;
@@ -71,7 +72,7 @@
             total_size += past_lens_mem_lock[i];
         }
 
-        total_size += static_cast<int64_t>(impl_param.get_input_layout(0).get_shape()[0]);
+        total_size += static_cast<int64_t>(data_layout.get_shape()[0]);
 
         output_layouts.push_back(layout{ov::PartialShape{total_size}, output_dt, format::bfyx});
     } else {
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
index 7a300aaee1a16a..5ffa0f083da719 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
@@ -118,7 +118,8 @@ KERNEL(pa_sdpa_opt)(
 {
 #if STORE_QUERY_TO_SLM
     const uint query_idx_local = sgid * SUBGROUP_SIZE + sglid;
-    const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM +
+    const uint query_idx = INPUT0_OFFSET +
+                           seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
                            head_num_idx * HEAD_SIZE +
                            query_idx_local;
@@ -137,7 +138,8 @@ KERNEL(pa_sdpa_opt)(
 #else
     INPUT0_TYPE q_val[HEAD_SIZE / SUBGROUP_SIZE];
     unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
-        const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM +
+        const uint query_idx = INPUT0_OFFSET +
+                               seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
                                head_num_idx * HEAD_SIZE +
                                i * SUBGROUP_SIZE;
         q_val[i] = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);
diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp
index cdb927a57ca2bb..7d7aab78d3efe0 100644
--- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp
+++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp
@@ -730,6 +730,37 @@ struct PagedAttentionTest : public ::testing::TestWithParam<T> {
         rotation_deltas_layout.set_partial_shape(ov::PartialShape{ -1, -1 });
         rotation_trig_lut_layout.set_partial_shape(ov::PartialShape{ -1, p.head_size });
 
+        if (p.dynamic_paddings) {
+            const auto padding_axis = 1;
+            const auto pad_before = p.head_size;
+            const auto pad_after = p.head_size * 2;
+
+            query_layout.data_padding._dynamic_dims_mask[padding_axis] = 1;
+
+            auto query_data_layout = query_mem->get_layout();
+            auto padded_query_data_layout = query_data_layout;
+            padded_query_data_layout.data_padding._lower_size[padding_axis] = pad_before;
+            padded_query_data_layout.data_padding._upper_size[padding_axis] = pad_after;
+
+            auto new_query_memory = get_test_engine().allocate_memory(padded_query_data_layout, false);
+
+            mem_lock<ov::float16> query_mem_lock(query_mem, get_test_stream());
+            mem_lock<ov::float16> new_query_mem_lock(new_query_memory, get_test_stream());
+
+            auto query_data_shape = query_data_layout.get_shape();
+            for (size_t b = 0; b < query_data_shape[0]; b++) {
+                for (size_t f = 0; f < query_data_shape[1]; f++) {
+                    auto input_offset =
+                        query_data_layout.get_linear_offset(cldnn::tensor(static_cast<int32_t>(b), static_cast<int32_t>(f), 0, 0, 0, 0));
+                    auto output_offset =
+                        padded_query_data_layout.get_linear_offset(cldnn::tensor(static_cast<int32_t>(b), static_cast<int32_t>(f), 0, 0, 0, 0));
+
+                    new_query_mem_lock[output_offset] = query_mem_lock[input_offset];
+                }
+            }
+            query_mem = new_query_memory;
+        }
+
         std::vector<input_info> pa_inputs = {
             input_info("query"),
             input_info("key"),
@@ -857,6 +888,7 @@ struct paged_attention_test_params {
     int num_heads;
     int head_size;
     int block_size;
+    bool dynamic_paddings;
     bool scores_output;
     CacheRotationDescriptor rotation_config;
 };
@@ -873,31 +905,34 @@
 const auto DISABLE_SCORES = false;
 const auto PER_BLOCK_ROTATION = CacheRotationDescriptor{ true, true };
 const auto PER_TOKEN_ROTATION = CacheRotationDescriptor{ true, false };
 const auto DISABLE_ROTATION = CacheRotationDescriptor{ false, false };
+const auto STATIC_INPUT_PAD = false;
+const auto DYNAMIC_INPUT_PAD = true;
 
 INSTANTIATE_TEST_SUITE_P(smoke_paged_attention,
                          paged_attention_test,
                          ::testing::ValuesIn(std::vector<paged_attention_test_params>{
     /* with scores output */
-    paged_attention_test_params{ {{10, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
-    paged_attention_test_params{ {{36, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
-    paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token long
-    paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
-    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
-    paged_attention_test_params{ {{1, 10}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
-    /* without scores output */
-    paged_attention_test_params{ {{10, 0}}, 2, 64, 16, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token
-    paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token long
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, DISABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{10, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
+    paged_attention_test_params{ {{36, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token long
+    paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
+    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
+    paged_attention_test_params{ {{1, 10}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
+    /* without scores output, dynamic input query paddings */
+    paged_attention_test_params{ {{10, 0}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token long
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
     /* with scores, per_block rotation */
-    paged_attention_test_params{ {{10, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
-    paged_attention_test_params{ {{36, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
-    paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token long
-    paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
-    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
-    paged_attention_test_params{ {{1, 10}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{10, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
+    paged_attention_test_params{ {{36, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token long
+    paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
+    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
+    paged_attention_test_params{ {{1, 10}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
     /* with scores, per_token rotation */
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, ENABLE_SCORES, PER_TOKEN_ROTATION }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, ENABLE_SCORES, PER_TOKEN_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
 }));
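
Note (illustration, not part of the patch): the kernel change above folds the query input's feature-axis padding into the per-token row pitch, and the test's get_linear_offset() loop prepares the padded buffer the same way. Below is a minimal standalone C++ sketch of that indexing arithmetic, assuming padding is applied only on the feature axis; all names and sizes are illustrative, not taken from the OpenVINO sources.

    // pad_demo.cpp - padded-row indexing sketch (illustrative only).
    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
        const std::size_t tokens     = 4;       // token count (axis 0 of the query)
        const std::size_t features   = 2 * 64;  // HEADS_NUM * HEAD_SIZE
        const std::size_t pad_before = 64;      // lower padding on the feature axis
        const std::size_t pad_after  = 128;     // upper padding on the feature axis

        // Dense source buffer: element (t, f) lives at t * features + f.
        std::vector<float> dense(tokens * features);
        for (std::size_t i = 0; i < dense.size(); ++i)
            dense[i] = static_cast<float>(i);

        // Padded destination: each token row gains pad_before + pad_after
        // unused elements, so the row pitch grows accordingly. This mirrors
        // the test's copy loop over get_linear_offset().
        const std::size_t pitch = features + pad_before + pad_after;
        std::vector<float> padded(tokens * pitch, 0.0f);
        for (std::size_t t = 0; t < tokens; ++t)
            for (std::size_t f = 0; f < features; ++f)
                padded[pad_before + t * pitch + f] = dense[t * features + f];

        // Read back with the kernel's formula:
        //   query_idx = INPUT0_OFFSET
        //             + seq_idx * (HEAD_SIZE * HEADS_NUM
        //                          + INPUT0_PAD_BEFORE_FEATURE_NUM
        //                          + INPUT0_PAD_AFTER_FEATURE_NUM)
        //             + ...
        // Here the offset of the first valid element is simply pad_before,
        // because only the feature axis is padded.
        const std::size_t input0_offset = pad_before;
        for (std::size_t t = 0; t < tokens; ++t)
            for (std::size_t f = 0; f < features; ++f)
                assert(padded[input0_offset + t * pitch + f] == dense[t * features + f]);
        return 0;
    }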