From 116464ae43585f58c4534d4b76d85563e8abd15c Mon Sep 17 00:00:00 2001 From: Nadav Timor Date: Mon, 27 Jan 2025 19:05:09 -0500 Subject: [PATCH 1/2] tok ids to `torch.int64` (reference: https://huggingface.co/docs/transformers.js/en/api/tokenizers) --- src/transformers/generation/candidate_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 086a87750cfb..1e7a85b5a970 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -893,7 +893,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to if tokens_to_remove > 0: self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove] assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) - assistant_input_ids = assistant_input_ids.to(torch.int) + assistant_input_ids = assistant_input_ids.to(torch.int64) return assistant_input_ids, len(assistant_new_ids[0]) From 6ffd3fa93dbe3a9c98cc491a799e911cf570a27e Mon Sep 17 00:00:00 2001 From: Nadav Timor Date: Mon, 27 Jan 2025 19:16:19 -0500 Subject: [PATCH 2/2] `LongTensor` --- src/transformers/generation/candidate_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 1e7a85b5a970..062bdd056b70 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -893,7 +893,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to if tokens_to_remove > 0: self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove] assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) - assistant_input_ids = assistant_input_ids.to(torch.int64) + assistant_input_ids = assistant_input_ids.to(torch.LongTensor) return assistant_input_ids, len(assistant_new_ids[0])