forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TensorIndexing.cpp
90 lines (77 loc) · 3.19 KB
/
TensorIndexing.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <ATen/TensorIndexing.h>
#include <c10/util/Exception.h>
namespace at {
namespace indexing {
// The singleton `Ellipsis` index value, analogous to Python's `...` in indexing expressions.
const EllipsisIndexType Ellipsis = EllipsisIndexType();
// Renders a Slice in Python-style "start:stop:step" notation.
std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
  stream << slice.start();
  stream << ":" << slice.stop();
  stream << ":" << slice.step();
  return stream;
}
// Renders a single TensorIndex in Python-like notation (None / ... / integer /
// boolean / slice / tensor). If the index matches none of the known variants,
// nothing is printed.
std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index) {
  if (tensor_index.is_none()) {
    return stream << "None";
  }
  if (tensor_index.is_ellipsis()) {
    return stream << "...";
  }
  if (tensor_index.is_integer()) {
    return stream << tensor_index.integer();
  }
  if (tensor_index.is_boolean()) {
    // NOTE: std::boolalpha is a sticky stream flag; it intentionally remains
    // set afterwards, matching the original behavior.
    return stream << std::boolalpha << tensor_index.boolean();
  }
  if (tensor_index.is_slice()) {
    return stream << tensor_index.slice();
  }
  if (tensor_index.is_tensor()) {
    return stream << tensor_index.tensor();
  }
  return stream;
}
std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
stream << "(";
for (size_t i = 0; i < tensor_indices.size(); i++) {
stream << tensor_indices[i];
if (i < tensor_indices.size() - 1) stream << ", ";
}
stream << ")";
return stream;
}
// This mirrors `THPVariable_setitem` in torch/csrc/autograd/python_variable_indexing.cpp
// for "the assigned value is a Scalar" case: the scalar is materialized as a
// tensor and the tensor-valued set_item overload does the actual assignment.
static inline void set_item(Tensor& self, ArrayRef<TensorIndex> indices, Scalar v) {
  Tensor value;
  {
    // Scalar-to-tensor conversion happens outside of autograd tracking.
    at::AutoNonVariableTypeMode guard;
    // TODO: This qint special case looks very suspicious...
    value = isQIntType(self.scalar_type())
        ? at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU))
        : at::indexing::scalarToTensor(v, self.options(), self.device());
  }
  set_item(self, indices, value);
}
} // namespace indexing
// Advanced-indexing read: returns the sub-tensor selected by `indices`.
// Rejects an empty index list up front, then dispatches to the shared
// get_item implementation under a device guard for this tensor's device.
Tensor Tensor::index(ArrayRef<at::indexing::TensorIndex> indices) const {
  TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index() is not valid syntax");
  OptionalDeviceGuard device_guard(device_of(*this));
  return at::indexing::get_item(*this, indices);
}
// Convenience overload: forwards a braced index list to the ArrayRef overload.
Tensor Tensor::index(std::initializer_list<at::indexing::TensorIndex> indices) const {
  ArrayRef<at::indexing::TensorIndex> indices_ref(indices);
  return index(indices_ref);
}
// Advanced-indexing write: assigns `rhs` into the positions selected by
// `indices`, in place. Returns *this for chaining.
Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, Tensor const & rhs) {
  TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
  OptionalDeviceGuard device_guard(device_of(*this));
  at::indexing::set_item(*this, indices, rhs);
  return *this;
}
// Advanced-indexing write with a scalar value: assigns `v` into the positions
// selected by `indices`, in place. Returns *this for chaining.
Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, Scalar v) {
  TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
  OptionalDeviceGuard device_guard(device_of(*this));
  at::indexing::set_item(*this, indices, v);
  return *this;
}
// Convenience overload: forwards a braced index list to the ArrayRef overload.
Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Tensor const & rhs) {
  ArrayRef<at::indexing::TensorIndex> indices_ref(indices);
  return index_put_(indices_ref, rhs);
}
// Convenience overload: forwards a braced index list to the ArrayRef overload.
Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Scalar v) {
  ArrayRef<at::indexing::TensorIndex> indices_ref(indices);
  return index_put_(indices_ref, v);
}
} // namespace at