From 27e3ddad1e27b2d27c0a4a61e510f897f41a39bf Mon Sep 17 00:00:00 2001 From: Ruofan Kong Date: Wed, 2 Jun 2021 00:30:05 -0700 Subject: [PATCH] No Case: Fix the RLLib hybrid shape bug for SAC 1.3.0 --- rllib/agents/sac/sac_tf_model.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py index b457f1e947e0..b3b6f9f9f2f9 100644 --- a/rllib/agents/sac/sac_tf_model.py +++ b/rllib/agents/sac/sac_tf_model.py @@ -241,7 +241,23 @@ def _get_q_value(self, model_out, actions, net): if self.concat_obs_and_actions: input_dict = {"obs": tf.concat([model_out, actions], axis=-1)} else: - input_dict = {"obs": force_list(model_out) + [actions]} + # For the discrete case, action is always None. + shapes = [] + for space in net.obs_space: + if isinstance(space, Discrete): + shapes.append(space.n) + elif isinstance(space, Box): + shapes.append(space.shape) + else: + raise RuntimeError("The space type is not supported.") + + input_dict = { + "obs": tf.split( + tf.concat(force_list(model_out) + [actions], axis=-1), + num_or_size_splits=shapes, + axis=-1, + ) + } # Discrete case -> return q-vals for all actions. else: input_dict = {"obs": model_out}