
Commit d650414

[Accuracy diff No.84-85] Fix accuracy diff for paddle.einsum API (#74257)
* fix einsum_grad when the contraction involves broadcasting
* add a strict check for this situation
1 parent b910b04 commit d650414
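
For reference, the case this commit fixes can be worked out by hand. The sketch below (NumPy only, not part of the commit; it mirrors the shapes used in the new test_case2) shows why the gradient of x must be built against y's broadcast shape, while the gradient of y must be reduced back to its stored shape:

import numpy as np

# einsum("ij,ij->j", x, y) where y is broadcast from (1, 2) to (2, 2).
x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.5, 0.6]], dtype=np.float32)

y_b = np.broadcast_to(y, x.shape)          # y viewed with the broadcast shape (2, 2)
out = np.einsum("ij,ij->j", x, y_b)        # forward result, shape (2,)

# With loss = out.sum():
#   dloss/dx[i, j] = y_b[i, j]             -> uses the *broadcast* shape of y
#   dloss/dy[0, j] = sum_i x[i, j]         -> reduced back over the broadcast axis
grad_x = np.array(y_b)                     # [[0.5, 0.6], [0.5, 0.6]]
grad_y = x.sum(axis=0, keepdims=True)      # [[0.4, 0.6]]

Before this change, the backward pass re-inferred the label extents from its own operands and could lose the broadcast size; the labelshape plumbing in the diffs below addresses exactly that.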

File tree

3 files changed: +85 −6 lines changed

paddle/phi/kernels/impl/einsum_grad_kernel_impl.h

Lines changed: 2 additions & 0 deletions
@@ -215,6 +215,7 @@ void EinsumGradKernel(const Context& dev_ctx,
   }
   EinsumKernelImpl<T, Context>(dev_ctx,
                                all_labels,
+                               labelshape,
                                operands_for_A,
                                equation_for_A,
                                &dA,
@@ -223,6 +224,7 @@ void EinsumGradKernel(const Context& dev_ctx,
 
   EinsumKernelImpl<T, Context>(dev_ctx,
                                all_labels,
+                               labelshape,
                                operands_for_B,
                                equation_for_B,
                                &dB,

paddle/phi/kernels/impl/einsum_kernel_impl.h

Lines changed: 38 additions & 6 deletions
@@ -204,7 +204,9 @@ inline static void InferLabelShape(
     const std::vector<std::string>& op_labels,
     const std::vector<DDim>& inputs,
     LabelMap* labelshape,
-    std::vector<std::vector<int64_t>>* broadcast_shapes) {
+    std::vector<std::vector<int64_t>>* broadcast_shapes,
+    LabelMap* labeltype) {
+  LabelMap labelshape_copy = *labelshape;
   VLOG(5) << "Start InferLabelShape";
   for (size_t i = 0; i < op_labels.size(); ++i) {
     auto& op_str = op_labels[i];
@@ -233,6 +235,20 @@ inline static void InferLabelShape(
   }
   for (size_t i = 0; i < op_labels.size(); ++i) {
     for (auto& c : op_labels[i]) {
+      // Note: When broadcasting is involved, ensure the gradient is calculated
+      // with respect to the broadcasted shape. For example, in
+      // einsum("ij,ij->j", x(2,2), y(1,2)), y is broadcast to (2,2). The
+      // gradient calculation for x must use this broadcasted shape of y.
+      if (labelshape_copy.exist(c) && labelshape_copy[c] > (*labelshape)[c]) {
+        // Strict check for the situation.
+        PADDLE_ENFORCE_EQ(
+            (*labelshape)[c] == 1 && ((*labeltype)[c] == LabelType::AO ||
+                                      (*labeltype)[c] == LabelType::BO),
+            true,
+            common::errors::InvalidArgument(
+                "Broadcast dims must be 1 for label: `%c`", c));
+        (*labelshape)[c] = labelshape_copy[c];
+      }
       (*broadcast_shapes)[i].push_back((*labelshape)[c]);
     }
   }
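
The merge rule this new block enforces can be modelled compactly in Python. This is a simplified, illustrative stand-in (a plain value instead of Paddle's LabelMap entry, with the AO/BO label-type condition reduced to its observable effect), not the actual implementation:

def merge_label_extent(forward_extent, inferred_extent, label):
    # No forward record, or the freshly inferred extent is already the larger one.
    if forward_extent is None or forward_extent <= inferred_extent:
        return inferred_extent
    # The forward pass saw a larger extent: only a size-1 dim may be broadcast up.
    if inferred_extent != 1:
        raise ValueError(f"Broadcast dims must be 1 for label: `{label}`")
    return forward_extent

# einsum("ij,ij->j", x(2,2), y(1,2)): for label "i", y contributes 1 but the
# forward pass recorded 2, so the broadcast extent 2 is kept for the gradient.
assert merge_label_extent(2, 1, "i") == 2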
@@ -282,7 +298,7 @@ inline static void ParseEinsumEquation(
   // split_string("->") -> [], we push back a "".
   if (op_labels.empty()) op_labels.emplace_back("");
   GlobalInfo(op_labels, *right, labeltype, all_labels);
-  InferLabelShape(op_labels, inputs, labelshape, broadcast_shapes);
+  InferLabelShape(op_labels, inputs, labelshape, broadcast_shapes, labeltype);
   VLOG(5) << "Einsum Infershape: right:" << *right;
   VLOG(5) << "Einsum Infershape: left :"
           << paddle::string::join_strings(op_labels, '\n');
@@ -603,6 +619,7 @@ DenseTensor TransposeToOutput(const Context& dev_ctx,
 template <typename T, typename Context>
 void EinsumKernelImpl(const Context& dev_ctx,
                       const std::vector<char>& forward_all_labels,
+                      const LabelMap& forward_label_shape,
                       const std::vector<const DenseTensor*>& inputs,
                       const std::string& equation,
                       DenseTensor* out,
@@ -629,6 +646,7 @@ void EinsumKernelImpl(const Context& dev_ctx,
   std::string right;
   if (!is_forward) {
     all_labels = forward_all_labels;
+    labelshape = forward_label_shape;
   }
   ParseEinsumEquation(equation,
                       input_dims,
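
The reason the backward run needs the forward label shape can be seen with the shapes from the new test: roughly speaking, the gradient for x is itself an einsum over y and the output gradient, and neither of those tensors carries the broadcast extent any more. A dict-based illustration (hypothetical names, not Paddle internals):

# einsum("ij,ij->j", x(2,2), y(1,2)): the forward pass records i -> 2, j -> 2.
forward_label_shape = {"i": 2, "j": 2}

# The backward computation for x only sees y(1, 2) and the output gradient of
# shape (2,); inferring the extents from those operands alone would give i -> 1.
inferred_in_backward = {"i": 1, "j": 2}

# The fix: reuse the forward extents so dx is built with the broadcast size i == 2.
labelshape = dict(forward_label_shape)
assert labelshape["i"] == 2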
@@ -680,15 +698,22 @@ void EinsumKernel(const Context& dev_ctx,
     }
   }
   std::vector<char> tmp;
+  LabelMap labelshape_holder;
   // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
   // may have nullptr and the cache.size() is not equal to inputs.size(). refer
   // to BuildPhiKernelContext for details.
   int diff = inputs.size() - cache.size();
   for (int i = 0; i < diff; ++i) {
     cache.push_back(nullptr);
   }
-  EinsumKernelImpl<T, Context>(
-      dev_ctx, tmp, inputs, equation, out, cache, /*forward=*/true);
+  EinsumKernelImpl<T, Context>(dev_ctx,
+                               tmp,
+                               labelshape_holder,
+                               inputs,
+                               equation,
+                               out,
+                               cache,
+                               /*forward=*/true);
 }
 
 template <typename T, typename Context>
@@ -697,13 +722,20 @@ void EinsumInferKernel(const Context& dev_ctx,
                        const std::string& equation,
                        DenseTensor* out) {
   std::vector<char> place_holder;
+  LabelMap labelshape_holder;
   std::vector<DenseTensor*> cache_tensor(
       inputs.size());  // set empty; TA, TB, TdC
   for (size_t i = 0; i < inputs.size(); ++i) {
     cache_tensor[i] = nullptr;
   }
-  EinsumKernelImpl<T, Context>(
-      dev_ctx, place_holder, inputs, equation, out, cache_tensor, true);
+  EinsumKernelImpl<T, Context>(dev_ctx,
+                               place_holder,
+                               labelshape_holder,
+                               inputs,
+                               equation,
+                               out,
+                               cache_tensor,
+                               true);
 }
 
 }  // namespace phi

test/legacy_test/test_einsum.py

Lines changed: 45 additions & 0 deletions
@@ -532,5 +532,50 @@ def test_static_graph(self):
         self.check_output_equal(a, e)
 
 
+class TestContractionBroadcastGrad(unittest.TestCase):
+    def setUp(self):
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_case1(self):
+        with paddle.base.dygraph.guard(self.place):
+            # paddle.einsum("i, i", Tensor([2],"float32"), Tensor([1],"float32"), )
+            x_np = np.array([0.1, 0.2]).astype(np.float32)
+            y_np = np.array([0.5]).astype(np.float32)
+            except_res = np.einsum("i, i", x_np, y_np)
+            except_grad_x = np.array([0.5, 0.5]).astype(np.float32)
+            except_grad_y = np.array([0.3]).astype(np.float32)
+            x = paddle.to_tensor(x_np, stop_gradient=False)
+            y = paddle.to_tensor(y_np, stop_gradient=False)
+            res = paddle.einsum("i, i", x, y)
+            np.testing.assert_allclose(res.numpy(), except_res)
+            res.sum().backward()
+            x.grad.get_tensor()  # To check if accessing unallocated memory
+            np.testing.assert_allclose(x.grad.numpy(), except_grad_x)
+            np.testing.assert_allclose(y.grad.numpy(), except_grad_y)
+
+    def test_case2(self):
+        with paddle.base.dygraph.guard(self.place):
+            # paddle.einsum("ij,ij->j", Tensor([2, 2],"float32"), Tensor([1, 2],"float32"), )
+            x_np = np.array([[0.1, 0.2], [0.3, 0.4]]).astype(np.float32)
+            y_np = np.array([[0.5, 0.6]]).astype(np.float32)
+            except_res = np.einsum("ij,ij->j", x_np, y_np)
+            except_grad_x = np.array([[0.5, 0.6], [0.5, 0.6]]).astype(
+                np.float32
+            )
+            except_grad_y = np.array([[0.4, 0.6]]).astype(np.float32)
+            x = paddle.to_tensor(x_np, stop_gradient=False)
+            y = paddle.to_tensor(y_np, stop_gradient=False)
+            res = paddle.einsum("ij,ij->j", x, y)
+            np.testing.assert_allclose(res.numpy(), except_res)
+            res.sum().backward()
+            x.grad.get_tensor()  # To check if accessing unallocated memory
+            np.testing.assert_allclose(x.grad.numpy(), except_grad_x)
+            np.testing.assert_allclose(y.grad.numpy(), except_grad_y)
+
+
 if __name__ == "__main__":
     unittest.main()
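
For completeness, the expected constants in test_case1 can be re-derived the same way as for test_case2 above (hand computation, not part of the commit):

import numpy as np

# test_case1: einsum("i, i", x(2), y(1)); y broadcasts to (2,), output is a scalar.
x = np.array([0.1, 0.2], dtype=np.float32)
y = np.array([0.5], dtype=np.float32)
y_b = np.broadcast_to(y, x.shape)

loss = np.einsum("i,i", x, y_b)        # 0.1*0.5 + 0.2*0.5 = 0.15
grad_x = np.array(y_b)                 # [0.5, 0.5] -> except_grad_x
grad_y = x.sum(keepdims=True)          # [0.3]      -> except_grad_y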
