Support PettingZoo Parallel API and action mask #305
Conversation
hi @nkzawa ! would it be possible to make sure pre-commit passes, as well as Ubuntu tests? A short doc page would be very nice to have too :) I will do a proper review soon! EDIT: just noticed your comment about pyflakes
Great contribution, thank you!
There are a couple of small comments to address.
If you have the time to add a small documentation page, that'd be very helpful! Something similar to other contributed env integrations, like https://www.samplefactory.dev/09-environment-integrations/nethack/ (does not have to be as elaborate as this, installation instruction and training example would already be great, a wandb run/report or some videos - doubly great!)
```python
epsilons = torch.full_like(probs, 1e-6)
probs = torch.where(all_zero, epsilons, probs)  # ensure sum of probabilities is non-zero

samples = torch.multinomial(probs, 1, True)
```
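To illustrate what this guard prevents, here is a small sketch with made-up tensors (not from the PR): `torch.multinomial` raises on a row whose weights sum to zero, so replacing all-zero rows with a tiny uniform epsilon keeps sampling well-defined.

```python
import torch

# Hypothetical probabilities; the second row is all zeros, which would make
# torch.multinomial raise "invalid multinomial distribution".
probs = torch.tensor([[0.3, 0.7], [0.0, 0.0]])
all_zero = probs.sum(dim=-1, keepdim=True) == 0

# Replace all-zero rows with a tiny uniform epsilon so every row has positive mass.
epsilons = torch.full_like(probs, 1e-6)
probs = torch.where(all_zero, epsilons, probs)

samples = torch.multinomial(probs, 1, True)
assert samples.shape == (2, 1)
```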
Just checking: do we need to re-normalize the probabilities here so they add up to 1, or does torch.multinomial do this internally?
Technically, it seems there is no need for them to sum to 1, according to the docs:

> The rows of input do not need to sum to one (in which case we use the values as weights), ...

https://pytorch.org/docs/stable/generated/torch.multinomial.html

But I'm honestly not so sure (I'm a newbie at RL), so please feel free to fix anything that looks wrong with the code.
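For what it's worth, a quick check with made-up weights (my own example, not the PR's code) matches the docs: sampling works with unnormalized weights, and a zero-weight entry is never drawn.

```python
import torch

torch.manual_seed(0)
# Unnormalized weights: torch.multinomial treats each row as relative weights,
# so no explicit normalization is needed just to sample from them.
weights = torch.tensor([4.0, 4.0, 0.0])
samples = torch.multinomial(weights, 1000, replacement=True)

# Index 2 has zero weight and is never drawn.
assert (samples != 2).all()
```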
Actually, it seems the values need to be normalized with softmax as far as I understand, so I implemented that.
`sample_factory/model/actor_critic.py` (outdated)

```diff
@@ -321,4 +336,13 @@ def create_actor_critic(cfg: Config, obs_space: ObsSpace, action_space: ActionSp
     from sample_factory.algo.utils.context import global_model_factory

     make_actor_critic_func = global_model_factory().make_actor_critic_func
-    return make_actor_critic_func(cfg, obs_space, action_space)
+    return make_actor_critic_func(cfg, obs_space_without_action_mask(obs_space), action_space)
```
I think it would be a bit cleaner to add special treatment for `action_mask` inside `make_actor_critic_func`, but I'm fine with this solution too 👍
I assumed it meant doing it inside `default_make_actor_critic_func`. Fixed it that way anyway 🙏
```python
def register_pettingzoo_fixture(self):
    register_custom_components()
    yield  # this is where the actual test happens
    reset_global_context()
```
Added this (similar to other env tests)
This resets the global encoder factory; otherwise other tests can fail if this one runs first.
(I'd be the first to admit this is not a perfect solution but hey I wrote this years ago)
Will add docs as well 👍

EDIT: Based on this paper and the Maskable PPO implementation in Stable-Baselines3, it appears that action masks should also be applied when calculating log probabilities. This helps the model learn to avoid selecting invalid actions, as far as I understand. I'll modify the code accordingly. EDIT: done
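A minimal sketch of that idea (my own illustration with made-up tensors, not the PR's actual code): pushing invalid logits to a large negative value before building the distribution makes both sampling and the log-probabilities used in the PPO loss respect the mask.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 0.0, 1.0]])  # action 1 is invalid

# Mask logits before the softmax, so log-probs also reflect the mask
# (the approach used by SB3's MaskablePPO).
masked_logits = logits + (mask == 0) * -1e9
log_probs = F.log_softmax(masked_logits, dim=-1)

# The invalid action ends up with (numerically) zero probability.
assert float(log_probs.exp()[0, 1]) < 1e-6
```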
```python
def __init__(self, full_env_name, cfg, render_mode=None):
    ...
    self.observation_space = gym.spaces.Dict({
        "obs": gym.spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8),
```
Small nit: I wonder if `low=0, high=1` here is intentional; would this mean binary observations? I understand 0/1 in `action_mask`, since that is a binary mask.
This is intentional, since it's taken from PettingZoo's Tic-Tac-Toe, but I think we can change it if it's confusing.
```diff
@@ -0,0 +1,38 @@
+# Action Masking
```
This documentation is wonderful, thank you!
```diff
@@ -0,0 +1,46 @@
+# PettingZoo
```
Love this. Thank you!
```python
# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L243
def masked_softmax(logits, mask):
    # To limit numerical errors from large vector elements outside the mask, we zero these out.
    result = functional.softmax(logits * mask, dim=-1)
```
Can you help me understand this please?
I think logits in general can be negative, or positive but close to 0, in which case multiplying them by zero does not achieve the desired effect.
I'd say we should probably use something like this instead?

```python
def masked_softmax(logits, mask):
    # Mask out the invalid logits by adding a large negative number (-1e9)
    logits = logits + (mask == 0) * -1e9
    result = functional.softmax(logits, dim=-1)
    result = result * mask
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result
```

The choice of -1e9 is arbitrary here, but it could be something like `-max(abs(logits)) * 1e6` to make this universal.
It was taken from AllenNLP, including the comment, so I don't fully understand it. But as far as I investigated, your version seems safer in some cases, even though the results are usually identical in both versions 👍
fixed
```python
# vector + mask.log() is an easy way to zero out masked elements in logspace, but it
# results in nans when the whole vector is masked. We need a very small value instead of a
# zero in the mask for these cases.
logits = logits + (mask + 1e-13).log()
```
This makes more sense to me: it is essentially adding log(1e-13), which is about -30, to non-valid elements. I'm not sure if this is universally correct, but it should most likely work. Why can't we just explicitly add a large negative constant, like -1e9 or `-max(abs(logits)) * 1e6`, as in the previous example?
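A quick numerical check of that concern (my own made-up tensors, not the PR's code): when all valid logits are very negative, adding log(1e-13) ≈ -30 to a masked logit of 0 is not enough, and the masked action soaks up nearly all the probability, whereas the large negative constant suppresses it correctly.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.0, -100.0, -100.0]])
mask = torch.tensor([[0.0, 1.0, 1.0]])  # index 0 is masked out

# log-mask variant: the masked logit becomes 0 + log(1e-13) ≈ -29.9,
# which is still far above the valid logits of -100.
p_log = F.softmax(logits + (mask + 1e-13).log(), dim=-1)

# large-negative-constant variant: the masked logit becomes ≈ -1e9.
p_neg = F.softmax(logits + (mask == 0) * -1e9, dim=-1)

assert float(p_log[0, 0]) > 0.99  # masked action dominates: the failure mode
assert float(p_neg[0, 0]) < 1e-6  # masked action correctly suppressed
```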
It seems you're correct. This version causes a problem in extreme cases as far as I tested.
fixed
As the title says, this adds support for PettingZoo by providing a wrapper class. It also adds support for action masks.
Creating PettingZoo env
Currently, only the Parallel API is supported, since supporting the AEC API would require a different execution flow.
Action mask
It works when you add the `action_mask` key to the `dict` observation. This seems to be the most common interface for providing action masks in PettingZoo, so I think it makes sense to follow it. It's also common to provide the value in `info` as `info["action_mask"]`, but that's not supported for now, since sample-factory needs to know the shape in order to allocate buffers, as far as I understand.

I added an example that trains Tic-Tac-Toe, but I'm not sure about the configuration, so I'd appreciate any suggestions.
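As an illustration of the interface described above (a made-up observation, not code from the PR), each per-agent observation is a dict carrying the raw observation under `"obs"` and a binary vector under `"action_mask"`, where 1 marks a legal action:

```python
import numpy as np

# Hypothetical Tic-Tac-Toe observation: 3x3 board planes plus a 9-way action mask.
observation = {
    "obs": np.zeros((3, 3, 2), dtype=np.int8),
    "action_mask": np.array([1, 1, 1, 0, 1, 1, 1, 1, 1], dtype=np.int8),  # cell 3 is taken
}

# Downstream code can recover the set of legal actions from the mask.
legal_actions = np.flatnonzero(observation["action_mask"])
assert 3 not in legal_actions
```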
btw, I followed `CONTRIBUTING.md`, but `make check-codestyle` fails with the error:

Also, `make test` fails when `mujoco` is involved.

Python version is 3.11.9 on an Intel Mac.