@@ -13,79 +13,96 @@ namespace ov {
13
13
namespace snippets {
14
14
namespace op {
15
15
16
+ Buffer::Buffer (const OutputVector& arguments, const ov::Shape& shape, size_t id, ov::element::Type element_type)
17
+ : Op(arguments), m_shape(shape), m_id(id), m_element_type(std::move(element_type)), m_offset(0 ) {
18
+ constructor_validate_and_infer_types ();
19
+ }
20
+
21
+ bool Buffer::visit_attributes (AttributeVisitor& visitor) {
22
+ INTERNAL_OP_SCOPE (Buffer_visit_attributes);
23
+ visitor.on_attribute (" allocation_shape" , m_shape);
24
+ visitor.on_attribute (" offset" , m_offset);
25
+ visitor.on_attribute (" id" , m_id);
26
+ visitor.on_attribute (" element_type" , m_element_type);
27
+ return true ;
28
+ }
16
29
17
- Buffer::Buffer (const ov::Shape& shape, ov::element::Type element_type, size_t id)
18
- : Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0 ), m_id(id), m_element_type(std::move(element_type)) {
30
+ size_t Buffer::get_byte_size () const {
31
+ const auto shape = get_allocation_shape ();
32
+ return ov::shape_size (shape) * get_element_type ().size ();
33
+ }
34
+
35
+ IntermediateMemoryBuffer::IntermediateMemoryBuffer (const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id)
36
+ : Buffer({arg}, shape, id) {
19
37
constructor_validate_and_infer_types ();
20
38
}
21
39
22
- Buffer::Buffer (const ov::Output<ov::Node>& arg, const ov::Shape& shape , size_t id)
23
- : Op ({arg}), m_type(Type::IntermediateMemory), m_shape(shape ), m_offset( 0 ), m_id( id) {
40
+ IntermediateMemoryBuffer::IntermediateMemoryBuffer (const ov::Output<ov::Node>& arg, int32_t allocation_rank , size_t id)
41
+ : Buffer ({arg}, compute_shape_from_allocation_rank(arg, allocation_rank ), id) {
24
42
constructor_validate_and_infer_types ();
25
43
}
26
44
27
- Buffer::Buffer (const ov::Output<ov::Node>& arg, int32_t allocation_rank, size_t id)
28
- : Op({arg}), m_type(Type::IntermediateMemory), m_offset(0 ), m_id(id) {
45
+ ov::Shape IntermediateMemoryBuffer::compute_shape_from_allocation_rank (const ov::Output<ov::Node>& arg, int32_t allocation_rank) {
29
46
const auto & pshape = arg.get_partial_shape ();
30
47
OPENVINO_ASSERT (pshape.is_static (), " Buffer supports only static input shape" );
31
48
const auto shape = pshape.get_shape ();
32
49
const auto normalize_rank = utils::normalize_rank (static_cast <int32_t >(allocation_rank), shape.size ());
33
50
const auto offset = static_cast <int32_t >(shape.size ()) - normalize_rank;
34
- m_shape = {shape.begin () + offset, shape.end ()};
35
- constructor_validate_and_infer_types ();
36
- }
37
-
38
- bool Buffer::visit_attributes (AttributeVisitor& visitor) {
39
- INTERNAL_OP_SCOPE (Buffer_visit_attributes);
40
- visitor.on_attribute (" allocation_shape" , m_shape);
41
- visitor.on_attribute (" offset" , m_offset);
42
- visitor.on_attribute (" id" , m_id);
43
- visitor.on_attribute (" element_type" , m_element_type);
44
- return true ;
51
+ return ov::Shape{shape.begin () + offset, shape.end ()};
45
52
}
46
53
47
- void Buffer ::validate_and_infer_types () {
54
+ void IntermediateMemoryBuffer ::validate_and_infer_types () {
48
55
INTERNAL_OP_SCOPE (Buffer_validate_and_infer_types);
49
56
ov::PartialShape output_shape;
50
- if (m_type == Type::NewMemory) {
51
- OPENVINO_ASSERT (get_input_size () == 0 , " Buffer with new allocated memory must to not have arguments!" );
52
- output_shape = m_shape;
53
- } else if (m_type == Type::IntermediateMemory) {
54
- m_element_type = get_input_element_type (0 );
55
- output_shape = get_input_partial_shape (0 );
56
- } else {
57
- OPENVINO_THROW (" Buffer supports only the following types: NewMemory and IntermediateMemory" );
58
- }
57
+ m_element_type = get_input_element_type (0 );
58
+ output_shape = get_input_partial_shape (0 );
59
59
set_output_type (0 , m_element_type, output_shape);
60
60
}
61
61
62
- std::shared_ptr<Node> Buffer ::clone_with_new_inputs (const OutputVector& new_args) const {
62
+ std::shared_ptr<Node> IntermediateMemoryBuffer ::clone_with_new_inputs (const OutputVector& new_args) const {
63
63
INTERNAL_OP_SCOPE (Buffer_clone_with_new_inputs);
64
64
check_new_args_count (this , new_args);
65
- std::shared_ptr<op::Buffer> new_buffer = nullptr ;
66
- if (m_type == Type::NewMemory) {
67
- new_buffer = std::make_shared<Buffer>(m_shape, m_element_type, m_id);
68
- } else if (m_type == Type::IntermediateMemory) {
69
- new_buffer = std::make_shared<Buffer>(new_args.at (0 ), m_shape, m_id);
70
- } else {
71
- OPENVINO_THROW (" Buffer supports only the following types: NewMemory and IntermediateMemory" );
72
- }
73
- new_buffer->m_offset = m_offset;
65
+ auto new_buffer = std::make_shared<IntermediateMemoryBuffer>(new_args.at (0 ), m_shape, m_id);
66
+ new_buffer->set_offset (m_offset);
74
67
return new_buffer;
75
68
}
76
69
77
- size_t Buffer::get_byte_size () const {
78
- const auto shape = get_allocation_shape ();
79
- return ov::shape_size (shape) * get_element_type (). size ();
70
+ NewMemoryBuffer::NewMemoryBuffer ( const ov::Shape& shape, size_t id, ov::element::Type element_type)
71
+ : Buffer({}, shape, id) {
72
+ constructor_validate_and_infer_types ();
80
73
}
81
74
82
- void Buffer::set_element_type (ov::element::Type element_type) {
83
- OPENVINO_ASSERT (is_new_memory (), " Only Buffer with NewMemory can change his output precision!" );
75
+ void NewMemoryBuffer::validate_and_infer_types () {
76
+ INTERNAL_OP_SCOPE (Buffer_validate_and_infer_types);
77
+ OPENVINO_ASSERT (get_input_size () == 0 , " Buffer with new allocated memory mustn't have arguments!" );
78
+ set_output_type (0 , m_element_type, m_shape);
79
+ }
80
+
81
+ std::shared_ptr<Node> NewMemoryBuffer::clone_with_new_inputs (const OutputVector& new_args) const {
82
+ INTERNAL_OP_SCOPE (Buffer_clone_with_new_inputs);
83
+ check_new_args_count (this , new_args);
84
+ auto new_buffer = std::make_shared<NewMemoryBuffer>(m_shape, m_id, m_element_type);
85
+ new_buffer->set_offset (m_offset);
86
+ return new_buffer;
87
+ }
88
+
89
+ void NewMemoryBuffer::set_element_type (ov::element::Type element_type) {
84
90
m_element_type = std::move (element_type);
85
91
// Apply the change
86
92
validate_and_infer_types ();
87
93
}
88
94
95
+ NewMemoryBuffer::ShapeInfer::ShapeInfer (const std::shared_ptr<ov::Node>& n) {
96
+ const auto & buffer = ov::as_type_ptr<NewMemoryBuffer>(n);
97
+ OPENVINO_ASSERT (buffer, " Got invalid node in NewMemoryBuffer::ShapeInfer" );
98
+ m_shape = buffer->get_shape ();
99
+ }
100
+
101
+ IShapeInferSnippets::Result NewMemoryBuffer::ShapeInfer::infer (const std::vector<VectorDimsRef>& input_shapes) {
102
+ OPENVINO_ASSERT (input_shapes.size () == 1 , " Got unexpected number of input shapes" );
103
+ return {{m_shape}, ShapeInferStatus::success};
104
+ }
105
+
89
106
} // namespace op
90
107
} // namespace snippets
91
108
} // namespace ov
0 commit comments