
AttributeError: 'Distribution' object has no attribute 'dist_type' #82

Open
Ciccios96 opened this issue Feb 6, 2021 · 2 comments

Comments

@Ciccios96

Hello, when I try to run the torchvision.utils.save_image() function inside the save_and_sample() function, I get the error in the title. Can I get some help? The error occurs in the CIFAR-10 training script.

@Gass2109

Same issue here with torchvision 0.8.2 (it works properly with version 0.4.2).

@universome

universome commented May 6, 2021

The problem arises because the grid variable is not a plain torch.Tensor but a tweaked Distribution object. To fix the issue, I had to change the following lines.

In train_fns.py, change

torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename,

to

torchvision.utils.save_image(torch.from_numpy(fixed_Gz.float().cpu().numpy()), image_filename,

In utils.py, add the line

out_ims = torch.from_numpy(out_ims.numpy())

after the line

out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2],
                                        ims[0].shape[3]).data.float().cpu()

What this basically does is convert the tweaked Distribution tensor into a plain torch.Tensor via torch.from_numpy(t.numpy()).
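The failure mode can be reproduced with a small, hypothetical tensor subclass. This sketch assumes PyTorch 1.7 or newer, where operations on a torch.Tensor subclass return instances of that subclass without copying their Python attributes, and it assumes the repository's Distribution class overrides to() and reads dist_type there (which is what the error message suggests). TaggedTensor, init_tag and to_plain_tensor are illustrative names, not part of the repository or of PyTorch.

```python
import torch


class TaggedTensor(torch.Tensor):
    """Hypothetical stand-in for a tensor subclass such as the repo's Distribution."""

    def init_tag(self, tag):
        # The attribute exists only on objects explicitly set up through this method.
        self.tag = tag

    def to(self, *args, **kwargs):
        # Assumed override pattern: move/cast the data, then copy the tag over.
        # Raises AttributeError when self was produced by another op (e.g. .float()).
        new_obj = super().to(*args, **kwargs)
        new_obj.tag = self.tag
        return new_obj


def to_plain_tensor(t):
    """Return t's data as an ordinary torch.Tensor via a NumPy round trip.

    For CPU tensors the result shares memory with t; .detach() is needed
    because .numpy() refuses tensors that require grad.
    """
    return torch.from_numpy(t.detach().cpu().numpy())


if __name__ == "__main__":
    t = torch.arange(6, dtype=torch.float64).reshape(2, 3).as_subclass(TaggedTensor)
    t.init_tag("normal")

    derived = t.float()  # still a TaggedTensor, but without the tag attribute
    print(type(derived).__name__, hasattr(derived, "tag"))

    try:
        derived.to(torch.uint8)  # a .to() call like the one save_image makes
    except AttributeError as err:
        print("AttributeError:", err)

    plain = to_plain_tensor(derived)
    print(type(plain) is torch.Tensor, plain.to(torch.uint8).dtype)
```

The NumPy round trip works because .numpy() returns an ndarray, which carries no subclass information, so torch.from_numpy() can only produce a plain tensor. On PyTorch 1.7 and later, t.as_subclass(torch.Tensor) should give the same result without going through NumPy, which also avoids the .cpu() requirement.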


Also, FYI, here are the scores I got on CIFAR-10 when running the launch_cifar_ema.sh script:

  • best IS = 8.375 at itr=95000
  • best FID = 6.335 at itr=95000

For comparison, the paper reported IS=9.22 and FID=14.73 in Appendix C.2.
