Skip to content

Commit 8f03b66

Browse files
committed
fix
1 parent 0a5a93b commit 8f03b66

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tests/operators/test_eagle_get_hidden_states.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ def test_eagle_get_hidden_states(self):
2626
paddle.seed(2023)
2727
bs = np.random.randint(1, 8 + 1, dtype=np.int32)
2828
input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32)
29-
dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32)
29+
dim_embed = np.array(1024, dtype=np.int32)
3030
actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)
3131

3232
seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
3333
seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
3434
accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
3535
base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
3636
base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)
37-
# don't care
37+
3838
seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
3939
stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)
4040

@@ -43,13 +43,12 @@ def test_eagle_get_hidden_states(self):
4343
accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32)
4444
base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32)
4545
base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32)
46-
# don't care
46+
4747
seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32)
4848
stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32)
4949

50-
# fp32 test
5150
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
52-
input_tensor = paddle.to_tensor(input, dtype=paddle.float32)
51+
input_tensor = paddle.to_tensor(input, dtype=paddle.float16)
5352
gpu_out = eagle_get_hidden_states(
5453
input_tensor,
5554
seq_lens_this_time_tensor,
@@ -61,7 +60,8 @@ def test_eagle_get_hidden_states(self):
6160
base_model_seq_lens_encoder_tensor,
6261
actual_draft_token_num,
6362
)
64-
print(gpu_out.numpy())
63+
out_ref = np.array([6, 4, 3, 3], dtype=np.int32)
64+
np.testing.assert_allclose(gpu_out.numpy()[0][0:4], out_ref)
6565

6666

6767
if __name__ == "__main__":

0 commit comments

Comments
 (0)