
Commit 6827b0d

Fix einsum accuracy diff when with contraction dim (#74303)
* fix einsum accuracy diff when with contraction dim
* refine
1 parent c2a5776 commit 6827b0d


2 files changed: +64 additions, -5 deletions

paddle/phi/kernels/impl/einsum_grad_kernel_impl.h

Lines changed: 19 additions & 1 deletion
@@ -43,17 +43,35 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
   std::vector<int64_t> repeat_times;
   std::vector<int64_t> resize_dims;
   std::vector<int64_t> recover_shape;
-  for (int c : op_label) {
+  std::vector<int64_t> t_shape = common::vectorize<int64_t>(t.dims());
+  for (int i = 0; i < op_label.size(); i++) {
+    int c = op_label[i];
     if (label2type[c] == LabelType::Reduction) {
       repeat_times.push_back(label2shape[c]);
       resize_dims.push_back(1);
       recover_shape.push_back(label2shape[c]);
+      t_shape.insert(t_shape.begin() + i, 1);
     } else {
       resize_dims.push_back(label2shape[c]);
       repeat_times.push_back(1);
       recover_shape.push_back(label2shape[c]);
     }
   }
+  PADDLE_ENFORCE_EQ(op_label.size(),
+                    t_shape.size(),
+                    common::errors::InvalidArgument(
+                        "Input shape size doesn't match label nums, input "
+                        "shape size: `%d`, but got label nums: `%d`",
+                        t_shape.size(),
+                        op_label.size()));
+  for (int i = 0; i < op_label.size(); i++) {
+    int c = op_label[i];
+    if (label2type[c] == LabelType::Contraction &&
+        t_shape[i] != label2shape[c]) {
+      repeat_times[i] = label2shape[c];
+      resize_dims[i] = 1;
+    }
+  }
   t.Resize(common::make_ddim(resize_dims));
   DenseTensor after_tile;
   if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int64_t x) {
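
The core of this grad-kernel change is a second pass over the operand labels: after the usual Reduction handling, any Contraction label whose actual extent in the incoming tensor is 1 (i.e. it was broadcast) gets its repeat/resize entries rewritten so the gradient is tiled back up to the full labeled size. Below is a minimal standalone sketch of that bookkeeping, using plain std types in place of Paddle's LabelMap and DenseTensor; the names PlanTileAndReduce, TilePlan, and the simplified LabelType enum are illustrative, not taken from the codebase.

// Sketch of the repeat/resize planning done in PerformTileAndReduction.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

enum class LabelType { Batch, Free, Contraction, Reduction };

struct TilePlan {
  std::vector<int64_t> repeat_times;  // tile factor per labeled dim
  std::vector<int64_t> resize_dims;   // view shape applied before tiling
};

// op_label: one label char per labeled dim of the incoming gradient tensor.
// t_shape:  actual extents of that tensor; a contraction dim may be 1 here
//           even though the label's full extent is larger (broadcast case).
TilePlan PlanTileAndReduce(const std::vector<char>& op_label,
                           std::vector<int64_t> t_shape,
                           const std::map<char, LabelType>& label2type,
                           const std::map<char, int64_t>& label2shape) {
  TilePlan plan;
  for (size_t i = 0; i < op_label.size(); ++i) {
    char c = op_label[i];
    if (label2type.at(c) == LabelType::Reduction) {
      // Reduced-away dims are re-materialized by tiling from size 1.
      plan.repeat_times.push_back(label2shape.at(c));
      plan.resize_dims.push_back(1);
      t_shape.insert(t_shape.begin() + i, 1);  // keep t_shape label-aligned
    } else {
      plan.repeat_times.push_back(1);
      plan.resize_dims.push_back(label2shape.at(c));
    }
  }
  // The fix: a contraction dim that was broadcast (extent 1 in t_shape) must
  // also be tiled up to its full labeled extent, not left as-is.
  for (size_t i = 0; i < op_label.size(); ++i) {
    char c = op_label[i];
    if (label2type.at(c) == LabelType::Contraction &&
        t_shape[i] != label2shape.at(c)) {
      plan.repeat_times[i] = label2shape.at(c);
      plan.resize_dims[i] = 1;
    }
  }
  return plan;
}

int main() {
  // An "ik" operand whose contraction dim k was broadcast to size 1.
  auto plan = PlanTileAndReduce(
      {'i', 'k'}, {3, 1},
      {{'i', LabelType::Free}, {'k', LabelType::Contraction}},
      {{'i', 3}, {'k', 4}});
  for (auto r : plan.repeat_times) std::cout << r << ' ';  // prints: 1 4
  std::cout << '\n';
}

Without the second loop, the broadcast contraction dim stays at extent 1 and the tiled gradient ends up with the wrong shape, which is the accuracy issue the commit title refers to.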

paddle/phi/kernels/impl/einsum_kernel_impl.h

Lines changed: 45 additions & 4 deletions
@@ -546,10 +546,51 @@ DenseTensor PerformContraction(
      trans_t = PerformTranspose<T, Context>(
          dev_ctx, reduct_t, perm, reordered_all_labels, label2type);
      if (cache[operand_idx] != nullptr) {
-       cache[operand_idx]->ShareBufferWith(trans_t);
-       cache[operand_idx]->Resize(trans_t.dims());
-       VLOG(5) << "Set dims of cache[" << operand_idx
-               << "]: " << trans_t.dims();
+       std::vector<int64_t> broadcast_shapes_restore(
+           broadcast_shapes[operand_idx].size());
+
+       auto contraction_dim1 =
+           [&](const std::vector<int64_t>& broadcast_shapes,
+               const std::vector<int64_t>& original_shapes) -> bool {
+         bool found = false;
+         for (size_t i = 0; i < broadcast_shapes.size(); ++i) {
+           if (broadcast_shapes[i] != original_shapes[i] &&
+               label2type[input_strs[operand_idx][i]] ==
+                   LabelType::Contraction) {
+             broadcast_shapes_restore[i] = original_shapes[i];
+             found = true;
+           } else {
+             broadcast_shapes_restore[i] = broadcast_shapes[i];
+           }
+         }
+         return found;
+       };
+       if (!contraction_dim1(broadcast_shapes[operand_idx],
+                             common::vectorize<int64_t>(t.dims()))) {
+         cache[operand_idx]->ShareBufferWith(trans_t);
+         cache[operand_idx]->Resize(trans_t.dims());
+         VLOG(5) << "Set dims of cache[" << operand_idx
+                 << "]: " << trans_t.dims();
+       } else {
+         auto reduct_t_for_cache =
+             PerformDiagonalAndReduction<T, Context>(dev_ctx,
+                                                     t,
+                                                     input_strs[operand_idx],
+                                                     perm,
+                                                     all_labels,
+                                                     broadcast_shapes_restore,
+                                                     label2type);
+         DenseTensor trans_t_for_cache;
+         trans_t_for_cache = PerformTranspose<T, Context>(dev_ctx,
+                                                          reduct_t_for_cache,
+                                                          perm,
+                                                          reordered_all_labels,
+                                                          label2type);
+         cache[operand_idx]->ShareBufferWith(trans_t_for_cache);
+         cache[operand_idx]->Resize(trans_t_for_cache.dims());
+         VLOG(5) << "Set dims of cache[" << operand_idx
+                 << "]: " << trans_t_for_cache.dims();
+       }
      }
    }
    auto mul_dims = GetShapeByType<int64_t>(
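
The forward-kernel change guards the operand cache: before caching the transposed operand, it checks whether any contraction label was broadcast (its planned broadcast extent differs from the tensor's real extent). If so, the cached copy is rebuilt from restored, un-broadcast shapes instead of reusing the broadcast tensor. Below is a minimal sketch of that detection step, again with plain std types standing in for Paddle's label maps; the function name NeedsCacheRebuild is illustrative, not from the codebase.

// Sketch of the "contraction dim was broadcast" check used before filling
// the einsum operand cache.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

enum class LabelType { Batch, Free, Contraction, Reduction };

// Returns true if any contraction label has a planned broadcast extent that
// differs from the operand's real extent. In that case `restored` holds the
// real extents for those dims (and the broadcast extents everywhere else),
// so the cached tensor can be recomputed from un-broadcast shapes.
bool NeedsCacheRebuild(const std::string& labels,
                       const std::vector<int64_t>& broadcast_shape,
                       const std::vector<int64_t>& original_shape,
                       const std::map<char, LabelType>& label2type,
                       std::vector<int64_t>* restored) {
  restored->assign(broadcast_shape.begin(), broadcast_shape.end());
  bool found = false;
  for (size_t i = 0; i < broadcast_shape.size(); ++i) {
    if (broadcast_shape[i] != original_shape[i] &&
        label2type.at(labels[i]) == LabelType::Contraction) {
      (*restored)[i] = original_shape[i];  // undo the broadcast for this dim
      found = true;
    }
  }
  return found;
}

int main() {
  // Operand labeled "ik": k is a contraction dim planned to broadcast to 4,
  // but the real tensor extent is 1, so the cached copy must be rebuilt.
  std::vector<int64_t> restored;
  bool rebuild = NeedsCacheRebuild(
      "ik", {3, 4}, {3, 1},
      {{'i', LabelType::Free}, {'k', LabelType::Contraction}}, &restored);
  std::cout << std::boolalpha << rebuild << '\n';          // true
  std::cout << restored[0] << ' ' << restored[1] << '\n';  // 3 1
}

The design intent, as the diff reads, is that the fast path (sharing the already-transposed buffer with the cache) stays unchanged whenever no contraction dim was broadcast; only the broadcast case pays for the extra diagonal-and-reduction plus transpose on the restored shapes.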
