Refactor and clean-up of common code #89

Merged: AdamGleave merged 8 commits into master from base-class-review on Jul 8, 2020

@AdamGleave (Collaborator) commented Jul 3, 2020

Description

Overall the code looked good -- I was only tempted to make a few fairly minor changes amongst the couple of thousand lines of code.

Changes from reviewing common code as per #17. Specific changes:

  • Add module docstrings
  • base_class.py:
    • Refactor some environment creation logic into maybe_make_env
    • Remove raise NotImplementedError() on abstract methods. This should not be needed as Python will refuse to instantiate classes with abstract methods.
    • Removed check for "_setup_model" existing: hasattr(base_class.BaseAlgorithm, '_setup_model') is True so I think this check cannot ever fire.
  • distributions.py:
    • Make Distribution an ABC, use @abstractmethod instead of NotImplementedError
    • Have subclass methods defined in a consistent order
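To illustrate the two claims above (that Python refuses to instantiate classes with unimplemented abstract methods, and that hasattr on the base class is always True for a method the base declares), here is a minimal sketch. The classes are illustrative stand-ins, not the actual stable_baselines3 code, and `Incomplete`, `Constant` and `try_instantiate` are hypothetical names.

```python
from abc import ABC, abstractmethod


class Distribution(ABC):
    """Illustrative stand-in, not the real stable_baselines3 class."""

    @abstractmethod
    def sample(self) -> float:
        """Return a sample from the distribution."""


class Incomplete(Distribution):
    """Hypothetical subclass that forgets to override sample()."""


class Constant(Distribution):
    """Hypothetical subclass that implements every abstract method."""

    def sample(self) -> float:
        return 0.0


def try_instantiate(cls) -> str:
    """Report whether cls() succeeds or is rejected with TypeError."""
    try:
        cls()
    except TypeError:
        return "TypeError"
    return "ok"


results = {
    cls.__name__: try_instantiate(cls)
    for cls in (Distribution, Incomplete, Constant)
}

# The attribute exists on the base class even though it is abstract,
# so a hasattr() check against the base can never be False.
has_sample = hasattr(Distribution, "sample")
```

Only `Constant` can be constructed, so a `raise NotImplementedError()` in the abstract body would only ever run through an explicit `super()` call.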

Issues I noticed but wasn't sure how to resolve (tick indicates resolved subsequently):

  • In BaseAlgorithm.load, we accept data missing observation and action space, provided it contains the "env" key. But if the user provides an env argument, then we check against the (possibly non-existent) observation/action spaces. Can we just always require that data contains observation & action spaces?
  • In BaseAlgorithm._setup_learn I think we may have a similar issue to that reported in hill-a/stable-baselines#619 (Support learn() with total timesteps less than episode length) and (partially) fixed in hill-a/stable-baselines#649 (RLModel.learn: Reuse logging statistics after each call). In brief, repeated calls to learn() each call _setup_learn(), which resets the ep_info_buffer and ep_success_buffer. Perhaps we should only reset these if reset_num_timesteps is True?
  • In Distribution, proba_distribution and proba_distribution_net are not included, but every subclass defines these methods. Should we add them to the ABC?
  • In BasePolicy.predict, the code for state and mask are commented out. Not sure what's going on here -- do we just not support RNN policies yet? If so, should warn of this in the docstring, and probably open an enhancement issue to resolve this.
  • In BasePolicy._get_data, two lines are commented out for squash_output and features_extractor -- I'm not sure if this is desired or not?
  • In ActorCriticPolicy._build, self.action_net may not get constructed if we don't recognize the distribution. Is this intentional (e.g. responsibility of subclass to build it) or should we add an else: raise NotImplementedError()? I note we do raise NotImplementedError in ActorCriticPolicy._get_action_dist_from_latent.
  • The Critic in SAC and TD3 subclassing BasePolicy feels a little odd, e.g. they do not implement _predict. However, I'm not sure it's actually worth refactoring.
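As a rough sketch of the reset_num_timesteps suggestion in the second item above: the class below is a hypothetical stand-in for BaseAlgorithm that models only the buffer bookkeeping. The buffer size of 100 and the fake episode record are assumptions made for illustration, not taken from the library.

```python
from collections import deque
from typing import Deque, Optional


class TinyAlgorithm:
    """Hypothetical stand-in for BaseAlgorithm; only buffer logic is modelled."""

    def __init__(self) -> None:
        self.num_timesteps = 0
        self.ep_info_buffer: Optional[Deque[dict]] = None
        self.ep_success_buffer: Optional[Deque[bool]] = None

    def _setup_learn(self, reset_num_timesteps: bool = True) -> None:
        # Create the buffers on the first call, or when a reset is requested.
        if self.ep_info_buffer is None or reset_num_timesteps:
            self.ep_info_buffer = deque(maxlen=100)  # size is an assumption
            self.ep_success_buffer = deque(maxlen=100)
        if reset_num_timesteps:
            self.num_timesteps = 0

    def learn(self, total_timesteps: int, reset_num_timesteps: bool = True) -> "TinyAlgorithm":
        self._setup_learn(reset_num_timesteps)
        # Pretend exactly one episode finished during this call.
        self.ep_info_buffer.append({"r": 1.0, "l": total_timesteps})
        self.num_timesteps += total_timesteps
        return self


algo = TinyAlgorithm()
algo.learn(10)
algo.learn(10, reset_num_timesteps=False)
kept = len(algo.ep_info_buffer)         # statistics from both calls survive
algo.learn(10)                          # default: start from scratch
after_reset = len(algo.ep_info_buffer)  # only the latest call remains
```

With this guard, a continued run keeps its logging statistics across short learn() calls, while the default call still starts from empty buffers.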

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have checked the codestyle using make lint
  • I have ensured make pytest and make type both pass.

@AdamGleave AdamGleave mentioned this pull request Jul 3, 2020
9 tasks
@AdamGleave AdamGleave requested a review from araffin July 3, 2020 04:45
@AdamGleave AdamGleave marked this pull request as ready for review July 3, 2020 04:45
@@ -23,12 +25,30 @@
from stable_baselines3.common.noise import ActionNoise


def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]:
Member

I don't know if it should be there, as a static method or in utils.py...
because this is only used by the BaseAlgorithm class.

Collaborator Author

If only used by BaseAlgorithm then probably belongs somewhere in this file.

I'd be OK with module-level function or static method. I've got in the habit of avoiding @staticmethod after being brainwashed by Google but I don't think their argument is that strong.

@araffin (Member) commented Jul 6, 2020

Can we just always require that data contains observation & action spaces?

Sounds reasonable ;) (should be already the case in fact)

Perhaps we should only reset these if reset_num_timesteps is True?

Yep. I tried not to reproduce the errors from SB2, but I may have missed that one.

In Distribution, proba_distribution and proba_distribution_net are not included, but every subclass defines these methods. Should we add them to the ABC?

Good point. I forgot to add them after refactoring the distributions.

Not sure what's going on here -- do we just not support RNN policies yet? If so, should warn of this in the docstring, and probably open an enhancement issue to resolve this.

No, we do not support RNNs yet; this is on the roadmap for v1.1+ though. I agree with opening an issue to keep track of it ;)

two lines are commented out for squash_output and features_extractor -- I'm not sure if this is desired or not?

Yes, this is desired (see comment above). I had issues in the past because of inheritance.

self.action_net may not get constructed if we don't recognize the distribution. Is this intentional (e.g. responsibility of subclass to build it) or should we add an else: raise NotImplementedError()?

An error will be thrown earlier in fact: self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs).
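A toy sketch of why the missing else branch is unreachable under this reading: if the distribution factory already rejects unknown action spaces in the constructor, _build only ever sees recognized distributions. All names here (`ToyGaussian`, `ToyCategorical`, `make_toy_distribution`, `ToyPolicy`) are hypothetical, and action spaces are reduced to plain strings; this is not the real make_proba_distribution.

```python
class ToyGaussian:
    """Hypothetical distribution for continuous actions."""


class ToyCategorical:
    """Hypothetical distribution for discrete actions."""


def make_toy_distribution(space_kind: str):
    """Toy factory: fail fast on anything it does not recognize."""
    if space_kind == "box":
        return ToyGaussian()
    if space_kind == "discrete":
        return ToyCategorical()
    raise NotImplementedError(f"Unsupported action space kind: {space_kind!r}")


class ToyPolicy:
    def __init__(self, space_kind: str) -> None:
        self.action_net = None
        # Raises here for unknown kinds, before _build() is reached.
        self.action_dist = make_toy_distribution(space_kind)
        self._build()

    def _build(self) -> None:
        if isinstance(self.action_dist, ToyGaussian):
            self.action_net = "gaussian head"
        elif isinstance(self.action_dist, ToyCategorical):
            self.action_net = "categorical head"
        # No else branch: the constructor already rejected other kinds.
```

An explicit else that raises would still guard against a new distribution being added to the factory without a matching branch in _build.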

The Critic in SAC and TD3 subclassing BasePolicy feels a little odd, e.g. they do not implement _predict. However, I'm not sure it's actually worth refactoring.

I agree with that feeling... Maybe it needs an additional BaseModel class (that implements all the basic operations) and BasePolicy would derive from it?
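One possible shape for that split, as a sketch only: the attribute names, constructor arguments and method bodies below are assumptions for illustration, and nothing here reflects how the library actually implements it.

```python
from abc import ABC, abstractmethod
from typing import List


class BaseModel(ABC):
    """Hypothetical shared base: bookkeeping common to policies and critics."""

    def __init__(self, observation_dim: int, action_dim: int) -> None:
        self.observation_dim = observation_dim
        self.action_dim = action_dim

    def _get_data(self) -> dict:
        # Constructor arguments needed to re-create the model.
        return dict(observation_dim=self.observation_dim, action_dim=self.action_dim)


class BasePolicy(BaseModel):
    """Adds action prediction on top of the shared base."""

    @abstractmethod
    def _predict(self, observation: List[float]) -> List[float]:
        """Return an action for the observation."""

    def predict(self, observation: List[float]) -> List[float]:
        return self._predict(observation)


class Critic(BaseModel):
    """A critic only scores (observation, action) pairs; no _predict needed."""

    def q_value(self, observation: List[float], action: List[float]) -> float:
        return 0.0  # placeholder value


class ZeroPolicy(BasePolicy):
    """Toy concrete policy that always outputs zeros."""

    def _predict(self, observation: List[float]) -> List[float]:
        return [0.0] * self.action_dim
```

Under this layout the critic no longer inherits an abstract _predict it cannot meaningfully implement, while policies still get the shared save/load data from the common base.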

Member

@araffin araffin left a comment

LGTM =)

@AdamGleave AdamGleave merged commit c39ed39 into master Jul 8, 2020
@AdamGleave AdamGleave deleted the base-class-review branch July 8, 2020 19:19