@@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
13021302    }
13031303};
13041304
1305+ //  GGML_OP_REPEAT_BACK: test case that builds a graph around ggml_repeat_back(src, target), where src holds ne[i]*nr[i] elements in dimension i and target holds ne[i]
1306+ struct  test_repeat_back  : public  test_case  {
1307+     const  ggml_type type; // element type used when allocating both src and target
1308+     const  std::array<int64_t , 4 > ne; // shape of target (4 dimensions)
1309+     const  std::array<int , 4 > nr; // per-dimension multiplier: src is allocated with ne[i]*nr[i] elements in dimension i
1310+     const  bool  v; //  whether src is a noncontiguous view
1311+ 
1312+     std::string vars () override  { // reports the four test parameters (type, ne, nr, v) through VARS_TO_STR4
1313+         return  VARS_TO_STR4 (type, ne, nr, v);
1314+     }
1315+ 
1316+     size_t  op_size (ggml_tensor * t) override  { // returns twice the byte size that ggml_nbytes reports for t
1317+         return  ggml_nbytes (t) * 2 ;
1318+     }
1319+ 
1320+     test_repeat_back (ggml_type type = GGML_TYPE_F32, // defaults: F32 elements, ne = {8, 6, 4, 2}, nr = {2, 2, 2, 2}, src used directly (no view)
1321+             std::array<int64_t , 4 > ne = {8 , 6 , 4 , 2 },
1322+             std::array<int , 4 > nr = {2 , 2 , 2 , 2 },
1323+             bool  v = false )
1324+         : type(type), ne(ne), nr(nr), v(v) {}
1325+ 
1326+     ggml_tensor * build_graph (ggml_context * ctx) override  { // creates src (optionally swapped for a view of itself) and target, then returns the ggml_repeat_back node
1327+         ggml_tensor * src = ggml_new_tensor_4d (ctx, type, ne[0 ]*nr[0 ], ne[1 ]*nr[1 ], ne[2 ]*nr[2 ], ne[3 ]*nr[3 ]); // src is nr[i] times the extent of target along each dimension i
1328+         ggml_set_name (src, " src"  );
1329+ 
1330+         if  (v) { // view mode: replace src with a view of itself that keeps the parent's strides
1331+             GGML_ASSERT (ne[0 ] % 2  == 0 ); // view mode requires every ne[i] to be even ...
1332+             GGML_ASSERT (ne[1 ] % 2  == 0 );
1333+             GGML_ASSERT (ne[2 ] % 2  == 0 );
1334+             GGML_ASSERT (ne[3 ] % 2  == 0 );
1335+             GGML_ASSERT (nr[0 ] % 2  == 0  || nr[0 ] == 1 ); // ... and every nr[i] to be either even or exactly 1
1336+             GGML_ASSERT (nr[1 ] % 2  == 0  || nr[1 ] == 1 );
1337+             GGML_ASSERT (nr[2 ] % 2  == 0  || nr[2 ] == 1 );
1338+             GGML_ASSERT (nr[3 ] % 2  == 0  || nr[3 ] == 1 );
1339+ 
1340+             const  int64_t  ne00 = nr[0 ] == 1  ? src->ne [0 ] : src->ne [0 ] / 2 ; // view extent per dimension: half of src where nr[i] != 1, the full extent where nr[i] == 1
1341+             const  int64_t  ne01 = nr[1 ] == 1  ? src->ne [1 ] : src->ne [1 ] / 2 ;
1342+             const  int64_t  ne02 = nr[2 ] == 1  ? src->ne [2 ] : src->ne [2 ] / 2 ;
1343+             const  int64_t  ne03 = nr[3 ] == 1  ? src->ne [3 ] : src->ne [3 ] / 2 ;
1344+ 
1345+             src = ggml_view_4d (ctx, src, ne00, ne01, ne02, ne03, src->nb [1 ], src->nb [2 ], src->nb [3 ], 0 ); // view starts at offset 0 and reuses the parent's nb[1..3]; when every nr[i] is 1 the view has the same extents as the parent
1346+         }
1347+ 
1348+         ggml_tensor * target = ggml_new_tensor (ctx, type, 4 , ne.data ()); // 4-dimensional tensor with shape ne, passed as the second argument of ggml_repeat_back below
1349+         ggml_set_name (target, " target"  );
1350+ 
1351+         ggml_tensor * out = ggml_repeat_back (ctx, src, target); // the op under test; its result is the graph output
1352+         ggml_set_name (out, " out"  );
1353+ 
1354+         return  out;
1355+     }
1356+ };
1357+ 
13051358//  GGML_OP_DUP
13061359struct  test_dup  : public  test_case  {
13071360    const  ggml_type type;
@@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
18491902        return  5e-4 ;
18501903    }
18511904
1905+     int64_t  grad_nmax () override  { // test_mul_mat override returning a fixed 20000; NOTE(review): presumably the element-count cap for gradient checking — confirm against test_case::grad_nmax
1906+         return  20000 ;
1907+     }
1908+ 
18521909    uint64_t  op_flops (ggml_tensor * t) override  {
18531910        GGML_UNUSED (t);
18541911        return  2  * m * n * k * bs[0 ] * nr[0 ] * bs[1 ] * nr[1 ];
@@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
18781935
18791936            a = ggml_new_tensor_4d (ctx, type_a, ne_a[per[0 ]], ne_a[per[1 ]], ne_a[per[2 ]], ne_a[per[3 ]]);
18801937            b = ggml_new_tensor_4d (ctx, type_b, ne_b[per[0 ]], ne_b[per[1 ]], ne_b[per[2 ]], ne_b[per[3 ]]);
1881-             ggml_set_param (ctx, a);
1882-             ggml_set_param (ctx, b);
1938+             if  (!ggml_is_quantized (type_a)) {
1939+                 if  (bs[1 ] == 1  && nr[1 ] == 1 ) {
1940+                     ggml_set_param (ctx, a);
1941+                 }
1942+                 ggml_set_param (ctx, b);
1943+             }
18831944            ggml_set_name (a, " a"  );
18841945            ggml_set_name (b, " b"  );
18851946
@@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
18901951        } else  {
18911952            a = ggml_new_tensor_4d (ctx, type_a, k, m, bs[0 ],       bs[1 ]);
18921953            b = ggml_new_tensor_4d (ctx, type_b, k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
1893-             ggml_set_param (ctx, a);
1894-             ggml_set_param (ctx, b);
1954+             if  (!ggml_is_quantized (type_a)) {
1955+                 if  (bs[1 ] == 1  && nr[1 ] == 1 ) {
1956+                     ggml_set_param (ctx, a);
1957+                 }
1958+                 ggml_set_param (ctx, b);
1959+             }
18951960            ggml_set_name (a, " a"  );
18961961            ggml_set_name (b, " b"  );
18971962        }
@@ -3798,6 +3863,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
37983863        test_cases.emplace_back (new  test_repeat (GGML_TYPE_I16, {10 , 5 , 4 , ne3}, {1 , 1 , 1 , 2 }));
37993864    }
38003865
3866+     for  (bool  view : {false , true }) {
3867+         test_cases.emplace_back (new  test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 1 , 1 , 1 }, view));
3868+         test_cases.emplace_back (new  test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {2 , 1 , 1 , 1 }, view));
3869+         test_cases.emplace_back (new  test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 2 , 1 , 1 }, view));
3870+         test_cases.emplace_back (new  test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 1 , 2 , 1 }, view));
3871+         test_cases.emplace_back (new  test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 1 , 1 , 2 }, view));
3872+         test_cases.emplace_back (new  test_repeat_back (GGML_TYPE_I32, {8 , 6 , 4 , 2 }, {2 , 1 , 1 , 1 }, view));
3873+         test_cases.emplace_back (new  test_repeat_back (GGML_TYPE_I16, {8 , 6 , 4 , 2 }, {1 , 1 , 1 , 2 }, view));
3874+     }
3875+ 
38013876    test_cases.emplace_back (new  test_dup (GGML_TYPE_F32));
38023877    test_cases.emplace_back (new  test_dup (GGML_TYPE_F16));
38033878    test_cases.emplace_back (new  test_dup (GGML_TYPE_I32));
@@ -3919,21 +3994,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39193994    for  (ggml_type type_a : base_types) {
39203995        for  (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
39213996            //  test cases without permutation
3922-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , { 1 ,  1 }, {1 , 1 }));
3923-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 ,  1 }, {1 , 1 }));
3924-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 ,  1 }, {2 , 1 }));
3925-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 , 10 }, {1 , 1 }));
3926-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 , 10 }, {2 , 1 }));
3927-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 , 10 }, {1 , 2 }));
3928-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 , 10 }, {2 , 2 }));
3929- 
3930-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , { 1 ,  1 }, {1 , 1 }));
3931-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 ,  1 }, {1 , 1 }));
3932-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 ,  1 }, {2 , 1 }));
3933-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {1 , 1 }));
3934-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 1 }));
3935-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {1 , 2 }));
3936-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 2 }));
3997+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {1 , 1 }, {1 , 1 }));
3998+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {1 , 1 }, {2 , 1 }));
3999+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {1 , 1 }, {1 , 2 }));
4000+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {3 , 1 }, {1 , 1 }));
4001+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {3 , 1 }, {2 , 1 }));
4002+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {3 , 2 }, {1 , 1 }));
4003+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {3 , 2 }, {2 , 1 }));
4004+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {3 , 2 }, {1 , 2 }));
4005+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {3 , 2 }, {2 , 2 }));
4006+ 
4007+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {1 , 1 }, {1 , 1 }));
4008+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {1 , 1 }, {2 , 1 }));
4009+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {1 , 1 }, {1 , 2 }));
4010+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 1 }, {1 , 1 }));
4011+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 1 }, {2 , 1 }));
4012+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {1 , 1 }));
4013+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {2 , 1 }));
4014+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {1 , 2 }));
4015+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {2 , 2 }));
39374016
39384017            //  test cases with permutation
39394018            test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
0 commit comments