@@ -1151,6 +1151,60 @@ struct test_glu : public test_case {
11511151 }
11521152};
11531153
1154+ struct test_glu_split : public test_case {
1155+ const ggml_glu_op op;
1156+ const ggml_type type;
1157+ const std::array<int64_t , 4 > ne_a;
1158+ int v; // view (1 : non-contiguous a)
1159+
1160+ std::string vars () override {
1161+ return VARS_TO_STR3 (type, ne_a, v);
1162+ }
1163+
1164+ test_glu_split (ggml_glu_op op,
1165+ ggml_type type = GGML_TYPE_F32,
1166+ std::array<int64_t , 4 > ne_a = {128 , 2 , 2 , 2 },
1167+ int v = 0 )
1168+ : op(op), type(type), ne_a(ne_a), v(v) {}
1169+
1170+ ggml_tensor * build_graph (ggml_context * ctx) override {
1171+ ggml_tensor * a;
1172+ ggml_tensor * b;
1173+ if (v & 1 ) {
1174+ auto ne = ne_a; ne[0 ] *= 3 ;
1175+ a = ggml_new_tensor (ctx, type, 4 , ne.data ());
1176+ ggml_set_name (a, " a" );
1177+
1178+ a = ggml_view_4d (ctx, a, ne_a[0 ], ne_a[1 ], ne_a[2 ], ne_a[3 ], a->nb [1 ], a->nb [2 ], a->nb [3 ], 0 );
1179+ ggml_set_name (a, " view_of_a" );
1180+
1181+ b = ggml_new_tensor (ctx, type, 4 , ne.data ());
1182+ ggml_set_name (b, " b" );
1183+
1184+ b = ggml_view_4d (ctx, b, ne_a[0 ], ne_a[1 ], ne_a[2 ], ne_a[3 ], b->nb [1 ], b->nb [2 ], b->nb [3 ], 0 );
1185+ ggml_set_name (a, " view_of_b" );
1186+ } else {
1187+ a = ggml_new_tensor (ctx, type, 4 , ne_a.data ());
1188+ ggml_set_name (a, " a" );
1189+
1190+ b = ggml_new_tensor (ctx, type, 4 , ne_a.data ());
1191+ ggml_set_name (b, " b" );
1192+ }
1193+
1194+ ggml_tensor * out = ggml_glu_split (ctx, a, b, op);
1195+ ggml_set_name (out, " out" );
1196+
1197+ return out;
1198+ }
1199+
1200+ void initialize_tensors (ggml_context * ctx) override {
1201+ for (ggml_tensor * t = ggml_get_first_tensor (ctx); t != NULL ; t = ggml_get_next_tensor (ctx, t)) {
1202+ // test extended range of values to check for NaNs in GELU
1203+ init_tensor_uniform (t, -150 .f , 150 .f );
1204+ }
1205+ }
1206+ };
1207+
11541208// GGML_OP_GET_ROWS
11551209struct test_get_rows : public test_case {
11561210 const ggml_type type;
@@ -3986,6 +4040,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39864040 test_cases.emplace_back (new test_glu ((ggml_glu_op) op, type, { 128 , 2 , 2 , 2 }, v, swapped));
39874041 test_cases.emplace_back (new test_glu ((ggml_glu_op) op, type, { 5 , 7 , 11 , 13 }, v, swapped));
39884042 }
4043+
4044+ test_cases.emplace_back (new test_glu_split ((ggml_glu_op) op, type, { 128 , 2 , 2 , 2 }, v));
4045+ test_cases.emplace_back (new test_glu_split ((ggml_glu_op) op, type, { 5 , 7 , 11 , 13 }, v));
39894046 }
39904047 }
39914048 }
0 commit comments