@@ -26,15 +26,15 @@ def test_eagle_get_hidden_states(self):
2626 paddle .seed (2023 )
2727 bs = np .random .randint (1 , 8 + 1 , dtype = np .int32 )
2828 input_token_num = np .random .randint (2 * 1024 , 4 * 1024 + 1 , dtype = np .int32 )
29- dim_embed = np .random . randint ( 1 , 4 * 1024 + 1 , dtype = np .int32 )
29+ dim_embed = np .array ( 1024 , dtype = np .int32 )
3030 actual_draft_token_num = np .random .randint (2 , 6 , dtype = np .int32 )
3131
3232 seq_lens_this_time = np .random .randint (0 , 2 , bs , dtype = np .int32 )
3333 seq_lens_encoder = np .random .randint (0 , input_token_num // bs + 1 , bs , dtype = np .int32 )
3434 accept_nums = np .random .randint (0 , actual_draft_token_num + 1 , bs , dtype = np .int32 )
3535 base_model_seq_lens_this_time = np .random .randint (0 , input_token_num // bs + 1 , bs , dtype = np .int32 )
3636 base_model_seq_lens_encoder = np .random .randint (0 , 2 , bs , dtype = np .int32 )
37- # don't care
37+
3838 seq_lens_decoder = np .random .randint (0 , input_token_num // bs + 1 , bs , dtype = np .int32 )
3939 stop_flags = np .random .randint (0 , 2 , bs , dtype = np .int32 )
4040
@@ -43,13 +43,12 @@ def test_eagle_get_hidden_states(self):
4343 accept_nums_tensor = paddle .to_tensor (accept_nums , dtype = paddle .int32 )
4444 base_model_seq_lens_this_time_tensor = paddle .to_tensor (base_model_seq_lens_this_time , dtype = paddle .int32 )
4545 base_model_seq_lens_encoder_tensor = paddle .to_tensor (base_model_seq_lens_encoder , dtype = paddle .int32 )
46- # don't care
46+
4747 seq_lens_decoder_tensor = paddle .to_tensor (seq_lens_decoder , dtype = paddle .int32 )
4848 stop_flags_tensor = paddle .to_tensor (stop_flags , dtype = paddle .int32 )
4949
50- # fp32 test
5150 input = np .random .randint (0 , 10 , (input_token_num , dim_embed ), dtype = np .int32 )
52- input_tensor = paddle .to_tensor (input , dtype = paddle .float32 )
51+ input_tensor = paddle .to_tensor (input , dtype = paddle .float16 )
5352 gpu_out = eagle_get_hidden_states (
5453 input_tensor ,
5554 seq_lens_this_time_tensor ,
@@ -61,7 +60,8 @@ def test_eagle_get_hidden_states(self):
6160 base_model_seq_lens_encoder_tensor ,
6261 actual_draft_token_num ,
6362 )
64- print (gpu_out .numpy ())
63+ out_ref = np .array ([6 , 4 , 3 , 3 ], dtype = np .int32 )
64+ np .testing .assert_allclose (gpu_out .numpy ()[0 ][0 :4 ], out_ref )
6565
6666
6767if __name__ == "__main__" :
0 commit comments