diff --git a/applications/neural_search/recall/simcse/inference.py b/applications/neural_search/recall/simcse/inference.py index 0e11c6ad65e4..097c348c736f 100644 --- a/applications/neural_search/recall/simcse/inference.py +++ b/applications/neural_search/recall/simcse/inference.py @@ -66,8 +66,10 @@ def convert_example(example, tokenizer, max_seq_length=512, do_evalute=False): max_seq_length=max_seq_length) batchify_fn = lambda samples, fn=Tuple( - Pad(axis=0, pad_val=tokenizer.pad_token_id), # text_input - Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # text_segment + Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64" + ), # text_input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64" + ), # text_segment ): [data for data in fn(samples)] pretrained_model = AutoModel.from_pretrained(model_name_or_path)