@@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
13021302 }
13031303};
13041304
1305+ // GGML_OP_REPEAT_BACK
1306+ struct test_repeat_back : public test_case {
1307+ const ggml_type type;
1308+ const std::array<int64_t , 4 > ne;
1309+ const std::array<int , 4 > nr;
1310+ const bool v; // whether src is a noncontiguous view
1311+
1312+ std::string vars () override {
1313+ return VARS_TO_STR4 (type, ne, nr, v);
1314+ }
1315+
1316+ size_t op_size (ggml_tensor * t) override {
1317+ return ggml_nbytes (t) * 2 ;
1318+ }
1319+
1320+ test_repeat_back (ggml_type type = GGML_TYPE_F32,
1321+ std::array<int64_t , 4 > ne = {8 , 6 , 4 , 2 },
1322+ std::array<int , 4 > nr = {2 , 2 , 2 , 2 },
1323+ bool v = false )
1324+ : type(type), ne(ne), nr(nr), v(v) {}
1325+
1326+ ggml_tensor * build_graph (ggml_context * ctx) override {
1327+ ggml_tensor * src = ggml_new_tensor_4d (ctx, type, ne[0 ]*nr[0 ], ne[1 ]*nr[1 ], ne[2 ]*nr[2 ], ne[3 ]*nr[3 ]);
1328+ ggml_set_name (src, " src" );
1329+
1330+ if (v) {
1331+ GGML_ASSERT (ne[0 ] % 2 == 0 );
1332+ GGML_ASSERT (ne[1 ] % 2 == 0 );
1333+ GGML_ASSERT (ne[2 ] % 2 == 0 );
1334+ GGML_ASSERT (ne[3 ] % 2 == 0 );
1335+ GGML_ASSERT (nr[0 ] % 2 == 0 || nr[0 ] == 1 );
1336+ GGML_ASSERT (nr[1 ] % 2 == 0 || nr[1 ] == 1 );
1337+ GGML_ASSERT (nr[2 ] % 2 == 0 || nr[2 ] == 1 );
1338+ GGML_ASSERT (nr[3 ] % 2 == 0 || nr[3 ] == 1 );
1339+
1340+ const int64_t ne00 = nr[0 ] == 1 ? src->ne [0 ] : src->ne [0 ] / 2 ;
1341+ const int64_t ne01 = nr[1 ] == 1 ? src->ne [1 ] : src->ne [1 ] / 2 ;
1342+ const int64_t ne02 = nr[2 ] == 1 ? src->ne [2 ] : src->ne [2 ] / 2 ;
1343+ const int64_t ne03 = nr[3 ] == 1 ? src->ne [3 ] : src->ne [3 ] / 2 ;
1344+
1345+ src = ggml_view_4d (ctx, src, ne00, ne01, ne02, ne03, src->nb [1 ], src->nb [2 ], src->nb [3 ], 0 );
1346+ }
1347+
1348+ ggml_tensor * target = ggml_new_tensor (ctx, type, 4 , ne.data ());
1349+ ggml_set_name (target, " target" );
1350+
1351+ ggml_tensor * out = ggml_repeat_back (ctx, src, target);
1352+ ggml_set_name (out, " out" );
1353+
1354+ return out;
1355+ }
1356+ };
1357+
13051358// GGML_OP_DUP
13061359struct test_dup : public test_case {
13071360 const ggml_type type;
@@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
18491902 return 5e-4 ;
18501903 }
18511904
1905+ int64_t grad_nmax () override {
1906+ return 20000 ;
1907+ }
1908+
18521909 uint64_t op_flops (ggml_tensor * t) override {
18531910 GGML_UNUSED (t);
18541911 return 2 * m * n * k * bs[0 ] * nr[0 ] * bs[1 ] * nr[1 ];
@@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
18781935
18791936 a = ggml_new_tensor_4d (ctx, type_a, ne_a[per[0 ]], ne_a[per[1 ]], ne_a[per[2 ]], ne_a[per[3 ]]);
18801937 b = ggml_new_tensor_4d (ctx, type_b, ne_b[per[0 ]], ne_b[per[1 ]], ne_b[per[2 ]], ne_b[per[3 ]]);
1881- ggml_set_param (ctx, a);
1882- ggml_set_param (ctx, b);
1938+ if (!ggml_is_quantized (type_a)) {
1939+ if (bs[1 ] == 1 && nr[1 ] == 1 ) {
1940+ ggml_set_param (ctx, a);
1941+ }
1942+ ggml_set_param (ctx, b);
1943+ }
18831944 ggml_set_name (a, " a" );
18841945 ggml_set_name (b, " b" );
18851946
@@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
18901951 } else {
18911952 a = ggml_new_tensor_4d (ctx, type_a, k, m, bs[0 ], bs[1 ]);
18921953 b = ggml_new_tensor_4d (ctx, type_b, k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
1893- ggml_set_param (ctx, a);
1894- ggml_set_param (ctx, b);
1954+ if (!ggml_is_quantized (type_a)) {
1955+ if (bs[1 ] == 1 && nr[1 ] == 1 ) {
1956+ ggml_set_param (ctx, a);
1957+ }
1958+ ggml_set_param (ctx, b);
1959+ }
18951960 ggml_set_name (a, " a" );
18961961 ggml_set_name (b, " b" );
18971962 }
@@ -3798,6 +3863,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
37983863 test_cases.emplace_back (new test_repeat (GGML_TYPE_I16, {10 , 5 , 4 , ne3}, {1 , 1 , 1 , 2 }));
37993864 }
38003865
3866+ for (bool view : {false , true }) {
3867+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 1 , 1 , 1 }, view));
3868+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {2 , 1 , 1 , 1 }, view));
3869+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 2 , 1 , 1 }, view));
3870+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 1 , 2 , 1 }, view));
3871+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 1 , 1 , 2 }, view));
3872+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_I32, {8 , 6 , 4 , 2 }, {2 , 1 , 1 , 1 }, view));
3873+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_I16, {8 , 6 , 4 , 2 }, {1 , 1 , 1 , 2 }, view));
3874+ }
3875+
38013876 test_cases.emplace_back (new test_dup (GGML_TYPE_F32));
38023877 test_cases.emplace_back (new test_dup (GGML_TYPE_F16));
38033878 test_cases.emplace_back (new test_dup (GGML_TYPE_I32));
@@ -3919,21 +3994,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39193994 for (ggml_type type_a : base_types) {
39203995 for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
39213996 // test cases without permutation
3922- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , { 1 , 1 }, {1 , 1 }));
3923- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {1 , 1 }));
3924- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {2 , 1 }));
3925- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 1 }));
3926- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 1 }));
3927- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 2 }));
3928- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 2 }));
3929-
3930- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , { 1 , 1 }, {1 , 1 }));
3931- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 1 }, {1 , 1 }));
3932- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 1 }, {2 , 1 }));
3933- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {1 , 1 }));
3934- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 1 }));
3935- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {1 , 2 }));
3936- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 2 }));
3997+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {1 , 1 }, {1 , 1 }));
3998+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {1 , 1 }, {2 , 1 }));
3999+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {1 , 1 }, {1 , 2 }));
4000+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 1 }, {1 , 1 }));
4001+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 1 }, {2 , 1 }));
4002+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 2 }, {1 , 1 }));
4003+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 2 }, {2 , 1 }));
4004+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 2 }, {1 , 2 }));
4005+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 2 }, {2 , 2 }));
4006+
4007+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {1 , 1 }, {1 , 1 }));
4008+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {1 , 1 }, {2 , 1 }));
4009+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {1 , 1 }, {1 , 2 }));
4010+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 1 }, {1 , 1 }));
4011+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 1 }, {2 , 1 }));
4012+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {1 , 1 }));
4013+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {2 , 1 }));
4014+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {1 , 2 }));
4015+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {2 , 2 }));
39374016
39384017 // test cases with permutation
39394018 test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
0 commit comments