Support PettingZoo Parallel API and action mask #305

Merged

Conversation

@nkzawa (Contributor) commented Sep 15, 2024

As the title says, this adds support for PettingZoo by providing a wrapper class, plus support for action masks.

Creating PettingZoo env

```python
from typing import Optional

from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv

def make_pettingzoo_env(full_env_name, cfg=None, env_config=None, render_mode: Optional[str] = None):
    return PettingZooParallelEnv(some_env.parallel_env(render_mode=render_mode))
```

Currently it supports only the Parallel API, since supporting the AEC API would require a different execution flow.
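For context, the Parallel API steps all agents at once with dict-keyed actions, while the AEC API advances one agent per turn; that is the different execution flow in question. Here is a minimal sketch of the parallel rollout loop, using a toy stand-in env (`ToyParallelEnv` is hypothetical, not part of PettingZoo or sample-factory):

```python
class ToyParallelEnv:
    """Toy stand-in that mimics the shape of the PettingZoo Parallel API."""

    def __init__(self, episode_len=3):
        self.agents = []
        self._episode_len = episode_len
        self._t = 0

    def reset(self, seed=None):
        self.agents = ["player_0", "player_1"]
        self._t = 0
        observations = {a: 0 for a in self.agents}
        infos = {a: {} for a in self.agents}
        return observations, infos

    def step(self, actions):
        # All agents act simultaneously; this simultaneous step is what
        # distinguishes the Parallel API from the turn-based AEC API.
        self._t += 1
        done = self._t >= self._episode_len
        observations = {a: self._t for a in self.agents}
        rewards = {a: 1.0 for a in self.agents}
        terminations = {a: done for a in self.agents}
        truncations = {a: False for a in self.agents}
        infos = {a: {} for a in self.agents}
        if done:
            self.agents = []  # PettingZoo envs clear the agent list when the episode ends
        return observations, rewards, terminations, truncations, infos

# The standard parallel loop a wrapper such as PettingZooParallelEnv has to drive:
env = ToyParallelEnv()
observations, infos = env.reset(seed=42)
steps = 0
while env.agents:
    actions = {agent: 0 for agent in env.agents}  # a real policy would pick these
    observations, rewards, terminations, truncations, infos = env.step(actions)
    steps += 1
```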

Action mask

Action masking works when you add an action_mask key to the dict observation.

```python
from typing import Optional

import gymnasium as gym
import numpy as np

class CustomEnv(gym.Env):
    def __init__(self, full_env_name, cfg, render_mode: Optional[str] = None):
        self.observation_space = gym.spaces.Dict({
            "obs": gym.spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8),
            "action_mask": gym.spaces.Box(low=0, high=1, shape=(9,), dtype=np.int8),
        })
        self.action_space = gym.spaces.Discrete(9)

    def step(self, action):
        ...
        return {"obs": obs, "action_mask": action_mask}, reward, terminated, truncated, info
```

It seems this is the most common interface for providing action masks in PettingZoo, so I think it makes sense to follow it. It's also common to provide the value in info as info["action_mask"], but that's not supported for now, since sample-factory needs to know the shape up front to allocate buffers, as far as I understand.
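If an env only exposes the mask via info["action_mask"], a thin wrapper can lift it into the dict observation to match the convention above. A hypothetical sketch (`MaskToObsWrapper` and `_DummyEnv` are made-up names, not part of sample-factory):

```python
import numpy as np

class MaskToObsWrapper:
    """Lift info["action_mask"] into the dict observation."""

    def __init__(self, env, num_actions):
        self.env = env
        self.num_actions = num_actions

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Fall back to an all-ones mask (every action valid) if the env
        # did not provide one in info.
        mask = info.get("action_mask", np.ones(self.num_actions, dtype=np.int8))
        return {"obs": obs, "action_mask": mask}, reward, terminated, truncated, info

class _DummyEnv:
    def step(self, action):
        info = {"action_mask": np.array([1, 0, 1], dtype=np.int8)}
        return np.zeros(3), 0.0, False, False, info

wrapped = MaskToObsWrapper(_DummyEnv(), num_actions=3)
obs, reward, terminated, truncated, info = wrapped.step(0)
```

A real wrapper would also have to add the mask entry to observation_space, since that is exactly where sample-factory gets the shape it needs for buffer allocation.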

I added an example that trains Tic-Tac-Toe, but I'm not sure about the configuration, so I'd appreciate any suggestions.


btw, I followed CONTRIBUTING.md but make check-codestyle fails with the error:

.../python3.11/site-packages/sympy/polys/numberfields/resolvent_lookup.py: "pyflakes[F]" failed during execution due to RecursionError('maximum recursion depth exceeded')

Also, make test fails when mujoco is involved.

OSError: dlopen(/System/Library/OpenGL.framework/OpenGL, 0x0006): tried: '/System/Library/OpenGL.framework/OpenGL' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/System/Library/OpenGL.framework/OpenGL' (no such file), '/System/Library/OpenGL.framework/OpenGL' (no such file, not in dyld cache)

Python version is 3.11.9 on an Intel Mac.

@alex-petrenko (Owner) commented Sep 16, 2024

hi @nkzawa!
thank you for this contribution

would it be possible to make sure pre-commit passes, as well as the Ubuntu tests?
There are some details here: https://www.samplefactory.dev/12-community/contribution/

A short doc page would be very nice to have too :) I will do a proper review soon!

EDIT: just noticed your comment about pyflakes
I will take a look

@alex-petrenko (Owner) left a comment

Great contribution, thank you!
There are a couple of small comments to address.
If you have the time to add a small documentation page, that'd be very helpful! Something similar to other contributed env integrations, like https://www.samplefactory.dev/09-environment-integrations/nethack/ (it doesn't have to be as elaborate as that; installation instructions and a training example would already be great, and a wandb run/report or some videos - doubly great!)

```python
epsilons = torch.full_like(probs, 1e-6)
probs = torch.where(all_zero, epsilons, probs)  # ensure sum of probabilities is non-zero

samples = torch.multinomial(probs, 1, True)
```
@alex-petrenko (Owner) commented:

just checking if we don't have to re-normalize the probabilities here so they add up to 1.
Does torch.multinomial do this internally?

@nkzawa (Contributor, Author) commented:

Technically, it seems there is no need for them to add up to 1, according to the docs:

> The rows of input do not need to sum to one (in which case we use the values as weights), ...

https://pytorch.org/docs/stable/generated/torch.multinomial.html

But honestly I'm not so sure (I'm an RL newbie), so please feel free to fix it if you see something wrong with the code.

@nkzawa (Contributor, Author) commented:

Actually, it seems the values do need to be normalized with softmax, as far as I understand, so I implemented that.

```diff
@@ -321,4 +336,13 @@ def create_actor_critic(cfg: Config, obs_space: ObsSpace, action_space: ActionSp
     from sample_factory.algo.utils.context import global_model_factory

     make_actor_critic_func = global_model_factory().make_actor_critic_func
-    return make_actor_critic_func(cfg, obs_space, action_space)
+    return make_actor_critic_func(cfg, obs_space_without_action_mask(obs_space), action_space)
```
@alex-petrenko (Owner) commented:

I think it would be a bit cleaner to add special treatment for action_mask inside make_actor_critic_func but I'm fine with this solution too 👍

@nkzawa (Contributor, Author) commented:

I assumed this meant doing it inside default_make_actor_critic_func. Fixed it that way anyway 🙏

```python
def register_pettingzoo_fixture(self):
    register_custom_components()
    yield  # this is where the actual test happens
    reset_global_context()
```
@alex-petrenko (Owner) commented:

Added this (similar to other env tests). It resets the global encoder factory; otherwise other tests can fail if this one runs first.

(I'd be the first to admit this is not a perfect solution, but hey, I wrote this years ago.)
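The fixture relies on pytest's setup/yield/teardown pattern: everything before the yield runs before the test, everything after runs afterwards, so the global context gets reset for subsequent tests. A minimal stand-alone illustration, where REGISTRY, register_custom_components, and reset_global_context are simplified stand-ins for sample-factory's actual globals:

```python
REGISTRY = {}  # stand-in for the global encoder factory / context

def register_custom_components():
    REGISTRY["encoder"] = "custom"

def reset_global_context():
    REGISTRY.clear()

def register_pettingzoo_fixture():
    register_custom_components()
    yield  # with @pytest.fixture, the test body runs at this point
    reset_global_context()

# Drive the generator by hand to show the setup/teardown ordering:
fixture = register_pettingzoo_fixture()
next(fixture)                      # setup phase
state_during_test = dict(REGISTRY)
try:
    next(fixture)                  # teardown phase
except StopIteration:
    pass
```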

@nkzawa (Contributor, Author) commented Sep 17, 2024

Will add docs as well 👍

EDIT:
Added docs, though no video or report. Feel free to fix/improve 🙏

@nkzawa (Contributor, Author) commented Oct 3, 2024

Based on this paper and the Maskable PPO implementation in Stable-Baselines3, it appears that action masks should also be applied when calculating log probabilities. This helps the model learn to avoid selecting invalid actions, as far as I understand. I'll modify the code accordingly.

EDIT: done
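The idea from Maskable PPO, as I understand it, is that the mask has to enter the log-probabilities too, so the same masked distribution is used for both sampling and the policy-gradient loss. A numpy sketch of the principle (`masked_log_softmax` here is an illustrative name, not sample-factory's actual function):

```python
import numpy as np

def masked_log_softmax(logits, mask, neg=-1e9):
    # Invalid actions get a huge negative logit, so they carry ~zero
    # probability mass and their log-probs never dominate the loss.
    z = np.where(mask > 0, logits, neg)
    z = z - z.max()                      # standard stabilization before exp
    return z - np.log(np.exp(z).sum())

logits = np.array([2.0, -1.0, 0.5])
mask = np.array([1, 0, 1])               # action 1 is invalid
log_probs = masked_log_softmax(logits, mask)
probs = np.exp(log_probs)                # valid actions share all the mass
```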

```python
def __init__(self, full_env_name, cfg, render_mode=None):
    ...
    self.observation_space = gym.spaces.Dict({
        "obs": gym.spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8),
```
@alex-petrenko (Owner) commented:

Small nit: I wonder if low=0, high=1 here is intentional - would this mean binary observations?

I understand 0/1 in action_mask since this is a binary mask

@nkzawa (Contributor, Author) commented:

This is intentional, since it's taken from PettingZoo's Tic-Tac-Toe env, but I think we can change it if it's confusing.

@@ -0,0 +1,38 @@
# Action Masking
@alex-petrenko (Owner) commented:

This documentation is wonderful, thank you!

@@ -0,0 +1,46 @@
# PettingZoo
@alex-petrenko (Owner) commented:

Love this. Thank you!

```python
# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L243
def masked_softmax(logits, mask):
    # To limit numerical errors from large vector elements outside the mask, we zero these out.
    result = functional.softmax(logits * mask, dim=-1)
```
@alex-petrenko (Owner) commented:

Can you help me understand this please?

I think logits in general can be negative, or positive but close to 0, in which case multiplying them by zero does not achieve the desired effect.

I'd say we should probably use something like this instead?

```python
def masked_softmax(logits, mask):
    # Mask out the invalid logits by adding a large negative number (-1e9)
    logits = logits + (mask == 0) * -1e9
    result = functional.softmax(logits, dim=-1)
    result = result * mask
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result
```

The choice of -1e9 is arbitrary here, but it could be something like -max(abs(logits)) * 1e6 to make this universal.
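A small numeric check of the concern above: when the valid logits are negative, multiplying by the mask turns the invalid logit into 0, which then beats the valid ones, while the additive version correctly zeroes it out (numpy stands in for the torch code here):

```python
import numpy as np

def masked_softmax_mult(logits, mask):
    # AllenNLP-style version: zero out masked logits, then softmax
    z = logits * mask
    e = np.exp(z - z.max())
    return e / e.sum()

def masked_softmax_add(logits, mask):
    # additive large-negative mask, then re-mask and renormalize
    z = logits + (mask == 0) * -1e9
    e = np.exp(z - z.max())
    p = e / e.sum()
    p = p * mask
    return p / (p.sum() + 1e-13)

logits = np.array([-5.0, -1.0, 2.0])
mask = np.array([1.0, 1.0, 0.0])  # the last action is invalid

p_mult = masked_softmax_mult(logits, mask)  # masked logit becomes 0, beating -5 and -1
p_add = masked_softmax_add(logits, mask)    # masked action gets exactly zero mass
```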

@nkzawa (Contributor, Author) commented:

It's taken from AllenNLP, including the comment, so I don't fully understand it, but as far as I investigated, your version seems safer in some cases, even though the results are usually identical in both versions 👍

@nkzawa (Contributor, Author) commented:

fixed

```python
# vector + mask.log() is an easy way to zero out masked elements in logspace, but it
# results in nans when the whole vector is masked. We need a very small value instead of a
# zero in the mask for these cases.
logits = logits + (mask + 1e-13).log()
```
@alex-petrenko (Owner) commented:

This makes more sense to me: it essentially adds log(1e-13), which is about -30, to non-valid elements. I'm not sure this is universally correct, but it should most likely work. Why can't we just explicitly add a large negative constant though, like -1e9 or -max(abs(logits)) * 1e6 as in the previous example?
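A quick numpy check of the limitation: log(1e-13) only subtracts about 30, so an invalid action can still end up with the highest probability whenever the logits spread over a wider range than that (numpy stands in for the torch code here):

```python
import numpy as np

mask = np.array([1.0, 0.0])        # the second action is invalid
penalty = np.log(mask + 1e-13)     # approximately [0, -29.9]: only ~-30 for masked slots

logits = np.array([-40.0, 10.0])   # spread wider than 30
z = logits + penalty               # approximately [-40, -19.9]: the masked slot still wins

e = np.exp(z - z.max())
probs = e / e.sum()                # nearly all mass lands on the invalid action
```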

@nkzawa (Contributor, Author) commented:

It seems you're correct. This version causes problems in extreme cases, as far as I tested.

@nkzawa (Contributor, Author) commented:

fixed

@alex-petrenko alex-petrenko merged commit abbc459 into alex-petrenko:master Oct 23, 2024
4 of 8 checks passed