
test_grad_checkpoint.py fails if PyTorch is compiled with CUDA support. #6086

Closed · ysiraichi opened this issue Dec 9, 2023 · 1 comment · Fixed by #6178

🐛 Bug

Enabling CUDA support for PyTorch on CI breaks test_grad_checkpoint.py. I managed to reproduce this issue by modifying the test:

diff --git a/test/test_grad_checkpoint.py b/test/test_grad_checkpoint.py
index 9a5fd19aa..e7d29357b 100644
--- a/test/test_grad_checkpoint.py
+++ b/test/test_grad_checkpoint.py
@@ -4,6 +4,7 @@ import torch_xla.debug.metrics as met
 import torch_xla
 import torch_xla.utils.checkpoint as checkpoint
 
+torch.cuda.init()
 
 def run():
   device = xm.xla_device()
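
For context on why the one-line torch.cuda.init() change flips the behavior: the reentrant checkpoint only stashes per-device RNG state when CUDA has already been initialized, which it detects through the internal torch.cuda._initialized flag. A runnable probe of that flag (a sketch, assuming a CUDA-enabled PyTorch build):

import torch

print(torch.cuda._initialized)  # False in a fresh process
torch.cuda.init()
print(torch.cuda._initialized)  # True; CheckpointFunction.forward now calls
                                # get_device_states(*args), per the traceback below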

Running the modified test_grad_checkpoint.py then produces the following error:

$ python test/test_grad_checkpoint.py
Traceback (most recent call last):
  File "test/test_grad_checkpoint.py", line 37, in <module>
    run()
  File "test/test_grad_checkpoint.py", line 27, in run
    x = checkpoint.checkpoint(layer, x)
  File "xla/torch_xla/utils/checkpoint.py", line 212, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "xla/torch_xla/utils/checkpoint.py", line 49, in forward
    ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
  File "torch/utils/checkpoint.py", line 177, in get_device_states
    device_module = _get_device_module(_infer_device_type(*args))
  File "torch/utils/checkpoint.py", line 97, in _get_device_module
    device_module = getattr(torch, device)
  File "torch/__init__.py", line 1927, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'xla'
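
The last few frames show the root cause: torch.utils.checkpoint infers a device type string from the tensor arguments ("xla" here) and then resolves the backend module with getattr(torch, device). The XLA backend lives in the separate torch_xla package, not as a torch attribute, so the lookup fails. A minimal sketch of the failing lookup, independent of the test:

# Reproduces the failing getattr from torch/utils/checkpoint.py with an XLA tensor.
import torch
import torch_xla.core.xla_model as xm

x = torch.ones(2, 2, device=xm.xla_device())
device_type = x.device.type   # "xla"
getattr(torch, device_type)   # AttributeError: module 'torch' has no attribute 'xla'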

Environment

Additional Context

Blocking: #6070

@ysiraichi (Author) commented:

This issue should be fixed in the main PyTorch repo.
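
Until that lands, one possible user-side workaround (an assumption on my part, not the change from #6178) is to disable RNG-state preservation so get_device_states() is never reached. This assumes torch_xla's wrapper forwards the preserve_rng_state kwarg the way torch.utils.checkpoint does, and it gives up RNG reproducibility inside the checkpointed segment:

# Hypothetical workaround sketch; preserve_rng_state=False skips the
# device-state stashing in CheckpointFunction.forward entirely.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.utils.checkpoint as checkpoint

device = xm.xla_device()
layer = torch.nn.Linear(4, 4).to(device)
x = torch.ones(2, 4, device=device, requires_grad=True)
y = checkpoint.checkpoint(layer, x, preserve_rng_state=False)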
