Implement maskable version of Graph2NodePPO (sb3 framework)
- an action is a node of the observation graph
- we use action masking to predict only applicable actions
- we add an example with a variable number of nodes (as it may be relevant
  for handling action mask shapes)
- we add some docstrings to the new algos and to previous ones (GraphPPO,
  Graph2NodePPO, MaskableGraphPPO) and their policies.
nhuet authored and g-poveda committed Feb 21, 2025
1 parent acca9e6 commit 857c3a7
Showing 10 changed files with 568 additions and 242 deletions.
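The idea described in the commit message — each action is a node of the observation graph, and a boolean mask marks which nodes are currently applicable — can be illustrated in isolation. The sketch below is not scikit-decide code: the mask values and the masked-logits trick are only an illustration of the technique, written in plain NumPy/PyTorch.

import numpy as np
import torch

# Hypothetical observation graph with 5 nodes; only nodes 1 and 3 correspond to applicable actions.
action_mask = np.array([False, True, False, True, False])

# Per-node logits, e.g. produced by a GNN that scores each node of the graph.
logits = torch.randn(len(action_mask))

# Masked nodes get -inf so the categorical distribution never samples them.
masked_logits = logits.masked_fill(~torch.as_tensor(action_mask), float("-inf"))
action = int(torch.distributions.Categorical(logits=masked_logits).sample())
print(f"sampled node / applicable action: {action}")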
3 changes: 1 addition & 2 deletions examples/gnn/domains.py
@@ -240,13 +240,12 @@ class GraphJspDomain(D):

_gym_env: DisjunctiveGraphJspEnv

def __init__(self, gym_env, deterministic=False):
def __init__(self, gym_env):
self._gym_env = gym_env
if self._gym_env.normalize_observation_space:
self.n_nodes_features = gym_env.n_machines + 1
else:
self.n_nodes_features = 2
self.deterministic = deterministic

def _state_reset(self) -> D.T_state:
return self._np_state2graph_state(self._gym_env.reset()[0])
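With the `deterministic` flag removed from the constructor, the domain is now built from the gym environment alone. A construction sketch follows; the JSP instance values are made up and the DisjunctiveGraphJspEnv keyword arguments are assumptions (only normalize_observation_space is confirmed by the attribute used above), so adapt them to the actual env signature.

import numpy as np
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv

from domains import GraphJspDomain

# Made-up 2-job / 3-machine instance: first block = machine order per job, second block = durations.
jsp = np.array(
    [
        [[0, 1, 2], [2, 0, 1]],
        [[3, 2, 2], [2, 4, 3]],
    ]
)

# Assumed DisjunctiveGraphJspEnv arguments; adjust to the env's real signature.
domain_factory = lambda: GraphJspDomain(
    gym_env=DisjunctiveGraphJspEnv(
        jps_instance=jsp,
        normalize_observation_space=False,
    )
)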
96 changes: 25 additions & 71 deletions examples/gnn/gnn_graph2node_sb3_jsp.py
@@ -1,25 +1,11 @@
from typing import Any

import numpy as np
from domains import GraphJspDomain
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
from gymnasium.spaces import Box, Graph, GraphInstance

from skdecide.builders.domain import (
FullyObservable,
Initializable,
Markovian,
Renderable,
Rewards,
Sequential,
SingleAgent,
)
from skdecide.core import Space, TransitionOutcome, Value
from skdecide.domains import Domain
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn.ppo.ppo import Graph2NodePPO
from skdecide.hub.solver.utils.gnn.torch_utils import extract_module_parameters_values
from skdecide.hub.space.gym import GymSpace, ListSpace
from skdecide.hub.solver.stable_baselines.gnn.ppo_mask.ppo_mask import (
MaskableGraph2NodePPO,
)
from skdecide.utils import rollout

jsp = np.array(
@@ -49,14 +35,35 @@
)


# Uncomment the block below to use PPO without action masking
# with StableBaseline(
# domain_factory=domain_factory,
# algo_class=Graph2NodePPO,
# baselines_policy="GraphInputPolicy",
# policy_kwargs=dict(debug=True),
# learn_config={
# "total_timesteps": 10_000,
# },
# ) as solver:
# solver.solve()
# rollout(
# domain=domain_factory(),
# solver=solver,
# max_steps=30,
# num_episodes=1,
# render=True,
# )

# PPO graph -> node + action masking
with StableBaseline(
domain_factory=domain_factory,
algo_class=Graph2NodePPO,
algo_class=MaskableGraph2NodePPO,
baselines_policy="GraphInputPolicy",
policy_kwargs=dict(debug=True),
learn_config={
"total_timesteps": 10_000,
},
use_action_masking=True,
) as solver:
solver.solve()
rollout(
@@ -66,56 +73,3 @@
num_episodes=1,
render=True,
)


# action gnn parameters
initial_parameters = solver._algo.policy.action_net.initial_parameters
final_parameters = extract_module_parameters_values(solver._algo.policy.action_net)
same_parameters: dict[str, bool] = {
name: (initial_parameters[name] == final_parameters[name]).all()
for name in final_parameters
}

if all(same_parameters.values()):
print("Action full GNN parameters have not changed during training!")
else:
unchanging_parameters = [name for name, same in same_parameters.items() if same]
print(
f"Action full GNN parameter unchanged after training: {unchanging_parameters}"
)
changing_parameters = [name for name, same in same_parameters.items() if not same]
print(
f"Action full GNN parameters having changed during training: {changing_parameters}"
)
diff_parameters = {
name: abs(initial_parameters[name] - final_parameters[name]).max()
for name in changing_parameters
}
print(diff_parameters)

# value gnn parameters
initial_parameters = solver._algo.policy.features_extractor.extractor.initial_parameters
final_parameters = extract_module_parameters_values(
solver._algo.policy.features_extractor.extractor
)
same_parameters: dict[str, bool] = {
name: (initial_parameters[name] == final_parameters[name]).all()
for name in final_parameters
}

if all(same_parameters.values()):
print("Value GNN feature extractor parameters have not changed during training!")
else:
unchanging_parameters = [name for name, same in same_parameters.items() if same]
print(
f"Value GNN feature extracto parameter unchanged after training: {unchanging_parameters}"
)
changing_parameters = [name for name, same in same_parameters.items() if not same]
print(
f"Value GNN feature extractor parameters having changed during training: {changing_parameters}"
)
diff_parameters = {
name: abs(initial_parameters[name] - final_parameters[name]).max()
for name in changing_parameters
}
print(diff_parameters)
11 changes: 11 additions & 0 deletions skdecide/hub/solver/stable_baselines/gnn/a2c/a2c.py
@@ -8,6 +8,17 @@


class GraphA2C(GraphOnPolicyAlgorithm, A2C):
"""Advantage Actor Critic (A2C) with graph observations.
It is meant to be applied to a gymnasium environment whose observation space is
- either a `gymnasium.spaces.Graph` => you should use policy="GraphInputPolicy",
- or a `gymnasium.spaces.Dict` with some subspaces being `gymnasium.spaces.Graph`
  => you should use policy="MultiInputPolicy".
The policies use a GNN to extract features from the observation before they are fed to an MLP for prediction.
"""

policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"GraphInputPolicy": GNNActorCriticPolicy,
"MultiInputPolicy": MultiInputGNNActorCriticPolicy,
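As with the PPO variants in the example above, the graph algorithms are typically driven through scikit-decide's StableBaseline solver. A minimal sketch for GraphA2C follows; the import path mirrors this commit's file layout, and domain_factory is assumed to be defined as in examples/gnn/gnn_graph2node_sb3_jsp.py.

from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn.a2c.a2c import GraphA2C
from skdecide.utils import rollout

# domain_factory: a callable returning a domain whose observation space is a gymnasium.spaces.Graph
with StableBaseline(
    domain_factory=domain_factory,
    algo_class=GraphA2C,
    baselines_policy="GraphInputPolicy",
    learn_config={"total_timesteps": 10_000},
) as solver:
    solver.solve()
    rollout(domain=domain_factory(), solver=solver, max_steps=30, num_episodes=1)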
66 changes: 53 additions & 13 deletions skdecide/hub/solver/stable_baselines/gnn/common/buffers.py
@@ -24,6 +24,7 @@
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
from torch.nn.utils.rnn import pad_sequence

from .preprocessing import get_obs_shape
from .utils import copy_graph_instance, graph_instance_to_thg_data
@@ -80,7 +81,7 @@ def reset(self) -> None:

def add(
self,
obs: spaces.GraphInstance,
obs: list[spaces.GraphInstance],
action: np.ndarray,
reward: np.ndarray,
episode_start: np.ndarray,
@@ -111,6 +112,11 @@ def _add_obs(self, obs: list[spaces.GraphInstance]) -> None:
def _swap_and_flatten_obs(self) -> None:
self.observations = _swap_and_flatten_nested_list(self.observations)

def _swap_and_flatten_action_masks(self) -> None:
"""Method to override in buffers meant to be used with action masks."""
# by default, no action masks
...

def get(
self, batch_size: Optional[int] = None
) -> Generator[RolloutBufferSamples, None, None]:
@@ -119,6 +125,7 @@ def get(
# Prepare the data
if not self.generator_ready:
self._swap_and_flatten_obs()
self._swap_and_flatten_action_masks()
for tensor in self.tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
@@ -224,34 +231,61 @@ def _get_observations_samples(


class _BaseMaskableRolloutBuffer:

tensor_names = [
"actions",
"values",
"log_probs",
"advantages",
"returns",
"action_masks",
]

def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
"""
:param action_masks: Masks applied to constrain the choice of possible actions.
"""

self._add_action_masks(action_masks=action_masks)
super().add(*args, **kwargs)

def _add_action_masks(self, action_masks: Optional[np.ndarray] = None):
if action_masks is not None:
self.action_masks[self.pos] = action_masks.reshape(
(self.n_envs, self.mask_dims)
)

super().add(*args, **kwargs)
def _swap_and_flatten_action_masks(self) -> None:
self.action_masks = self.swap_and_flatten(self.action_masks)

def _get_samples(
self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
) -> MaskableRolloutBufferSamples:
samples_wo_action_masks = super()._get_samples(batch_inds=batch_inds, env=env)
action_masks = self._get_action_masks_samples(batch_inds=batch_inds)
return MaskableRolloutBufferSamples(
*samples_wo_action_masks,
action_masks=self.action_masks[batch_inds].reshape(-1, self.mask_dims),
action_masks=action_masks,
)

def _get_action_masks_samples(self, batch_inds: np.ndarray) -> np.ndarray:
return self.to_torch(self.action_masks[batch_inds].reshape(-1, self.mask_dims))


class _BaseMaskableGraph2NodeRolloutBuffer(_BaseMaskableRolloutBuffer):

action_masks: list[np.ndarray]

def reset(self):
super().reset()
self.action_masks = list()

def _add_action_masks(self, action_masks: Optional[list[np.ndarray]] = None):
if action_masks is not None:
self.action_masks.append(action_masks.reshape(self.n_envs, -1))
else:
self.action_masks.append([])

def _swap_and_flatten_action_masks(self) -> None:
if self.n_envs > 1:
raise NotImplementedError()
else:
self.action_masks = [a.flatten() for a in self.action_masks]

def _get_action_masks_samples(self, batch_inds: np.ndarray) -> th.Tensor:
return pad_sequence(
[self.to_torch(self.action_masks[idx]) for idx in batch_inds],
batch_first=True,
)


Expand All @@ -267,6 +301,12 @@ class MaskableDictGraphRolloutBuffer(
...


class MaskableGraph2NodeRolloutBuffer(
_BaseMaskableGraph2NodeRolloutBuffer, GraphRolloutBuffer, MaskableRolloutBuffer
):
...


class GraphReplayBuffer(ReplayBuffer, GraphBaseBuffer):
observations: list[spaces.GraphInstance]
next_observations: list[spaces.GraphInstance]
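Because the number of graph nodes — and hence the mask length — can vary from one step to the next, _get_action_masks_samples pads the per-step masks to a common length before batching. The behaviour of torch's pad_sequence on such variable-length masks can be checked in isolation (the mask values below are made up):

import torch
from torch.nn.utils.rnn import pad_sequence

# Three steps whose observation graphs have 3, 5 and 2 nodes respectively.
masks = [
    torch.tensor([True, False, True]),
    torch.tensor([False, True, True, False, True]),
    torch.tensor([True, False]),
]

# Right-pads with False (the default padding value 0) up to the longest mask,
# giving a (batch, max_n_nodes) tensor usable by the maskable policy.
batched = pad_sequence(masks, batch_first=True)
print(batched.shape)  # torch.Size([3, 5])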