Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: use device in all Torch models #5026

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fiftyone/utils/clip/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _predict_all(self, imgs):
frame_size = (width, height)

if self._using_gpu:
imgs = imgs.cuda()
imgs = imgs.to(self.device)

text_features = self._get_text_features()
image_features = self._model.encode_image(imgs)
Expand Down
10 changes: 6 additions & 4 deletions fiftyone/utils/open_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _get_text_features(self):
# Tokenize text
text = self._tokenizer(prompts)
if self._using_gpu:
text = text.cuda()
text = text.to(self.device)
self._text_features = self._model.encode_text(text)

return self._text_features
Expand All @@ -118,7 +118,7 @@ def _embed_prompts(self, prompts):
# Tokenize text
text = self._tokenizer(formatted_prompts)
if self._using_gpu:
text = text.cuda()
text = text.to(self.device)
return self._model.encode_text(text)

def _get_class_logits(self, text_features, image_features):
Expand All @@ -143,9 +143,11 @@ def _predict_all(self, imgs):
frame_size = (width, height)

if self._using_gpu:
imgs = imgs.cuda()
imgs = imgs.to(self.device)
Copy link
Contributor

@coderabbitai coderabbitai bot Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Consider adding error handling for device transfers.

While the device handling change is correct, the subsequent autocast context is hardcoded to "cuda". This might cause issues when running on CPU.

Apply this diff to make it device-aware:

         if self._using_gpu:
             imgs = imgs.to(self.device)
 
-        with torch.no_grad(), torch.amp.autocast("cuda"):
+        with torch.no_grad(), torch.amp.autocast(device_type=self.device.type if self._using_gpu else "cpu"):
             image_features = self._model.encode_image(imgs)
             text_features = self._get_text_features()

Also applies to: 147-152

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jacobsela coderabbit raises an interesting point here. Does torch.amp.autocast("cuda") need to be updated?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is valid and will cause problems if not handled. It's in my todo for this week to more thoroughly review the code before moving further with this PR because this message makes me think that there are probably more places I haven't noticed that make hardware assumptions.


with torch.no_grad(), torch.amp.autocast("cuda"):
with torch.no_grad(), torch.amp.autocast(
device_type=self.device.type if self._using_gpu else "cpu"
):
image_features = self._model.encode_image(imgs)
text_features = self._get_text_features()

Expand Down
2 changes: 1 addition & 1 deletion fiftyone/utils/super_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _load_model(self, config):
)

if self._using_gpu:
model = model.cuda()
model = model.to(self.device)

return model

Expand Down
Loading