Skip to content

Commit 2c7622b

Browse files
committed
Buffer operation separation
1 parent e1fc633 commit 2c7622b

File tree

14 files changed

+137
-93
lines changed

14 files changed

+137
-93
lines changed

src/common/snippets/include/snippets/op/buffer.hpp

+46-22
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#pragma once
66

77
#include "openvino/op/op.hpp"
8+
#include "snippets/shape_inference/shape_inference.hpp"
89

910
namespace ov {
1011
namespace snippets {
@@ -13,13 +14,10 @@ namespace op {
1314
/**
1415
* @interface Buffer
1516
* @brief This is a base class for memory storage.
16-
* If Buffer has a parent, the operation is for intermediate data storage - IntermediateMemory type.
17-
* Otherwise, the operation is for allocation of new empty memory with shape `m_shape` - NewMemory type
1817
* Notes:
1918
* - All buffers with the same ID in a graph have the same memory pointer. So if we have a few buffers,
2019
* each the corresponding MemoryAccess op for Buffer should have offset for common memory pointer of this Buffer
2120
* - Buffer should be a single consumer for operation output port
22-
* @param m_type - type of Buffer: IntermediateMemory/NewMemory
2321
* @param m_shape - output allocation shape for Buffer with type NewMemory
2422
* @param m_offset - offset in common Buffer scratchpad
2523
* @param m_id - Buffer ID in common Buffer system
@@ -29,39 +27,65 @@ class Buffer : public ov::op::Op {
2927
public:
3028
OPENVINO_OP("Buffer", "SnippetsOpset");
3129
Buffer() = default;
32-
Buffer(const ov::Shape& shape, ov::element::Type element_type = ov::element::u8, size_t id = 0);
33-
Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id = 0);
34-
Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1, size_t id = 0);
30+
Buffer(const OutputVector& arguments, const ov::Shape& shape, size_t id, ov::element::Type element_type = ov::element::u8);
3531

3632
bool visit_attributes(AttributeVisitor& visitor) override;
37-
void validate_and_infer_types() override;
38-
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
39-
40-
enum Type {
41-
NewMemory,
42-
IntermediateMemory
43-
};
4433

4534
size_t get_id() const { return m_id; }
46-
Type get_type() const { return m_type; }
4735
int64_t get_offset() const { return m_offset; }
4836
void set_id(size_t id) { m_id = id; }
4937
const ov::Shape& get_allocation_shape() const { return m_shape; }
5038
void set_allocation_shape(const ov::Shape& allocation_shape) { m_shape = allocation_shape; }
5139
void set_offset(int64_t offset) { m_offset = offset; }
5240
size_t get_byte_size() const;
5341

54-
void set_element_type(ov::element::Type element_type);
55-
56-
bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; }
57-
bool is_new_memory() const { return m_type == Type::NewMemory; }
58-
59-
private:
60-
Type m_type = Type::IntermediateMemory;
42+
protected:
6143
ov::Shape m_shape = {};
62-
int64_t m_offset = 0;
6344
size_t m_id = 0; // Default ID - 0. All Buffers are from the same set
6445
ov::element::Type m_element_type = ov::element::u8; // u8 - default 1 byte
46+
int64_t m_offset = 0;
47+
};
48+
49+
/**
50+
* @interface IntermediateMemoryBuffer
51+
* @brief Represents an intermediate memory storage operation. It always has a parent.
52+
* @ingroup snippets
53+
*
54+
*/
55+
class IntermediateMemoryBuffer : public Buffer {
56+
public:
57+
OPENVINO_OP("IntermediateMemoryBuffer", "SnippetsOpset", Buffer);
58+
IntermediateMemoryBuffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id = 0);
59+
IntermediateMemoryBuffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1, size_t id = 0);
60+
61+
void validate_and_infer_types() override;
62+
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
63+
64+
private:
65+
ov::Shape compute_shape_from_allocation_rank(const ov::Output<ov::Node>& arg, int32_t allocation_rank);
66+
};
67+
68+
/**
69+
* @interface NewMemoryBuffer
70+
* @brief Represents a new empty memory for allocation with specified shape. It has no parent operations.
71+
* @ingroup snippets
72+
*
73+
*/
74+
class NewMemoryBuffer : public Buffer {
75+
public:
76+
OPENVINO_OP("NewMemoryBuffer", "SnippetsOpset", Buffer);
77+
NewMemoryBuffer(const ov::Shape& shape, size_t id = 0, ov::element::Type element_type = ov::element::u8);
78+
79+
void validate_and_infer_types() override;
80+
void set_element_type(ov::element::Type element_type);
81+
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
82+
83+
class ShapeInfer : public IShapeInferSnippets {
84+
ov::Shape m_shape;
85+
public:
86+
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
87+
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
88+
};
6589
};
6690

6791
} // namespace op

