Skip to content

Commit

Permalink
optimize some functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles-hit committed Jan 8, 2023
1 parent c2ff1a7 commit caec995
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,17 @@ class GradCompositeOpMakerBase {
std::vector<paddle::optional<paddle::experimental::Tensor>> outputs_opt;
std::vector<framework::VarDesc*> outputs_descs =
this->MultiForwardOutput(name);
outputs_opt.resize(outputs_descs.size());
for (size_t i = 0; i < outputs_descs.size(); ++i) {
if (outputs_descs[i]) {
outputs_opt[i] = paddle::make_optional<paddle::experimental::Tensor>(
paddle::experimental::Tensor(
std::make_shared<DescTensor>(outputs_descs[i])));
outputs_opt.reserve(outputs_descs.size());
for (const auto& output_desc : outputs_descs) {
if (output_desc) {
outputs_opt.emplace_back(
paddle::make_optional<paddle::experimental::Tensor>(
paddle::experimental::Tensor(
std::make_shared<DescTensor>(output_desc))));
} else {
outputs_opt.emplace_back(
paddle::make_optional<paddle::experimental::Tensor>(
paddle::experimental::Tensor()));
}
}
return outputs_opt;
Expand All @@ -214,12 +219,17 @@ class GradCompositeOpMakerBase {
std::vector<paddle::optional<paddle::experimental::Tensor>> inputs_opt;
std::vector<framework::VarDesc*> inputs_descs =
this->MultiForwardInput(name);
inputs_opt.resize(inputs_descs.size());
for (size_t i = 0; i < inputs_descs.size(); ++i) {
if (inputs_descs[i]) {
inputs_opt[i] = paddle::make_optional<paddle::experimental::Tensor>(
paddle::experimental::Tensor(
std::make_shared<DescTensor>(inputs_descs[i])));
inputs_opt.reserve(inputs_descs.size());
for (const auto& input_desc : inputs_descs) {
if (input_desc) {
inputs_opt.emplace_back(
paddle::make_optional<paddle::experimental::Tensor>(
paddle::experimental::Tensor(
std::make_shared<DescTensor>(input_desc))));
} else {
inputs_opt.emplace_back(
paddle::make_optional<paddle::experimental::Tensor>(
paddle::experimental::Tensor()));
}
}
return inputs_opt;
Expand All @@ -230,12 +240,17 @@ class GradCompositeOpMakerBase {
std::vector<paddle::optional<paddle::experimental::Tensor>> outputs_grads;
std::vector<framework::VarDesc*> outputs_grads_descs =
this->MultiOutputGrad(name);
outputs_grads.resize(outputs_grads_descs.size());
for (size_t i = 0; i < outputs_grads_descs.size(); ++i) {
if (outputs_grads_descs[i]) {
outputs_grads[i] = paddle::make_optional<paddle::experimental::Tensor>(
paddle::experimental::Tensor(
std::make_shared<DescTensor>(outputs_grads_descs[i])));
outputs_grads.reserve(outputs_grads_descs.size());
for (const auto& output_grad_desc : outputs_grads_descs) {
if (output_grad_desc) {
outputs_grads.emplace_back(
paddle::make_optional<paddle::experimental::Tensor>(
paddle::experimental::Tensor(
std::make_shared<DescTensor>(output_grad_desc))));
} else {
outputs_grads.emplace_back(
paddle::make_optional<paddle::experimental::Tensor>(
paddle::experimental::Tensor()));
}
}
return outputs_grads;
Expand Down

0 comments on commit caec995

Please sign in to comment.