diff --git a/paddle/phi/kernels/impl/einsum_grad_kernel_impl.h b/paddle/phi/kernels/impl/einsum_grad_kernel_impl.h
index 089dea7ae4e8da..0354e28761ab9a 100644
--- a/paddle/phi/kernels/impl/einsum_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/einsum_grad_kernel_impl.h
@@ -43,17 +43,35 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
   std::vector<int> repeat_times;
   std::vector<int> resize_dims;
   std::vector<int> recover_shape;
-  for (int c : op_label) {
+  std::vector<int> t_shape = common::vectorize<int>(t.dims());
+  for (int i = 0; i < op_label.size(); i++) {
+    int c = op_label[i];
     if (label2type[c] == LabelType::Reduction) {
       repeat_times.push_back(label2shape[c]);
       resize_dims.push_back(1);
       recover_shape.push_back(label2shape[c]);
+      t_shape.insert(t_shape.begin() + i, 1);
     } else {
       resize_dims.push_back(label2shape[c]);
       repeat_times.push_back(1);
       recover_shape.push_back(label2shape[c]);
     }
   }
+  PADDLE_ENFORCE_EQ(op_label.size(),
+                    t_shape.size(),
+                    common::errors::InvalidArgument(
+                        "Input shape size doesn't match label nums, input "
+                        "shape size: `%d`, but got label nums: `%d`",
+                        t_shape.size(),
+                        op_label.size()));
+  for (int i = 0; i < op_label.size(); i++) {
+    int c = op_label[i];
+    if (label2type[c] == LabelType::Contraction &&
+        t_shape[i] != label2shape[c]) {
+      repeat_times[i] = label2shape[c];
+      resize_dims[i] = 1;
+    }
+  }
   t.Resize(common::make_ddim(resize_dims));
   DenseTensor after_tile;
   if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int64_t x) {
diff --git a/paddle/phi/kernels/impl/einsum_kernel_impl.h b/paddle/phi/kernels/impl/einsum_kernel_impl.h
index 4ea9be82f1f571..d561b92fb86ccf 100644
--- a/paddle/phi/kernels/impl/einsum_kernel_impl.h
+++ b/paddle/phi/kernels/impl/einsum_kernel_impl.h
@@ -546,10 +546,51 @@ DenseTensor PerformContraction(
       trans_t = PerformTranspose<T, Context>(
           dev_ctx, reduct_t, perm, reordered_all_labels, label2type);
       if (cache[operand_idx] != nullptr) {
-        cache[operand_idx]->ShareBufferWith(trans_t);
-        cache[operand_idx]->Resize(trans_t.dims());
-        VLOG(5) << "Set dims of cache[" << operand_idx
-                << "]: " << trans_t.dims();
+        std::vector<int> broadcast_shapes_restore(
+            broadcast_shapes[operand_idx].size());
+
+        auto contraction_dim1 =
+            [&](const std::vector<int>& broadcast_shapes,
+                const std::vector<int>& original_shapes) -> bool {
+          bool found = false;
+          for (size_t i = 0; i < broadcast_shapes.size(); ++i) {
+            if (broadcast_shapes[i] != original_shapes[i] &&
+                label2type[input_strs[operand_idx][i]] ==
+                    LabelType::Contraction) {
+              broadcast_shapes_restore[i] = original_shapes[i];
+              found = true;
+            } else {
+              broadcast_shapes_restore[i] = broadcast_shapes[i];
+            }
+          }
+          return found;
+        };
+        if (!contraction_dim1(broadcast_shapes[operand_idx],
+                              common::vectorize<int>(t.dims()))) {
+          cache[operand_idx]->ShareBufferWith(trans_t);
+          cache[operand_idx]->Resize(trans_t.dims());
+          VLOG(5) << "Set dims of cache[" << operand_idx
+                  << "]: " << trans_t.dims();
+        } else {
+          auto reduct_t_for_cache =
+              PerformDiagonalAndReduction<T, Context>(dev_ctx,
+                                                      t,
+                                                      input_strs[operand_idx],
+                                                      perm,
+                                                      all_labels,
+                                                      broadcast_shapes_restore,
+                                                      label2type);
+          DenseTensor trans_t_for_cache;
+          trans_t_for_cache = PerformTranspose<T, Context>(dev_ctx,
+                                                           reduct_t_for_cache,
+                                                           perm,
+                                                           reordered_all_labels,
+                                                           label2type);
+          cache[operand_idx]->ShareBufferWith(trans_t_for_cache);
+          cache[operand_idx]->Resize(trans_t_for_cache.dims());
+          VLOG(5) << "Set dims of cache[" << operand_idx
+                  << "]: " << trans_t_for_cache.dims();
+        }
       }
     }
     auto mul_dims = GetShapeByType<int>(