Skip to content

Commit

Permalink
[GPU] allow scalar eltwise primitive fusion with gemm (#28764)
Browse files Browse the repository at this point in the history
### Details:
- allows fusion of scalar eltwise layers with gemm to prevent unfusion
for some LLMs

### Tickets:
 - 161678
  • Loading branch information
e-ddykim authored Feb 4, 2025
1 parent 77d0779 commit 898f6e1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2672,14 +2672,16 @@ bool primitive_inst::is_valid_fusion() const {
const auto& outer_dep = _deps[outer_dep_idx];

const auto& outer_dep_pshape = outer_dep.first->_impl_params->get_output_layout().get_partial_shape();
size_t outer_dep_pshape_count = outer_dep_pshape.is_static() ? ov::shape_size(outer_dep_pshape.to_shape()) : 0;
auto merged_shape = out_pshape;
bool can_broadcast = true;
if (fd.is_type<eltwise>())
can_broadcast = ov::PartialShape::broadcast_merge_into(merged_shape, outer_dep_pshape, fd.typed_desc<eltwise>()->broadcast_spec);

// Check if broadcast happens more than single axis.
// Current gemm_tiled_opt kernel FUSED_OP_LOAD macro cannot support broadcast on dynamic dimension.
if (_node->is_type<gemm>() && can_broadcast == true && merged_shape.rank().get_length() >= outer_dep_pshape.rank().get_length()) {
if (_node->is_type<gemm>() && can_broadcast == true && merged_shape.rank().get_length() >= outer_dep_pshape.rank().get_length() &&
outer_dep_pshape_count != 1) {
uint8_t broadcast_more_than_single_axis = 0;
auto updated_outer_dep_pshape = ov::PartialShape(outer_dep_pshape);

Expand Down Expand Up @@ -2715,7 +2717,7 @@ bool primitive_inst::is_valid_fusion() const {
cldnn::format::dimension(data_layout.format),
false);

if (gemm_dims[0] != data_dims[0])
if (gemm_dims[0] != data_dims[0] && outer_dep_pshape_count != 1)
return false;
} else if (_node->is_type<fully_connected>() && _node->get_preferred_impl_type() == impl_types::onednn) {
const auto& fc_layout = _impl_params->get_output_layout();
Expand Down
34 changes: 34 additions & 0 deletions src/plugins/intel_gpu/tests/unit/fusions/gemm_fusion_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,40 @@ TEST_P(gemm_2in_add, eltwise_postop_scalar) {
execute(p, false, true);
}

TEST_P(gemm_2in_add, eltwise_postop_scalar_dynamic) {
auto p = GetParam();

if (engine.get_device_info().supports_immad) {
ov::intel_gpu::ImplementationDesc gemmv_impl = { cldnn::format::type::any, "", impl_types::onednn };
cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "gemm_prim", gemmv_impl } }));
cfg_fused.set_property(ov::intel_gpu::use_only_static_kernels_for_dynamic_shape(true));
}

auto add_data_layout = get_output_layout(p);
auto add_data_size = add_data_layout.get_partial_shape();
for (size_t i = 0; i < add_data_size.size(); i++)
add_data_size[i] = 1;
add_data_layout.set_partial_shape(add_data_size);

auto in_layout0 = get_input_layout(p, 0);
auto in_layout1 = get_input_layout(p, 1);

in_layout0.set_partial_shape(ov::PartialShape::dynamic(p.in_shapes[0].size()));
in_layout1.set_partial_shape(ov::PartialShape::dynamic(p.in_shapes[1].size()));

create_topologies(
input_layout("input0", in_layout0),
input_layout("input1", in_layout1),
data("add_data", get_mem(add_data_layout, 0.5f)),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
);

tolerance = default_tolerance(p.default_type);
execute(p, true, true);
}

INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_add, ::testing::ValuesIn(std::vector<gemm_test_params>{
// gemm_test_params{ CASE_GEMM_2IN_FP16_3, 3, 4, "", broadcast_kinds::none, eltwise_mode::sum }, // TODO: check why failed in eltwise_postop_dynamic
gemm_test_params{ CASE_GEMM_2IN_FP16_4, 3, 4, "", broadcast_kinds::none, eltwise_mode::sum },
Expand Down

0 comments on commit 898f6e1

Please sign in to comment.