@@ -98,7 +98,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
9898 sampler_tester tester (probs, probs_expected);
9999
100100 DUMP (&tester.cur_p );
101- tester.apply (llama_sampler_init_top_p (p, 1 ));
101+ tester.apply (llama_sampler_init_top_p (p, 0 ));
102102 tester.apply (llama_sampler_init_dist (0 ));
103103 DUMP (&tester.cur_p );
104104
@@ -130,7 +130,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
130130 sampler_tester tester (probs, probs_expected);
131131
132132 DUMP (&tester.cur_p );
133- tester.apply (llama_sampler_init_typical (p, 1 ));
133+ tester.apply (llama_sampler_init_typical (p, 0 ));
134134 DUMP (&tester.cur_p );
135135
136136 tester.check ();
@@ -342,8 +342,8 @@ int main(void) {
342342 printf (" XTC should not:\n " );
343343 test_xtc ({0 .4f , 0 .3f , 0 .2f , 0 .1f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 0 .99f , 0 .39f );
344344
345- test_typical ({0 .97f , 0 .01f , 0 .01f , 0 .01f }, {0 .97f }, 0 .5f );
346- test_typical ({0 .4f , 0 .2f , 0 .2f , 0 .2f }, {0 .2f , 0 .2f , 0 .2f }, 0 .5f );
345+ test_typical ({0 .97f , 0 .01f , 0 .01f , 0 .01f }, {0 .97f }, 0 .5f );
346+ test_typical ({0 .4f , 0 .2f , 0 .2f , 0 .2f }, {0 .2f , 0 .2f , 0 .2f }, 0 .5f );
347347
348348 test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 .25f , 0 .25f , 0 .25f , 0 .25f , 0 }, 50 .0f , 0 .0f , 0 .0f );
349349 test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 .5f , 0 .5f , 0 , 0 , 0 }, 50 .0f , 0 .0f , 0 .0f );
0 commit comments