forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
register_ops_common_utils.cpp
103 lines (94 loc) · 3.22 KB
/
register_ops_common_utils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <ATen/core/dynamic_type.h>
#include <ATen/core/type_factory.h>
#include <torch/csrc/jit/mobile/register_ops_common_utils.h>
namespace torch {
namespace jit {
// Converts a Python-style index into a plain offset: a negative `idx` counts
// back from the end of a list of length `list_size` (so -1 becomes
// list_size - 1), while a non-negative `idx` is returned unchanged.
// No bounds checking is done here: an index below -list_size still yields a
// negative value and an index >= list_size passes through, so callers must
// validate the result themselves.
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  const bool counts_from_end = idx < 0;
  return counts_from_end ? list_size + idx : idx;
}
// Converts the slice of tensor data starting at `data` into a (possibly
// nested) list IValue shaped like `ty`.
//
// When `ty` is a ListType, one output element is produced for each index of
// dimension `cur_dim`. Each element is built by a recursive call on dimension
// `cur_dim + 1` using the list's element type. Otherwise `ty` must be one of
// int, float, complex or bool, and a single element is read from `data`.
//
// Parameters:
//   data            pointer to the first element of the current slice.
//   cur_dim         tensor dimension processed by this call.
//   num_tensor_dims not read here; forwarded unchanged to recursive calls.
//   ty              output type for this slice (a list type or a scalar type).
//   scalar_ty       dtype of the tensor elements; selects float vs. double
//                   (and complex<float> vs. complex<double>) reads.
//   sizes, strides  per-dimension sizes and strides. Strides are in elements
//                   and are scaled by `element_size` to get a byte offset.
//   element_size    size in bytes of one tensor element.
//
// Returns a GenericList IValue for list types, or the scalar IValue.
// Throws via TORCH_CHECK on an unsupported leaf type, or when `ty` has more
// list levels than `sizes`/`strides` have entries. TORCH_INTERNAL_ASSERT fires
// if `scalar_ty` does not match a float/complex leaf type.
IValue tensorToListRecursive(
    char* data,
    int64_t cur_dim,
    int64_t num_tensor_dims,
    at::TypePtr ty,
    at::ScalarType scalar_ty,
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    size_t element_size) {
  if (auto list_type = ty->cast<at::ListType>()) {
    // Descend one level: the elements of this list have the list's
    // element type.
    ty = list_type->getElementType();
  } else {
    // Leaf: read exactly one scalar of the requested type.
    if (ty == at::IntType::get()) {
      // NOTE(review): an int64_t is read regardless of scalar_ty; presumably
      // the caller has already verified the tensor dtype is Long — confirm.
      return IValue(*reinterpret_cast<const int64_t*>(data));
    } else if (ty == at::FloatType::get()) {
      TORCH_INTERNAL_ASSERT(
          scalar_ty == at::ScalarType::Float ||
              scalar_ty == at::ScalarType::Double,
          "Unexpected scalar type for Tensor");
      // Both float and double data are returned as a double.
      const double scalar = scalar_ty == at::ScalarType::Float
          ? static_cast<double>(*reinterpret_cast<const float*>(data))
          : *reinterpret_cast<const double*>(data);
      return IValue(scalar);
    } else if (ty == at::ComplexType::get()) {
      TORCH_INTERNAL_ASSERT(
          scalar_ty == at::ScalarType::ComplexFloat ||
              scalar_ty == at::ScalarType::ComplexDouble,
          "Unexpected scalar type for Tensor");
      // Both complex<float> and complex<double> data are returned as
      // complex<double>.
      const c10::complex<double> scalar =
          scalar_ty == at::ScalarType::ComplexFloat
          ? static_cast<c10::complex<double>>(
                *reinterpret_cast<const c10::complex<float>*>(data))
          : *reinterpret_cast<const c10::complex<double>*>(data);
      return IValue(scalar);
    } else if (ty == at::BoolType::get()) {
      // NOTE(review): a bool is read regardless of scalar_ty; presumably the
      // caller has already verified the tensor dtype is Bool — confirm.
      return IValue(*reinterpret_cast<const bool*>(data));
    } else {
      TORCH_CHECK(
          false,
          ty->repr_str(),
          " is not one of the supported types for tolist: int, float, complex, bool");
    }
  }

  // Guard the sizes/strides lookups below. A list annotation with more levels
  // than the tensor has dimensions would otherwise index out of range.
  TORCH_CHECK(
      cur_dim >= 0 && cur_dim < static_cast<int64_t>(sizes.size()) &&
          cur_dim < static_cast<int64_t>(strides.size()),
      "tolist: output annotation has more list dimensions than the tensor (",
      sizes.size(),
      " dimensions)");

  // This call handles dimension cur_dim, so it produces sizes[cur_dim] output
  // elements of type ty.
  const int64_t num_elements = sizes[cur_dim];
  // Byte distance between consecutive slices along cur_dim. It is computed in
  // signed int64_t rather than size_t, so the offset never goes through
  // unsigned arithmetic.
  const int64_t stride_bytes =
      strides[cur_dim] * static_cast<int64_t>(element_size);

  auto result = c10::impl::GenericList(ty);
  result.reserve(num_elements);
  for (int64_t i = 0; i < num_elements; ++i) {
    // The recursive call returns either a list or one of the supported
    // scalars (every other path throws), so its IValue is appended as-is.
    result.emplace_back(tensorToListRecursive(
        data,
        cur_dim + 1,
        num_tensor_dims,
        ty,
        scalar_ty,
        sizes,
        strides,
        element_size));
    data += stride_bytes;
  }
  return result;
}
} // namespace jit
} // namespace torch