Compatibility with PyTorch 0.4 #10

Open
AziziShekoofeh opened this issue Jan 30, 2019 · 1 comment

@AziziShekoofeh

Hi, thanks for the nice package and code. I had a few issues running the code on PyTorch 0.4, especially when loading the model. I saw a few similar open issues in which people suggested downgrading to PyTorch 0.2. Since most recent packages are based on PyTorch 0.4+, and I wasn't interested in using the conda solution or downgrading, I spent some time finding a way to run the code. This issue just shares the places you may need to change for recent versions of PyTorch:

1- Loading the pre-trained model issue:

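Loading the pretrained checkpoint fails because the key "checkpoint['state_dict']['FeatureExtraction.model.1.num_batches_tracked']" does not exist in the checkpoint's OrderedDict. BatchNorm layers in newer PyTorch versions register a num_batches_tracked buffer that older checkpoints do not contain. I'd appreciate it if you could check this error.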

To solve this, you need to copy only the parameters whose names are common to the pretrained checkpoint and the TwoStageCNNGeometric model:

    import torch
    from collections import OrderedDict

    if model_aff_tps_path != '':
        # Load all tensors onto the CPU, regardless of the device they were saved from
        checkpoint = torch.load(model_aff_tps_path, map_location=lambda storage, loc: storage)
        # The checkpoint calls the VGG backbone 'vgg'; the model calls it 'model'
        checkpoint['state_dict'] = OrderedDict(
            [(k.replace('vgg', 'model'), v) for k, v in checkpoint['state_dict'].items()])

        # Copy only entries present in both the model and the checkpoint, so keys
        # missing from the old checkpoint (e.g. num_batches_tracked) are skipped
        for prefix in ('FeatureExtraction', 'FeatureRegression', 'FeatureRegression2'):
            submodule = getattr(model, prefix)
            for name, param in submodule.state_dict().items():
                key = prefix + '.' + name
                if key in checkpoint['state_dict']:
                    # state_dict() tensors share storage with the model's parameters,
                    # so this in-place copy updates the model's weights
                    param.copy_(checkpoint['state_dict'][key])

A more Pythonic alternative would be to filter the keys while building the OrderedDict, something like:

    checkpoint['state_dict'] = OrderedDict(
        [(k.replace('vgg', 'model'), v) for k, v in checkpoint['state_dict'].items()
         if v in model.FeatureExtraction.state_dict()])

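However, this definitely doesn't work as written: it tests the value v (a tensor) for membership instead of the key, and the keys of model.FeatureExtraction.state_dict() lack the 'FeatureExtraction.' prefix used in the checkpoint, so nothing matches and the resulting dictionary is empty. I couldn't find a cleaner way, so I ended up adding the explicit if statements.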
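For comparison, here is a minimal sketch of a filter that tests the renamed key rather than the value. The helper name remap_and_filter is hypothetical (it is not part of this repository); it works on any mapping, so it runs without building a model, and it also reports which expected keys (such as num_batches_tracked) the checkpoint lacks.

```python
from collections import OrderedDict


def remap_and_filter(old_state, model_keys, old='vgg', new='model'):
    """Hypothetical helper: rename checkpoint keys and keep only those the model has.

    old_state:  mapping from parameter name to tensor (e.g. checkpoint['state_dict'])
    model_keys: full names expected by the model (e.g. model.state_dict().keys())
    """
    model_keys = set(model_keys)
    renamed = OrderedDict((k.replace(old, new), v) for k, v in old_state.items())
    # Filter on the key, not the value
    filtered = OrderedDict((k, v) for k, v in renamed.items() if k in model_keys)
    # Keys the model expects but the checkpoint does not provide
    missing = sorted(model_keys - set(filtered))
    return filtered, missing


# Demo with placeholder values standing in for tensors
old_ckpt = OrderedDict([('FeatureExtraction.vgg.1.running_mean', 0.0)])
wanted = ['FeatureExtraction.model.1.running_mean',
          'FeatureExtraction.model.1.num_batches_tracked']
state, missing = remap_and_filter(old_ckpt, wanted)
print(missing)  # ['FeatureExtraction.model.1.num_batches_tracked']
```

With a real model, the filtered dictionary could then be passed to model.load_state_dict(state, strict=False); strict=False makes PyTorch tolerate the missing keys instead of raising an error.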

2- The second issue occurs later, in preprocess_image(), inside normalize_image() in ./image/normalization.py, line 38:

    if isinstance(image,torch.autograd.variable.Variable):
    ....

The reason is that the Tensor and Variable classes were merged in PyTorch 0.4, so there is no need to check whether image is a Variable.

So you can simply replace this whole line with "else:".
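Because Tensor and Variable are merged, a plain tensor can be normalized with broadcasting and autograd still tracks it, with no Variable wrapping. The following is a minimal sketch under assumptions, not the repository's normalize_image(): the name normalize_image_sketch is hypothetical, and the ImageNet mean/std defaults are assumed.

```python
import torch


def normalize_image_sketch(image, forward=True,
                           mean=(0.485, 0.456, 0.406),
                           std=(0.229, 0.224, 0.225)):
    """Hypothetical stand-in for normalize_image() on PyTorch 0.4+.

    Accepts (C, H, W) or (N, C, H, W) tensors; the mean/std defaults are
    assumed ImageNet statistics.
    """
    # Shape (C, 1, 1) so the values broadcast over height and width
    # (and over the batch dimension for 4-D input)
    mean_t = torch.tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    if forward:
        return (image - mean_t) / std_t
    # Inverse transform: undo the normalization
    return image * std_t + mean_t


# A tensor that requires grad needs no Variable wrapper in PyTorch 0.4+
img = torch.rand(2, 3, 8, 8, requires_grad=True)
out = normalize_image_sketch(img)
out.sum().backward()  # gradients flow back to img
```

Since mean_t and std_t are created on image.device, the same code also works for CUDA tensors without a separate is_cuda branch.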

Hope this is helpful for others too.

@yimengli46

Thanks, your suggestions work perfectly.
