diff --git a/src/common/snippets/include/snippets/pass/softmax_decomposition.hpp b/src/common/snippets/include/snippets/pass/softmax_decomposition.hpp new file mode 100644 index 00000000000000..51d80520d4991f --- /dev/null +++ b/src/common/snippets/include/snippets/pass/softmax_decomposition.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/matcher.hpp" + +namespace ov { +namespace snippets { +namespace pass { + +/** + * @interface SoftmaxDecomposition + * @brief Matcher pass that replaces a Softmax node (v1 or v8) with an explicit chain: max-reduction over the softmax axis, subtraction, exponent, sum-reduction, power(-1) and multiplication; it also sets subtensor port descriptors on the inputs and outputs of both reductions + * @ingroup snippets + */ +class SoftmaxDecomposition: public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("SoftmaxDecomposition", "0"); + SoftmaxDecomposition(); +}; + +} // namespace pass +} // namespace snippets +} // namespace ov diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index d2ff6ac5d6c9a6..6b66b10fafc578 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -12,6 +12,7 @@ #include "snippets/pass/convert_constants.hpp" #include "snippets/pass/convert_power_to_powerstatic.hpp" #include "snippets/pass/transpose_decomposition.hpp" +#include "snippets/pass/softmax_decomposition.hpp" #include "snippets/pass/matmul_to_brgemm.hpp" #include "snippets/pass/fuse_transpose_brgemm.hpp" #include "snippets/pass/set_softmax_ports.hpp" @@ -401,7 +402,11 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input manager.register_pass(); manager.register_pass(); manager.register_pass(); - manager.register_pass(); + if (getenv("DISABLE_DATA_FLOW_DECOMPOSITION")) { /* NOTE(review): environment-variable switch between two passes; looks like a debugging aid -- confirm it is meant to ship */ + manager.register_pass(); + } else { + manager.register_pass(); + } } manager.register_pass(); manager.register_pass(); diff --git a/src/common/snippets/src/pass/softmax_decomposition.cpp 
b/src/common/snippets/src/pass/softmax_decomposition.cpp
new file mode 100644
index 00000000000000..aeeff79128adbe
--- /dev/null
+++ b/src/common/snippets/src/pass/softmax_decomposition.cpp
@@ -0,0 +1,83 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "snippets/pass/softmax_decomposition.hpp"
+
+#include "openvino/op/softmax.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "snippets/itt.hpp"
+#include "snippets/lowered/port_descriptor.hpp"
+#include "snippets/op/reduce.hpp"
+#include "snippets/snippets_isa.hpp"
+
+namespace ov {
+namespace snippets {
+namespace pass {
+using namespace lowered;
+
+// Registers a matcher for Softmax (v1 or v8) that replaces the node with
+//     exp(x - reduce_max(x)) * (reduce_sum(exp(x - reduce_max(x))))^(-1)
+// taken along the softmax axis, and attaches subtensor port descriptors to the
+// inputs and outputs of both reductions.
+// The callback fails (OPENVINO_ASSERT / OPENVINO_THROW) when the input rank is
+// dynamic, when the matched node is neither Softmax version, or when the axis is
+// not smaller than the rank.
+SoftmaxDecomposition::SoftmaxDecomposition() {
+    MATCHER_SCOPE(SoftmaxDecomposition);
+    auto softmax_v1_m = ov::pass::pattern::wrap_type<ov::op::v1::Softmax>();
+    auto softmax_v8_m = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>();
+    auto softmax_m = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{softmax_v1_m, softmax_v8_m});
+
+    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::SoftmaxDecomposition")
+        auto softmax = m.get_match_root();
+
+        const auto& pshape = softmax->get_input_partial_shape(0);
+        OPENVINO_ASSERT(!pshape.rank().is_dynamic(), "SoftmaxDecomposition doesn't support dynamic ranks");
+        const auto rank = pshape.size();
+
+        size_t axis = 0;
+        if (const auto softmax_v8 = ov::as_type_ptr<ov::op::v8::Softmax>(softmax)) {
+            OPENVINO_SUPPRESS_DEPRECATED_START
+            axis = ov::normalize_axis(softmax->get_friendly_name(), softmax_v8->get_axis(), rank);
+            OPENVINO_SUPPRESS_DEPRECATED_END
+        } else if (const auto softmax_v1 = ov::as_type_ptr<ov::op::v1::Softmax>(softmax)) {
+            axis = softmax_v1->get_axis();
+        } else {
+            OPENVINO_THROW("Unexpected node matched");
+        }
+        // Validate the axis before it is handed to the reduction nodes below.
+        OPENVINO_ASSERT(axis < rank, "Softmax has incorrect axis");
+
+        const auto& softmax_input = softmax->input_value(0);
+        const auto reduce_max = std::make_shared<ov::snippets::op::ReduceMax>(softmax_input, axis);
+        const auto subtract = std::make_shared<ov::op::v1::Subtract>(softmax_input, reduce_max);
+        const auto exp = std::make_shared<ov::op::v0::Exp>(subtract);
+
+        const auto reduce_sum = std::make_shared<ov::snippets::op::ReduceSum>(exp, axis);
+        // Division by the sum is expressed as multiplication by sum^(-1).
+        const auto power = std::make_shared<ov::snippets::op::PowerStatic>(reduce_sum, -1.f);
+        const auto multiply = std::make_shared<ov::op::v1::Multiply>(exp, power);
+
+        // Subtensor: 1 for every dimension before the axis, FULL_DIM from the axis onwards.
+        std::vector<size_t> subtensor(rank, 1);
+        for (size_t i = axis; i < rank; ++i)
+            subtensor[i] = PortDescriptor::ServiceDimensions::FULL_DIM;
+
+        PortDescriptorUtils::set_port_descriptor_ptr(reduce_max->input(0), std::make_shared<PortDescriptor>(reduce_max->input(0), subtensor));
+        PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->input(0), std::make_shared<PortDescriptor>(reduce_sum->input(0), subtensor));
+        PortDescriptorUtils::set_port_descriptor_ptr(reduce_max->output(0), std::make_shared<PortDescriptor>(reduce_max->output(0), subtensor));
+        PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->output(0), std::make_shared<PortDescriptor>(reduce_sum->output(0), subtensor));
+
+        return ov::replace_node_update_name(softmax, multiply);
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(softmax_m, matcher_name);
+    register_matcher(m, callback);
+}
+
+} // namespace pass
+} // namespace snippets
+} // namespace ov