[Cherry-pick PR 37420] Fix inplace bug when the first grad_var (loss_grad) is an inplace var #37488

Merged
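For context, here is a minimal sketch of the scenario this cherry-picked fix targets, mirroring the unit test added below (assuming a Paddle 2.x dynamic-graph environment; the test itself wraps the same steps in `paddle.fluid.dygraph.guard()`):

```python
import paddle

# The variable passed to backward() is itself the output of an inplace op
# (tanh_), so the first grad var prepared by the engine is an inplace var.
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False

var_b = var_a * 2
loss = var_b.tanh_()  # inplace: loss shares its underlying variable with var_b

loss.backward()
print(var_a.grad.numpy())  # expected to equal the gradient obtained with the
                           # out-of-place var_b.tanh() variant
```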
15 changes: 9 additions & 6 deletions paddle/fluid/imperative/basic_engine.cc
@@ -53,6 +53,10 @@ void BasicEngine::Init(
                     platform::errors::AlreadyExists(
                         "Accumulators are not empty before preparing it for "
                         "backward network execution."));
+  PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
+                    platform::errors::AlreadyExists(
+                        "Accumulators with grad_node as the key are not empty "
+                        "before preparing it for backward network execution."));
 
   for (size_t i = 0; i < tensors.size(); ++i) {
     auto var = tensors[i];
@@ -73,7 +77,6 @@ void BasicEngine::Init(
       VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
               << " because of retain_graph=False when calling backward";
       var->GradVarBase()->SetGraphIsFreed(true);
-      var->GradVarBase()->ClearGradNode();
     }
 
     if (init_node == nullptr || var->OverridedStopGradient()) {
@@ -108,14 +111,18 @@ void BasicEngine::Init(
     }
 
     VariableWrapper* init_grad_var = var->GradVarBase()->SharedVar().get();
-    auto& accumulator = accumulators_[init_grad_var];
+    auto& accumulator =
+        accumulators_with_grad_node_[init_grad_var->GetGradNode()]
+                                    [init_grad_var];
     if (!accumulator) {
       if (FLAGS_sort_sum_gradient) {
         accumulator.reset(new SortedGradientAccumulator(init_grad_var));
       } else {
         accumulator.reset(new EagerGradientAccumulator(init_grad_var));
       }
     }
+    accumulator->IncreaseRefCnt();
+    accumulator->IncreaseCurCnt();
 
     init_nodes_.push_back(init_node);
   }
@@ -253,10 +260,6 @@ void BasicEngine::PrepareDeps() {
       node_deps_.empty(), true,
       platform::errors::AlreadyExists("Op deps are not empty before preparing "
                                       "it for backward network execution."));
-  PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
-                    platform::errors::AlreadyExists(
-                        "Accumulators with grad_node as the key are not empty "
-                        "before preparing it for backward network execution."));
 
   std::queue<GradOpNode*> q;
   std::unordered_set<GradOpNode*> visited;
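In short, the engine-side change above moves the emptiness check on `accumulators_with_grad_node_` from `PrepareDeps()` into `Init()`, no longer clears the grad node of the entry variable's grad var when `retain_graph` is false, and registers the accumulator for each initial gradient variable in `accumulators_with_grad_node_` (keyed by that variable's grad node) instead of the plain `accumulators_` map. A rough, purely illustrative Python sketch of the new bookkeeping (hypothetical names and structure, not Paddle's actual C++ internals):

```python
from collections import defaultdict

# Illustrative stand-in for the accumulator maps: the init (loss) grad var is
# now filed under its grad node, so later lookups that key on the grad node
# can resolve to the accumulator created during Init().
accumulators_with_grad_node = defaultdict(dict)  # grad_node -> {grad_var: accumulator}

def register_init_grad_var(grad_node, grad_var):
    bucket = accumulators_with_grad_node[grad_node]
    if grad_var not in bucket:
        bucket[grad_var] = []  # stand-in for an Eager/SortedGradientAccumulator
    return bucket[grad_var]
```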
25 changes: 25 additions & 0 deletions python/paddle/fluid/tests/unittests/test_inplace.py
@@ -409,5 +409,30 @@ def inplace_api_processing(self, var):
         return var.subtract_(self.input_var_2)
 
 
+class TestLossIsInplaceVar(unittest.TestCase):
+    def test_loss_is_inplace_var(self):
+        with paddle.fluid.dygraph.guard():
+            var_a = paddle.ones((2, 2))
+            var_a.stop_gradient = False
+
+            var_b = var_a * 2
+            loss = var_b.tanh_()
+
+            loss.backward()
+            inplace_grad_var_a = var_a.grad.numpy()
+
+        with paddle.fluid.dygraph.guard():
+            var_a = paddle.ones((2, 2))
+            var_a.stop_gradient = False
+
+            var_b = var_a * 2
+            loss = var_b.tanh()
+
+            loss.backward()
+            grad_var_a = var_a.grad.numpy()
+
+        self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))
+
+
 if __name__ == '__main__':
     unittest.main()