-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix device mismatch error when using masks #1061
Conversation
nerfstudio/data/pixel_samplers.py
Outdated
@@ -40,7 +40,7 @@ def collate_image_dataset_batch(batch: Dict, num_rays_per_batch: int, keep_full_ | |||
if "mask" in batch: | |||
nonzero_indices = torch.nonzero(batch["mask"][..., 0], as_tuple=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it by any chance work if you instead do
nonzero_indices = torch.nonzero(batch["mask"][..., 0].to(device), as_tuple=False)
This puts it on the gpu earlier which might be a little faster
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that seems to work too, but FYI 'device' in this case seems to be cpu (sorry, had a typo in the original PR description)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* Fix device mismatch error when using masks * Update pixel_samplers.py
* Fix device mismatch error when using masks * Update pixel_samplers.py
When using masks, I get the following error:
printing the keys and devices in my batch, I see:
Not sure if this fix is the most elegant solution