@@ -204,7 +204,9 @@ inline static void InferLabelShape(
     const std::vector<std::string>& op_labels,
     const std::vector<DDim>& inputs,
     LabelMap* labelshape,
-    std::vector<std::vector<int64_t>>* broadcast_shapes) {
+    std::vector<std::vector<int64_t>>* broadcast_shapes,
+    LabelMap* labeltype) {
+  LabelMap labelshape_copy = *labelshape;
   VLOG(5) << "Start InferLabelShape";
   for (size_t i = 0; i < op_labels.size(); ++i) {
     auto& op_str = op_labels[i];
@@ -233,6 +235,20 @@ inline static void InferLabelShape(
   }
   for (size_t i = 0; i < op_labels.size(); ++i) {
     for (auto& c : op_labels[i]) {
+      // Note: When broadcasting is involved, ensure the gradient is calculated
+      // with respect to the broadcasted shape. For example, in
+      // einsum("ij,ij->j", x(2,2), y(1,2)), y is broadcast to (2,2). The
+      // gradient calculation for x must use this broadcasted shape of y.
+      if (labelshape_copy.exist(c) && labelshape_copy[c] > (*labelshape)[c]) {
+        // Strict check: the smaller extent must be 1 and the label AO/BO-typed.
+        PADDLE_ENFORCE_EQ(
+            (*labelshape)[c] == 1 && ((*labeltype)[c] == LabelType::AO ||
+                                      (*labeltype)[c] == LabelType::BO),
+            true,
+            common::errors::InvalidArgument(
+                "Broadcast dims must be 1 for label: `%c`", c));
+        (*labelshape)[c] = labelshape_copy[c];
+      }
       (*broadcast_shapes)[i].push_back((*labelshape)[c]);
     }
   }
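
Note: the added check implements the standard broadcasting rule per einsum label: when the extent inferred from the current operands is smaller than the extent recorded earlier, the smaller extent must be exactly 1, and the larger (broadcast) extent wins. A minimal standalone sketch of that rule, not Paddle's code: `ReconcileLabelExtent` is a hypothetical name, std::invalid_argument stands in for PADDLE_ENFORCE_EQ, and the AO/BO label-type restriction is omitted.

#include <cstdint>
#include <iostream>
#include <stdexcept>

// Reconcile the extent recorded for a label in the forward pass with the
// extent re-inferred from the backward pass's operands. A mismatch is legal
// only when the smaller extent is 1, i.e. the label was broadcast.
int64_t ReconcileLabelExtent(int64_t recorded, int64_t reinferred) {
  if (recorded > reinferred) {
    if (reinferred != 1) {
      throw std::invalid_argument("Broadcast dims must be 1");
    }
    return recorded;  // restore the broadcast extent
  }
  return reinferred;
}

int main() {
  // einsum("ij,ij->j", x(2,2), y(1,2)): label 'i' is 2 in x but 1 in y, so y
  // is broadcast to (2,2) and the gradient w.r.t. x must use extent 2.
  std::cout << ReconcileLabelExtent(2, 1) << "\n";  // prints 2 (legal)
  std::cout << ReconcileLabelExtent(2, 2) << "\n";  // prints 2 (no broadcast)
  return 0;
}
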
@@ -282,7 +298,7 @@ inline static void ParseEinsumEquation(
   // split_string("->") -> [], we push back a "".
   if (op_labels.empty()) op_labels.emplace_back("");
   GlobalInfo(op_labels, *right, labeltype, all_labels);
-  InferLabelShape(op_labels, inputs, labelshape, broadcast_shapes);
+  InferLabelShape(op_labels, inputs, labelshape, broadcast_shapes, labeltype);
   VLOG(5) << "Einsum Infershape: right:" << *right;
   VLOG(5) << "Einsum Infershape: left :"
           << paddle::string::join_strings(op_labels, '\n');
@@ -603,6 +619,7 @@ DenseTensor TransposeToOutput(const Context& dev_ctx,
 template <typename T, typename Context>
 void EinsumKernelImpl(const Context& dev_ctx,
                       const std::vector<char>& forward_all_labels,
+                      const LabelMap& forward_label_shape,
                       const std::vector<const DenseTensor*>& inputs,
                       const std::string& equation,
                       DenseTensor* out,
@@ -629,6 +646,7 @@ void EinsumKernelImpl(const Context& dev_ctx,
   std::string right;
   if (!is_forward) {
     all_labels = forward_all_labels;
+    labelshape = forward_label_shape;
   }
   ParseEinsumEquation(equation,
                       input_dims,
@@ -680,15 +698,22 @@ void EinsumKernel(const Context& dev_ctx,
     }
   }
   std::vector<char> tmp;
+  LabelMap labelshape_holder;
   // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
   // may have nullptr and the cache.size() is not equal to inputs.size(). refer
   // to BuildPhiKernelContext for details.
   int diff = inputs.size() - cache.size();
   for (int i = 0; i < diff; ++i) {
     cache.push_back(nullptr);
   }
-  EinsumKernelImpl<T, Context>(
-      dev_ctx, tmp, inputs, equation, out, cache, /*forward=*/true);
+  EinsumKernelImpl<T, Context>(dev_ctx,
+                               tmp,
+                               labelshape_holder,
+                               inputs,
+                               equation,
+                               out,
+                               cache,
+                               /*forward=*/true);
 }

 template <typename T, typename Context>
@@ -697,13 +722,20 @@ void EinsumInferKernel(const Context& dev_ctx,
                        const std::string& equation,
                        DenseTensor* out) {
   std::vector<char> place_holder;
+  LabelMap labelshape_holder;
   std::vector<DenseTensor*> cache_tensor(
       inputs.size());  // set empty; TA, TB, TdC
   for (size_t i = 0; i < inputs.size(); ++i) {
     cache_tensor[i] = nullptr;
   }
-  EinsumKernelImpl<T, Context>(
-      dev_ctx, place_holder, inputs, equation, out, cache_tensor, true);
+  EinsumKernelImpl<T, Context>(dev_ctx,
+                               place_holder,
+                               labelshape_holder,
+                               inputs,
+                               equation,
+                               out,
+                               cache_tensor,
+                               true);
 }

 }  // namespace phi