
Soft Actor-Critic #120

Merged: 31 commits merged into master from the sac branch on Dec 13, 2018

Conversation


@araffin araffin commented Dec 8, 2018

This PR adds the Soft Actor-Critic (SAC) algorithm and fixes some bugs.
Fixes:

  • DDPG target network not being saved
  • DQN prioritized replay buffer parameter not being used

Notes:

Differences with the original implementation (a configuration sketch matching the original settings is given after the list):

  • no regularization on policy parameters
  • usage of an entropy coefficient (equivalent to the inverse of the reward scale), which prevents high values in the Q-value losses
  • default network architecture is [64, 64] and not [256, 256]
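
For reference, here is a minimal configuration sketch (not part of the PR) showing how these defaults could be overridden to get back to the original settings, assuming the stable-baselines SAC API; the keyword names, in particular `layers` inside `policy_kwargs`, are assumptions to verify against the documentation.

```python
from stable_baselines import SAC
from stable_baselines.sac.policies import MlpPolicy

# Sketch only: revert the defaults listed above towards the original implementation.
model = SAC(
    MlpPolicy,
    "HalfCheetah-v2",                       # env id used as an example
    policy_kwargs=dict(layers=[256, 256]),  # original uses [256, 256]; default here is [64, 64]
    ent_coef="auto",                        # or 1.0 / reward_scale to mimic the original reward scaling
    batch_size=256,
    verbose=1,
)
model.learn(total_timesteps=int(1e6))
```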


@hill-a hill-a left a comment


minor detail, otherwise LGTM

hill-a previously approved these changes Dec 10, 2018

@hill-a hill-a left a comment


Merge from master, and it should be good

hill-a previously approved these changes Dec 12, 2018
The Python dependencies needed to be installed beforehand because of the __version__ import
@hill-a hill-a merged commit c4d41d3 into master Dec 13, 2018
@araffin araffin deleted the sac branch December 13, 2018 10:54
@xionghuichen

Hello, can Soft Actor-Critic reach a performance similar to the original implementation? I tested it on HalfCheetah and the final performance is about 11,000 (14,000 in the published results).

@araffin

araffin commented Oct 20, 2019

Hello,
What hyperparameters did you use? How many training steps? How many seeds? What version of HalfCheetah?
Please look at the rl zoo, where we were able to get good results on most envs.

@xionghuichen

@araffin Thanks for the reply. I used the same hyperparameters and number of timesteps as the original paper. I will check the HalfCheetah-v2 hyperparameters on this page (https://github.com/araffin/rl-baselines-zoo/blob/master/hyperparams/sac.yml) and send you a report as soon as possible.

@araffin

araffin commented Oct 20, 2019

Be sure to evaluate the agent with a test env and with `deterministic=True` (in `predict`).
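
For reference, a minimal sketch of such an evaluation (the env id and model path are placeholders):

```python
import gym
import numpy as np

from stable_baselines import SAC

eval_env = gym.make("HalfCheetah-v2")   # separate test env
model = SAC.load("sac_halfcheetah")     # hypothetical path to a trained model

episode_returns = []
for _ in range(10):
    obs, done, episode_return = eval_env.reset(), False, 0.0
    while not done:
        # deterministic=True: use the mean action instead of sampling
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _ = eval_env.step(action)
        episode_return += reward
    episode_returns.append(episode_return)

print("mean evaluation return:", np.mean(episode_returns))
```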

@xionghuichen

xionghuichen commented Oct 23, 2019

@araffin Hello, the results for HalfCheetah-v2 are as follows:
[plot: training curves of original-sac vs rl-baselines-zoo-sac on HalfCheetah-v2]

Additional remarks:

  1. The experiments were run directly with the default parameters. In particular, "original-sac" runs with https://github.com/haarnoja/sac/blob/master/examples/mujoco_all_sac.py, and "rl-baselines-zoo-sac" runs with https://github.com/araffin/rl-baselines-zoo/blob/master/train.py;
  2. The results are based on a single-seed experiment, since the performance variance is small in HalfCheetah-v2;
  3. During evaluation, the policy runs with deterministic actions.

To my knowledge, there are some differences between our implementation and haarnoja/sac:

| params              | haarnoja/sac | rl-baselines-zoo |
|---------------------|--------------|------------------|
| sac_ent_coef        | 0.2          | auto             |
| regularization_coef | 0.001        | None             |

However, just fixing these differences does not reach the haarnoja/sac performance.

PS: the evaluation code I used with stable-baselines:

```python
# Snippet from inside the training loop: every 4000 steps, run ~10 evaluation
# episodes on a separate eval env with deterministic actions.
if self.num_timesteps % 4000 == 0:
    eval_ob = self.eval_env.reset()
    eval_epi_rewards = 0
    eval_epis = 0
    eval_performance = []
    while True:
        eval_action = self.policy_tf.step(eval_ob[None], deterministic=True).flatten()
        # rescale the action from [-1, 1] to the env's action space
        eval_rescaled_action = eval_action * np.abs(self.action_space.low)
        eval_new_obs, eval_reward, eval_done, eval_info = self.eval_env.step(eval_rescaled_action)
        eval_epi_rewards += eval_reward
        eval_ob = eval_new_obs
        if eval_done:
            eval_ob = self.eval_env.reset()
            eval_performance.append(eval_epi_rewards)
            eval_epi_rewards = 0
            eval_epis += 1
            if eval_epis > 10:
                break

    logger.record_tabular("eval/performance", np.mean(eval_performance))
```

Edit:

  • 10.29.2019: modified sac_ent_coef for rl-baselines-zoo in the table (0.1 -> auto).

@araffin

araffin commented Oct 23, 2019

Default hyperparameters won't give you the best results. You should change the network architecture and batch size to match the paper's hyperparameters.
For the entropy coefficient, I don't get where you found that value; it is 'auto' by default.

@araffin

araffin commented Oct 23, 2019

For the evaluation, call `predict` directly.

@araffin

araffin commented Oct 23, 2019

If you cannot match the result, please open an issue with the complete steps to reproduce your experiments. Note that it is HalfCheetah-v1 that is used in the SAC repo.

@xionghuichen

@araffin Thank you for the suggestions.

I have also tried hyperparameters matching the original paper, but it still does not work. I will open an issue in a few days (a little busy these days :( ).

By the way, although the SAC repo tests on HalfCheetah-v1, the performance on HalfCheetah-v2 is similar (the orange line is tested on v2). The entropy coefficient can be regarded as alpha^-1, where alpha is the reward scale that the haarnoja/sac repo exposes as a hyperparameter; alpha = 5 for HalfCheetah, which is why I said sac_ent_coef = 0.2 for haarnoja/sac.
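
Spelled out (using the relation between the entropy coefficient and the reward scale quoted above):

```python
reward_scale = 5.0                 # haarnoja/sac hyperparameter for HalfCheetah
sac_ent_coef = 1.0 / reward_scale  # = 0.2, the value listed in the table above
```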

@xionghuichen

xionghuichen commented Nov 7, 2019

@araffin Hello, I have good news! Recently, I checked the code in the two repos and found the critical difference leading to the performance gap.

In haarnoja/sac, the environment runs without the TimeLimit wrapper, while we run with it (ref: https://github.com/haarnoja/sac/blob/8258e33633c7e37833cc39315891e77adfbe14b2/sac/envs/gym_env.py#L75).

[plot: training curves with and without the TimeLimit wrapper]
The red and blue lines are 2 seeds run without the TimeLimit wrapper, while the pink and grey lines are 2 seeds run with TimeLimit. It's surprising that the Markov assumption is so important in HalfCheetah!
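
For reference, one way to reproduce the two setups with gym (illustrative; gym registers the MuJoCo envs behind a TimeLimit wrapper, 1000 steps per episode for HalfCheetah-v2):

```python
import gym

# Standard setup: gym.make() returns a TimeLimit-wrapped env.
env_with_limit = gym.make("HalfCheetah-v2")

# Accessing .env drops the TimeLimit wrapper, similar to what the
# referenced gym_env.py in haarnoja/sac does.
env_without_limit = gym.make("HalfCheetah-v2").env
```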

@araffin

araffin commented Nov 7, 2019

Good news, thanks for the update =)
In fact, SAC works quite well on all envs from the rl zoo, so I was a bit surprised ;)

Yes, the Markov assumption is important here for value estimation (I was planning to write a post about time limits in RL too).
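
As an illustration of that point (a sketch, not code from this repo): when an episode ends only because of the time limit, the TD target should still bootstrap from the next state. Recent gym versions flag such endings via `info["TimeLimit.truncated"]`:

```python
def td_target(reward, next_value, done, info, gamma=0.99):
    """Illustrative 1-step TD target that keeps bootstrapping on timeouts."""
    truncated = info.get("TimeLimit.truncated", False)  # set by gym's TimeLimit wrapper on timeout
    true_terminal = done and not truncated              # only real terminal states stop bootstrapping
    return reward + gamma * (0.0 if true_terminal else next_value)
```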

@xionghuichen

Cool, look forward to your post! 

By the way, do you have any plans to add new algorithms to stable_baselines? I like the code structure and the algorithm implementations of stable_baselines. Maybe I can contribute to it.

@araffin

araffin commented Nov 7, 2019

> I like the code structure and the algorithm implementations of stable_baselines.

thanks =)

Well, if you want to add an algorithm, open an issue and we will discuss it.
But currently, our focus is on the tf2 migration and other improvements (cf. the roadmap and milestones).
And contributions are welcomed ;)

@xionghuichen

OK. I will follow your roadmap and find something interesting! But how can I track the progress of the tf2 migration?

@araffin

araffin commented Nov 16, 2019

There is a WIP PR for that, as well as an open issue.

@xionghuichen

👌 I got it
