forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinput_metadata.cpp
145 lines (124 loc) · 4.57 KB
/
input_metadata.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#include <torch/csrc/autograd/input_metadata.h>
// TODO: we may be able to move some #includes from input_metadata.h into this
// file, but function.h appears to depend on some of them transitively.
namespace torch {
namespace autograd {
namespace {

// Returns whether the TensorImpl behind `tensor` reports is_python_dispatch().
// InputMetadata records this bit as its `is_tensor_subclass` flag.
bool is_python_dispatch(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->is_python_dispatch();
}

// A "C++ nested tensor" is a nested tensor that is NOT backed by Python
// dispatch. For such tensors the shape is described by a nested-size tensor
// (`_nested_tensor_size()`) rather than by a flat list of symbolic sizes.
bool is_cpp_nested_tensor(const at::Tensor& tensor) {
  return tensor.is_nested() && !is_python_dispatch(tensor);
}

// Builds the shape variant stored in InputMetadata:
//  - C++ nested tensors store their nested-size tensor;
//  - everything else (plain tensors and Python-dispatch nested tensors)
//    stores its symbolic sizes.
// Reuses is_cpp_nested_tensor so this classification cannot drift from the
// one used elsewhere in this file.
MetadataShape compute_variant_shape(const at::Tensor& input) {
  if (is_cpp_nested_tensor(input)) {
    return MetadataShape{
        std::in_place_type<at::Tensor>, input._nested_tensor_size()};
  }
  return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}

} // namespace
// Primary constructor. Stores the tensor options, the shape variant and the
// nestedness flags, marks the object as not default-constructed, and records
// the stream that the device guard implementation reports for the options'
// device.
InputMetadata::InputMetadata(
    const at::TensorOptions& options,
    MetadataShape input_shape,
    bool is_tensor_subclass,
    bool is_nested)
    : options_{options},
      shape_{std::move(input_shape)},
      is_tensor_subclass_{is_tensor_subclass},
      is_nested_{is_nested},
      was_default_constructed_{false} {
  // Plain local name: in this class a trailing underscore denotes a data member.
  const auto device = options.device();
  stream_ = c10::impl::getDeviceGuardImpl(device.type())->getStream(device);
}
// Convenience constructor that derives all metadata from tensor `t`:
// its options, its shape variant, its Python-dispatch bit (recorded as
// is_tensor_subclass) and whether it is nested.
InputMetadata::InputMetadata(const at::Tensor& t)
    : InputMetadata(
          /*options=*/t.options(),
          /*input_shape=*/compute_variant_shape(t),
          /*is_tensor_subclass=*/is_python_dispatch(t),
          /*is_nested=*/t.is_nested()) {}
// Returns a zero-filled tensor with this metadata's symbolic shape and
// options. Nested inputs are rejected up front via TORCH_CHECK, before
// shape_as_dim_vector() is consulted.
at::Tensor InputMetadata::zeros_like() const {
  TORCH_CHECK(
      !is_nested_, "Zeros is not currently supported for nested tensors.");
  return at::zeros_symint(shape_as_dim_vector(), options_);
}
// Reports whether `grad` has exactly the shape recorded here. Nestedness must
// agree first. C++ nested tensors are then compared by their nested-size
// tensors, and all other tensors by their symbolic sizes.
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
  if (is_nestedness_same(grad)) {
    return is_cpp_nested_tensor()
        ? grad._nested_tensor_size().is_same_size(shape_as_tensor())
        : grad.sym_sizes().equals(shape_as_dim_vector());
  }
  return false;
}
// Returns whether this metadata's shape is expandable to `grad`'s sizes, as
// decided by at::is_expandable_to. The cheap nestedness pre-check
// (maybe_expandable_to) runs first and short-circuits the size comparison.
bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
  return maybe_expandable_to(grad) &&
      at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
}
// Reduces `grad` to this metadata's symbolic shape via at::sum_to.
// Precondition: is_expandable_to_shape(grad) returned true; only the
// nestedness part of that precondition is re-asserted here.
// `grad` is passed to at::sum_to through std::move, so callers should not rely
// on the contents of their tensor after this call.
at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
  TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
  auto target_shape = shape_as_dim_vector();
  return at::sum_to(std::move(grad), target_shape);
}
// Builds the diagnostic used when `grad` at position `index` has a shape that
// is incompatible with this metadata, of the form:
//   "invalid gradient at index <i> - got <grad shape> but expected shape
//    compatible with <expected shape>"
// Shapes of C++ nested tensors are printed as nested-size tensors; all other
// shapes are printed as symbolic sizes.
std::stringstream InputMetadata::incompatible_shape_error_message(
    const size_t index,
    const at::Tensor& grad) const {
  std::stringstream msg{};
  msg << "invalid gradient at index " << index << " - got ";
  if (!::torch::autograd::is_cpp_nested_tensor(grad)) {
    msg << grad.sym_sizes();
  } else {
    msg << grad._nested_tensor_size();
  }
  msg << " but expected shape compatible with ";
  if (!is_cpp_nested_tensor()) {
    msg << shape_as_dim_vector();
  } else {
    msg << shape_as_tensor();
  }
  return msg;
}
// Returns whether this metadata describes a C++ (non-Python-dispatch) nested
// tensor, i.e. whether shape_ currently holds a nested-size tensor.
// It also asserts that the stored variant alternative agrees with the
// recorded is_nested_ / is_tensor_subclass_ flags.
bool InputMetadata::is_cpp_nested_tensor() const {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  const bool ret = std::holds_alternative<at::Tensor>(shape_);
  TORCH_INTERNAL_ASSERT(ret == (is_nested_ && !is_tensor_subclass_));
  return ret;
}
// Returns a non-owning view of the stored symbolic sizes. The view points into
// shape_, so it must not outlive this object or a later change to shape_.
// std::get throws std::bad_variant_access if shape_ holds a nested-size
// tensor instead of sizes.
c10::SymIntArrayRef InputMetadata::shape_as_dim_vector() const {
  const SymIntSmallVec& sizes = std::get<SymIntSmallVec>(shape_);
  const auto* first = sizes.data();
  const auto count = sizes.size();
  return c10::SymIntArrayRef(first, count);
}
// Danger: not thread safe, caller must protect with lock
// Returns a mutable reference to the stored symbolic sizes so callers can edit
// the shape in place. std::get throws std::bad_variant_access if shape_ holds
// a nested-size tensor.
SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
  auto& sizes = std::get<SymIntSmallVec>(shape_);
  return sizes;
}
// Checks that `grad` agrees with this metadata on both nestedness properties:
// whether it is nested at all, and whether it is a C++ (non-Python-dispatch)
// nested tensor. The second check runs only if the first one matches.
bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
  if (grad.is_nested() != is_nested_) {
    return false;
  }
  return ::torch::autograd::is_cpp_nested_tensor(grad) ==
      is_cpp_nested_tensor();
}
// Returns (by value, i.e. a copy of the Tensor handle) the nested-size tensor
// stored for C++ nested tensors. std::get throws std::bad_variant_access if
// shape_ holds symbolic sizes instead.
at::Tensor InputMetadata::shape_as_tensor() const {
  const auto& nested_sizes = std::get<at::Tensor>(shape_);
  return nested_sizes;
}
// First-stage check, using nestedness information alone, for whether the
// tensor described by this metadata may be expanded to `grad`. When this
// returns true, is_expandable_to_shape goes on to compare the actual sizes.
// Supported expansions:
//   (1) plain Tensor -> plain Tensor
//   (2) python NT    -> python NT
//   (3) plain Tensor -> python NT
bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
  const bool grad_is_nested = grad.is_nested();
  // Case (1): no NestedTensors on either side.
  if (!grad_is_nested && !is_nested_) {
    return true;
  }
  // Cases (2)/(3): grad must be a nested tensor backed by Python dispatch...
  if (!grad_is_nested || !is_python_dispatch(grad)) {
    return false;
  }
  // ...and this input must be either non-nested or itself a tensor subclass.
  return !is_nested_ || is_tensor_subclass_;
}
} // namespace autograd
} // namespace torch