diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 079696c244fe..d34449fa57db 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2562,8 +2562,8 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, return type(data)(self._prepare_input(v) for v in data) elif isinstance(data, torch.Tensor): kwargs = {"device": self.args.device} - if self.deepspeed and data.dtype != torch.int64: - # NLP models inputs are int64 and those get adjusted to the right dtype of the + if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)): + # NLP models inputs are int/uint and those get adjusted to the right dtype of the # embedding. Other models such as wav2vec2's inputs are already float and thus # may need special handling to match the dtypes of the model kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()})