Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix device mismatch error when using masks #1061

Merged
merged 3 commits into from
Dec 1, 2022

Conversation

hturki
Copy link
Contributor

@hturki hturki commented Nov 30, 2022

When using masks, I get the following error:

  File "/home/hturki/miniconda3/envs/my-end/lib/python3.9/site-packages/nerfstudio/data/pixel_samplers.py", line 51, in <dictcomp>
    collated_batch = {key: value[c, y, x] for key, value in batch.items() if key != "image_idx" and value is not None}
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

printing the keys and devices in my batch, I see:

image_idx cuda:0
image cpu
mask cuda:0

Not sure if this fix is the most elegant solution

@@ -40,7 +40,7 @@ def collate_image_dataset_batch(batch: Dict, num_rays_per_batch: int, keep_full_
if "mask" in batch:
nonzero_indices = torch.nonzero(batch["mask"][..., 0], as_tuple=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it by any chance work if you instead do
nonzero_indices = torch.nonzero(batch["mask"][..., 0].to(device), as_tuple=False)
This puts it on the gpu earlier which might be a little faster

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that seems to work too, but FYI 'device' in this case seems to be cpu (sorry, had a typo in the original PR description)

Copy link
Contributor

@tancik tancik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tancik tancik merged commit 8c3f093 into nerfstudio-project:main Dec 1, 2022
@hturki hturki deleted the patch-1 branch December 1, 2022 14:04
tancik pushed a commit to dozeri83/nerfstudio that referenced this pull request Jan 20, 2023
* Fix device mismatch error when using masks

* Update pixel_samplers.py
chris838 pushed a commit to chris838/nerfstudio that referenced this pull request Apr 22, 2023
* Fix device mismatch error when using masks

* Update pixel_samplers.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants