Skip to content

Commit

Permalink
Softmax decomposition: WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Nov 29, 2023
1 parent e0bfeea commit ef4178e
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

namespace ov {
namespace snippets {
namespace pass {

/**
* @interface SoftmaxDecomposition
* @brief Decomposes Softmax (ov::op::v1::Softmax and ov::op::v8::Softmax) into a chain of
*        low-level operations:
*        ReduceMax -> Subtract -> Exp -> ReduceSum -> PowerStatic(-1) -> Multiply,
*        i.e. exp(x - max(x)) * (1 / sum(exp(x - max(x)))) computed along the softmax axis.
*        The pass also assigns port descriptors (subtensors) to the input and output ports
*        of both Reduce operations.
* @ingroup snippets
*/
class SoftmaxDecomposition: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("SoftmaxDecomposition", "0");
SoftmaxDecomposition();
};

} // namespace pass
} // namespace snippets
} // namespace ov
7 changes: 6 additions & 1 deletion src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "snippets/pass/convert_constants.hpp"
#include "snippets/pass/convert_power_to_powerstatic.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/pass/softmax_decomposition.hpp"
#include "snippets/pass/matmul_to_brgemm.hpp"
#include "snippets/pass/fuse_transpose_brgemm.hpp"
#include "snippets/pass/set_softmax_ports.hpp"
Expand Down Expand Up @@ -401,7 +402,11 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input
manager.register_pass<snippets::pass::MatMulToBrgemm>();
manager.register_pass<snippets::pass::FuseTransposeBrgemm>();
manager.register_pass<snippets::pass::TransposeDecomposition>();
manager.register_pass<snippets::pass::SetSoftmaxPorts>();
if (getenv("DISABLE_DATA_FLOW_DECOMPOSITION")) {
manager.register_pass<snippets::pass::SetSoftmaxPorts>();
} else {
manager.register_pass<snippets::pass::SoftmaxDecomposition>();
}
}
manager.register_pass<snippets::pass::BroadcastToMoveBroadcast>();
manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
Expand Down
73 changes: 73 additions & 0 deletions src/common/snippets/src/pass/softmax_decomposition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/pass/softmax_decomposition.hpp"

#include "openvino/op/softmax.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "snippets/itt.hpp"
#include "snippets/lowered/port_descriptor.hpp"
#include "snippets/op/reduce.hpp"
#include "snippets/snippets_isa.hpp"

namespace ov {
namespace snippets {
namespace pass {
using namespace lowered;

/**
 * Registers a matcher that replaces a Softmax node (opset1 or opset8) with the explicit chain
 *
 *     max  = ReduceMax(x, axis)
 *     e    = Exp(Subtract(x, max))
 *     sum  = ReduceSum(e, axis)
 *     out  = Multiply(e, PowerStatic(sum, -1))
 *
 * and assigns port descriptors to the input/output ports of both Reduce operations.
 *
 * The callback throws (via OPENVINO_ASSERT / OPENVINO_THROW) if the input rank is dynamic,
 * if the matched node is neither v1 nor v8 Softmax, or if the resolved axis is out of range.
 */
SoftmaxDecomposition::SoftmaxDecomposition() {
    MATCHER_SCOPE(SoftmaxDecomposition);
    auto softmax_v1_m = ov::pass::pattern::wrap_type<ov::op::v1::Softmax>();
    auto softmax_v8_m = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>();
    auto softmax_m = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{softmax_v1_m, softmax_v8_m});

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::SoftmaxDecomposition")
        auto softmax = m.get_match_root();

        const auto& pshape = softmax->get_input_partial_shape(0);
        OPENVINO_ASSERT(!pshape.rank().is_dynamic(), "SoftmaxDecomposition doesn't support dynamic ranks");
        const auto rank = pshape.size();

        // v8 Softmax stores a signed axis that has to be normalized against the rank;
        // v1 Softmax already exposes an unsigned axis.
        size_t axis;
        if (const auto softmax_v8 = ov::as_type_ptr<ov::op::v8::Softmax>(softmax)) {
            OPENVINO_SUPPRESS_DEPRECATED_START
            axis = ov::normalize_axis(softmax->get_friendly_name(), softmax_v8->get_axis(), rank);
            OPENVINO_SUPPRESS_DEPRECATED_END
        } else if (const auto softmax_v1 = ov::as_type_ptr<ov::op::v1::Softmax>(softmax)) {
            axis = softmax_v1->get_axis();
        } else {
            OPENVINO_THROW("Unexpected node matched");
        }
        // Validate the axis before any node is built from it, so that an invalid axis
        // is reported by this pass instead of surfacing inside a Reduce constructor.
        OPENVINO_ASSERT(axis < rank, "Softmax has incorrect axis");

        const auto& softmax_input = softmax->input_value(0);
        const auto reduce_max = std::make_shared<ov::snippets::op::ReduceMax>(softmax_input, axis);
        const auto subtract = std::make_shared<ov::op::v1::Subtract>(softmax_input, reduce_max);
        const auto exp = std::make_shared<ov::op::v0::Exp>(subtract);

        const auto reduce_sum = std::make_shared<ov::snippets::op::ReduceSum>(exp, axis);
        // Multiplying by sum^(-1) is used in place of a division by the sum.
        const auto power = std::make_shared<ov::snippets::op::PowerStatic>(reduce_sum, -1.f);
        const auto multiply = std::make_shared<ov::op::v1::Multiply>(exp, power);

        // Subtensor: 1 for every dimension before the softmax axis, FULL_DIM for the axis
        // and all dimensions after it.
        std::vector<size_t> subtensor(rank, 1);
        for (size_t i = axis; i < rank; ++i)
            subtensor[i] = PortDescriptor::ServiceDimensions::FULL_DIM;

        // The same subtensor is applied to the input and the output port of both Reduce ops.
        PortDescriptorUtils::set_port_descriptor_ptr(reduce_max->input(0), std::make_shared<PortDescriptor>(reduce_max->input(0), subtensor));
        PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->input(0), std::make_shared<PortDescriptor>(reduce_sum->input(0), subtensor));
        PortDescriptorUtils::set_port_descriptor_ptr(reduce_max->output(0), std::make_shared<PortDescriptor>(reduce_max->output(0), subtensor));
        PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->output(0), std::make_shared<PortDescriptor>(reduce_sum->output(0), subtensor));

        return ov::replace_node_update_name(softmax, multiply);
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(softmax_m, matcher_name);
    register_matcher(m, callback);
}

} // namespace pass
} // namespace snippets
} // namespace ov

0 comments on commit ef4178e

Please sign in to comment.