Refactor and clean-up of common code #89
@@ -23,12 +25,30 @@
from stable_baselines3.common.noise import ActionNoise


def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]:
I don't know if it should be there, as a static method or in utils.py ... because this is only used by the BaseAlgorithm class.
If it is only used by BaseAlgorithm then it probably belongs somewhere in this file.
I'd be OK with a module-level function or a static method. I've got into the habit of avoiding @staticmethod after being brainwashed by Google, but I don't think their argument is that strong.
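For reference, the two options under discussion look roughly like this. This is a minimal sketch, not the code in the PR: `maybe_make_thing`, `AlgoWithModuleHelper` and `AlgoWithStaticHelper` are hypothetical stand-ins for `maybe_make_env` and `BaseAlgorithm`.

```python
from typing import Optional


def maybe_make_thing(name: Optional[str]) -> Optional[dict]:
    """Module-level helper (hypothetical stand-in for maybe_make_env)."""
    if name is None:
        return None
    return {"name": name}


class AlgoWithModuleHelper:
    """Option 1: the class calls a function defined in the same module."""

    def __init__(self, name: Optional[str]) -> None:
        self.thing = maybe_make_thing(name)


class AlgoWithStaticHelper:
    """Option 2: the helper lives on the class as a static method."""

    def __init__(self, name: Optional[str]) -> None:
        self.thing = self._maybe_make_thing(name)

    @staticmethod
    def _maybe_make_thing(name: Optional[str]) -> Optional[dict]:
        if name is None:
            return None
        return {"name": name}
```

Both behave the same for callers; the difference is only where the helper is looked up (module namespace versus class namespace, where a subclass could override it).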
Sounds reasonable ;) (should be already the case in fact)
Yep. I tried not to reproduce the errors from SB2, but I may have missed that one.
Good point. I forgot to add them after refactoring the distributions.
No, we do not support RNN yet, this is in the roadmap for v1.1+ though. I agree with opening an issue to keep track of it ;)
Yes, this is desired (see comment above). I had issues in the past because of inheritance.
An error will be thrown earlier in fact:
I agree with that feeling... Maybe it needs an additional
LGTM =)
Description
Overall the code looked good -- I was only tempted to make a few fairly minor changes amongst the couple of thousand lines of code.
Changes from reviewing common code as per #17. Specific changes:
- base_class.py:
  - maybe_make_env
  - Remove raise NotImplementedError() on abstract methods. This should not be needed, as Python will refuse to instantiate classes with abstract methods.
  - hasattr(base_class.BaseAlgorithm, '_setup_model') is True, so I think this check cannot ever fire.
- distributions.py:
  - Make Distributions an ABC, use @abstractmethod instead of NotImplementedError
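The two points about abstract methods can be checked with a small standalone sketch. The `Distribution` and `Constant` classes below are toy stand-ins, not the real stable_baselines3 classes; only the standard-library `abc` behaviour is being illustrated.

```python
from abc import ABC, abstractmethod


class Distribution(ABC):
    """Toy stand-in for an abstract base class."""

    @abstractmethod
    def sample(self) -> float:
        """Return a sample. No `raise NotImplementedError()` body is needed."""


class Constant(Distribution):
    """Concrete subclass that implements the abstract method."""

    def __init__(self, value: float) -> None:
        self.value = value

    def sample(self) -> float:
        return self.value


# Python refuses to instantiate a class that still has abstract methods.
try:
    Distribution()
    instantiated = True
except TypeError:
    instantiated = False

# The abstract method is still an attribute of the base class, so a
# hasattr() check against the base class always succeeds.
has_sample = hasattr(Distribution, "sample")
```

Here `instantiated` ends up `False` and `has_sample` ends up `True`, which is why a `hasattr` guard against a method declared on the base class cannot fire.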
Issues I noticed but wasn't sure how to resolve (tick indicates resolved subsequently):
- In BaseAlgorithm.load, we accept data missing observation and action space, provided it contains the "env" key. But if the user provides an env argument, then we check against the (possibly non-existent) observation/action spaces. Can we just always require that data contains observation & action spaces?
- In BaseAlgorithm._setup_learn, I think we may have a similar issue to that reported in "Support learn() with total timesteps less than episode length" (hill-a/stable-baselines#619) and (partially) fixed in "RLModel.learn: Reuse logging statistics after each call" (hill-a/stable-baselines#649). In brief, repeated calls to learn() each call _setup_learn(), which resets the ep_info_buffer and ep_success_buffer. Perhaps we should only reset these if reset_num_timesteps is True?
- In Distribution, proba_distribution and proba_distribution_net are not included, but every subclass defines these methods. Should we add them to the ABC?
- In BasePolicy.predict, the code for state and mask is commented out. Not sure what's going on here -- do we just not support RNN policies yet? If so, we should warn of this in the docstring, and probably open an enhancement issue to resolve this.
- In BasePolicy._get_data, two lines are commented out for squash_output and features_extractor -- I'm not sure if this is desired or not?
- In ActorCriticPolicy._build, self.action_net may not get constructed if we don't recognize the distribution. Is this intentional (e.g. responsibility of subclass to build it) or should we add an else: raise NotImplementedError()? I note we do raise NotImplementedError in ActorCriticPolicy._get_action_dist_from_latent.
- Critic in SAC and TD3 subclassing BasePolicy feels a little odd, e.g. they do not implement _predict. However, I'm not sure it's actually worth refactoring.

Motivation and Context
Types of changes
Checklist:
- make lint
- make pytest and make type both pass.
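On the _setup_learn question in the issues list above, one possible shape for the suggested change is sketched below. This is an assumption-laden toy, not the stable_baselines3 implementation: `ToyAlgorithm` is a hypothetical stand-in for BaseAlgorithm, the buffer size of 100 is arbitrary, and each learn() call simply pretends one episode finished.

```python
from collections import deque
from typing import Deque, Optional


class ToyAlgorithm:
    """Hypothetical stand-in for BaseAlgorithm; only the buffer logic is modelled."""

    def __init__(self) -> None:
        self.num_timesteps = 0
        self.ep_info_buffer: Optional[Deque[dict]] = None
        self.ep_success_buffer: Optional[Deque[bool]] = None

    def _setup_learn(self, reset_num_timesteps: bool = True) -> None:
        # Create the buffers on the first call, or when a full reset is requested.
        # Otherwise keep statistics from earlier learn() calls.
        if self.ep_info_buffer is None or reset_num_timesteps:
            self.ep_info_buffer = deque(maxlen=100)  # size is an assumption
            self.ep_success_buffer = deque(maxlen=100)
        if reset_num_timesteps:
            self.num_timesteps = 0

    def learn(self, total_timesteps: int, reset_num_timesteps: bool = True) -> "ToyAlgorithm":
        self._setup_learn(reset_num_timesteps)
        # Pretend this call completed exactly one episode.
        self.num_timesteps += total_timesteps
        self.ep_info_buffer.append({"r": 1.0, "l": total_timesteps})
        return self


algo = ToyAlgorithm()
algo.learn(10)
algo.learn(10, reset_num_timesteps=False)  # statistics from the first call are kept
kept_episodes = len(algo.ep_info_buffer)
algo.learn(5)  # default reset clears the buffers again
episodes_after_reset = len(algo.ep_info_buffer)
```

With this guard, short repeated learn() calls with reset_num_timesteps=False accumulate episode statistics instead of discarding them on every call.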