Skip to content

Commit 6612e25

Browse files
committed
WIP: List[Tensor] embedding to pure Tensors supported by OV
1 parent a08e369 commit 6612e25

File tree

13 files changed

+606
-85
lines changed

13 files changed

+606
-85
lines changed

Diff for: src/bindings/python/src/openvino/frontend/pytorch/__init__.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def get_output_type (self, index):
7676
return self.get_type_for_value(output)
7777

7878
def _get_known_type_for_value (self, type):
79-
return
79+
#return
8080
'''
8181
Returns known/unknown types wrapped as OVAny
8282
'''
@@ -85,12 +85,17 @@ def _get_known_type_for_value (self, type):
8585
# TODO: Don't use str, use native types
8686
if str(type) in pt_to_ov_type_map:
8787
print(f'Recognized native type, type.__class__ = {type.__class__}')
88-
return OVAny(pt_to_ov_type_map[type])
88+
return OVAny(pt_to_ov_type_map[str(type)])
8989
elif type.__class__ is torch.TensorType:
9090
print(f'Recognized Tensor type with type.dtype() = {type.dtype()}')
9191
# Tensor type, parse element type
9292
# TODO: replace string by native type
93+
#return OVAny(PartialShape([1,2,3]))
9394
return OVAny(DecoderType.Tensor(self._get_known_type_for_value(type.dtype())))
95+
elif type.__class__ is torch.ListType:
96+
element_type = type.getElementType()
97+
print(f'Recognized torch List type. Type of element is {element_type}')
98+
return OVAny(DecoderType.List(self._get_known_type_for_value(element_type)))
9499
else:
95100
print(f'Not a tensor nor native type: {type}')
96101
# Not yet recognized
@@ -113,15 +118,27 @@ def get_shape_for_value (self, value):
113118
return PartialShape.dynamic()
114119

115120
def get_type_for_value (self, value):
116-
#DecoderType.print(self._get_known_type_for_value(value.type()))
121+
print(f'Decoding value type for value {value}')
122+
full_type = self._get_known_type_for_value(value.type())
123+
DecoderType.print(full_type) # new (full) type interpretation
124+
return full_type
125+
# Old version of this function directly treat Tensor[type] as type
126+
# assuming that regular type for vaue is Tensor, so it just
127+
# decodes its element type.
128+
# In full_type we code a complete type according to PT, it allows
129+
# to distiguish int from scalar Tensor[int] in particular.
130+
# It is necessary to interpreting some operations converting scalar values (not tensors)
131+
# to scalar tensors.
132+
# In this new interpretation we leave old beheviout to FE code if it is still needed
117133
if value.isCompleteTensor():
118134
pt_type = str(value.type().dtype())
135+
print(f'Trying to decode tensor element type: {pt_type}')
119136
if pt_type in pt_to_ov_type_map:
120137
ov_type = pt_to_ov_type_map[pt_type]
121-
#print(f'[ DEBUG ] Decoded ov type: {ov_type}', flush=True)
138+
print(f'[ DEBUG ] Decoded ov type: {ov_type}', flush=True)
122139
return OVAny(ov_type)
123140
else:
124-
#print(f'[ DEBUG ] Unrecognized pt element type for a tensor: {pt_type}. Captured it as custom type.', flush=True)
141+
print(f'[ DEBUG ] Unrecognized pt element type for a tensor: {pt_type}. Captured it as custom type.', flush=True)
125142
# TODO: Replace it by Tensor[dynamic]
126143
return OVAny(OVType.dynamic)
127144
else:
@@ -195,6 +212,9 @@ def as_constant (self):
195212
return None
196213
pt_value = self._raw_output(0)
197214
is_tensor = pt_value.isCompleteTensor()
215+
216+
print(f'Decoding value type for constant value {pt_value}')
217+
DecoderType.print(self._get_known_type_for_value(pt_value.type()))
198218

