
AlphaZero torch model doesn't support CUDA, only CPU #14970

Open · 2 tasks done
lukaszkn opened this issue Mar 27, 2021 · 0 comments

Labels
bug: Something that is supposed to be working; but isn't
P2: Important issue, but not time-critical

What is the problem?

I'm trying to run the AlphaZero implementation sample from ray/rllib/contrib/alpha_zero/examples/train_cartpole.py, but this exception is raised:
RuntimeError: Tensor for argument #2 'mat1' is on CPU, but expected it to be on GPU

There is no exception with the CPU-only build of torch installed. Does AlphaZero support CPU only?
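
For context, the same RuntimeError can be reproduced outside RLlib whenever a model whose weights live on the GPU receives a CPU tensor; a minimal sketch, assuming torch 1.8.x with a CUDA build (newer torch versions word the error differently):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2).cuda()  # weights moved to cuda:0
x = torch.zeros(1, 4)           # input tensor left on the CPU
layer(x)  # RuntimeError: Tensor for argument #2 'mat1' is on CPU, but expected it to be on GPU
```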

Ray version and other system information (Python version, TensorFlow version, OS):
ray 1.2.0
torch 1.8.1 gpu
python 3.6

Reproduction (REQUIRED)

Take the sample from ray/rllib/contrib/alpha_zero/examples/train_cartpole.py and run it with torch 1.8.1+gpu installed; the following exception is raised:

```
Traceback (most recent call last):
  File "C:\Python36\lib\site-packages\ray\rllib\agents\trainer.py", line 526, in train
    raise e
  File "C:\Python36\lib\site-packages\ray\rllib\agents\trainer.py", line 515, in train
    result = Trainable.train(self)
  File "C:\Python36\lib\site-packages\ray\tune\trainable.py", line 226, in train
    result = self.step()
  File "C:\Python36\lib\site-packages\ray\rllib\agents\trainer_template.py", line 148, in step
    res = next(self.train_exec_impl)
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 876, in apply_flatten
    for item in it:
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 828, in add_wait_hooks
    item = next(it)
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 471, in base_iterator
    yield ray.get(futures, timeout=timeout)
  File "C:\Python36\lib\site-packages\ray\_private\client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "C:\Python36\lib\site-packages\ray\worker.py", line 1456, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::RolloutWorker.par_iter_next() (pid=12020, ip=192.168.0.107)
  File "python\ray\_raylet.pyx", line 480, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 432, in ray._raylet.execute_task.function_executor
  File "C:\Python36\lib\site-packages\ray\function_manager.py", line 556, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "C:\Python36\lib\site-packages\ray\util\iter.py", line 1152, in par_iter_next
    return next(self.local_it)
  File "C:\Python36\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 327, in gen_rollouts
    yield self.sample()
  File "C:\Python36\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 662, in sample
    batches = [self.input_reader.next()]
  File "C:\Python36\lib\site-packages\ray\rllib\evaluation\sampler.py", line 95, in next
    batches = [self.get_data()]
  File "C:\Python36\lib\site-packages\ray\rllib\evaluation\sampler.py", line 224, in get_data
    item = next(self.rollout_provider)
  File "C:\Python36\lib\site-packages\ray\rllib\evaluation\sampler.py", line 656, in _env_runner
    tf_sess=tf_sess,
  File "C:\Python36\lib\site-packages\ray\rllib\evaluation\sampler.py", line 1344, in _do_policy_eval_w_trajectory_view_api
    episodes=[active_episodes[t.env_id] for t in eval_data])
  File "C:\Python36\lib\site-packages\ray\rllib\contrib\alpha_zero\core\alpha_zero_policy.py", line 91, in compute_actions_from_input_dict
    tree_node)
  File "C:\Python36\lib\site-packages\ray\rllib\contrib\alpha_zero\core\mcts.py", line 129, in compute_action
    leaf.obs)
  File "C:\Python36\lib\site-packages\ray\rllib\contrib\alpha_zero\models\custom_torch_models.py", line 53, in compute_priors_and_value
    model_out = self.forward(input_dict, None, [1])
  File "C:\Python36\lib\site-packages\ray\rllib\contrib\alpha_zero\models\custom_torch_models.py", line 37, in forward
    x = self.shared_layers(x)
  File "C:\Python36\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Python36\lib\site-packages\torch\nn\modules\container.py", line 119, in forward
    input = module(input)
  File "C:\Python36\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Python36\lib\site-packages\torch\nn\modules\linear.py", line 94, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Python36\lib\site-packages\torch\nn\functional.py", line 1753, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: Tensor for argument #2 'mat1' is on CPU, but expected it to be on GPU (while checking arguments for addmm)
python-BaseException
```
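
The failing frame is x = self.shared_layers(x) in custom_torch_models.py: the model's weights sit on the GPU, but the observation tensor built during MCTS rollouts stays on the CPU. A possible workaround is sketched below; it moves the observation batch onto whatever device the model's parameters live on before the forward pass. The body is illustrative only (the tensor construction and return handling are assumptions, not the actual RLlib source):

```python
import torch

def compute_priors_and_value(self, obs):
    # Sketch of a patched compute_priors_and_value for
    # ray/rllib/contrib/alpha_zero/models/custom_torch_models.py.
    # Build the observation batch on the same device as the weights,
    # so self.shared_layers (possibly on cuda:0) matches its input.
    device = next(self.parameters()).device
    obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        logits, _ = self.forward({"obs": obs}, None, [1])
        value = self.value_function()
        priors = torch.softmax(torch.squeeze(logits), dim=-1)
        # MCTS consumes numpy arrays, so move results back to the CPU.
        return priors.cpu().numpy(), torch.squeeze(value).cpu().numpy()
```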

  • I have verified my script runs in a clean environment and reproduces the issue.
  • I have verified the issue also occurs with the latest wheels.
@lukaszkn lukaszkn added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Mar 27, 2021
@richardliaw richardliaw added this to the RLlib Bugs milestone Apr 21, 2021
@richardliaw richardliaw added P2 Important issue, but not time-critical and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Apr 21, 2021
@anyscalesam anyscalesam removed this from the RLlib Bugs milestone Jun 15, 2024