@@ -546,10 +546,51 @@ DenseTensor PerformContraction(
546546 trans_t = PerformTranspose<T, Context>(
547547 dev_ctx, reduct_t , perm, reordered_all_labels, label2type);
548548 if (cache[operand_idx] != nullptr ) {
549- cache[operand_idx]->ShareBufferWith (trans_t );
550- cache[operand_idx]->Resize (trans_t .dims ());
551- VLOG (5 ) << " Set dims of cache[" << operand_idx
552- << " ]: " << trans_t .dims ();
549+ std::vector<int64_t > broadcast_shapes_restore (
550+ broadcast_shapes[operand_idx].size ());
551+
552+ auto contraction_dim1 =
553+ [&](const std::vector<int64_t >& broadcast_shapes,
554+ const std::vector<int64_t >& original_shapes) -> bool {
555+ bool found = false ;
556+ for (size_t i = 0 ; i < broadcast_shapes.size (); ++i) {
557+ if (broadcast_shapes[i] != original_shapes[i] &&
558+ label2type[input_strs[operand_idx][i]] ==
559+ LabelType::Contraction) {
560+ broadcast_shapes_restore[i] = original_shapes[i];
561+ found = true ;
562+ } else {
563+ broadcast_shapes_restore[i] = broadcast_shapes[i];
564+ }
565+ }
566+ return found;
567+ };
568+ if (!contraction_dim1 (broadcast_shapes[operand_idx],
569+ common::vectorize<int64_t >(t.dims ()))) {
570+ cache[operand_idx]->ShareBufferWith (trans_t );
571+ cache[operand_idx]->Resize (trans_t .dims ());
572+ VLOG (5 ) << " Set dims of cache[" << operand_idx
573+ << " ]: " << trans_t .dims ();
574+ } else {
575+ auto reduct_t_for_cache =
576+ PerformDiagonalAndReduction<T, Context>(dev_ctx,
577+ t,
578+ input_strs[operand_idx],
579+ perm,
580+ all_labels,
581+ broadcast_shapes_restore,
582+ label2type);
583+ DenseTensor trans_t_for_cache;
584+ trans_t_for_cache = PerformTranspose<T, Context>(dev_ctx,
585+ reduct_t_for_cache,
586+ perm,
587+ reordered_all_labels,
588+ label2type);
589+ cache[operand_idx]->ShareBufferWith (trans_t_for_cache);
590+ cache[operand_idx]->Resize (trans_t_for_cache.dims ());
591+ VLOG (5 ) << " Set dims of cache[" << operand_idx
592+ << " ]: " << trans_t_for_cache.dims ();
593+ }
553594 }
554595 }
555596 auto mul_dims = GetShapeByType<int64_t >(
0 commit comments