Skip to content

Commit

Permalink
Ivan's comments applied: 1st part
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Feb 1, 2024
1 parent 7048b7a commit 0cad73b
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ namespace pass {
* @attention Only Reduce by last dimension is supported
* @ingroup snippets
*/
class ReduceDecomposition : public Pass {
class ReduceDecomposition : public RangedPass {
public:
OPENVINO_RTTI("ReduceDecomposition", "Pass")
OPENVINO_RTTI("ReduceDecomposition", "RangedPass")
explicit ReduceDecomposition(size_t vector_size);
bool run(LinearIR& linear_ir) override;
bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_vector_size;
size_t m_vector_size = 0;
};

} // namespace pass
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/include/snippets/op/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class ReduceBase : public ov::op::Op {
size_t get_axis() const { return m_axis; }

protected:
size_t m_axis;
size_t m_axis = 0;
};

class ReduceSum : public ReduceBase {
Expand Down
6 changes: 3 additions & 3 deletions src/common/snippets/src/lowered/pass/reduce_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ std::shared_ptr<ov::Node> get_horizon_node(const ov::Output<ov::Node>& input, co
using LoopInfo = LinearIR::LoopManager::LoopInfo;
using HandlerType = LoopInfo::SpecificIterationHandlers::HandlerType;

ReduceDecomposition::ReduceDecomposition(size_t vector_size) : m_vector_size{vector_size} {}
ReduceDecomposition::ReduceDecomposition(size_t vector_size) : RangedPass(), m_vector_size{vector_size} {}

bool ReduceDecomposition::run(LinearIR& linear_ir) {
bool ReduceDecomposition::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ReduceMaxDecompositionLowered")
const auto& loop_manager = linear_ir.get_loop_manager();
bool modified = false;
for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
for (auto expr_it = begin; expr_it != end; expr_it++) {
const auto& reduce_expr = *expr_it;
const auto& reduce = ov::as_type_ptr<ov::snippets::op::ReduceBase>(reduce_expr->get_node());
if (!reduce)
Expand Down
8 changes: 2 additions & 6 deletions src/common/snippets/src/lowered/target_machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,14 @@
using namespace ov::snippets;
std::function<std::shared_ptr<Emitter>(const lowered::ExpressionPtr&)> TargetMachine::get(const ov::DiscreteTypeInfo& type) const {
auto jitter = jitters.find(type);
if (jitter == jitters.end()) {
OPENVINO_THROW(std::string("Target code emitter is not available for ") + type.name + " operation.");
}
OPENVINO_ASSERT(jitter != jitters.end(), "Target code emitter is not available for ", type.name, " operation.");
return jitter->second.first;
}

std::function<std::set<ov::element::TypeVector>(const std::shared_ptr<ov::Node>&)>
TargetMachine::get_supported_precisions(const ov::DiscreteTypeInfo& type) const {
auto jitter = jitters.find(type);
if (jitter == jitters.end()) {
OPENVINO_THROW(std::string("Target code emitter is not available for ") + type.name + " operation.");
}
OPENVINO_ASSERT(jitter != jitters.end(), "Supported precisions set is not available for ", type.name, " operation.");
return jitter->second.second;
}

Expand Down
4 changes: 0 additions & 4 deletions src/common/snippets/src/pass/propagate_precision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
bool was_updated = false;
for (const auto& op : f->get_ordered_ops()) {
auto type_info = op->get_type_info();
OPENVINO_ASSERT(
target_machine->has(type_info),
"operation '" + std::string(type_info.version_id) + "::" + std::string(type_info.name) + "' was not found in target machine");

auto exec = target_machine->get_supported_precisions(type_info);
const auto& supported_precisions = exec(op);

Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/shape_inference/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::Brgemm, BrgemmShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::ReduceMax, ReduceShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::ReduceSum, ReduceShapeInfer),
// Note that Result has no output PortConnectors, so the shape must be empty
SHAPE_INFER_PREDEFINED(ov::op::v0::Result, EmptyShapeInfer),
//
Expand All @@ -84,8 +86,6 @@ std::shared_ptr<IShapeInferSnippets> make_shape_inference(const std::shared_ptr<
ov::is_type<ov::op::util::BinaryElementwiseComparison>(op) ||
ov::is_type<ov::op::util::BinaryElementwiseLogical>(op)) {
return std::make_shared<NumpyBroadcastShapeInfer>();
} else if (ov::is_type<ov::snippets::op::ReduceBase>(op)) {
return std::make_shared<ReduceShapeInfer>(op);
} else {
OPENVINO_THROW("Operation type " + std::string(op->get_type_info().name) + " is not supported in Snippets shape inference pipeline");
}
Expand Down

0 comments on commit 0cad73b

Please sign in to comment.