[rllib] Fix Multidiscrete support #4869
```diff
@@ -15,7 +15,7 @@
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy import TFPolicy, \
     LearningRateSchedule
-from ray.rllib.models.action_dist import MultiCategorical
+from ray.rllib.models.action_dist import Categorical
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.explained_variance import explained_variance
```
```diff
@@ -191,9 +191,7 @@ def __init__(self,
         unpacked_outputs = tf.split(
             self.model.outputs, output_hidden_shape, axis=1)

-        dist_inputs = unpacked_outputs if is_multidiscrete else \
-            self.model.outputs
-        action_dist = dist_class(dist_inputs)
+        action_dist = dist_class(self.model.outputs)

         values = self.model.value_function()
         self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
```
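For context, the `tf.split` call above is what turns the model's flat logit vector into one logit block per MultiDiscrete component. A minimal NumPy stand-in for that split (component sizes and batch shape are made up for illustration; the real code operates on TF1 graph tensors):

```python
import numpy as np

# Per-component sizes, in the spirit of output_hidden_shape for a
# MultiDiscrete action space such as MultiDiscrete([1, 2, 3, 4]).
component_sizes = [1, 2, 3, 4]
flat_logits = np.random.randn(5, sum(component_sizes))  # batch of 5

# np.split takes cut points rather than sizes (tf.split takes sizes),
# so accumulate the sizes into split indices first.
cuts = np.cumsum(component_sizes)[:-1]
unpacked = np.split(flat_logits, cuts, axis=1)

for i, block in enumerate(unpacked):
    print("component", i, "logit block shape:", block.shape)  # (5, size_i)
```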
```diff
@@ -258,32 +256,13 @@ def make_time_major(tensor, drop_last=False):
             rewards=make_time_major(rewards, drop_last=True),
             values=make_time_major(values, drop_last=True),
             bootstrap_value=make_time_major(values)[-1],
-            dist_class=dist_class,
+            dist_class=Categorical if is_multidiscrete else dist_class,
             valid_mask=make_time_major(mask, drop_last=True),
             vf_loss_coeff=self.config["vf_loss_coeff"],
             entropy_coeff=self.config["entropy_coeff"],
             clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
             clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

-        # KL divergence between worker and learner logits for debugging
-        model_dist = MultiCategorical(unpacked_outputs)
-        behaviour_dist = MultiCategorical(unpacked_behaviour_logits)
-
-        kls = model_dist.kl(behaviour_dist)
-        if len(kls) > 1:
-            self.KL_stats = {}
-
-            for i, kl in enumerate(kls):
-                self.KL_stats.update({
-                    "mean_KL_{}".format(i): tf.reduce_mean(kl),
-                    "max_KL_{}".format(i): tf.reduce_max(kl),
-                })
-        else:
-            self.KL_stats = {
-                "mean_KL": tf.reduce_mean(kls[0]),
-                "max_KL": tf.reduce_max(kls[0]),
-            }

         # Initialize TFPolicy
         loss_in = [
             (SampleBatch.ACTIONS, actions),
```

**Contributor (Author):** Removed the KL stats since I doubt they were very useful.

**Contributor:** I think they should stay in. They can be useful for debugging the policy (e.g. large fluctuations are a bad sign). See slide 21 here.
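To make the KL discussion concrete: the deleted block reported, per MultiDiscrete component, the mean and max KL divergence between the learner's current logits and the behaviour (worker) logits that generated the samples. A self-contained NumPy sketch of the same statistic, with made-up shapes, outside the TF1 graph the real code lives in:

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def categorical_kl(p_logits, q_logits):
    """KL(P || Q) per batch row, for categorical logits."""
    p, q = softmax(p_logits), softmax(q_logits)
    return (p * (np.log(p) - np.log(q))).sum(axis=-1)

# One logit block per MultiDiscrete component (batch of 5; sizes 2, 3, 4).
sizes = (2, 3, 4)
learner_logits = [np.random.randn(5, n) for n in sizes]
behaviour_logits = [np.random.randn(5, n) for n in sizes]

kl_stats = {}
for i, (m, b) in enumerate(zip(learner_logits, behaviour_logits)):
    kl = categorical_kl(m, b)
    kl_stats["mean_KL_{}".format(i)] = kl.mean()
    kl_stats["max_KL_{}".format(i)] = kl.max()
print(kl_stats)
```

Large or rapidly fluctuating values here indicate the learner has drifted far from the policy that collected the data, which is the debugging signal the reviewer wants to keep.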
```diff
@@ -318,7 +297,7 @@ def make_time_major(tensor, drop_last=False):
         self.sess.run(tf.global_variables_initializer())

         self.stats_fetches = {
-            LEARNER_STATS_KEY: dict({
+            LEARNER_STATS_KEY: {
                 "cur_lr": tf.cast(self.cur_lr, tf.float64),
                 "policy_loss": self.loss.pi_loss,
                 "entropy": self.loss.entropy,
```
```diff
@@ -328,7 +307,7 @@ def make_time_major(tensor, drop_last=False):
                 "vf_explained_var": explained_variance(
                     tf.reshape(self.loss.vtrace_returns.vs, [-1]),
                     tf.reshape(make_time_major(values, drop_last=True), [-1])),
-            }, **self.KL_stats),
+            },
         }

     @override(TFPolicy)
```
```diff
@@ -2,7 +2,7 @@
 import traceback

 import gym
-from gym.spaces import Box, Discrete, Tuple, Dict
+from gym.spaces import Box, Discrete, Tuple, Dict, MultiDiscrete
 from gym.envs.registration import EnvSpec
 import numpy as np
 import sys
```
```diff
@@ -17,6 +17,7 @@
 ACTION_SPACES_TO_TEST = {
     "discrete": Discrete(5),
     "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
+    "multidiscrete": MultiDiscrete([1, 2, 3, 4]),
     "tuple": Tuple(
         [Discrete(2),
          Discrete(3),
```

**Contributor:** Should there be an explicit test for the case that failed earlier? It's possible that this test already covers it though.

**Contributor (Author):** Yep, this covers the failing case. The test fails before the other changes in the PR.
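For readers unfamiliar with the space being added here, a quick standalone illustration of what `MultiDiscrete([1, 2, 3, 4])` looks like (not part of the PR):

```python
from gym.spaces import MultiDiscrete

# Four independent sub-actions; sub-action i takes integer values in
# [0, n_i). The first component admits only the value 0, which makes
# this a nicely degenerate edge case for a support test.
space = MultiDiscrete([1, 2, 3, 4])
sample = space.sample()
print(sample)                  # e.g. [0 1 2 3]
print(space.contains(sample))  # True
```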
```diff
@@ -61,7 +62,7 @@ def step(self, action):
     return StubEnv


-def check_support(alg, config, stats, check_bounds=False):
+def check_support(alg, config, stats, check_bounds=False, name=None):
     for a_name, action_space in ACTION_SPACES_TO_TEST.items():
         for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
             print("=== Testing", alg, action_space, obs_space, "===")
```
```diff
@@ -87,7 +88,7 @@ def check_support(alg, config, stats, check_bounds=False):
                 pass
             print(stat)
             print()
-            stats[alg, a_name, o_name] = stat
+            stats[name or alg, a_name, o_name] = stat


 def check_support_multiagent(alg, config):
```
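Why the new `name` argument matters: `testAll` below now runs APPO twice (vtrace off and on), and `stats` is keyed by `(name or alg, a_name, o_name)`, so without a distinct name the second run would silently overwrite the first run's entries. A toy illustration with made-up values:

```python
stats = {}

# Both APPO runs keyed by algorithm name alone: the second write
# clobbers the first.
stats["APPO", "multidiscrete", "vector"] = "ok (vtrace=False)"
stats["APPO", "multidiscrete", "vector"] = "ok (vtrace=True)"
print(len(stats))  # 1

# With name="APPO-vt", the vtrace run gets its own key.
stats["APPO-vt", "multidiscrete", "vector"] = "ok (vtrace=True)"
print(len(stats))  # 2
```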
```diff
@@ -114,6 +115,11 @@ def testAll(self):
         stats = {}
         check_support("IMPALA", {"num_gpus": 0}, stats)
         check_support("APPO", {"num_gpus": 0, "vtrace": False}, stats)
+        check_support(
+            "APPO", {
+                "num_gpus": 0,
+                "vtrace": True
+            }, stats, name="APPO-vt")
         check_support(
             "DDPG", {
                 "exploration_ou_noise_scale": 100.0,
```
**Comment:** Is this needed anymore?

**Reply:** It's used below in a spot.