Closed
Labels
bug, community-backlog, rllib, stability, triage, usability
Description
What happened + What you expected to happen
I wanted to implement a custom TorchRLModule.
class ChessRLModule(VPGTorchRLModule):
    def setup(self):
        # obs_space['observation'] is (8, 8, 111) for chess_v6
        obs_space = self.config.observation_space["player_0"]["observation"]
        act_space = self.config.action_space
        # Calculate input dimension for a simple flattened encoder
        input_dim = int(torch.prod(torch.tensor(obs_space.shape)))
        hidden_dim = self.model_config["hidden_dim"]
        output_dim = act_space.n  # <-- the AttributeError is raised here (my own bug)
        self.policy_net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )
When I then ran the following code:
config = (
    VPGConfig()
    .framework("torch")
    .rl_module(
        model_config={"hidden_dim": 64},
        rl_module_spec=RLModuleSpec(
            module_class=ChessRLModule,
        ),
    )  # custom config for the learner
    .environment("chess_v6")
    .env_runners(
        num_env_runners=0,
    )
    .training(
        num_epochs=1
    )
)
algo = config.build_algo()
algo.train()
algo.evaluate()
algo.stop()
I received an error stating that the optimizer couldn't find any parameters.
ValueError: optimizer got an empty parameter list
This happened because my custom setup() (overriding VPGTorchRLModule.setup()) raised an AttributeError at the line output_dim = act_space.n.
However, RLModule silently ignores this error unless its message matches one specific pattern (L469-473):
try:
    self.setup()
except AttributeError as e:
    if "'NoneType' object has no attribute " in e.args[0]:
        raise (self._catalog_ctor_error or e)
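The effect of this pattern can be reproduced without Ray. In the sketch below, FakeRLModule is a hypothetical, simplified stand-in that mirrors only the quoted try/except (it is not Ray code). An AttributeError whose message does not contain the 'NoneType' pattern falls through the except block and is discarded, so the subclass ends up with no layers, analogous to the empty parameter list the optimizer later complains about.

```python
class FakeRLModule:
    """Hypothetical stand-in that mirrors only the quoted RLModule try/except."""

    _catalog_ctor_error = None

    def __init__(self):
        self.layers = []
        try:
            self.setup()
        except AttributeError as e:
            # Only this one message pattern is re-raised; every other
            # AttributeError falls off the end of the except block and is lost.
            if "'NoneType' object has no attribute " in e.args[0]:
                raise (self._catalog_ctor_error or e)

    def setup(self):
        pass


class BuggyModule(FakeRLModule):
    def setup(self):
        act_space = {"player_0": None}  # illustrative placeholder, has no `.n`
        output_dim = act_space.n  # raises AttributeError: 'dict' object has no attribute 'n'
        self.layers.append(output_dim)


module = BuggyModule()  # no exception escapes the constructor
print(len(module.layers))  # setup() aborted before adding any layer
```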
I would expect the error to propagate in every case.
This is particularly annoying because the swallowed exception only surfaces later as a seemingly unrelated error (the empty optimizer parameter list).
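A minimal sketch of the expected behavior, assuming the fix is simply a bare raise at the end of the except block. PatchedRLModule is a hypothetical stand-in, not Ray code: the existing special case for 'NoneType' messages is kept, and every other AttributeError from setup() reaches the caller with its original traceback.

```python
class PatchedRLModule:
    """Hypothetical stand-in showing the expected behavior: never swallow."""

    _catalog_ctor_error = None

    def __init__(self):
        self.layers = []
        try:
            self.setup()
        except AttributeError as e:
            if "'NoneType' object has no attribute " in e.args[0]:
                # Keep the existing special case unchanged.
                raise (self._catalog_ctor_error or e)
            # Proposed change: re-raise every other AttributeError as-is.
            raise

    def setup(self):
        pass


class BuggyModule(PatchedRLModule):
    def setup(self):
        act_space = {"player_0": None}  # illustrative placeholder, has no `.n`
        self.layers.append(act_space.n)  # the AttributeError now reaches the caller


try:
    BuggyModule()
except AttributeError as err:
    print(f"setup() failed loudly: {err}")
```

With this change, the reproduction script below would stop at the real bug in setup() instead of failing later in the optimizer.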
Versions / Dependencies
macOS 13.7.8
Python 3.10
Ray 2.49.2
pettingzoo 1.25.0
Reproduction script
from ray.rllib.core.rl_module import RLModuleSpec
from ray.rllib.env import PettingZooEnv
from ray.rllib.examples.algorithms.classes.vpg import VPGConfig
from ray.rllib.examples.rl_modules.classes.vpg_torch_rlm import VPGTorchRLModule
from ray.tune import register_env
from ray.util.annotations import RayDeprecationWarning
from ray.rllib.core.columns import Columns
import torch
from pettingzoo.classic import chess_v6
register_env("chess_v6", lambda _: PettingZooEnv(chess_v6.env()))
class ChessRLModule(VPGTorchRLModule):
    def setup(self):
        obs_space = self.config.observation_space["player_0"]["observation"]
        act_space = self.config.action_space
        input_dim = int(torch.prod(torch.tensor(obs_space.shape)))
        hidden_dim = self.model_config["hidden_dim"]
        output_dim = act_space.n
        self.policy_net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

    def _forward(self, batch, **kwargs):
        obs_dict = batch[Columns.OBS]
        board_state = obs_dict["observation"].float()  # ignore action-mask
        action_logits = self.policy_net(board_state)
        return {
            Columns.ACTION_DIST_INPUTS: action_logits
        }

    def _forward_inference(self, batch, **kwargs):
        return self._forward(batch, **kwargs)

    def _forward_exploration(self, batch, **kwargs):
        return self._forward(batch, **kwargs)

    def _forward_train(self, batch, **kwargs):
        return self._forward(batch, **kwargs)
config = (
    VPGConfig()
    .framework("torch")
    .rl_module(
        model_config={"hidden_dim": 64},
        rl_module_spec=RLModuleSpec(
            module_class=ChessRLModule,
        ),
    )  # custom config for the learner
    .environment("chess_v6")
    .env_runners(
        num_env_runners=0,
    )
    .training(
        num_epochs=1
    )
)
algo = config.build_algo()
algo.train()
algo.evaluate()
algo.stop()
Issue Severity
High: It blocks me from completing my task.
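Until the constructor behavior changes, one possible user-side workaround is to convert AttributeErrors raised inside setup() into a RuntimeError. This is a sketch, not an official RLlib API: surface_attribute_errors and SwallowingBase are hypothetical names, and SwallowingBase only imitates the quoted try/except. Because that code catches AttributeError only, a different exception type should propagate out of the constructor instead of being swallowed.

```python
import functools


def surface_attribute_errors(setup_fn):
    """Hypothetical helper: re-raise AttributeError from setup() as RuntimeError."""

    @functools.wraps(setup_fn)
    def wrapper(self, *args, **kwargs):
        try:
            return setup_fn(self, *args, **kwargs)
        except AttributeError as e:
            # A RuntimeError is not matched by `except AttributeError`,
            # so it escapes the constructor; the original error is chained.
            raise RuntimeError(f"{type(self).__name__}.setup() failed: {e}") from e

    return wrapper


class SwallowingBase:
    """Hypothetical stand-in imitating the quoted RLModule constructor logic."""

    def __init__(self):
        try:
            self.setup()
        except AttributeError as e:
            if "'NoneType' object has no attribute " in e.args[0]:
                raise


class Demo(SwallowingBase):
    @surface_attribute_errors
    def setup(self):
        act_space = {"player_0": None}  # illustrative placeholder, has no `.n`
        self.output_dim = act_space.n


try:
    Demo()
except RuntimeError as err:
    print(err)  # the original AttributeError is available as err.__cause__
```

In the reproduction script this would mean decorating ChessRLModule.setup with @surface_attribute_errors. One trade-off: the 'NoneType' case is converted as well, so RLModule's substitution of self._catalog_ctor_error no longer applies to errors raised inside setup().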