|
| 1 | +// Copyright (C) 2023 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | + |
| 5 | +#include "snippets/lowered/pass/reduce_max_decomposition.hpp" |
| 6 | + |
| 7 | +#include "snippets/lowered/linear_ir.hpp" |
| 8 | +#include "snippets/lowered/loop_manager.hpp" |
| 9 | +#include "snippets/lowered/pass/mark_loops.hpp" |
| 10 | +#include "snippets/lowered/pass/iter_handler.hpp" |
| 11 | +#include "snippets/snippets_isa.hpp" |
| 12 | +#include "snippets/itt.hpp" |
| 13 | + |
| 14 | +#include "openvino/pass/pattern/op/wrap_type.hpp" |
| 15 | +#include "openvino/pass/pattern/matcher.hpp" |
| 16 | + |
| 17 | + |
| 18 | +namespace ov { |
| 19 | +namespace snippets { |
| 20 | +namespace lowered { |
| 21 | +namespace pass { |
| 22 | + |
| 23 | +using LoopInfo = LinearIR::LoopManager::LoopInfo; |
| 24 | + |
| 25 | +ReduceMaxDecomposition::ReduceMaxDecomposition(size_t vector_size) : m_vector_size{vector_size} {} |
| 26 | + |
| 27 | +bool ReduceMaxDecomposition::run(LinearIR& linear_ir) { |
| 28 | + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ReduceMaxDecompositionLowered") |
| 29 | + const auto& loop_manager = linear_ir.get_loop_manager(); |
| 30 | + |
| 31 | + bool modified = false; |
| 32 | + for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) { |
| 33 | + const auto& op = (*expr_it)->get_node(); |
| 34 | + if (!ov::is_type<ov::snippets::op::ReduceMax>(op)) |
| 35 | + continue; |
| 36 | + |
| 37 | + const auto reduce = op; |
| 38 | + const auto reduce_expr = *expr_it; |
| 39 | + const auto& input_shape = reduce_expr->get_input_port_descriptor(0)->get_shape(); |
| 40 | + const auto work_amount = *(input_shape.rbegin()); |
| 41 | + const bool is_dynamic = reduce->is_dynamic(); |
| 42 | + |
| 43 | + // We need an iterator to the inserted element |
| 44 | + auto push_node = [&](const std::shared_ptr<Node>& n) { |
| 45 | + const auto expr = linear_ir.insert(expr_it, n); |
| 46 | + if (is_dynamic) |
| 47 | + expr->get()->updateShapes(); |
| 48 | + return std::make_pair(expr, n); |
| 49 | + }; |
| 50 | + // Float constant values in byte representation |
| 51 | + const auto fill_value = uint32_t(0xff7fffff); |
| 52 | + // Note: VectorBuffer is a special case, since it should go before the initial Load. |
| 53 | + // The buffer must be initialized with fill_value before reduction |
| 54 | + const auto vector_buffer = push_node(std::make_shared<op::VectorBuffer>()); |
| 55 | + const auto initial_fill = push_node(std::make_shared<op::Fill>(vector_buffer.second, 0, fill_value)); |
| 56 | + |
| 57 | + // Reduce loop |
| 58 | + const auto fill = push_node(std::make_shared<op::Fill>(reduce->get_input_source_output(0), m_vector_size, fill_value)); |
| 59 | + const auto max = push_node(std::make_shared<ov::op::v1::Maximum>(fill.second, initial_fill.second)); |
| 60 | + |
| 61 | + const auto reduce_loop_id = loop_manager->mark_loop( |
| 62 | + fill.first, |
| 63 | + expr_it, |
| 64 | + work_amount, |
| 65 | + m_vector_size, |
| 66 | + 0, |
| 67 | + std::vector<ExpressionPort>{(*fill.first)->get_input_port(0), (*max.first)->get_input_port(1)}, |
| 68 | + std::vector<ExpressionPort>{(*max.first)->get_output_port(0)}); |
| 69 | + const auto reduce_loop_info = loop_manager->get_loop_info(reduce_loop_id); |
| 70 | + const auto tail_size = work_amount % m_vector_size; |
| 71 | + if (tail_size != 0) { |
| 72 | + reduce_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size); |
| 73 | + reduce_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<SetFillOffset>(tail_size); |
| 74 | + if (work_amount > m_vector_size) { |
| 75 | + reduce_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ReduceWorkAmount>(tail_size); |
| 76 | + reduce_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ZeroFinalizationOffsets>(); |
| 77 | + } |
| 78 | + } |
| 79 | + |
| 80 | + const auto horizon = push_node(std::make_shared<op::HorizonMax>(max.second)); |
| 81 | + |
| 82 | + // Transfer original ExpressionPorts |
| 83 | + linear_ir.replace_input((*fill.first)->get_input_port(0), reduce_expr->get_input_port_connector(0)); |
| 84 | + linear_ir.replace_input(reduce_expr->get_output_port_connector(0)->get_consumers(), (*horizon.first)->get_output_port_connector(0)); |
| 85 | + |
| 86 | + // Update Loop info for outer loops |
| 87 | + const std::vector<ExpressionPort> entry_points{(*fill.first)->get_input_port(0)}; |
| 88 | + const std::vector<ExpressionPort> exit_points{(*horizon.first)->get_output_port(0)}; |
| 89 | + for (auto loop_id : reduce_expr->get_loop_ids()) { |
| 90 | + loop_manager->expression_replacement(vector_buffer.first, |
| 91 | + expr_it, |
| 92 | + reduce_expr, |
| 93 | + loop_id, |
| 94 | + entry_points, |
| 95 | + exit_points); |
| 96 | + } |
| 97 | + |
| 98 | + expr_it = linear_ir.erase(expr_it); |
| 99 | + modified = true; |
| 100 | + } |
| 101 | + return modified; |
| 102 | +} |
| 103 | + |
| 104 | +} // namespace pass |
| 105 | +} // namespace lowered |
| 106 | +} // namespace snippets |
| 107 | +} // namespace ov |
0 commit comments