Skip to content

Commit

Permalink
Softmax loops in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Nov 17, 2023
1 parent 0aaab75 commit 14c45e1
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 43 deletions.
2 changes: 0 additions & 2 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ class LinearIR {

exprIt erase(exprIt pos);
exprIt erase(constExprIt pos);
exprIt erase(exprIt begin, exprIt end);
exprIt erase(constExprIt begin, constExprIt end);

constExprIt find(const ExpressionPtr& target) const;
template<typename iterator>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ class ZeroFinalizationOffsets : public pass::SubgraphPass {
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
};

class SoftmaxTailLoopHandler : public pass::SubgraphPass {
class InsertFill : public pass::SubgraphPass {
public:
SoftmaxTailLoopHandler(size_t tail_size);
OPENVINO_RTTI("SoftmaxTailLoopHandler", "Pass")
InsertFill(size_t tail_size);
OPENVINO_RTTI("InsertFill", "Pass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
Expand Down
12 changes: 0 additions & 12 deletions src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,18 +293,6 @@ LinearIR::exprIt LinearIR::erase(LinearIR::constExprIt pos) {
return m_expressions.erase(pos);
}

// Erases all expressions in the half-open range [begin, end) from the linear IR.
// Each expression is unregistered first (presumably removing it from the IR's
// auxiliary lookup structures — confirm in unregister_expression) so no stale
// bookkeeping outlives the erase.
// Returns the iterator following the last removed expression.
LinearIR::exprIt LinearIR::erase(LinearIR::exprIt begin, LinearIR::exprIt end) {
    // Pre-increment: post-increment on a std::list iterator creates a needless copy.
    for (auto it = begin; it != end; ++it)
        unregister_expression(*it);
    return m_expressions.erase(begin, end);
}

// Const-iterator overload: erases all expressions in [begin, end) from the
// linear IR, unregistering each expression before the underlying list erase
// (keeps behavior identical to the non-const overload above).
// Returns the (mutable) iterator following the last removed expression.
LinearIR::exprIt LinearIR::erase(LinearIR::constExprIt begin, LinearIR::constExprIt end) {
    // Pre-increment: post-increment on a std::list iterator creates a needless copy.
    for (auto it = begin; it != end; ++it)
        unregister_expression(*it);
    return m_expressions.erase(begin, end);
}

void LinearIR::move(LinearIR::constExprIt from, LinearIR::constExprIt to) {
// Instead of `insert()` + `erase()`, we use `splice()` for the same list
m_expressions.splice(to, m_expressions, from);
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/pass/iter_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ bool ZeroFinalizationOffsets::run(const LinearIR& linear_ir, LinearIR::constExpr
return true;
}

SoftmaxTailLoopHandler::SoftmaxTailLoopHandler(size_t tail_size) : m_tail_size(tail_size) {}
InsertFill::InsertFill(size_t tail_size) : m_tail_size(tail_size) {}

bool SoftmaxTailLoopHandler::run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
bool InsertFill::run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
const auto& config = linear_ir.get_config();
if (!config.m_need_fill_tail_register)
return false;
Expand Down
44 changes: 20 additions & 24 deletions src/common/snippets/src/lowered/pass/softmax_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ namespace snippets {
namespace lowered {
namespace pass {

using LoopInfo = LinearIR::LoopManager::LoopInfo;

SoftmaxDecomposition::SoftmaxDecomposition(size_t vector_size) : m_vector_size{vector_size} {}

bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
Expand Down Expand Up @@ -69,14 +71,13 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
(*max.first)->get_input_port(1)},
std::vector<ExpressionPort>{(*max.first)->get_output_port(0)});
const auto& reduce_max_loop_info = loop_manager->get_loop_info(reduce_max_loop_id);
if (inner_work_amount < m_vector_size) {
reduce_max_loop_info->handlers[LinearIR::LoopManager::LoopInfo::MAIN_BODY].register_pass<DefaultTailLoopHandler>(inner_work_amount);
reduce_max_loop_info->handlers[LinearIR::LoopManager::LoopInfo::MAIN_BODY].register_pass<SoftmaxTailLoopHandler>(inner_work_amount);
} else {
const auto tail_size = inner_work_amount % m_vector_size;
if (tail_size != 0) {
reduce_max_loop_info->handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
reduce_max_loop_info->handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER].register_pass<SoftmaxTailLoopHandler>(tail_size);
const auto tail_size = inner_work_amount % m_vector_size;
if (tail_size != 0) {
reduce_max_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
reduce_max_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<InsertFill>(tail_size);
if (inner_work_amount > m_vector_size) {
reduce_max_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ReduceWorkAmount>(tail_size);
reduce_max_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ZeroFinalizationOffsets>();
}
}
const auto broadcast_horizon_max = push_node(std::make_shared<op::BroadcastMove>(horizon_max.second, broadcasted_dim));
Expand All @@ -99,14 +100,12 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
std::vector<ExpressionPort>{(*exp.first)->get_output_port(0),
(*sum.first)->get_output_port(0)});
const auto& reduce_sum_loop_info = loop_manager->get_loop_info(reduce_sum_loop_id);
if (inner_work_amount < m_vector_size) {
reduce_sum_loop_info->handlers[LinearIR::LoopManager::LoopInfo::MAIN_BODY].register_pass<DefaultTailLoopHandler>(inner_work_amount);
reduce_sum_loop_info->handlers[LinearIR::LoopManager::LoopInfo::MAIN_BODY].register_pass<SoftmaxTailLoopHandler>(inner_work_amount);
} else {
const auto tail_size = inner_work_amount % m_vector_size;
if (tail_size != 0) {
reduce_sum_loop_info->handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
reduce_sum_loop_info->handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER].register_pass<SoftmaxTailLoopHandler>(tail_size);
if (tail_size != 0) {
reduce_max_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<InsertFill>(tail_size);
reduce_sum_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
if (inner_work_amount > m_vector_size) {
reduce_sum_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ReduceWorkAmount>(tail_size);
reduce_sum_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ZeroFinalizationOffsets>();
}
}

Expand All @@ -128,14 +127,11 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
(*mul.first)->get_input_port(1)},
std::vector<ExpressionPort>{(*mul.first)->get_output_port(0)});
const auto& mul_loop_info = loop_manager->get_loop_info(mul_loop_id);
if (inner_work_amount < m_vector_size) {
mul_loop_info->handlers[LinearIR::LoopManager::LoopInfo::MAIN_BODY].register_pass<DefaultTailLoopHandler>(inner_work_amount);
mul_loop_info->handlers[LinearIR::LoopManager::LoopInfo::MAIN_BODY].register_pass<SoftmaxTailLoopHandler>(inner_work_amount);
} else {
const auto tail_size = inner_work_amount % m_vector_size;
if (tail_size != 0) {
mul_loop_info->handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
mul_loop_info->handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER].register_pass<SoftmaxTailLoopHandler>(tail_size);
if (tail_size != 0) {
mul_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
if (inner_work_amount > m_vector_size) {
mul_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ReduceWorkAmount>(tail_size);
mul_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ZeroFinalizationOffsets>();
}
}

Expand Down

0 comments on commit 14c45e1

Please sign in to comment.