Commit

Added accumulation codes to Eager Dygraph
jim19930609 committed Nov 23, 2021
1 parent 15984f4 commit 061be26
Showing 15 changed files with 576 additions and 9 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/eager/CMakeLists.txt
@@ -1,6 +1,7 @@
+add_subdirectory(accumulation)
 add_subdirectory(tests)
 
 cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api)
 cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
 
-cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info)
+cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation)
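The extra gradient_accumulation dependency exists because merging a new gradient into an already-populated slot needs the TensorAdd helper from the new accumulation library. grad_tensor_holder.cc itself is not part of this excerpt, so the following is only a minimal sketch of that pattern (the AddToSlot name and the buffer parameter are invented for illustration); the real holder code may differ:

// Hypothetical illustration; not the actual GradTensorHolder code.
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include "paddle/fluid/eager/eager_tensor.h"

void AddToSlot(egr::EagerTensor* buffer, const egr::EagerTensor& grad) {
  if (!buffer->defined() || !buffer->initialized()) {
    // First gradient for this slot: plain copy.
    *buffer = grad;
  } else {
    // Subsequent gradients: accumulate in place via the new library.
    egr::TensorAdd(grad, buffer);
  }
}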
2 changes: 2 additions & 0 deletions paddle/fluid/eager/accumulation/CMakeLists.txt
@@ -0,0 +1,2 @@
cc_library(gradient_accumulation SRCS gradient_accumulation.cc DEPS blas pten pten_api var_type_traits layer math_function)
cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulation pten pten_api grad_node_info)
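gradient_accumulation.h/.cc are not shown in this excerpt. Judging only from the call sites in accumulation_node.cc below, the interface this target contributes is presumably along these lines; the exact signature is an assumption (it may well be templated over the tensor type in the real header):

// Assumed interface sketch, inferred from the egr::TensorAdd(t, tensor) call
// in accumulation_node.cc below; the real declaration may differ.
#include "paddle/fluid/eager/eager_tensor.h"

namespace egr {

// In-place accumulation: adds the values of src into *dst (dst += src).
void TensorAdd(const EagerTensor& src, EagerTensor* dst);

}  // namespace egr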
81 changes: 81 additions & 0 deletions paddle/fluid/eager/accumulation/accumulation_node.cc
@@ -0,0 +1,81 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include "paddle/fluid/eager/eager_tensor.h"

#include "paddle/pten/api/all.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/include/core.h"

#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

#include "glog/logging.h"

static void CopyOrAddTensor(egr::EagerTensor* tensor,
                            const egr::EagerTensor& t) {
  if (!tensor->defined() || !tensor->initialized()) {
    // Simply copy tensor->impl
    *tensor = t;
  } else {
    // Accumulation
    egr::TensorAdd(t, tensor);
  }
}

namespace egr {

void GradNodeAccumulation::RetainGrad(
    const std::function<egr::EagerTensor(const egr::EagerTensor&)>& hook) {
  retain_grad_hook_ = hook;
}

std::vector<std::vector<egr::EagerTensor>> GradNodeAccumulation::operator()(
    const std::vector<std::vector<egr::EagerTensor>>& grads) {
  PADDLE_ENFORCE(grads.size() == 1,
                 paddle::platform::errors::Fatal(
                     "GradNodeAccumulation should take exactly 1 grad tensor. "
                     "However received: %d slots.",
                     grads.size()));
  PADDLE_ENFORCE(grads[0].size() == 1,
                 paddle::platform::errors::Fatal(
                     "GradNodeAccumulation should take exactly 1 grad tensor. "
                     "However received: %d in slot %d.",
                     grads[0].size(), 0));
  // Apply Gradient Hooks
  if (GradientHooksRegistered()) {
    std::vector<std::vector<egr::EagerTensor>> hooked_grads =
        ApplyGradientHooks(grads);
    // TODO(jiabin): It's a little weird
    CopyOrAddTensor(&accumulated_grad, hooked_grads[0][0]);
  } else {
    CopyOrAddTensor(&accumulated_grad, grads[0][0]);
  }

  if (retain_grad_hook_ != nullptr) {
    retain_grad_hook_(accumulated_grad);
  }

  // Apply Reduce Hooks
  if (ReduceHooksRegistered()) {
    ApplyReduceHooks();
  }

  return {{accumulated_grad}};
}

} // namespace egr
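Taken together, operator() behaves like a running sum over every gradient that reaches this node: the first call copies the incoming gradient into accumulated_grad, and later calls add into it before returning the current total. A hedged usage sketch follows; the backward engine and EagerTensor construction are outside this diff, so the two input gradients are simply assumed to exist:

#include "paddle/fluid/eager/accumulation/accumulation_node.h"

// grad_a and grad_b are assumed to be valid, initialized gradients produced
// elsewhere; how the engine obtains them is not part of this commit.
void AccumulateTwice(const egr::EagerTensor& grad_a,
                     const egr::EagerTensor& grad_b) {
  egr::GradNodeAccumulation node;

  // First call: accumulated_grad is still uninitialized, so grad_a is copied.
  auto out1 = node({{grad_a}});

  // Second call: TensorAdd merges grad_b into the stored gradient, so the
  // returned tensor now holds grad_a + grad_b.
  auto out2 = node({{grad_b}});

  (void)out1;
  (void)out2;
}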
41 changes: 41 additions & 0 deletions paddle/fluid/eager/accumulation/accumulation_node.h
@@ -0,0 +1,41 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/eager/grad_node_info.h"

namespace egr {

class GradNodeAccumulation : public GradNodeBase {
 public:
  // Constructor: configure fwd input tensors to grad node
  GradNodeAccumulation() : GradNodeBase(1, 1) { SetDefaultGradInOutMeta(); }

  ~GradNodeAccumulation() override = default;

  // Functor: perform backward computations
  virtual std::vector<std::vector<egr::EagerTensor>> operator()(
      const std::vector<std::vector<egr::EagerTensor>>& grads) override;

  void RetainGrad(
      const std::function<egr::EagerTensor(const egr::EagerTensor&)>& hook);

 private:
  egr::EagerTensor accumulated_grad;

  std::function<egr::EagerTensor(const egr::EagerTensor&)> retain_grad_hook_;
};

} // namespace egr
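One piece the header leaves implicit is how retain_grad_hook_ gets installed. A hedged sketch of that wiring follows; the grad_slot storage and the InstallRetainGradHook helper are invented for illustration, and the real engine-side code is not in this diff:

#include "paddle/fluid/eager/accumulation/accumulation_node.h"

// grad_slot stands in for wherever the user-visible gradient tensor lives.
void InstallRetainGradHook(egr::GradNodeAccumulation* node,
                           egr::EagerTensor* grad_slot) {
  node->RetainGrad(
      [grad_slot](const egr::EagerTensor& accumulated) -> egr::EagerTensor {
        // Mirror the accumulated gradient into user-visible storage each time
        // GradNodeAccumulation::operator() finishes an accumulation step.
        *grad_slot = accumulated;
        return accumulated;
      });
}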
(The remaining changed files are not rendered in this excerpt.)