199219
if is_tensor and str(pt_value.type().dtype()) in pt_to_py_type_map:
200220
return self.as_constant_tensor(pt_value)
@@ -208,6 +228,9 @@ def as_constant (self):
208228
if str(pt_value.type()) in ['torch.int32', 'int']:
209229
#print(f'Found int value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
210230
return op.Constant(OVType.i32, Shape([]), [pt_value.toIValue()]).outputs()
231+
if str(pt_value.type()) in ['torch.FloatType', 'float']:
232+
#print(f'Found float value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
233+
return op.Constant(OVType.f32, Shape([]), [pt_value.toIValue()]).outputs()
211234
if str(pt_value.type()) in ['torch.bool', 'bool']:
212235
#print('Scalar bool detected')
213236
return op.Constant(OVType.boolean, Shape([]), [pt_value.toIValue()]).outputs()
@@ -218,6 +241,9 @@ def as_constant (self):
218241
return None
219242

220243
def as_constant_tensor (self, pt_value):
244+
# Constant interpretation doesn't respect new-full type of PT
245+
# It recognizes only tensors, and give lists as 1D tensors, and scalars as Tensor scalars
246+
# So only tensor-type constants are supported
221247
ovshape = PartialShape(pt_value.type().sizes())
222248
ovtype = pt_to_ov_type_map[str(pt_value.type().dtype())]
223249
np_value = pt_value.toIValue().cpu().detach().numpy().flatten().tolist() # TODO: find a better/shorter way

Diff for: src/bindings/python/src/pyopenvino/utils/utils.cpp

+85
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "pyopenvino/utils/utils.hpp"
66

7+
#include "openvino/frontend/pytorch/decoder.hpp"
8+
79
#include <pybind11/stl.h>
810

911
#include <map>
@@ -136,3 +138,86 @@ std::string convert_path_to_string(const py::object& path) {
136138
}
137139
}; // namespace utils
138140
}; // namespace Common
141+
142+
ov::Any py_object_to_any(const py::object& py_obj) {
143+
144+
// TODO: Investigate if there is a better alternative for converting any registered pybind11 type
145+
// Just listing all known types here looks a double work as we have already registed a lot of OV types
146+
// in other pybind11 definitions.
147+
// Another option is to not unpack pybind object until ov::Any is casted.
148+
149+
150+
// Python types
151+
if (py::isinstance<py::str>(py_obj)) {
152+
return py_obj.cast<std::string>();
153+
} else if (py::isinstance<py::bool_>(py_obj)) {
154+
return py_obj.cast<bool>();
155+
} else if (py::isinstance<py::float_>(py_obj)) {
156+
return py_obj.cast<double>();
157+
} else if (py::isinstance<py::int_>(py_obj)) {
158+
return py_obj.cast<int64_t>();
159+
} else if (py::isinstance<py::list>(py_obj)) {
160+
auto _list = py_obj.cast<py::list>();
161+
enum class PY_TYPE : int { UNKNOWN = 0, STR, INT, FLOAT, BOOL };
162+
PY_TYPE detected_type = PY_TYPE::UNKNOWN;
163+
for (const auto& it : _list) {
164+
auto check_type = [&](PY_TYPE type) {
165+
if (detected_type == PY_TYPE::UNKNOWN || detected_type == type) {
166+
detected_type = type;
167+
return;
168+
}
169+
OPENVINO_ASSERT("Incorrect attribute. Mixed types in the list are not allowed.");
170+
};
171+
if (py::isinstance<py::str>(it)) {
172+
check_type(PY_TYPE::STR);
173+
} else if (py::isinstance<py::int_>(it)) {
174+
check_type(PY_TYPE::INT);
175+
} else if (py::isinstance<py::float_>(it)) {
176+
check_type(PY_TYPE::FLOAT);
177+
} else if (py::isinstance<py::bool_>(it)) {
178+
check_type(PY_TYPE::BOOL);
179+
}
180+
}
181+
182+
switch (detected_type) {
183+
case PY_TYPE::STR:
184+
return _list.cast<std::vector<std::string>>();
185+
case PY_TYPE::FLOAT:
186+
return _list.cast<std::vector<double>>();
187+
case PY_TYPE::INT:
188+
return _list.cast<std::vector<int64_t>>();
189+
case PY_TYPE::BOOL:
190+
return _list.cast<std::vector<bool>>();
191+
default:
192+
OPENVINO_ASSERT(false, "Unsupported attribute type.");
193+
}
194+
// OV types
195+
} else if (py::isinstance<ov::Any>(py_obj)) {
196+
return py::cast<ov::Any>(py_obj);
197+
} else if (py::isinstance<ov::element::Type>(py_obj)) {
198+
return py::cast<ov::element::Type>(py_obj);
199+
} else if (py::isinstance<ov::hint::Priority>(py_obj)) {
200+
return py::cast<ov::hint::Priority>(py_obj);
201+
} else if (py::isinstance<ov::hint::PerformanceMode>(py_obj)) {
202+
return py::cast<ov::hint::PerformanceMode>(py_obj);
203+
} else if (py::isinstance<ov::log::Level>(py_obj)) {
204+
return py::cast<ov::log::Level>(py_obj);
205+
} else if (py::isinstance<ov::device::Type>(py_obj)) {
206+
return py::cast<ov::device::Type>(py_obj);
207+
} else if (py::isinstance<ov::streams::Num>(py_obj)) {
208+
return py::cast<ov::streams::Num>(py_obj);
209+
} else if (py::isinstance<ov::Affinity>(py_obj)) {
210+
return py::cast<ov::Affinity>(py_obj);
211+
// Custom PT FE Types
212+
} else if (py::isinstance<ov::frontend::pytorch::Type::Tensor>(py_obj)) {
213+
std::cout << "[ ANY PYBIND ] Detected Tensor\n";
214+
return py::cast<ov::frontend::pytorch::Type::Tensor>(py_obj);
215+
} else if (py::isinstance<ov::frontend::pytorch::Type::List>(py_obj)) {
216+
std::cout << "[ ANY PYBIND ] Detected List\n";
217+
return py::cast<ov::frontend::pytorch::Type::List>(py_obj);
218+
// If there is no match fallback to py::object
219+
} else if (py::isinstance<py::object>(py_obj)) {
220+
return py_obj;
221+
}
222+
OPENVINO_ASSERT(false, "Unsupported attribute type.");
223+
}

