From 4edfce84233eedb425c4fc60bc63cef6b2bf50b2 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Mon, 10 Oct 2022 23:12:50 +0500
Subject: [PATCH 1/2] create the arange tensor on device for enabling
 CUDA-Graph at higher-performance for SD

---
 src/transformers/models/clip/modeling_clip.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index d3af1c055b4e..de15ff8ca771 100755
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -662,7 +662,7 @@ def forward(
         # take features from the eot embedding (eot_token is the highest number in each sequence)
         # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
         pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+            torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
         ]

         if not return_dict:

From 5960a64327d9d4995beee9e1ed6e9cae776bbcc4 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Tue, 11 Oct 2022 20:55:27 -0700
Subject: [PATCH 2/2] sync

---
 src/transformers/models/groupvit/modeling_groupvit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py
index 210a848f28c7..6b83b533d37c 100644
--- a/src/transformers/models/groupvit/modeling_groupvit.py
+++ b/src/transformers/models/groupvit/modeling_groupvit.py
@@ -1134,7 +1134,7 @@ def forward(
         # take features from the eot embedding (eot_token is the highest number in each sequence)
         # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
         pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+            torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
         ]

         if not return_dict: