Skip to content

Commit

Permalink
Update trainer.py (PaddlePaddle#7123)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZHUI authored Sep 25, 2023
1 parent 9c3f8a4 commit c1157e5
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,7 +1634,12 @@ def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor
elif isinstance(data, (tuple, list)):
return type(data)(self._prepare_input(v) for v in data)
elif isinstance(data, paddle.Tensor):
return data._to(self.args.current_device, None, False)
# kwargs = dict(device=self.args.current_device)
# update data type for pure fp16
if data.place.is_cuda_pinned_place():
return data.cuda()
return data
# return data.to(**kwargs)
return data

def _prepare_inputs(self, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> Dict[str, Union[paddle.Tensor, Any]]:
Expand Down

0 comments on commit c1157e5

Please sign in to comment.