ValueError for Image size: Height 480 , Width 854 in RAFT #8848

Open
Neoyning opened this issue Jan 11, 2025 · 1 comment · May be fixed by #8851

Comments

@Neoyning

🐛 Describe the bug

...
device = "cuda" if torch.cuda.is_available() else "cpu"
raft_model = raft_small(pretrained=True, progress=False).to(device)
raft_model = raft_model.eval()
transform = transforms.ToTensor()
with torch.no_grad():
    list_of_flows = raft_model(old_batch.to(device), new_batch.to(device))
...

Versions

Hi there,

I am testing raft_small from the torchvision.models.optical_flow module. The code runs fine for image sizes (480, 752) and (800, 848).
However, when I test it on an image of height 480 and width 854, the code throws

ValueError: The feature encoder should downsample H and W by 8

I debugged the code at https://github.com/pytorch/vision/blob/d3beb52a00e16c71e821e192bcc592d614a490c0/torchvision/models/optical_flow/raft.py#L494

fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
if fmap1.shape[-2:] != (h // 8, w // 8):
    raise ValueError("The feature encoder should downsample H and W by 8")

For an image of height 480 and width 854, fmap1.shape[-2:] is torch.Size([60, 107]). h // 8 = 60 matches, but w // 8 = 106 does not, which triggers the ValueError.

I think this issue is related to the output dimensions of self.feature_encoder. Any help would be appreciated, thanks.
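The mismatch reported above can be reproduced with plain integer arithmetic. The sketch below assumes the feature encoder reaches its overall factor of 8 through three stride-2 stages, each of which rounds an odd size up. `encoder_out_size` is a hypothetical helper for illustration, not part of torchvision.

```python
def encoder_out_size(n: int, stages: int = 3) -> int:
    """Size after `stages` stride-2 steps, each assumed to round up (ceil(n / 2))."""
    for _ in range(stages):
        n = (n + 1) // 2  # integer ceil(n / 2)
    return n


# Compare the assumed encoder output width with the width the check expects.
for w in (752, 848, 854):
    print(w, encoder_out_size(w), w // 8)
```

Under that assumption, 752 and 848 give the same value on both sides (94 and 106), while 854 gives 107 against 854 // 8 = 106, which matches the shapes reported in the issue.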

@NicolasHug
Member

Thanks for the report @Neoyning

RAFT expects height and width to be divisible by 8, and the problem is that 854 isn't divisible by 8.

We have a check for that:

if not (h % 8 == 0) and (w % 8 == 0):
    raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")

But I realize now that it's incorrect (bug was introduced in https://github.com/pytorch/vision/pull/5587/files#diff-695b16043168ca4717ef0ac674e2c78c3055e752616b0366a0b2608a40375db8R469).

I'll submit a PR to fix the error message, which will now clearly indicate the source of the problem.
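To see why the quoted check lets a 480 x 854 input through: in Python, `not` binds more tightly than `and`, so the condition only fires when H is not divisible by 8 while W is. The sketch below compares it with one possible corrected form. Both function names are hypothetical, and the thread does not show the actual change made in the linked PR.

```python
def quoted_check_fires(h: int, w: int) -> bool:
    # Condition as quoted in the thread; Python parses it as
    # (not (h % 8 == 0)) and (w % 8 == 0).
    return not (h % 8 == 0) and (w % 8 == 0)


def intended_check_fires(h: int, w: int) -> bool:
    # One possible fix (an assumption, not the merged code):
    # negate the whole conjunction so either dimension can trigger it.
    return not ((h % 8 == 0) and (w % 8 == 0))


print(quoted_check_fires(480, 854))    # False: the early check is skipped
print(intended_check_fires(480, 854))  # True: the input would be rejected up front
```

With the quoted condition, the 480 x 854 input passes the early check and only fails later at the feature-encoder shape comparison, which is why the reporter saw the less specific error message.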

@NicolasHug NicolasHug linked a pull request Jan 13, 2025 that will close this issue
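Since the model expects both dimensions to be divisible by 8, one possible workaround on the caller's side (not suggested in the thread itself, so treat it as an assumption) is to pad the frames up to the next multiple of 8 before calling the model and crop the predicted flow back afterwards. The hypothetical helper below only computes the padding amounts.

```python
def pad_amounts(h: int, w: int, multiple: int = 8) -> tuple[int, int]:
    """Extra rows and columns needed to reach the next multiple of `multiple`."""
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    return pad_h, pad_w


pad_h, pad_w = pad_amounts(480, 854)
print(pad_h, pad_w)  # 0 2, i.e. 854 would be padded to 856

# With PyTorch, these amounts could be applied to a (N, C, H, W) batch with
# torch.nn.functional.pad(batch, (0, pad_w, 0, pad_h)), which pads the right
# and bottom edges; the output flow would then be cropped to [..., :h, :w].
```

Resizing to a size divisible by 8 would also satisfy the check, but it changes the scale of the flow vectors, whereas padding leaves the original pixels untouched.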