@@ -2305,14 +2305,17 @@ struct ggml_tensor * ggml_repeat(
23052305struct  ggml_tensor  *  ggml_repeat_back (
23062306        struct  ggml_context  *  ctx ,
23072307        struct  ggml_tensor   *  a ,
2308-         struct  ggml_tensor   *  b ) {
2308+         struct  ggml_tensor   *  b ,
2309+         bool                   adjacent ) {
23092310    GGML_ASSERT (ggml_can_repeat (b , a ));
23102311
23112312    struct  ggml_tensor  *  result  =  ggml_new_tensor (ctx , a -> type , GGML_MAX_DIMS , b -> ne );
23122313
23132314    result -> op      =  GGML_OP_REPEAT_BACK ;
23142315    result -> src [0 ] =  a ;
23152316
2317+     result -> op_params [0 ] =  adjacent  ? 1  : 0 ;
2318+ 
23162319    return  result ;
23172320}
23182321
@@ -5299,7 +5302,7 @@ static void ggml_compute_backward(
52995302            if  (src1_needs_grads ) {
53005303                struct  ggml_tensor  *  tmp  =  grad ;
53015304                if  (!ggml_are_same_shape (src0 , src1 )) {
5302-                     tmp  =  ggml_repeat_back (ctx , tmp , src1 );
5305+                     tmp  =  ggml_repeat_back (ctx , tmp , src1 , false );
53035306                }
53045307                ggml_add_or_set (ctx , cgraph , isrc1 , tmp );
53055308            }
@@ -5339,12 +5342,12 @@ static void ggml_compute_backward(
53395342        } break ;
53405343        case  GGML_OP_MUL : {
53415344            if  (src0_needs_grads ) {
5342-                 ggml_add_or_set (ctx , cgraph , isrc0 , ggml_mul (ctx , src1 ,  grad ));
5345+                 ggml_add_or_set (ctx , cgraph , isrc0 , ggml_mul (ctx , grad ,  src1 ));
53435346            }
53445347            if  (src1_needs_grads ) {
53455348                struct  ggml_tensor  *  tmp  =  ggml_mul (ctx , src0 , grad );
53465349                if  (!ggml_are_same_shape (src0 , src1 )) {
5347-                     tmp  =  ggml_repeat_back (ctx , tmp , src1 );
5350+                     tmp  =  ggml_repeat_back (ctx , tmp , src1 , false );
53485351                }
53495352                ggml_add_or_set (ctx , cgraph , isrc1 , tmp );
53505353            }
@@ -5399,7 +5402,7 @@ static void ggml_compute_backward(
53995402        } break ;
54005403        case  GGML_OP_REPEAT : {
54015404            if  (src0_needs_grads ) {
5402-                 ggml_add_or_set (ctx , cgraph , isrc0 , ggml_repeat_back (ctx , grad , src0 ));
5405+                 ggml_add_or_set (ctx , cgraph , isrc0 , ggml_repeat_back (ctx , grad , src0 , false ));
54035406            }
54045407        } break ;
54055408        case  GGML_OP_REPEAT_BACK : {
@@ -5431,21 +5434,18 @@ static void ggml_compute_backward(
54315434            // src1.shape   [n,p,qq,rr] 
54325435
54335436            if  (src0_needs_grads ) {
5434-                 struct  ggml_tensor  *  s1_tg  = 
5437+                 GGML_ASSERT (grad -> ne [2 ] ==  src1 -> ne [2 ]);
5438+                 GGML_ASSERT (grad -> ne [3 ] ==  src1 -> ne [3 ]);
5439+                 struct  ggml_tensor  *  tmp  = 
54355440                    ggml_out_prod (ctx , // [n,m,qq,rr] 
54365441                        src1 ,          // [n,p,qq,rr] 
54375442                        grad );         // [m,p,qq,rr] 
5438-                 const  int64_t  qq  =  s1_tg -> ne [2 ];
5439-                 const  int64_t  rr  =  s1_tg -> ne [3 ];
5440-                 const  int64_t  q1  =  src0 -> ne [2 ];
5441-                 const  int64_t  r1  =  src0 -> ne [3 ];
5442-                 const  bool  ne2_broadcasted  =  qq  >  q1 ;
5443-                 const  bool  ne3_broadcasted  =  rr  >  r1 ;
5444-                 if  (ne2_broadcasted  ||  ne3_broadcasted ) {
5445-                     // sum broadcast repetitions of s1_tg into shape of src0 
5446-                     s1_tg  =  ggml_repeat_back (ctx , s1_tg , src0 );
5443+                 if  (!ggml_are_same_shape (tmp , src0 )) {
5444+                     GGML_ASSERT (tmp -> ne [0 ] ==  src0 -> ne [0 ]);
5445+                     GGML_ASSERT (tmp -> ne [1 ] ==  src0 -> ne [1 ]);
5446+                     tmp  =  ggml_repeat_back (ctx , tmp , src0 , true);
54475447                }
5448-                 ggml_add_or_set (ctx , cgraph , isrc0 , s1_tg   /*= [n,m,q1,r1]*/ );
5448+                 ggml_add_or_set (ctx , cgraph , isrc0 , tmp );
54495449            }
54505450            if  (src1_needs_grads ) {
54515451                ggml_add_or_set (ctx , cgraph , isrc1 ,
@@ -5514,7 +5514,9 @@ static void ggml_compute_backward(
55145514            if  (src0_needs_grads ) {
55155515                GGML_ASSERT (!cgraph -> grads [isrc0 ] ||  ggml_is_contiguous (cgraph -> grads [isrc0 ]));
55165516                GGML_ASSERT (ggml_is_contiguous (grad ));
5517-                 ggml_add_or_set (ctx , cgraph , isrc0 , grad );
5517+                 GGML_ASSERT (ggml_nelements (tensor ) ==  ggml_nelements (src0 ));
5518+                 ggml_add_or_set (ctx , cgraph , isrc0 ,
5519+                     ggml_are_same_shape (tensor , src0 ) ? grad  : ggml_reshape (ctx , grad , src0 ));
55185520            }
55195521        } break ;
55205522        case  GGML_OP_RESHAPE : {
0 commit comments