Implement GNN-based PPO graph->node for ray.rllib framework #472
Open: nhuet wants to merge 4 commits into airbus:master from nhuet:gnn2node-rllib-mask
@nhuet I made a fix for the windows compilation error in |
The custom model follows what has been done for the sb3 framework:
- we predict values as before, with GNN feature extraction + reduction to a fixed number of features + MLP
- we predict actions with a single GNN which predicts logits for each node without knowing their number in advance.

To use this feature, we need to set the attribute `graph_node_action` to True. The doc of RayRLlib has been enriched to explain what happens in that case. The main issue here is that we also have to pad the rollout buffer entries corresponding to `action_dist_inputs` (the logits), as the number of available actions may vary from one observation to another.
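To illustrate why padding is needed before batching, here is a minimal sketch in plain Python (no rllib or torch dependency). The helper names `pad_logits` and `softmax` and the constant `PAD_LOGIT` are hypothetical and are not taken from this PR; the choice of a very negative padding value is an assumption, made so that padded slots carry no probability mass.

```python
import math
from typing import List, Sequence

# Assumption: padded slots receive a very negative logit, so they get
# (numerically) zero probability. The value actually used in the PR is not shown.
PAD_LOGIT = -1e30


def pad_logits(
    batch: Sequence[Sequence[float]], pad_value: float = PAD_LOGIT
) -> List[List[float]]:
    """Right-pad each row of logits to the length of the longest row."""
    width = max((len(row) for row in batch), default=0)
    return [list(row) + [pad_value] * (width - len(row)) for row in batch]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a single row of logits."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Two observations whose graphs have 3 and 5 nodes respectively.
batch = [[0.1, 0.2, 0.3], [1.0, 0.0, -1.0, 0.5, 0.2]]
padded = pad_logits(batch)  # every row now has length 5
probs = softmax(padded[0])  # padded slots end up with probability 0.0
```

Once every row has the same length, the rows can be stacked into a single rectangular array, which is what concatenating rollout buffer entries requires.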
The length of the action mask may vary, since the number of actions is the number of nodes in the observation graph, which may change from one step to the next. Because of that, the custom model used before (`TorchParametricActionsModel`) does not apply properly, since it uses an action embedding that needs the (max) number of actions in advance. We thus use a simpler model (`TorchMaskedActionsModel`), much like what is done for the sb3 framework: we predict action logits as when no action masking applies, and only at the end do we apply the mask (by adding log(action_mask) to the logits). The main difference is that the custom model no longer manages a last layer whose weights are changed for non-applicable actions; this is handled by the GNN instead.

Once again, the buffers may list action masks of different sizes, so we need to apply padding before concatenation.

Lastly, the first dummy samples used to initialize weights are generated by rllib from the observation space of the `AsRLlibMultiAgentEnv`. So, in the action_mask space enriching the observation space, we need to match the default size chosen for the nodes when converting the graph space into a dict space (to be recognized by rllib). This is done in `_create_agent_obs_space_for_rllib()`.

As more action masking code is added here, we create a directory tree similar to the one used for gnn to group the related code:
- the custom models for action masking are now available in:
  - skdecide/hub/solver/ray_rllib/action_masking/models/tf for tensorflow
  - skdecide/hub/solver/ray_rllib/action_masking/models/torch for torch
- the space conversion utilities (such as keys for true obs and action mask) in skdecide/hub/solver/ray_rllib/action_masking/utils/spaces/space_utils.py
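The masking step and the mask padding described above can be sketched as follows, in plain Python rather than torch. The names `pad_masks`, `mask_logits` and `FLOAT_MIN` are hypothetical and do not come from `TorchMaskedActionsModel`; replacing log(0) with a large negative finite value is a common convention assumed here, not something confirmed by this PR.

```python
import math
from typing import List, Sequence

# Assumption: log(0) is replaced by a large negative finite value rather than
# -inf, to avoid NaN in later computations. The PR's own constant is not shown.
FLOAT_MIN = -3.4e38


def pad_masks(masks: Sequence[Sequence[int]]) -> List[List[int]]:
    """Right-pad 0/1 masks with zeros, so padded slots are never allowed."""
    width = max((len(m) for m in masks), default=0)
    return [list(m) + [0] * (width - len(m)) for m in masks]


def mask_logits(logits: Sequence[float], action_mask: Sequence[int]) -> List[float]:
    """Add log(action_mask) to the logits.

    Allowed actions (mask 1) get +log(1) = 0, so their logit is unchanged.
    Forbidden actions (mask 0) get FLOAT_MIN in place of log(0).
    """
    if len(logits) != len(action_mask):
        raise ValueError("logits and action_mask must have the same length")
    return [
        logit + (math.log(m) if m > 0 else FLOAT_MIN)
        for logit, m in zip(logits, action_mask)
    ]


# Masks for graphs with 3 and 5 nodes, padded to a common length of 5.
masks = pad_masks([[1, 0, 1], [1, 1, 0, 1, 1]])
masked = mask_logits([0.5, 2.0, -0.3, 0.0, 0.0], masks[0])
best = max(range(len(masked)), key=lambda i: masked[i])  # index of the best allowed action
```

Padding masks with zeros means padded slots are treated exactly like non-applicable actions, so the same masking step handles both cases.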
This fixes a bug appearing in tests unrelated to gnn but chained after it.
0653bb6 to 1b376b8
Done |
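We are dealing here with domains whose observations are graphs and whose actions are nodes of these graphs.

The custom model follows what has been done for the sb3 framework:
- we predict values as before, with GNN feature extraction + reduction to a fixed number of features + MLP
- we predict actions with a single GNN which predicts logits for each node without knowing their number in advance.

To use this feature, we need to set the attribute `graph_node_action` to `True`. The doc of RayRLlib has been enriched to explain what happens in that case.

The main issue here is that we also have to pad the rollout buffer entries corresponding to `action_dist_inputs` (the logits), as the number of available actions may vary from one observation to another.

In case of action masking, a bit more work is necessary:
- The length of the action mask may vary, since the number of actions is the number of nodes in the observation graph, which may change from one step to the next. Because of that, the custom model used before (`TorchParametricActionsModel`) does not apply properly, since it uses an action embedding that needs the (max) number of actions in advance. We thus use a simpler model (`TorchMaskedActionsModel`), much like what is done for the sb3 framework: we predict action logits as when no action masking applies, and only at the end do we apply the mask (by adding log(action_mask) to the logits). The main difference is that the custom model no longer manages a last layer whose weights are changed for non-applicable actions; this is handled by the GNN instead.
- Once again, the buffers may list action masks of different sizes, so we need to apply padding before concatenation.
- Lastly, the first dummy samples used to initialize weights are generated by rllib from the observation space of the `AsRLlibMultiAgentEnv`. So, in the action_mask space enriching the observation space, we need to match the default size chosen for the nodes when converting the graph space into a dict space (to be recognized by rllib). This is done in `_create_agent_obs_space_for_rllib()`.

As more action masking code is added here, we create a directory tree similar to the one used for gnn to group the related code:
- the custom models for action masking are now available in:
  - skdecide/hub/solver/ray_rllib/action_masking/models/tf for tensorflow
  - skdecide/hub/solver/ray_rllib/action_masking/models/torch for torch
- the space conversion utilities (such as keys for true obs and action mask) in skdecide/hub/solver/ray_rllib/action_masking/utils/spaces/space_utils.py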