Skip to content

Commit

Permalink
[ Make FLAGS_einsum_opt as default ] Einsum memory optimization (#43397)
Browse files Browse the repository at this point in the history
* change logic for optimize

* modify

* optimize the backward speed of EinsumOp

* add cache optimizer for einsum op

* EinsumOp: fix new dygraph mode error

* fix bug

* change Cache->InnerCache

* fix code

* fix

* add nan inf utils for einsum op

* add as_extra

* memory optimizer for einsum

* update code
  • Loading branch information
2742195759 authored Jun 14, 2022
1 parent 2106f66 commit 83abec6
Show file tree
Hide file tree
Showing 13 changed files with 67 additions and 23 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/eager/nan_inf_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
}

} // namespace egr
3 changes: 2 additions & 1 deletion paddle/fluid/eager/nan_inf_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ using TupleOfFourTensors = std::tuple<Tensor, Tensor, Tensor, Tensor>;
using TupleOfFiveTensors = std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfSixTensors =
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfTensorAndVector = std::tuple<Tensor, std::vector<Tensor>>;
using TupleOfTensorAndVector =
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>>;

void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);

Expand Down
19 changes: 15 additions & 4 deletions paddle/fluid/operators/einsum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsExtra()
.AsIntermediate();

AddOutput("XShape", "(Tensor), The cache of the x_shape of: A and B.")
.AsDuplicable()
.AsExtra()
.AsIntermediate();
AddAttr<std::string>("equation",
"(string) A einsum equation. such as `ij,jk->ik`"
"There must have `->` and the number of operands in "
Expand All @@ -59,8 +63,8 @@ class EinsumGradOp : public framework::OperatorWithKernel {
// Shape inference for the einsum_grad operator: sets the dims of the outputs
// named GradVarName("Operands") from the dims of the "Operands" inputs, and
// calls ShareAllLoD from "Operands" to those gradient outputs.
//
// NOTE(review): this span comes from a diff rendering whose +/- markers were
// stripped. The first SetOutputsDim/ShareAllLoD pair (using x_name) and the
// second pair (using the "Operands" literal) are presumably the removed and
// the added version of the same two statements, not four live calls —
// confirm against the repository before relying on this text.
void InferShape(framework::InferShapeContext* ctx) const override {
auto x_name = "Operands";
// Gradient variable name derived from the forward input name.
auto x_grad_name = framework::GradVarName(x_name);
ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name));
ctx->ShareAllLoD(x_name, x_grad_name);
ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim("Operands"));
ctx->ShareAllLoD("Operands", x_grad_name);
}

protected:
Expand All @@ -79,8 +83,15 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {

void Apply(GradOpPtr<T> retv) const override {
retv->SetType("einsum_grad");
retv->SetInput("Operands", this->Input("Operands"));
retv->SetInput("InnerCache", this->Output("InnerCache"));
if (this->HasOutput("InnerCache")) {
retv->SetInput("InnerCache", this->Output("InnerCache"));
}
if (this->HasOutput("XShape")) {
// add if for compatibility.
retv->SetInput("Operands", this->Output("XShape")); // for memory save.
} else {
retv->SetInput("Operands", this->Input("Operands"));
}
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("Operands"),
Expand Down
9 changes: 8 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache) {
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
// Collect the following information to prepare einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
Expand Down Expand Up @@ -439,6 +440,12 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
out->set_dims(make_ddim(output_dims));
out->set_dtype(inputs[0]->dtype());
for (size_t i = 0; i < xshape.size(); ++i) {
if (xshape[i] != nullptr) {
xshape[i]->set_dims(inputs[i]->dims());
xshape[i]->set_dtype(inputs[i]->dtype());
}
}
}

void ExpandInferMeta(const MetaTensor& x,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ void EighInferMeta(const MetaTensor& x,
// Meta (shape/dtype) inference for einsum.
//
// In the definition shown for unary.cc, `out` has its dims set and receives
// the dtype of inputs[0], and every non-null xshape[i] receives the dims and
// dtype of inputs[i].
//
// NOTE(review): this span comes from a diff rendering whose +/- markers were
// stripped. The line ending in `inner_cache);` is presumably the removed
// four-parameter tail, and the two lines after it the added tail that
// introduces `xshape` — confirm against the repository.
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache);
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);

void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/einsum_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache);
std::vector<DenseTensor*> inner_cache,
std::vector<DenseTensor*> xshape);

} // namespace phi
1 change: 0 additions & 1 deletion paddle/phi/kernels/impl/einsum_grad_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ void EinsumGradKernel(const Context& dev_ctx,
cache[0].ShareBufferWith(*(inner_cache[0]));
cache[1].ShareBufferWith(*(inner_cache[1]));
}

EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
operands_for_A,
Expand Down
12 changes: 9 additions & 3 deletions paddle/phi/kernels/impl/einsum_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ DenseTensor PerformContraction(
}
// reduction
DenseTensor trans_t;
if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr &&
if (use_cache && cache[operand_idx] != nullptr &&
cache[operand_idx]->IsInitialized()) {
trans_t.ShareBufferWith(*(cache[operand_idx]));
VLOG(5) << "Cache Used!";
Expand All @@ -468,7 +468,7 @@ DenseTensor PerformContraction(
dev_ctx, t, perm, all_labels, ellipsis, label2type);
trans_t = PerformTranspose<T, Context>(
dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
if (FLAGS_einsum_opt && cache[operand_idx] != nullptr)
if (cache[operand_idx] != nullptr)
cache[operand_idx]->ShareBufferWith(trans_t);
}
auto mul_dims = GetShapeByType<int>(all_labels,
Expand Down Expand Up @@ -599,6 +599,11 @@ void EinsumKernelImpl(const Context& dev_ctx,
out);
// Reshape Procedure
} else if (inputs.size() == 1) {
if (cache[0] != nullptr) { // For compatibility, may be cache is nullptr if
// loading the program from v2.3.0
(*cache[0]) = *(inputs[0]); // ShareBuffer for backward, because backward
// we can only see cached tensor.
}
auto reduce_A = PerformReduction<T, Context>(dev_ctx,
*inputs[0],
label2perms[0],
Expand Down Expand Up @@ -627,7 +632,8 @@ void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache) {
std::vector<DenseTensor*> cache,
std::vector<DenseTensor*> xshape) {
std::vector<char> tmp;
// for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
// may have nullptr and the cache.size() is not equal to inputs.size(). refer
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/compat/einsum_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace phi {

// Builds the KernelSignature for the einsum op: kernel name "einsum", input
// "Operands", attribute "equation", and the listed output names. `ctx` is not
// consulted; the mapping is unconditional.
//
// NOTE(review): this span comes from a diff rendering whose +/- markers were
// stripped. The two argument lines are presumably the removed version
// (outputs "Out", "InnerCache") and the added version (which appends
// "XShape") — confirm against the repository.
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"});
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"});
}

KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
Expand Down
7 changes: 4 additions & 3 deletions python/paddle/fluid/tests/unittests/test_einsum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def setUp(self):
'Out':
out,
"InnerCache": [('cache_' + str(i), np.array([1.0]))
for i in range(len(self.operands))]
for i in range(len(self.operands))],
"XShape": [('xshape_' + str(i), np.array([1.0]))
for i in range(len(self.operands))],
}

def init_input(self):
Expand All @@ -48,14 +50,13 @@ def init_input(self):
self.inputs.append(np.random.random(s).astype(t))

def set_mandatory(self):
self.disable = False
self.shapes = [(10, 10, 20), (20, 6)]
self.types = [np.float64, np.float64]
self.equation = "mij,jk->ki"

def test_check_output(self):
if not self.disable:
self.check_output(no_check_set=["InnerCache"])
self.check_output(no_check_set=["InnerCache", "XShape"])

def test_grad(self):
if not self.disable:
Expand Down
11 changes: 8 additions & 3 deletions python/paddle/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,9 +807,9 @@ def gen_einsum_op(equation, *operands):

if _in_legacy_dygraph():
# dygraph
return _C_ops.einsum(operands, len(operands), 'equation', equation)[0]
return _C_ops.einsum(operands, len(operands), len(operands), 'equation',
equation)[0]

# static graph
for inp in operands:
check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum')
check_type(equation, 'equation', str, 'einsum')
Expand All @@ -821,11 +821,16 @@ def gen_einsum_op(equation, *operands):
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
xshape = [
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
helper.append_op(type='einsum',
inputs={'Operands': operands},
outputs={
'Out': out,
"InnerCache": caches
"InnerCache": caches,
"XShape": xshape
},
attrs=attrs)
return out
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@

- api : einsum
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
infer_meta :
func : EinsumInferMeta
param : [x, equation]
Expand Down
17 changes: 14 additions & 3 deletions python/paddle/utils/code_gen/backward.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
#- backward_api : einsum_grad

#forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
#args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
#output : Tensor[](x_grad){x.size()}
#infer_meta :
#func : UnchangedMultiInferMeta
#param : [x]
#kernel :
#func : einsum_grad

- backward_api : abs_double_grad
forward : abs_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad)
Expand Down Expand Up @@ -616,12 +627,12 @@
skip_transform : out_w, out_w_grad

- backward_api : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [x]
param : [x_shape]
kernel :
func : einsum_grad

Expand Down

0 comments on commit 83abec6

Please sign in to comment.