src/common/snippets/src/generator.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ Generator::opRegType Generator::get_op_reg_type(const std::shared_ptr<Node>& op)
7676
std::dynamic_pointer_cast<op::LoopBegin>(op) ||
7777
std::dynamic_pointer_cast<op::LoopEnd>(op) ||
7878
std::dynamic_pointer_cast<op::Brgemm>(op) ||
79-
std::dynamic_pointer_cast<op::Buffer>(op) ||
79+
std::dynamic_pointer_cast<op::IntermediateMemoryBuffer>(op) ||
80+
std::dynamic_pointer_cast<op::NewMemoryBuffer>(op) ||
8081
std::dynamic_pointer_cast<op::RankNormalization>(op))
8182
return gpr2gpr;
8283
else if (std::dynamic_pointer_cast<snippets::op::Load>(op) ||

src/common/snippets/src/lowered/pass/allocate_buffers.cpp

+14-15
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,18 @@ void AllocateBuffers::propagate_offset(const LinearIR& linear_ir, const Expressi
2020
buffer->set_offset(static_cast<int64_t>(offset));
2121

2222
// Propagate to up: in Store. Buffer can have only one Store
23-
{
24-
if (buffer->is_intermediate_memory()) {
25-
OPENVINO_ASSERT(buffer_expr->get_input_port_connectors().size() == 1, "Buffer with intermediate memory must have one parent");
26-
const auto& parent_output = buffer_expr->get_input_port_connector(0)->get_source();
27-
const auto& parent_expr = parent_output.get_expr();
28-
const auto port = parent_output.get_index();
29-
const auto& parent_node = parent_expr->get_node();
30-
auto memory_access = ov::as_type_ptr<ov::snippets::op::MemoryAccess>(parent_node);
31-
if (memory_access && memory_access->is_memory_access_output_port(port)) {
32-
memory_access->set_output_offset(offset, port);
33-
} else {
34-
OPENVINO_THROW(
35-
"Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
36-
}
23+
if (ov::is_type<op::IntermediateMemoryBuffer>(buffer)) {
24+
OPENVINO_ASSERT(buffer_expr->get_input_port_connectors().size() == 1, "Buffer with intermediate memory must have one parent");
25+
const auto& parent_output = buffer_expr->get_input_port_connector(0)->get_source();
26+
const auto& parent_expr = parent_output.get_expr();
27+
const auto port = parent_output.get_index();
28+
const auto& parent_node = parent_expr->get_node();
29+
auto memory_access = ov::as_type_ptr<ov::snippets::op::MemoryAccess>(parent_node);
30+
if (memory_access && memory_access->is_memory_access_output_port(port)) {
31+
memory_access->set_output_offset(offset, port);
32+
} else {
33+
OPENVINO_THROW(
34+
"Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
3735
}
3836
}
3937
// Propagate to down: in Load. Buffer can have several Load
@@ -89,7 +87,7 @@ bool AllocateBuffers::run(LinearIR& linear_ir) {
8987
continue;
9088
}
9189

92-
if (buffer->is_intermediate_memory()) {
90+
if (ov::is_type<op::IntermediateMemoryBuffer>(buffer)) {
9391
const auto& parent_expr = expr->get_input_port_connector(0)->get_source().get_expr();
9492
const auto& parent_node = parent_expr->get_node();
9593
// Full MemoryAccess ops need new memory. Previous logic is to check for parent isn't Loop
@@ -142,6 +140,7 @@ bool AllocateBuffers::run(LinearIR& linear_ir) {
142140
allocated_buffers.insert(expr);
143141
prev_data_size = current_data_size;
144142
} else {
143+
OPENVINO_ASSERT(ov::is_type<op::NewMemoryBuffer>(buffer), "NewMemoryBuffer is expected");
145144
if (!new_memory_buffer_allocated) {
146145
allocate(buffer, *expr_it, buffer_size);
147146
new_memory_buffer_allocated = true;

src/common/snippets/src/lowered/pass/assign_registers.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
6464
} else if (const auto& buffer = ov::as_type_ptr<op::Buffer>(op)) {
6565
const auto buffer_id = buffer->get_id();
6666
// All buffers have one common data pointer
67-
if (buffer->is_intermediate_memory()) {
67+
if (ov::is_type<op::IntermediateMemoryBuffer>(buffer)) {
6868
manually_assigned_gprs[expr->get_input_port_connector(0)] =
6969
static_cast<Reg>(num_results + num_parameters + buffer_id);
7070
}

src/common/snippets/src/lowered/pass/insert_buffers.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt&
155155
parent_loops,
156156
parent_expr_output,
157157
m_buffer_allocation_rank);
158-
const auto buffer = std::make_shared<op::Buffer>(parent->output(parent_port), allocation_shape);
158+
const auto buffer = std::make_shared<op::IntermediateMemoryBuffer>(parent->output(parent_port), allocation_shape);
159159
PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), parent_expr_output.get_descriptor_ptr()->clone());
160160
// Output connector is automatically filled from PortDescriptor
161161
const auto buffer_expr = linear_ir.create_expression(buffer, {input_connector});
@@ -248,7 +248,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt&
248248
current_loops,
249249
*exit_port,
250250
m_buffer_allocation_rank);
251-
auto buffer = std::make_shared<op::Buffer>(node->output(port_idx), allocation_shape);
251+
auto buffer = std::make_shared<op::IntermediateMemoryBuffer>(node->output(port_idx), allocation_shape);
252252
PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), exit_port->get_descriptor_ptr()->clone());
253253
// We cannot insert Node output connector on Buffer output because not all consumers of Node needs Buffer
254254
// Example:

src/common/snippets/src/lowered/pass/insert_load_store.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ bool InsertLoadStore::run(LinearIR& linear_ir) {
122122
modified |= insert_load(linear_ir, expr_it);
123123
} else if (ov::is_type<ov::op::v0::Result>(node)) {
124124
modified |= insert_store(linear_ir, expr_it);
125-
} else if (auto buffer = ov::as_type_ptr<op::Buffer>(node)) {
125+
} else if (ov::is_type<op::Buffer>(node)) {
126126
modified |= insert_load(linear_ir, expr_it);
127-
if (buffer->is_intermediate_memory())
127+
if (ov::is_type<op::IntermediateMemoryBuffer>(node))
128128
modified |= insert_store(linear_ir, expr_it);
129129
}
130130
}

src/common/snippets/src/op/buffer.cpp

+59-42
Original file line numberDiff line numberDiff line change
@@ -13,79 +13,96 @@ namespace ov {
1313
namespace snippets {
1414
namespace op {
1515

16+
Buffer::Buffer(const OutputVector& arguments, const ov::Shape& shape, size_t id, ov::element::Type element_type)
17+
: Op(arguments), m_shape(shape), m_id(id), m_element_type(std::move(element_type)), m_offset(0) {
18+
constructor_validate_and_infer_types();
19+
}
20+
21+
bool Buffer::visit_attributes(AttributeVisitor& visitor) {
22+
INTERNAL_OP_SCOPE(Buffer_visit_attributes);
23+
visitor.on_attribute("allocation_shape", m_shape);
24+
visitor.on_attribute("offset", m_offset);
25+
visitor.on_attribute("id", m_id);
26+
visitor.on_attribute("element_type", m_element_type);
27+
return true;
28+
}
1629

17-
Buffer::Buffer(const ov::Shape& shape, ov::element::Type element_type, size_t id)
18-
: Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id), m_element_type(std::move(element_type)) {
30+
size_t Buffer::get_byte_size() const {
31+
const auto shape = get_allocation_shape();
32+
return ov::shape_size(shape) * get_element_type().size();
33+
}
34+
35+
IntermediateMemoryBuffer::IntermediateMemoryBuffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id)
36+
: Buffer({arg}, shape, id) {
1937
constructor_validate_and_infer_types();
2038
}
2139

22-
Buffer::Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id)
23-
: Op({arg}), m_type(Type::IntermediateMemory), m_shape(shape), m_offset(0), m_id(id) {
40+
IntermediateMemoryBuffer::IntermediateMemoryBuffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank, size_t id)
41+
: Buffer({arg}, compute_shape_from_allocation_rank(arg, allocation_rank), id) {
2442
constructor_validate_and_infer_types();
2543
}
2644

27-
Buffer::Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank, size_t id)
28-
: Op({arg}), m_type(Type::IntermediateMemory), m_offset(0), m_id(id) {
45+
ov::Shape IntermediateMemoryBuffer::compute_shape_from_allocation_rank(const ov::Output<ov::Node>& arg, int32_t allocation_rank) {
2946
const auto& pshape = arg.get_partial_shape();
3047
OPENVINO_ASSERT(pshape.is_static(), "Buffer supports only static input shape");
3148
const auto shape = pshape.get_shape();
3249
const auto normalize_rank = utils::normalize_rank(static_cast<int32_t>(allocation_rank), shape.size());
3350
const auto offset = static_cast<int32_t>(shape.size()) - normalize_rank;
34-
m_shape = {shape.begin() + offset, shape.end()};
35-
constructor_validate_and_infer_types();
36-
}
37-
38-
bool Buffer::visit_attributes(AttributeVisitor& visitor) {
39-
INTERNAL_OP_SCOPE(Buffer_visit_attributes);
40-
visitor.on_attribute("allocation_shape", m_shape);
41-
visitor.on_attribute("offset", m_offset);
42-
visitor.on_attribute("id", m_id);
43-
visitor.on_attribute("element_type", m_element_type);
44-
return true;
51+
return ov::Shape{shape.begin() + offset, shape.end()};
4552
}
4653

47-
void Buffer::validate_and_infer_types() {
54+
void IntermediateMemoryBuffer::validate_and_infer_types() {
4855
INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types);
4956
ov::PartialShape output_shape;
50-
if (m_type == Type::NewMemory) {
51-
OPENVINO_ASSERT(get_input_size() == 0, "Buffer with new allocated memory must to not have arguments!");
52-
output_shape = m_shape;
53-
} else if (m_type == Type::IntermediateMemory) {
54-
m_element_type = get_input_element_type(0);
55-
output_shape = get_input_partial_shape(0);
56-
} else {
57-
OPENVINO_THROW("Buffer supports only the following types: NewMemory and IntermediateMemory");
58-
}
57+
m_element_type = get_input_element_type(0);
58+
output_shape = get_input_partial_shape(0);
5959
set_output_type(0, m_element_type, output_shape);
6060
}
6161

62-
std::shared_ptr<Node> Buffer::clone_with_new_inputs(const OutputVector& new_args) const {
62+
std::shared_ptr<Node> IntermediateMemoryBuffer::clone_with_new_inputs(const OutputVector& new_args) const {
6363
INTERNAL_OP_SCOPE(Buffer_clone_with_new_inputs);
6464
check_new_args_count(this, new_args);
65-
std::shared_ptr<op::Buffer> new_buffer = nullptr;
66-
if (m_type == Type::NewMemory) {
67-
new_buffer = std::make_shared<Buffer>(m_shape, m_element_type, m_id);
68-
} else if (m_type == Type::IntermediateMemory) {
69-
new_buffer = std::make_shared<Buffer>(new_args.at(0), m_shape, m_id);
70-
} else {
71-
OPENVINO_THROW("Buffer supports only the following types: NewMemory and IntermediateMemory");
72-
}
73-
new_buffer->m_offset = m_offset;
65+
auto new_buffer = std::make_shared<IntermediateMemoryBuffer>(new_args.at(0), m_shape, m_id);
66+
new_buffer->set_offset(m_offset);
7467
return new_buffer;
7568
}
7669

77-
size_t Buffer::get_byte_size() const {
78-
const auto shape = get_allocation_shape();
79-
return ov::shape_size(shape) * get_element_type().size();
70+
NewMemoryBuffer::NewMemoryBuffer(const ov::Shape& shape, size_t id, ov::element::Type element_type)
71+
: Buffer({}, shape, id) {
72+
constructor_validate_and_infer_types();
8073
}
8174

82-
void Buffer::set_element_type(ov::element::Type element_type) {
83-
OPENVINO_ASSERT(is_new_memory(), "Only Buffer with NewMemory can change his output precision!");
75+
void NewMemoryBuffer::validate_and_infer_types() {
76+
INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types);
77+
OPENVINO_ASSERT(get_input_size() == 0, "Buffer with new allocated memory mustn't have arguments!");
78+
set_output_type(0, m_element_type, m_shape);
79+
}
80+
81+
std::shared_ptr<Node> NewMemoryBuffer::clone_with_new_inputs(const OutputVector& new_args) const {
82+
INTERNAL_OP_SCOPE(Buffer_clone_with_new_inputs);
83+
check_new_args_count(this, new_args);
84+
auto new_buffer = std::make_shared<NewMemoryBuffer>(m_shape, m_id, m_element_type);
85+
new_buffer->set_offset(m_offset);
86+
return new_buffer;
87+
}
88+
89+
void NewMemoryBuffer::set_element_type(ov::element::Type element_type) {
8490
m_element_type = std::move(element_type);
8591
// Apply the change
8692
validate_and_infer_types();
8793
}
8894

95+
NewMemoryBuffer::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) {
96+
const auto& buffer = ov::as_type_ptr<NewMemoryBuffer>(n);
97+
OPENVINO_ASSERT(buffer, "Got invalid node in NewMemoryBuffer::ShapeInfer");
98+
m_shape = buffer->get_shape();
99+
}
100+
101+
IShapeInferSnippets::Result NewMemoryBuffer::ShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
102+
OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes");
103+
return {{m_shape}, ShapeInferStatus::success};
104+
}
105+
89106
} // namespace op
90107
} // namespace snippets
91108
} // namespace ov

src/common/snippets/src/shape_inference/shape_inference.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
3939
SHAPE_INFER_PREDEFINED(op::ConvertSaturation, PassThroughShapeInfer),
4040
SHAPE_INFER_PREDEFINED(op::Load, PassThroughShapeInfer),
4141
SHAPE_INFER_PREDEFINED(op::Store, PassThroughShapeInfer),
42-
SHAPE_INFER_PREDEFINED(op::Buffer, PassThroughShapeInfer),
42+
SHAPE_INFER_PREDEFINED(op::IntermediateMemoryBuffer, PassThroughShapeInfer),
4343
SHAPE_INFER_PREDEFINED(op::Fill, PassThroughShapeInfer),
4444
SHAPE_INFER_PREDEFINED(ov::op::v0::Parameter, PassThroughShapeInfer),
4545
// Note: We should remove Softmax shape infers after the decomposition activity,
@@ -66,6 +66,7 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
6666
SHAPE_INFER_OP_SPECIFIC(op::RankNormalization),
6767
SHAPE_INFER_OP_SPECIFIC(op::BroadcastLoad),
6868
SHAPE_INFER_OP_SPECIFIC(op::BroadcastMove),
69+
SHAPE_INFER_OP_SPECIFIC(op::NewMemoryBuffer),
6970
};
7071
#undef SHAPE_INFER_OP_SPECIFIC_EXTERNAL
7172
#undef SHAPE_INFER_OP_SPECIFIC

src/common/snippets/tests/src/lowering_utils.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ DummyTargetMachine::DummyTargetMachine(const std::vector<ov::Node::type_info_t>&
4242
jitters[ov::snippets::op::LoopBegin::get_type_info_static()] = dummy_functor;
4343
jitters[ov::snippets::op::LoopEnd::get_type_info_static()] = dummy_functor;
4444
jitters[ov::snippets::op::Brgemm::get_type_info_static()] = dummy_functor;
45-
jitters[ov::snippets::op::Buffer::get_type_info_static()] = dummy_functor;
45+
jitters[ov::snippets::op::IntermediateMemoryBuffer::get_type_info_static()] = dummy_functor;
46+
jitters[ov::snippets::op::NewMemoryBuffer::get_type_info_static()] = dummy_functor;
4647
jitters[ov::snippets::op::VectorBuffer::get_type_info_static()] = dummy_functor;
4748
jitters[ov::snippets::op::Fill::get_type_info_static()] = dummy_functor;
4849

0 commit comments

Comments
 (0)