Fix bug of device and dtype of WgtScaleBatchNorm.std #116

Merged · 2 commits merged into lava-nc:main on Oct 24, 2022

Conversation

@fangwei123456 (Contributor) commented on Oct 15, 2022

Issue Number: 115

Objective of pull request:

Pull request checklist

Your PR fulfills the following requirements:

Pull request type

Please check your PR type:

  • Bugfix
  • Feature
  • Code style update (formatting, renaming)
  • Refactoring (no functional changes, no api changes)
  • Build related changes
  • Documentation changes
  • Other (please describe):

What is the current behavior?

Run the following code:

from lava.lib.dl import slayer
import torch

# CUBA neuron using weight-scale batch normalization
net = slayer.neuron.cuba.Neuron(
    threshold=1.,
    current_decay=1.,
    voltage_decay=0.,
    scale=1 << 6,
    norm=slayer.neuron.norm.WgtScaleBatchNorm
)
device = 'cuda:0'
net.to(device)  # move the neuron's parameters to the GPU
with torch.no_grad():
    x = torch.rand([4, 4, 4], device=device)  # input on the same GPU
    net(x)

We get the following error:

Traceback (most recent call last):
  File "/home/wfang/spikingjelly_dev/spikingjelly/test4.py", line 15, in <module>
    net(x)
  File "/home/wfang/anaconda3/envs/lava-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wfang/anaconda3/envs/lava-env/lib/python3.10/site-packages/lava/lib/dl/slayer/neuron/cuba.py", line 439, in forward
    _, voltage = self.dynamics(input)
  File "/home/wfang/anaconda3/envs/lava-env/lib/python3.10/site-packages/lava/lib/dl/slayer/neuron/cuba.py", line 365, in dynamics
    current = self.norm(current)
  File "/home/wfang/anaconda3/envs/lava-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wfang/anaconda3/envs/lava-env/lib/python3.10/site-packages/lava/lib/dl/slayer/neuron/norm.py", line 209, in forward
    std = self.std(var)
  File "/home/wfang/anaconda3/envs/lava-env/lib/python3.10/site-packages/lava/lib/dl/slayer/neuron/norm.py", line 170, in std
    return torch.ones(1) << torch.ceil(torch.log2(std)).clamp(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

What is the new behavior?

The code runs without any error.

Does this introduce a breaking change?

  • Yes
  • No

Supplemental information

The error RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! can be fixed by changing line 170 of lava/lib/dl/slayer/neuron/norm.py (in WgtScaleBatchNorm.std) from

return torch.ones(1) << torch.ceil(torch.log2(std)).clamp(

to

return torch.ones(1, device=std.device) << torch.ceil(torch.log2(std)).clamp(

But this raises a new error:

RuntimeError: "lshift_cuda" not implemented for 'Float'

That error can be solved by casting both torch.ones(1, device=std.device) and torch.ceil(torch.log2(std)).clamp( ... to torch.int, as sketched below.
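A minimal sketch of that integer-shift workaround (illustrative only; the real clamp bounds are elided in the traceback above, so min=0 here is just a placeholder):

# Sketch of the integer-shift variant inside WgtScaleBatchNorm.std;
# torch is already imported in norm.py, and the clamp bound is a placeholder.
ones = torch.ones(1, dtype=torch.int64, device=std.device)
exponent = torch.ceil(torch.log2(std)).clamp(min=0).to(torch.int64)
return (ones << exponent).to(std.dtype)  # cast back to float for the later division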

However, considering that the return value std is used in a floating-point computation ( ... / std.view(1, -1) ), I think computing the result directly in float is better than using << with torch.int, for example:
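A minimal sketch of the float-only alternative, which inherits device and dtype from std and avoids the shift entirely (again, the clamp bound is a placeholder, not the library's actual value):

# Sketch of a float power-of-two rounding inside WgtScaleBatchNorm.std.
exponent = torch.ceil(torch.log2(std)).clamp(min=0)
return torch.pow(2.0, exponent)  # same rounding as 1 << exponent, computed in float on std's device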

@bamsumit (Contributor) left a comment

Thanks @fangwei123456 for identifying the issue and fixing the problem.

@bamsumit linked an issue on Oct 17, 2022 that may be closed by this pull request.
@bamsumit merged commit b0e2866 into lava-nc:main on Oct 24, 2022.
@tim-shea added this to the Release v0.3.1 milestone on Oct 25, 2022.
Labels: none yet
Projects: none yet
Development: successfully merging this pull request may close the issue "Bug of device and dtype of WgtScaleBatchNorm.std"
4 participants