Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2565,7 +2565,7 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor,
return type(data)(self._prepare_input(v) for v in data)
elif isinstance(data, torch.Tensor):
kwargs = {"device": self.args.device}
if self.deepspeed and data.dtype != torch.int64:
if self.deepspeed and data.dtype not in (torch.int32, torch.int64):
Comment thread
izapolsk marked this conversation as resolved.
Outdated
# NLP models inputs are int64 and those get adjusted to the right dtype of the
# embedding. Other models such as wav2vec2's inputs are already float and thus
# may need special handling to match the dtypes of the model
Expand Down