Diff for: src/bindings/python/src/pyopenvino/utils/utils.hpp

+1-68
Original file line numberDiff line numberDiff line change
@@ -21,71 +21,4 @@ namespace utils {
2121
}; // namespace utils
2222
}; // namespace Common
2323

24-
inline ov::Any py_object_to_any(const py::object& py_obj) {
25-
// Python types
26-
if (py::isinstance<py::str>(py_obj)) {
27-
return py_obj.cast<std::string>();
28-
} else if (py::isinstance<py::bool_>(py_obj)) {
29-
return py_obj.cast<bool>();
30-
} else if (py::isinstance<py::float_>(py_obj)) {
31-
return py_obj.cast<double>();
32-
} else if (py::isinstance<py::int_>(py_obj)) {
33-
return py_obj.cast<int64_t>();
34-
} else if (py::isinstance<py::list>(py_obj)) {
35-
auto _list = py_obj.cast<py::list>();
36-
enum class PY_TYPE : int { UNKNOWN = 0, STR, INT, FLOAT, BOOL };
37-
PY_TYPE detected_type = PY_TYPE::UNKNOWN;
38-
for (const auto& it : _list) {
39-
auto check_type = [&](PY_TYPE type) {
40-
if (detected_type == PY_TYPE::UNKNOWN || detected_type == type) {
41-
detected_type = type;
42-
return;
43-
}
44-
OPENVINO_ASSERT("Incorrect attribute. Mixed types in the list are not allowed.");
45-
};
46-
if (py::isinstance<py::str>(it)) {
47-
check_type(PY_TYPE::STR);
48-
} else if (py::isinstance<py::int_>(it)) {
49-
check_type(PY_TYPE::INT);
50-
} else if (py::isinstance<py::float_>(it)) {
51-
check_type(PY_TYPE::FLOAT);
52-
} else if (py::isinstance<py::bool_>(it)) {
53-
check_type(PY_TYPE::BOOL);
54-
}
55-
}
56-
57-
switch (detected_type) {
58-
case PY_TYPE::STR:
59-
return _list.cast<std::vector<std::string>>();
60-
case PY_TYPE::FLOAT:
61-
return _list.cast<std::vector<double>>();
62-
case PY_TYPE::INT:
63-
return _list.cast<std::vector<int64_t>>();
64-
case PY_TYPE::BOOL:
65-
return _list.cast<std::vector<bool>>();
66-
default:
67-
OPENVINO_ASSERT(false, "Unsupported attribute type.");
68-
}
69-
// OV types
70-
} else if (py::isinstance<ov::Any>(py_obj)) {
71-
return py::cast<ov::Any>(py_obj);
72-
} else if (py::isinstance<ov::element::Type>(py_obj)) {
73-
return py::cast<ov::element::Type>(py_obj);
74-
} else if (py::isinstance<ov::hint::Priority>(py_obj)) {
75-
return py::cast<ov::hint::Priority>(py_obj);
76-
} else if (py::isinstance<ov::hint::PerformanceMode>(py_obj)) {
77-
return py::cast<ov::hint::PerformanceMode>(py_obj);
78-
} else if (py::isinstance<ov::log::Level>(py_obj)) {
79-
return py::cast<ov::log::Level>(py_obj);
80-
} else if (py::isinstance<ov::device::Type>(py_obj)) {
81-
return py::cast<ov::device::Type>(py_obj);
82-
} else if (py::isinstance<ov::streams::Num>(py_obj)) {
83-
return py::cast<ov::streams::Num>(py_obj);
84-
} else if (py::isinstance<ov::Affinity>(py_obj)) {
85-
return py::cast<ov::Affinity>(py_obj);
86-
// If there is no match fallback to py::object
87-
} else if (py::isinstance<py::object>(py_obj)) {
88-
return py_obj;
89-
}
90-
OPENVINO_ASSERT(false, "Unsupported attribute type.");
91-
}
24+
ov::Any py_object_to_any(const py::object& py_obj);

