@@ -2347,11 +2347,12 @@ struct test_soft_max : public test_case {
23472347 const ggml_type type;
23482348 const std::array<int64_t , 4 > ne;
23492349 const bool mask;
2350+ const ggml_type m_prec;
23502351 const float scale;
23512352 const float max_bias;
23522353
23532354 std::string vars () override {
2354- return VARS_TO_STR5 (type, ne, mask, scale, max_bias);
2355+ return VARS_TO_STR6 (type, ne, mask, m_prec , scale, max_bias);
23552356 }
23562357
23572358 // the 1024 test with bias occasionally fails:
@@ -2363,9 +2364,10 @@ struct test_soft_max : public test_case {
23632364 test_soft_max (ggml_type type = GGML_TYPE_F32,
23642365 std::array<int64_t , 4 > ne = {10 , 5 , 4 , 3 },
23652366 bool mask = false ,
2367+ ggml_type m_prec = GGML_TYPE_F32,
23662368 float scale = 1 .0f ,
23672369 float max_bias = 0 .0f )
2368- : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
2370+ : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
23692371
23702372 ggml_tensor * build_graph (ggml_context * ctx) override {
23712373 ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
@@ -2374,7 +2376,7 @@ struct test_soft_max : public test_case {
23742376
23752377 ggml_tensor * mask = nullptr ;
23762378 if (this ->mask ) {
2377- mask = ggml_new_tensor_2d (ctx, GGML_TYPE_F32 , ne[0 ], ne[1 ]);
2379+ mask = ggml_new_tensor_2d (ctx, m_prec , ne[0 ], ne[1 ]);
23782380 ggml_set_name (mask, " mask" );
23792381 }
23802382
@@ -4150,17 +4152,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
41504152 for (float scale : {1 .0f , 0 .1f }) {
41514153 for (int64_t ne0 : {16 , 1024 }) {
41524154 for (int64_t ne1 : {16 , 1024 }) {
4153- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, scale, max_bias));
4154- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, scale, max_bias));
4155+ if (mask) {
4156+ for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
4157+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, m_prec, scale, max_bias));
4158+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, scale, max_bias));
4159+ }
4160+ } else {
4161+ /* The precision of mask here doesn't matter as boolean mask is false */
4162+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4163+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4164+ }
41554165 }
41564166 }
41574167 }
41584168 }
41594169 }
4160- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4161- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , 0 .1f , 0 .0f ));
4162- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4163- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 8 .0f ));
4170+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4171+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4172+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4173+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4174+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4175+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 8 .0f ));
4176+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 8 .0f ));
41644177
41654178 for (float max_bias : {0 .0f , 8 .0f }) {
41664179 for (float scale : {1 .0f , 0 .1f }) {
@@ -4296,13 +4309,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
42964309 test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {8192 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
42974310 test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {3072 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
42984311
4299- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4300- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4301- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4302- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4303- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4304- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4305- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4312+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4313+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4314+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4315+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4316+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4317+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4318+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
43064319
43074320 test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {32 , 10 , 1 , 1 }));
43084321 test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {1024 , 10 , 1 , 1 }));
0 commit comments