Commit 3d90e56

Update on "Add gradcheck for forward AD by default and functional API"
RFC: pytorch/rfcs#11

This PR adds the option to check forward grads using gradcheck. The current logic is:
- Forward grad is always checked.
- If the forward evaluation fails because an op is not implemented, the test silently passes.

The goal is to make sure that every formula that is added is properly tested, without having to add a new test for each op.

The final logic, after the next PR that adds the remaining formulas, is going to be:
- Forward grad is always checked.
- A failure caused by a not-implemented op is an actual failure.
- Users should set `check_forward=False` if they explicitly don't want to test forward grads (which should not be the case internally).

[ghstack-poisoned]
2 parents b9b7442 + dc852d0 commit 3d90e56
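
For context, here is a minimal sketch of the gradcheck flow the commit message describes. It only uses the long-standing gradcheck(fn, inputs) call; the opt-out keyword is spelled `check_forward` in the message above and appears only in a comment here, since the exact keyword name in released PyTorch may differ.

import torch
from torch.autograd import gradcheck

# gradcheck requires double-precision inputs.
x = torch.randn(4, dtype=torch.double, requires_grad=True)

# With this stack, the same call also exercises the forward-mode (JVP) formulas.
# Per the commit message, passing check_forward=False would opt out of that check
# (keyword name taken from the message, not verified against a released API).
assert gradcheck(torch.sin, (x,))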

File tree: 6 files changed, +92 −17 lines


torch/csrc/autograd/VariableTypeUtils.h

Lines changed: 19 additions & 2 deletions

@@ -137,6 +137,22 @@ template<typename... Args> inline variable_list flatten_tensor_args(Args&&... ar
 inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_differentiable,
     bool is_fw_differentiable, c10::optional<std::function<Tensor(const Tensor&)>> view_func=c10::nullopt,
     CreationMeta creation_meta=CreationMeta::DEFAULT, bool allow_tensor_metadata_change=true) {
+  if (!isForwardADEnabled()) {
+    // Fast codepath for backward only code
+    if (is_bw_differentiable) {
+      if (base.is_view()) {
+        auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
+        auto base_bw_info = diff_view_meta->get_backward_view();
+        return make_variable_differentiable_view(std::move(tensor), base_bw_info.chain(base, tensor, view_func),
+                                                 c10::nullopt, creation_meta, allow_tensor_metadata_change);
+      } else {
+        return make_variable_differentiable_view(std::move(tensor), ViewInfo(base, view_func),
+                                                 c10::nullopt, creation_meta, allow_tensor_metadata_change);
+      }
+    } else {
+      return make_variable_non_differentiable_view(base, std::move(tensor), allow_tensor_metadata_change);
+    }
+  }
   // Create both the forward and backward info that are needed
   c10::optional<ViewInfo> new_bw_info = c10::nullopt;
   c10::optional<ViewInfo> new_fw_info = c10::nullopt;
@@ -167,7 +183,8 @@ inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_dif
   }

   if (is_fw_differentiable || is_bw_differentiable) {
-    return make_variable_differentiable_view(std::move(tensor), new_bw_info, new_fw_info, creation_meta, allow_tensor_metadata_change);
+    return make_variable_differentiable_view(std::move(tensor), std::move(new_bw_info), std::move(new_fw_info),
+                                             creation_meta, allow_tensor_metadata_change);
   } else {
     return make_variable_non_differentiable_view(base, std::move(tensor), allow_tensor_metadata_change);
   }
@@ -195,7 +212,7 @@ inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& ten
   TORCH_CHECK(creation_meta == CreationMeta::DEFAULT,
       "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
   }
-  if (is_fw_differentiable) {
+  if (isForwardADEnabled() && is_fw_differentiable) {
     // Check if base is a forward differentiabble view
     auto is_view = torch::autograd::impl::get_autograd_meta(base) && torch::autograd::impl::get_autograd_meta(base)->is_view_;
     if (is_view && static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base))->has_fw_view()) {

torch/csrc/autograd/forward_grad.cpp

Lines changed: 18 additions & 0 deletions

@@ -12,6 +12,10 @@ namespace {

 const static at::Tensor singleton_undefined_tensor;

+// Temporary flag to disable forward mode
+// TODO(alband) remove these when perf issues are solved
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+static bool is_forward_grad_enabled = false;
 }

 uint64_t ForwardADLevel::get_next_idx() {
@@ -55,4 +59,18 @@ const at::Tensor& ForwardGrad::value(uint64_t level) const {
   return it == content_.end() ? singleton_undefined_tensor : (*it).second;
 }

+const at::Tensor& ForwardGrad::undef_grad() {
+  return singleton_undefined_tensor;
+}
+
+// Temporary functions to disable forward AD
+// TODO(alband) remove these when perf issues are solved
+bool isForwardADEnabled() {
+  return is_forward_grad_enabled;
+}
+
+void setForwardADEnabled(bool value) {
+  is_forward_grad_enabled = value;
+}
+
 }} // namespace torch::autograd

torch/csrc/autograd/forward_grad.h

Lines changed: 7 additions & 0 deletions

@@ -91,11 +91,18 @@ struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
     return content_.empty();
   }

+  static const at::Tensor& undef_grad();
+

 private:
   std::unordered_map<uint64_t, at::Tensor> content_;
   mutable std::mutex mutex_;

 };

+// Temporary functions to disable forward AD
+// TODO(alband) remove these when perf issues are solved
+bool TORCH_API isForwardADEnabled();
+void TORCH_API setForwardADEnabled(bool value);
+
 }} // namespace torch::autograd

torch/csrc/autograd/init.cpp

Lines changed: 22 additions & 0 deletions

@@ -235,6 +235,26 @@ static PyObject * autocast_decrement_nesting(PyObject* _unused, PyObject *arg) {
   END_HANDLE_TH_ERRORS
 }

+static PyObject * set_forward_AD_enabled(PyObject* _unused, PyObject *arg) {
+  HANDLE_TH_ERRORS
+  if (!PyBool_Check(arg)) {
+    throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
+  }
+  setForwardADEnabled(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * is_forward_AD_enabled(PyObject* _unused, PyObject *arg) {
+  HANDLE_TH_ERRORS
+  if (isForwardADEnabled()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) {
   HANDLE_TH_ERRORS
   if (!PyBool_Check(arg)) {
@@ -327,6 +347,8 @@ static PyObject * python_unpack_dual(PyObject* _unused, PyObject* args, PyObject
 static PyMethodDef methods[] = { // NOLINT
   {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
   {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
+  {"_set_forward_AD_enabled", set_forward_AD_enabled, METH_O, nullptr},
+  {"_is_forward_AD_enabled", is_forward_AD_enabled, METH_NOARGS, nullptr},
   {"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
   {"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
   {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},

torch/csrc/autograd/variable.cpp

Lines changed: 20 additions & 10 deletions

@@ -31,9 +31,10 @@ DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl,
     c10::optional<ViewInfo> backward_info,
     c10::optional<ViewInfo> forward_info,
     CreationMeta creation_meta)
-    : AutogradMeta(self_impl), creation_meta(creation_meta),
+    : AutogradMeta(self_impl),
       backward_info_(std::move(backward_info)),
-      forward_info_(std::move(forward_info)) {
+      forward_info_(std::move(forward_info)),
+      creation_meta(creation_meta) {
   is_view_ = true;
   if (backward_info_.has_value()) {
     self_impl->set_version_counter(impl::version_counter(backward_info_.value().base_));
@@ -594,6 +595,10 @@ namespace {
 // This function is will ensure that the fw_grad_ is properly a view of the base for inplace ops on
 // Tensors that do not have forward grad originally.
 void AutogradMeta::set_fw_grad(Variable& new_grad, const Variable& self, uint64_t level, bool is_inplace_op) {
+  if (!fw_grad_) {
+    // Lazy initialization
+    fw_grad_ = std::make_shared<ForwardGrad>();
+  }
   if (fw_grad_->contains(level)) {
     // Setting the forward grad again is only allowed if it is a no-op.
     // We do allow this case to simplify writing codegen for inplace ops.
@@ -652,33 +657,38 @@ void AutogradMeta::set_fw_grad(Variable& new_grad, const Variable& self, uint64_
 }

 const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const {
-  const auto& val = fw_grad_->value(level);
-  if (!val.defined() && is_view_) {
+  bool has_no_direct_fw_grad = !(fw_grad_ && fw_grad_->value(level).defined());
+  if (has_no_direct_fw_grad && is_view_) {
     // For view that don't have a forward grad, check if their base has one that
     // has been defined by an inplace operation.
     // See [Forward Grad View] for more details.
-    const auto this_view_meta = static_cast<const DifferentiableViewMeta*>(this);
+    auto this_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
     if (this_view_meta->has_fw_view()) {
       auto view_info = this_view_meta->get_forward_view();
       const auto& base = view_info.base_;

       const auto& base_val = base.fw_grad(level);
       if (base_val.defined()) {
+        // Lazy initialization
+        this_view_meta->fw_grad_ = std::make_shared<ForwardGrad>();
+
         Variable new_val;
         if (view_info.has_view_fn()) {
           new_val = view_info.view_fn()(base_val);
         } else {
           new_val = base_val.as_strided(self.sizes(), self.strides(), self.storage_offset());
         }

-        fw_grad_->set_value(new_val, level);
-        return fw_grad_->value(level);
-      } else {
-        return val;
+        this_view_meta->fw_grad_->set_value(new_val, level);
+        return this_view_meta->fw_grad_->value(level);
       }
     }
   }
-  return val;
+  if (fw_grad_) {
+    return fw_grad_->value(level);
+  } else {
+    return ForwardGrad::undef_grad();
+  }
 }

 }} // namespace torch::autograd
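
The lazy fw_grad_ handling above is what backs the user-facing dual-number API. Below is a small illustration, not part of this diff, written against the public helpers in torch.autograd.forward_ad as they exist in released PyTorch builds; on this exact commit the forward formula for sin may still be missing and the temporary flag defaults to off, per the commit message.

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(3, dtype=torch.double)
tangent = torch.ones(3, dtype=torch.double)

with fwAD.dual_level():
    # Attaching a tangent is what populates fw_grad_ (lazily, after this change).
    dual = fwAD.make_dual(primal, tangent)
    out = torch.sin(dual)
    # unpack_dual returns (primal, tangent); the tangent is the JVP: cos(primal) * tangent.
    _, jvp = fwAD.unpack_dual(out)
    assert torch.allclose(jvp, torch.cos(primal) * tangent)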

torch/csrc/autograd/variable.h

Lines changed: 6 additions & 5 deletions

@@ -191,10 +191,12 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   std::string name_;

   Variable grad_;
-  std::shared_ptr<ForwardGrad> fw_grad_;
   std::shared_ptr<Node> grad_fn_;
   std::weak_ptr<Node> grad_accumulator_;

+  // This field is lazily initialized
+  std::shared_ptr<ForwardGrad> fw_grad_;
+
   std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
   std::shared_ptr<hooks_list> cpp_hooks_list;

@@ -250,7 +252,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
     retains_grad_ = false;
     is_view_ = false;
     output_nr_ = gradient_edge.input_nr;
-    fw_grad_ = std::make_shared<ForwardGrad>();

     // set_requires_grad also checks error conditions.
     if (requires_grad) {
@@ -295,9 +296,9 @@ struct TORCH_API ViewInfo {
   ViewInfo chain(const Variable & base, const Variable & tensor,
     c10::optional<std::function<Variable(const Variable&)>> view_func=c10::nullopt);

-  ViewInfo(Variable base, c10::optional<std::function<Variable(const Variable&)>> view_fn) {
-    base_ = std::move(base);
-    view_fn_ = std::move(view_fn);
+  ViewInfo(Variable base, c10::optional<std::function<Variable(const Variable&)>> view_fn) :
+    base_(std::move(base)),
+    view_fn_(std::move(view_fn)) {
     TORCH_CHECK(base_.defined(), "base is undefined");
   }
 };
