diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp index e94714c84fdebf..96e0476fbd8044 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp @@ -847,6 +847,26 @@ void prepare_buffer_fusing::run(program& p) { if (user_info.first) { node.get_users().front()->set_output_layout(user_info.second); } + + // In case that the rank of weight node of gemm is less than 4 and, + // it transforms to extend to 4 dims by adding 1 to begin(). + // Therefore, the padding of crop_layout should be shifted properly. + const size_t TDIM = 4; + auto user = node.get_users().front(); + if (user->is_type() && pred_layout.is_static() && user->get_dependency(1).id().compare(node.id()) == 0) { + auto input_rank = user->get_kernel_impl_params()->typed_desc()->weight_rank; + if (input_rank < TDIM) { + std::vector l_pad = {0, 0, 0, 0}; + std::vector u_pad = {0, 0, 0, 0}; + + //shift right + size_t shift_right = TDIM - input_rank; + std::copy_n(crop_layout.data_padding._lower_size.begin(), l_pad.size() - shift_right, l_pad.begin() + shift_right); + std::copy_n(crop_layout.data_padding._upper_size.begin(), u_pad.size() - shift_right, u_pad.begin() + shift_right); + + crop_layout.data_padding = padding(l_pad, u_pad); + } + } } node.set_output_layout(crop_layout); node.can_be_optimized(true); diff --git a/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp b/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp index 5adc1e691b82a7..456fab4ae0286a 100644 --- a/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp @@ -11,6 +11,8 @@ #include "intel_gpu/graph/program.hpp" #include "data_inst.h" +#include "concatenation_inst.h" +#include "gemm_inst.h" #include "crop_inst.h" #include "convolution_inst.h" #include "gather_inst.h" @@ -707,6 +709,54 @@ TEST(prepare_buffer_fusing, in_place_crop_static) { ASSERT_EQ(output_ptr_2[i], out2[i]); } +TEST(prepare_buffer_fusing, in_place_crop_static_padding_and_gemm) { + auto& engine = get_test_engine(); + + auto gemm_input_mem = engine.allocate_memory({ {1, 4, 4, 2}, data_types::f32, format::bfyx }); + auto concat_input_mem = engine.allocate_memory({ {1, 4, 2}, data_types::f32, format::bfyx }); + + set_values(gemm_input_mem, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f }); + set_values(concat_input_mem, { -0.5f, 2.0f, 0.5f, 1.0f, 0.5f, -2.0f, -0.5f, -1.0f }); + + std::vector expected = { 0.5, 4, 0.5, 10, 0.5, 16, 0.5, 22, + 0.5, 4, 0.5, 10, 0.5, 16, 0.5, 22, + 0.5, 4, 0.5, 10, 0.5, 16, 0.5, 22, + 0.5, 4, 0.5, 10, 0.5, 16, 0.5, 22}; + cldnn::tensor refSize = {1, 2, 1, 2}; + + topology topology( + input_layout("gemm_input", gemm_input_mem->get_layout()), + input_layout("concat_input", concat_input_mem->get_layout()), + concatenation("concat", { input_info("concat_input"), input_info("concat_input") }, 2), + crop("crop", input_info("concat"), refSize, tensor(0, 0, 0, 0)), + gemm("gemm", { input_info("gemm_input"), input_info("crop") }, data_types::f32, false, false, 1.0, 0.0, 4, 3), + reorder("output", input_info("gemm"), format::bfyx, data_types::f32) + ); + + { + auto config = get_test_default_config(engine); + config.set_property(ov::intel_gpu::optimize_data(true)); + network network(engine, topology, config); + + network.set_input_data("gemm_input", gemm_input_mem); + network.set_input_data("concat_input", concat_input_mem); + + auto outputs = network.execute(); + + auto crop_prim = network.get_primitive("crop"); + ASSERT_EQ(crop_prim->can_be_optimized(), true); + + auto output = outputs.at("output").get_memory(); + cldnn::mem_lock output_ptr(output, get_test_stream()); + for (size_t i = 0; i < expected.size(); i++) { + ASSERT_EQ(output_ptr[i], expected[i]); + } + } +} + TEST(prepare_buffer_fusing, in_place_crop_dynamic) { auto& engine = get_test_engine();