
Commit

ReduceSum/Max::make removed
v-Golubev committed Feb 2, 2024
1 parent 64943f3 commit 6af22d2
Showing 6 changed files with 29 additions and 47 deletions.
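In short: the per-class factory helpers ReduceSum::make and ReduceMax::make are removed in favor of a single static helper ReduceBase::compute_and_set_reduce_subtensors, and call sites now construct the ops directly with std::make_shared. A minimal sketch of the call-site migration (data_input and axis are assumed to come from the surrounding pass, as in the diff below):

    #include "snippets/op/reduce.hpp"

    // Before this commit (removed API):
    //   const auto reduce = ov::snippets::op::ReduceSum::make(data_input, axis);
    // After this commit: construct the op directly, then compute and set the
    // input/output subtensors via the shared ReduceBase helper.
    const auto reduce = std::make_shared<ov::snippets::op::ReduceSum>(data_input, axis);
    ov::snippets::op::ReduceBase::compute_and_set_reduce_subtensors(reduce);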
13 changes: 1 addition & 12 deletions src/common/snippets/include/snippets/op/reduce.hpp
@@ -27,6 +27,7 @@ class ReduceBase : public ov::op::Op {
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
size_t get_axis() const { return m_axis; }
static void compute_and_set_reduce_subtensors(const std::shared_ptr<ReduceBase>& reduce);

protected:
size_t m_axis = 0;
@@ -38,12 +39,6 @@ class ReduceSum : public ReduceBase {
ReduceSum(const Output<Node>& x, size_t axis) : ReduceBase(x, axis) {}
ReduceSum() = default;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/**
* @brief Creates ReduceSum operation, computes and sets subtensors to input/output PortDescriptors
* @param x Reduce input
* @param axis Reduce axis
*/
static std::shared_ptr<ReduceSum> make(const Output<Node>& x, size_t axis);
};

class ReduceMax : public ReduceBase {
@@ -52,12 +47,6 @@ class ReduceMax : public ReduceBase {
ReduceMax(const Output<Node>& x, size_t axis) : ReduceBase(x, axis) {}
ReduceMax() = default;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
/**
* @brief Creates ReduceMax operation, computes and sets subtensors to input/output PortDescriptors
* @param x Reduce input
* @param axis Reduce axis
*/
static std::shared_ptr<ReduceMax> make(const Output<Node>& x, size_t axis);
};

} // namespace op
40 changes: 13 additions & 27 deletions src/common/snippets/src/op/reduce.cpp
@@ -10,21 +10,6 @@
namespace ov {
namespace snippets {
namespace op {
namespace {
void compute_and_set_reduce_subtensors(const std::shared_ptr<ReduceBase>& reduce) {
OPENVINO_ASSERT(reduce->get_input_partial_shape(0).rank().is_static(),
"Subtensors can be automatically calculated only for reduce with static rank.");
const auto reduce_rank = reduce->get_input_partial_shape(0).size();
const auto axis = reduce->get_axis();

std::vector<size_t> subtensor(reduce_rank, 1);
for (size_t i = axis; i < reduce_rank; ++i)
subtensor[i] = lowered::PortDescriptor::ServiceDimensions::FULL_DIM;
lowered::PortDescriptorUtils::set_port_descriptor_ptr(reduce->input(0), std::make_shared<lowered::PortDescriptor>(reduce->input(0), subtensor));
lowered::PortDescriptorUtils::set_port_descriptor_ptr(reduce->output(0), std::make_shared<lowered::PortDescriptor>(reduce->output(0), subtensor));
}
} // namespace

ReduceBase::ReduceBase(const Output<Node>& x, size_t axis) : Op({x}), m_axis(axis) {
constructor_validate_and_infer_types();
}
@@ -40,30 +25,31 @@ void ReduceBase::validate_and_infer_types() {
set_output_type(0, get_input_element_type(0), result_shape);
}

void ReduceBase::compute_and_set_reduce_subtensors(const std::shared_ptr<ReduceBase>& reduce) {
OPENVINO_ASSERT(reduce->get_input_partial_shape(0).rank().is_static(),
"Subtensors can be automatically calculated only for reduce with static rank.");
const auto reduce_rank = reduce->get_input_partial_shape(0).size();
const auto axis = reduce->get_axis();

std::vector<size_t> subtensor(reduce_rank, 1);
for (size_t i = axis; i < reduce_rank; ++i)
subtensor[i] = lowered::PortDescriptor::ServiceDimensions::FULL_DIM;
lowered::PortDescriptorUtils::set_port_descriptor_ptr(reduce->input(0), std::make_shared<lowered::PortDescriptor>(reduce->input(0), subtensor));
lowered::PortDescriptorUtils::set_port_descriptor_ptr(reduce->output(0), std::make_shared<lowered::PortDescriptor>(reduce->output(0), subtensor));
}

std::shared_ptr<Node> ReduceSum::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(ReduceSum);
check_new_args_count(this, new_args);
return std::make_shared<ReduceSum>(new_args.at(0), m_axis);
}

std::shared_ptr<ReduceSum> ReduceSum::make(const Output<Node>& x, size_t axis) {
const auto reduce = std::make_shared<ReduceSum>(x, axis);
compute_and_set_reduce_subtensors(reduce);
return reduce;
}

std::shared_ptr<Node> ReduceMax::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(ReduceMax);
check_new_args_count(this, new_args);
return std::make_shared<ReduceMax>(new_args.at(0), m_axis);
}

std::shared_ptr<ReduceMax> ReduceMax::make(const Output<Node>& x, size_t axis) {
const auto reduce = std::make_shared<ReduceMax>(x, axis);
compute_and_set_reduce_subtensors(reduce);
return reduce;
}

} // namespace op
} // namespace snippets
} // namespace ov
5 changes: 3 additions & 2 deletions src/common/snippets/src/pass/reduce_to_snippets_reduce.cpp
@@ -37,11 +37,12 @@ snippets::pass::ReduceToSnippetsReduce::ReduceToSnippetsReduce() {

std::shared_ptr<snippets::op::ReduceBase> snippets_reduce = nullptr;
if (ov::is_type<ov::op::v1::ReduceSum>(reduce))
snippets_reduce = ov::snippets::op::ReduceSum::make(data_input, axis);
snippets_reduce = std::make_shared<ov::snippets::op::ReduceSum>(data_input, axis);
else if (ov::is_type<ov::op::v1::ReduceMax>(reduce))
snippets_reduce = ov::snippets::op::ReduceMax::make(data_input, axis);
snippets_reduce = std::make_shared<ov::snippets::op::ReduceMax>(data_input, axis);
else
OPENVINO_THROW("Reduce ", reduce, " can't be converted to snippets opset.");
ov::snippets::op::ReduceBase::compute_and_set_reduce_subtensors(snippets_reduce);

ov::replace_node(reduce, snippets_reduce);
snippets_reduce->set_friendly_name(reduce->get_friendly_name());
6 changes: 4 additions & 2 deletions src/common/snippets/src/pass/softmax_decomposition.cpp
@@ -42,11 +42,13 @@ SoftmaxDecomposition::SoftmaxDecomposition() {
}

const auto& softmax_input = softmax->input_value(0);
const auto reduce_max = ov::snippets::op::ReduceMax::make(softmax_input, axis);
const auto reduce_max = std::make_shared<ov::snippets::op::ReduceMax>(softmax_input, axis);
ov::snippets::op::ReduceBase::compute_and_set_reduce_subtensors(reduce_max);
const auto subtract = std::make_shared<ov::op::v1::Subtract>(softmax_input, reduce_max);
const auto exp = std::make_shared<ov::op::v0::Exp>(subtract);

const auto reduce_sum = ov::snippets::op::ReduceSum::make(exp, axis);
const auto reduce_sum = std::make_shared<ov::snippets::op::ReduceSum>(exp, axis);
ov::snippets::op::ReduceBase::compute_and_set_reduce_subtensors(reduce_sum);
const auto power = std::make_shared<ov::snippets::op::PowerStatic>(reduce_sum, -1.f);
const auto multiply = std::make_shared<ov::op::v1::Multiply>(exp, power);

@@ -142,11 +142,13 @@ std::shared_ptr<ov::Model> MHABufferAllocationTest::GetModel() const {
const auto relu1 = std::make_shared<ov::op::v0::Relu>(matmul0);

// Decomposed Softmax
const auto reduce_max = ov::snippets::op::ReduceMax::make(relu1, 3);
const auto reduce_max = std::make_shared<ov::snippets::op::ReduceMax>(relu1, 3);
ov::snippets::op::ReduceBase::compute_and_set_reduce_subtensors(reduce_max);
const auto subtract = std::make_shared<ov::op::v1::Subtract>(relu1, reduce_max);
const auto exp = std::make_shared<ov::op::v0::Exp>(subtract);

const auto reduce_sum = ov::snippets::op::ReduceSum::make(exp, 3);
const auto reduce_sum = std::make_shared<ov::snippets::op::ReduceSum>(exp, 3);
ov::snippets::op::ReduceBase::compute_and_set_reduce_subtensors(reduce_sum);
const auto power = std::make_shared<ov::snippets::op::PowerStatic>(reduce_sum, -1.f);
const auto multiply = std::make_shared<ov::op::v1::Multiply>(exp, power);

@@ -156,11 +156,13 @@ class MHABF16AMXBufferAllocationTest : public BufferAllocationCPUTest {
const auto relu1 = std::make_shared<ov::op::v0::Relu>(brgemm_cpu0);

// Decomposed Softmax
const auto reduce_max = ov::snippets::op::ReduceMax::make(relu1, 3);
const auto reduce_max = std::make_shared<ov::snippets::op::ReduceMax>(relu1, 3);
ov::snippets::op::ReduceBase::compute_and_set_reduce_subtensors(reduce_max);
const auto subtract = std::make_shared<ov::op::v1::Subtract>(relu1, reduce_max);
const auto exp = std::make_shared<ov::op::v0::Exp>(subtract);

const auto reduce_sum = ov::snippets::op::ReduceSum::make(exp, 3);
const auto reduce_sum = std::make_shared<ov::snippets::op::ReduceSum>(exp, 3);
ov::snippets::op::ReduceBase::compute_and_set_reduce_subtensors(reduce_sum);
const auto power = std::make_shared<ov::snippets::op::PowerStatic>(reduce_sum, -1.f);
const auto multiply = std::make_shared<ov::op::v1::Multiply>(exp, power);
