@@ -3637,7 +3637,7 @@ struct test_flash_attn_ext : public test_case {
36373637
36383638        ggml_tensor * m = nullptr ;
36393639        if  (mask) {
3640-             m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[1 ], 1 );
3640+             m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[0 ], nr23[ 1 ] );
36413641            ggml_set_name (m, " m"  );
36423642        }
36433643
@@ -4751,7 +4751,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
47514751                                test_cases.emplace_back (new  test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {1 , 1 }, scale, max_bias));
47524752
47534753                                if  (ne0 <= 32  && ne1 <= 32 ) {
4754-                                     test_cases.emplace_back (new  test_soft_max (GGML_TYPE_F32, {ne0,   ne1,   1 , 1 }, mask, m_prec, {3 , 1 }, scale, max_bias));
4754+                                     test_cases.emplace_back (new  test_soft_max (GGML_TYPE_F32, {ne0,   ne1,   1 , 3 }, mask, m_prec, {3 , 1 }, scale, max_bias));
47554755                                    test_cases.emplace_back (new  test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {2 , 3 }, scale, max_bias));
47564756                                }
47574757                            }
0 commit comments