diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 04b55b7b6a33..ba94baaa7d0c 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -206,6 +206,10 @@ def _collate_batch(examples, tokenizer):
     return result
 
 
+def tolist(x: Union[List[Any], torch.Tensor]):
+    return x.tolist() if isinstance(x, torch.Tensor) else x
+
+
 @dataclass
 class DataCollatorForLanguageModeling:
     """
@@ -320,13 +324,13 @@ def __call__(
         mask_labels = []
         for e in examples:
             ref_tokens = []
-            for id in e["input_ids"].tolist():
+            for id in tolist(e["input_ids"]):
                 token = self.tokenizer._convert_id_to_token(id)
                 ref_tokens.append(token)
 
             # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
             if "chinese_ref" in e:
-                ref_pos = tolist(e["chinese_ref"])
+                ref_pos = tolist(e["chinese_ref"])
                 len_seq = e["input_ids"].size(0)
                 for i in range(len_seq):
                     if i in ref_pos:
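
A minimal, hypothetical sketch (not part of the patch) of how the new tolist helper behaves with both kinds of feature the collator may receive; the helper body is copied verbatim from the diff above, the surrounding demo is illustrative only:

from typing import Any, List, Union

import torch


def tolist(x: Union[List[Any], torch.Tensor]):
    # Tensors are converted to plain Python lists; lists pass through unchanged.
    return x.tolist() if isinstance(x, torch.Tensor) else x


# Works for tensor-backed features (e.g. produced with return_tensors="pt") ...
print(tolist(torch.tensor([101, 3449, 102])))  # -> [101, 3449, 102]
# ... and for plain list features, which would previously fail on .tolist().
print(tolist([101, 3449, 102]))                # -> [101, 3449, 102]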