diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index d3af1c055b4e..de15ff8ca771 100755
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -662,7 +662,7 @@ def forward(
         # take features from the eot embedding (eot_token is the highest number in each sequence)
         # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
         pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+            torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
         ]

         if not return_dict:
diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py
index 210a848f28c7..6b83b533d37c 100644
--- a/src/transformers/models/groupvit/modeling_groupvit.py
+++ b/src/transformers/models/groupvit/modeling_groupvit.py
@@ -1134,7 +1134,7 @@ def forward(
         # take features from the eot embedding (eot_token is the highest number in each sequence)
         # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
         pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+            torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
         ]

         if not return_dict:
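
A minimal standalone sketch (not part of the diff) of what the change does: torch.arange with no device argument allocates its output on the CPU, while input_ids and last_hidden_state may live on an accelerator, which can lead to device-mismatch problems in some execution paths (e.g. tracing or multi-device setups). Passing device=input_ids.device keeps the row-index tensor on the same device as the column indices and the hidden states. Variable names mirror the diff; shapes are illustrative.

# Illustrative only; assumes shapes (batch, seq_len, hidden) as in the diff.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size, seq_len, hidden_size = 2, 5, 8
last_hidden_state = torch.randn(batch_size, seq_len, hidden_size, device=device)
input_ids = torch.randint(0, 100, (batch_size, seq_len), device=device)

# Before the fix: row indices are created on the CPU regardless of where input_ids lives.
rows_cpu = torch.arange(last_hidden_state.shape[0])

# After the fix: row indices follow input_ids onto its device.
rows = torch.arange(last_hidden_state.shape[0], device=input_ids.device)
cols = input_ids.to(torch.int).argmax(dim=-1)  # position of the eot token per sequence

pooled_output = last_hidden_state[rows, cols]
print(pooled_output.shape)  # torch.Size([2, 8]); one pooled vector per sequence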