Diff for: src/core/include/openvino/op/parameter.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class OPENVINO_API Parameter : public op::Op {
6666
protected:
6767
PartialShape m_partial_shape;
6868
element::Type m_element_type;
69+
Any m_element_custom_type;
6970
bool m_is_relevant_to_shapes{false};
7071
};
7172
} // namespace v0

Diff for: src/core/include/openvino/op/util/framework_node.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace ov {
1414
namespace op {
1515
namespace util {
1616

17+
// TODO: Consider removing this
1718
class OPENVINO_API FrameworkNodeAttrs {
1819
public:
1920
using attrs_t = std::unordered_map<std::string, std::string>;
@@ -62,6 +63,10 @@ class OPENVINO_API FrameworkNodeAttrs {
6263
return m_type_name == other.m_type_name && m_opset_name == other.m_opset_name && m_attrs == other.m_attrs;
6364
}
6465

66+
attrs_t::const_iterator find(const std::string& key) const {
67+
return m_attrs.find(key);
68+
}
69+
6570
private:
6671
std::string m_type_name;
6772
std::string m_opset_name;

Diff for: src/core/src/op/parameter.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@ op::Parameter::Parameter(const element::Type& element_type, const ov::Any& eleme
2929
OPENVINO_ASSERT(element_type == element::custom, "Parameter ctor with 3 arguments accept element_type = element::custom only");
3030
// If element_type is custom, it doesn't mean that it is really custom, it may be just a way to hide normal type under Any
3131
// In some circumstances it is simpler to wrap a regular type in Any and then pass through multi-layer API that works with Any only
32+
std::cout << "Parameter of custom type: attempt to detect simple type\n";
3233
if(element_custom_type.is<element::Type>()) {
34+
std::cout << "Parameter of custom type is simple type: " << element_custom_type.as<element::Type>() << "\n";
3335
m_element_type = element_custom_type.as<element::Type>();
36+
} else {
37+
m_element_type = element_type; // custom
38+
m_element_custom_type = element_custom_type;
3439
}
3540
constructor_validate_and_infer_types();
3641
}
@@ -45,7 +50,11 @@ bool op::Parameter::visit_attributes(AttributeVisitor& visitor) {
4550
void op::Parameter::validate_and_infer_types() {
4651
NGRAPH_OP_SCOPE(v0_Parameter_validate_and_infer_types);
4752
Op::validate_and_infer_types();
48-
set_output_type(0, m_element_type, m_partial_shape);
53+
if(m_element_type == element::custom) {
54+
set_custom_output_type(0, m_element_custom_type, m_partial_shape);
55+
} else {
56+
set_output_type(0, m_element_type, m_partial_shape);
57+
}
4958
}
5059

5160
shared_ptr<Node> op::Parameter::clone_with_new_inputs(const OutputVector& new_args) const {

Diff for: src/core/src/op/strided_slice.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ shared_ptr<Node> calculate_default_strides(const Output<Node>& begin, const Outp
6060
strides_length = end_pshape[0].get_length();
6161
} else // dynamic case
6262
{
63-
NGRAPH_CHECK(begin_pshape.rank().is_static() && begin_pshape.rank().get_length() == 1,
64-
"Begin input must be 1D");
6563
return std::make_shared<op::v1::Broadcast>(op::Constant::create(element::i64, {}, {1}),
6664
std::make_shared<op::ShapeOf>(begin));
6765
}

Diff for: src/core/src/pass/serialize.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -632,9 +632,13 @@ std::string get_precision_name(const ngraph::element::Type& elem_type) {
632632
return "BIN";
633633
case ::ngraph::element::Type_t::boolean:
634634
return "BOOL";
635+
case ::ngraph::element::Type_t::custom:
636+
return "CUSTOM";
635637
default:
636638
std::stringstream msg;
639+
std::cerr << "[ ERROR ] Unsupported precision\n";
637640
msg << "Unsupported precision: " << elem_type;
641+
std::cerr << "[ ERROR ] End of unsupported precision " << elem_type << "\n";
638642
throw ngraph_error(msg.str());
639643
}
640644
}

Diff for: src/core/src/type/element_type.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ inline TypeInfo get_type_info(ov::element::Type_t type) {
7373
return {32, false, false, false, "uint32_t", "u32"};
7474
case ov::element::Type_t::u64:
7575
return {64, false, false, false, "uint64_t", "u64"};
76+
// This doesn't make sence for custom, but requird for various services
77+
case ov::element::Type_t::custom:
78+
return {0, false, false, false, "custom", "custom"};
7679
default:
7780
OPENVINO_UNREACHABLE("ov::element::Type_t not supported: ", type);
7881
}
@@ -96,7 +99,8 @@ std::vector<const ov::element::Type*> ov::element::Type::get_known_types() {
9699
&ov::element::u8,
97100
&ov::element::u16,
98101
&ov::element::u32,
99-
&ov::element::u64};
102+
&ov::element::u64,
103+
&ov::element::custom};
100104
return rc;
101105
}
102106

@@ -361,6 +365,8 @@ inline size_t compiler_byte_size(ov::element::Type_t et) {
361365
return 0;
362366
case ov::element::Type_t::dynamic:
363367
return 0;
368+
case ov::element::Type_t::custom:
369+
return 0;
364370
}
365371

366372
throw ov::Exception("compiler_byte_size: Unsupported value of ov::element::Type_t: " +
@@ -373,6 +379,7 @@ NGRAPH_API EnumNames<element::Type_t>& EnumNames<element::Type_t>::get() {
373379
static auto enum_names = EnumNames<element::Type_t>("element::Type_t",
374380
{{"undefined", element::Type_t::undefined},
375381
{"dynamic", element::Type_t::dynamic},
382+
{"custom", element::Type_t::custom},
376383
{"boolean", element::Type_t::boolean},
377384
{"bf16", element::Type_t::bf16},
378385
{"f16", element::Type_t::f16},

Diff for: src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ struct NamedTuple;
4141
struct Union;
4242

4343
inline void print (const Any& x) {
44+
std::cout << "XDecoder.print: {" << x.type_info().name() << "}: ";
4445
if(x.is<element::Type>()) {
4546
std::cout << x.as<element::Type>();
4647
} else if(x.is<Tensor>()) {
@@ -54,6 +55,7 @@ inline void print (const Any& x) {
5455
} else {
5556
std::cout << "UNKNWON_ANY_TYPE";
5657
}
58+
std::cout << std::flush;
5759
}
5860

5961
}

0 commit comments

Comments
 (0)