RLlib: RLModule swallows AttributeError #58854

@phisad

What happened + What you expected to happen

I wanted to implement a custom TorchRLModule.

class ChessRLModule(VPGTorchRLModule):

    def setup(self):
        # obs_space['observation'] is (8, 8, 111) for chess_v6
        obs_space = self.config.observation_space["player_0"]["observation"]
        act_space = self.config.action_space

        # Calculate input dimension for a simple flattened encoder
        input_dim = int(torch.prod(torch.tensor(obs_space.shape)))
        hidden_dim = self.model_config["hidden_dim"]
        output_dim = act_space.n  # <-- the AttributeError is raised here (my own mistake)

        self.policy_net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

However, when I run the example code:

config = (
    VPGConfig()
    .framework("torch")
    .rl_module(
        model_config={"hidden_dim": 64},
        rl_module_spec=RLModuleSpec(
            module_class=ChessRLModule,
        ),
    )  # custom config for the learner
    .environment("chess_v6")
    .env_runners(
        num_env_runners=0,
    )
    .training(
        num_epochs=1
    )
)

algo = config.build_algo()
algo.train()
algo.evaluate()
algo.stop()

I received an error stating that the optimizer couldn't find any parameters.

ValueError: optimizer got an empty parameter list

This happened because my custom VPGTorchRLModule.setup() raised an AttributeError at the line output_dim = act_space.n, so policy_net was never created.

However, RLModule silently ignores an AttributeError from setup() unless its message matches one particular pattern (L469-473):

        try:
            self.setup()
        except AttributeError as e:
            if "'NoneType' object has no attribute " in e.args[0]:
                raise (self._catalog_ctor_error or e)

I would expect the error to propagate in every case.
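The effect can be reproduced without Ray. The following is a minimal sketch using hypothetical stand-in classes (`SwallowingModule`, `PropagatingModule`, `BrokenSetupMixin`); it is not RLlib code and only mirrors the shape of the try/except quoted above. The second class shows one possible way the handler could let every other AttributeError through.

```python
class SwallowingModule:
    """Hypothetical stand-in mirroring the quoted try/except; not Ray code."""

    def __init__(self):
        try:
            self.setup()
        except AttributeError as e:
            # Only this one message is re-raised. Any other AttributeError
            # falls out of the except block and is silently dropped.
            if "'NoneType' object has no attribute " in e.args[0]:
                raise


class PropagatingModule:
    """Hypothetical variant in which every AttributeError propagates."""

    def __init__(self):
        try:
            self.setup()
        except AttributeError as e:
            if "'NoneType' object has no attribute " in e.args[0]:
                # A more specific error could be substituted here.
                raise
            raise  # all other AttributeErrors are re-raised as well


class BrokenSetupMixin:
    def setup(self):
        space = object()            # stand-in for a space without `.n`
        self.output_dim = space.n   # raises AttributeError
        self.policy_net = "built"   # never reached


class SwallowingBroken(BrokenSetupMixin, SwallowingModule):
    pass


class PropagatingBroken(BrokenSetupMixin, PropagatingModule):
    pass


module = SwallowingBroken()            # constructs "successfully"
print(hasattr(module, "policy_net"))   # the attribute was never set

try:
    PropagatingBroken()
except AttributeError as e:
    print("propagated:", e)
```

With the swallowing variant, construction succeeds but the module has no layers, which is consistent with the optimizer later finding no parameters.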

This is particularly annoying because the swallowed error only surfaces later as a different, seemingly unrelated error.
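Until this is changed, a user-side workaround might be to convert the AttributeError into a different exception type inside setup(), so that an `except AttributeError` handler in the caller cannot catch it. The decorator `surface_attribute_errors` and the class `CatchingBase` below are hypothetical names for illustration; `CatchingBase` only imitates a caller that drops AttributeErrors and is not the RLlib implementation.

```python
import functools


def surface_attribute_errors(setup_fn):
    """Hypothetical helper: re-raise AttributeError from setup() as RuntimeError."""

    @functools.wraps(setup_fn)
    def wrapper(self, *args, **kwargs):
        try:
            return setup_fn(self, *args, **kwargs)
        except AttributeError as e:
            # RuntimeError is not a subclass of AttributeError, so a caller's
            # `except AttributeError` clause will not intercept it.
            raise RuntimeError(f"setup() failed: {e}") from e

    return wrapper


class CatchingBase:
    """Hypothetical stand-in for a base class that drops AttributeErrors."""

    def __init__(self):
        try:
            self.setup()
        except AttributeError:
            pass  # swallowed


class MyModule(CatchingBase):
    @surface_attribute_errors
    def setup(self):
        space = object()           # stand-in for a space without `.n`
        self.output_dim = space.n  # raises AttributeError


try:
    MyModule()
except RuntimeError as e:
    print(e)                       # the original message is preserved
    print(type(e.__cause__).__name__)
```

Applied to the custom module in this issue, the decorator would wrap ChessRLModule.setup(); the original traceback stays available through exception chaining.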

Versions / Dependencies

macOS 13.7.8
Python 3.10
Ray 2.49.2
pettingzoo 1.25.0

Reproduction script

from ray.rllib.core.rl_module import RLModuleSpec
from ray.rllib.env import PettingZooEnv
from ray.rllib.examples.algorithms.classes.vpg import VPGConfig
from ray.rllib.examples.rl_modules.classes.vpg_torch_rlm import VPGTorchRLModule
from ray.tune import register_env
from ray.util.annotations import RayDeprecationWarning
from ray.rllib.core.columns import Columns
import torch
from pettingzoo.classic import chess_v6

register_env("chess_v6", lambda _: PettingZooEnv(chess_v6.env()))


class ChessRLModule(VPGTorchRLModule):

    def setup(self):
        obs_space = self.config.observation_space["player_0"]["observation"]
        act_space = self.config.action_space

        input_dim = int(torch.prod(torch.tensor(obs_space.shape)))
        hidden_dim = self.model_config["hidden_dim"]
        output_dim = act_space.n

        self.policy_net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

    def _forward(self, batch, **kwargs):
        obs_dict = batch[Columns.OBS]
        board_state = obs_dict["observation"].float()  # ignore action-mask
        action_logits = self.policy_net(board_state)
        return {
            Columns.ACTION_DIST_INPUTS: action_logits
        }

    def _forward_inference(self, batch, **kwargs):
        return self._forward(batch, **kwargs)

    def _forward_exploration(self, batch, **kwargs):
        return self._forward(batch, **kwargs)

    def _forward_train(self, batch, **kwargs):
        return self._forward(batch, **kwargs)


config = (
    VPGConfig()
    .framework("torch")
    .rl_module(
        model_config={"hidden_dim": 64},
        rl_module_spec=RLModuleSpec(
            module_class=ChessRLModule,
        ),
    )  # custom config for the learner
    .environment("chess_v6")
    .env_runners(
        num_env_runners=0,
    )
    .training(
        num_epochs=1
    )
)

algo = config.build_algo()
algo.train()
algo.evaluate()
algo.stop()

Issue Severity

High: It blocks me from completing my task.

Labels: bug, community-backlog, rllib, stability, triage, usability