From 26103901c92b4d9d5217f1b0ab836b3cac49f4e7 Mon Sep 17 00:00:00 2001
From: "amarsingh.thakur"
Date: Wed, 15 Jan 2025 16:12:18 +0530
Subject: [PATCH] Fix issue where tensors are not on the same device when running inference in a multithreaded environment

---
 ip_adapter/attention_processor.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py
index bfbb50d9..7abc4e57 100644
--- a/ip_adapter/attention_processor.py
+++ b/ip_adapter/attention_processor.py
@@ -427,6 +427,22 @@ def forward(
                 mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
             else:
                 mask = torch.ones_like(ip_hidden_states)
+
+            mask = mask.to(ip_hidden_states.device)
+            if mask.shape[1] < ip_hidden_states.shape[1]:
+                # Pad mask if it's shorter
+                pad_size = ip_hidden_states.shape[1] - mask.shape[1]
+                mask = F.pad(mask, (0, 0, 0, pad_size), mode="constant", value=1.0)
+            else:
+                # Truncate mask if it's longer
+                mask = mask[:, :ip_hidden_states.shape[1]]
+
+            # Ensure mask has the same number of dimensions as ip_hidden_states
+            if mask.ndim < ip_hidden_states.ndim:
+                mask = mask.unsqueeze(-1)
+            elif mask.ndim > ip_hidden_states.ndim:
+                mask = mask.squeeze(-1)
+
             ip_hidden_states = ip_hidden_states * mask
 
         hidden_states = hidden_states + self.scale * ip_hidden_states
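
For reference, the alignment logic added by this patch can be exercised in isolation. The sketch below is not part of the patch and does not use the repo's attention processor classes; the helper name align_mask and the tensor shapes (77 tokens, 768 channels) are illustrative assumptions. It reproduces the failure mode the hunk guards against: a region mask built on one device, or with a different token length, being multiplied into ip_hidden_states that live on another device.

import torch
import torch.nn.functional as F


def align_mask(mask: torch.Tensor, ip_hidden_states: torch.Tensor) -> torch.Tensor:
    """Mirrors the patched logic: move the mask onto the hidden states' device,
    pad or truncate it along the token dimension, and match its dimensionality."""
    mask = mask.to(ip_hidden_states.device)
    if mask.shape[1] < ip_hidden_states.shape[1]:
        # Pad with ones so the extra tokens are left unmasked
        pad_size = ip_hidden_states.shape[1] - mask.shape[1]
        mask = F.pad(mask, (0, 0, 0, pad_size), mode="constant", value=1.0)
    else:
        # Truncate if the mask covers more tokens than the hidden states
        mask = mask[:, :ip_hidden_states.shape[1]]
    if mask.ndim < ip_hidden_states.ndim:
        mask = mask.unsqueeze(-1)
    elif mask.ndim > ip_hidden_states.ndim:
        mask = mask.squeeze(-1)
    return mask


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ip_hidden_states = torch.randn(1, 77, 768, device=device)
    # Mask produced elsewhere (e.g. by another thread) on the CPU, 64 tokens long
    region_mask = torch.ones(1, 64, 1)
    # A bare `ip_hidden_states * region_mask` would fail here: on CUDA because of
    # the device mismatch, and on any device because 64 tokens cannot broadcast
    # against 77. The aligned mask multiplies cleanly.
    ip_hidden_states = ip_hidden_states * align_mask(region_mask, ip_hidden_states)
    print(ip_hidden_states.shape)  # torch.Size([1, 77, 768])

The padding value of 1.0 matches the patch's choice: tokens not covered by the region mask are left fully attended rather than zeroed out.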