diff --git a/rllib/BUILD b/rllib/BUILD
index f8f1cbd3c6f8..e75d6bffdc98 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -270,6 +270,16 @@ py_test(
 )
 
 # PPO
+py_test(
+    name = "run_regression_tests_cartpole_ppo_jax",
+    main = "tests/run_regression_tests.py",
+    tags = ["learning_tests_torch", "learning_tests_cartpole"],
+    size = "medium",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/ppo/cartpole-ppo.yaml"],
+    args = ["--yaml-dir=tuned_examples/ppo", "--framework=jax"]
+)
+
 py_test(
     name = "run_regression_tests_cartpole_ppo_tf",
     main = "tests/run_regression_tests.py",
@@ -1112,6 +1122,13 @@ py_test(
     srcs = ["models/tests/test_distributions.py"]
 )
 
+py_test(
+    name = "test_jax_models",
+    tags = ["models"],
+    size = "small",
+    srcs = ["models/tests/test_jax_models.py"]
+)
+
 py_test(
     name = "test_models",
     tags = ["models"],
diff --git a/rllib/agents/ppo/__init__.py b/rllib/agents/ppo/__init__.py
index 1207737074e0..4d63ecdc6216 100644
--- a/rllib/agents/ppo/__init__.py
+++ b/rllib/agents/ppo/__init__.py
@@ -1,4 +1,5 @@
 from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.ppo.ppo_jax_policy import PPOJAXPolicy
 from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
 from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
 from ray.rllib.agents.ppo.appo import APPOTrainer
@@ -8,6 +9,7 @@
     "APPOTrainer",
     "DDPPOTrainer",
     "DEFAULT_CONFIG",
+    "PPOJAXPolicy",
     "PPOTFPolicy",
     "PPOTorchPolicy",
     "PPOTrainer",
diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py
index 9270e6cf5f69..807b824b4a23 100644
--- a/rllib/agents/ppo/ppo.py
+++ b/rllib/agents/ppo/ppo.py
@@ -140,7 +140,7 @@ def validate_config(config: TrainerConfigDict) -> None:
             "trajectory). Consider setting batch_mode=complete_episodes.")
 
-    # Multi-gpu not supported for PyTorch and tf-eager.
-    if config["framework"] in ["tf2", "tfe", "torch"]:
+    # Multi-gpu not supported for PyTorch, tf-eager, or JAX.
+    if config["framework"] != "tf":
         config["simple_optimizer"] = True
     # Performance warning, if "simple" optimizer used with (static-graph) tf.
     elif config["simple_optimizer"]:
@@ -168,6 +168,9 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
     if config["framework"] == "torch":
         from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
         return PPOTorchPolicy
+    elif config["framework"] == "jax":
+        from ray.rllib.agents.ppo.ppo_jax_policy import PPOJAXPolicy
+        return PPOJAXPolicy
 
 
 class UpdateKL:
@@ -187,9 +190,10 @@ def update(pi, pi_id):
         assert "kl" not in fetches, (
             "kl should be nested under policy id key", fetches)
         if pi_id in fetches:
-            assert "kl" in fetches[pi_id], (fetches, pi_id)
-            # Make the actual `Policy.update_kl()` call.
-            pi.update_kl(fetches[pi_id]["kl"])
+            # The JAX policy does not report a "kl" metric yet, so only make
+            # the actual `Policy.update_kl()` call if the key is present.
+            if "kl" in fetches[pi_id]:
+                pi.update_kl(fetches[pi_id]["kl"])
         else:
             logger.warning("No data for {}, not updating kl".format(pi_id))
diff --git a/rllib/agents/ppo/ppo_jax_policy.py b/rllib/agents/ppo/ppo_jax_policy.py
new file mode 100644
index 000000000000..8e97eb6e0af0
--- /dev/null
+++ b/rllib/agents/ppo/ppo_jax_policy.py
@@ -0,0 +1,207 @@
+"""
+JAX policy class used for PPO.
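+
+A rough usage sketch (mirroring the CartPole setup in
+`rllib/agents/ppo/tests/test_ppo.py`; the exact config values below are only
+illustrative): `get_policy_class()` in `ppo.py` returns the class built at
+the bottom of this file whenever `config["framework"] == "jax"`, e.g.:
+
+    import ray
+    from ray.rllib.agents import ppo
+
+    ray.init()
+    config = ppo.DEFAULT_CONFIG.copy()
+    config["framework"] = "jax"
+    config["num_workers"] = 0
+    trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
+    print(trainer.train())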
+""" +import gym +import logging +from typing import List, Type, Union + +import ray +from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping +from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ + setup_config +from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \ + KLCoeffMixin, kl_and_loss_stats +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ + Postprocessing +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.policy.jax_policy import LearningRateSchedule +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy import EntropyCoeffSchedule +from ray.rllib.utils.framework import try_import_jax +from ray.rllib.utils.jax_ops import explained_variance +from ray.rllib.utils.typing import TensorType, TrainerConfigDict + +jax, flax = try_import_jax() +jnp = None +if jax: + import jax.numpy as jnp + +logger = logging.getLogger(__name__) + + +def ppo_surrogate_loss( + policy: Policy, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, + vars=None, +) -> Union[TensorType, List[TensorType]]: + """Constructs the loss for Proximal Policy Objective. + + Args: + policy (Policy): The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + dist_class (Type[ActionDistribution]: The action distr. class. + train_batch (SampleBatch): The training data. + + Returns: + Union[TensorType, List[TensorType]]: A single loss tensor or a list + of loss tensors. + """ + def loss_(train_batch, vars=None): + #if vars: + # for k, v in vars.items(): + # setattr(model, k, v) + + logits, value_out, state = model.forward_(train_batch["obs"]) + curr_action_dist = dist_class(logits, None) + + prev_action_dist = dist_class( + train_batch[SampleBatch.ACTION_DIST_INPUTS], None) + + logp_ratio = jnp.exp( + curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - + train_batch[SampleBatch.ACTION_LOGP]) + action_kl = prev_action_dist.kl(curr_action_dist) + policy._mean_kl = jnp.mean(action_kl) + + curr_entropy = curr_action_dist.entropy() + policy._mean_entropy = jnp.mean(curr_entropy) + + surrogate_loss = jnp.minimum( + train_batch[Postprocessing.ADVANTAGES] * logp_ratio, + train_batch[Postprocessing.ADVANTAGES] * jnp.clip( + logp_ratio, 1 - policy.config["clip_param"], + 1 + policy.config["clip_param"])) + policy._mean_policy_loss = jnp.mean(-surrogate_loss) + + if policy.config["use_gae"]: + prev_value_fn_out = train_batch[SampleBatch.VF_PREDS] + #value_fn_out = model.value_function() + vf_loss1 = jnp.square(value_out - + train_batch[Postprocessing.VALUE_TARGETS]) + vf_clipped = prev_value_fn_out + jnp.clip( + value_out - prev_value_fn_out, -policy.config["vf_clip_param"], + policy.config["vf_clip_param"]) + vf_loss2 = jnp.square(vf_clipped - + train_batch[Postprocessing.VALUE_TARGETS]) + vf_loss = jnp.maximum(vf_loss1, vf_loss2) + policy._mean_vf_loss = jnp.mean(vf_loss) + total_loss = jnp.mean(-surrogate_loss + policy.kl_coeff * action_kl + + policy.config["vf_loss_coeff"] * vf_loss - + policy.entropy_coeff * curr_entropy) + else: + policy._mean_vf_loss = 0.0 + total_loss = jnp.mean(-surrogate_loss + policy.kl_coeff * action_kl - + policy.entropy_coeff * curr_entropy) + #policy._value_out = value_out + #policy._total_loss = total_loss + + 
#policy._vf_explained_var = explained_variance( + # train_batch[Postprocessing.VALUE_TARGETS], + # value_out) # policy.model.value_function() + + #if vars: + # policy._total_loss = policy._total_loss.primal + # policy._mean_policy_loss = policy._mean_policy_loss.primal + # policy._mean_vf_loss = policy._mean_vf_loss.primal + # policy._vf_explained_var = policy._vf_explained_var.primal + # policy._mean_entropy = policy._mean_entropy.primal + # policy._mean_kl = policy._mean_kl.primal + + return total_loss #, mean_policy_loss, mean_vf_loss, value_out, mean_entropy, mean_kl + + if not hasattr(policy, "jit_loss"): + policy.jit_loss = jax.jit(loss_) + policy.gradient_loss = jax.grad(policy.jit_loss, argnums=1)# 4 + + #policy._total_loss = policy.jit_loss(train_batch["obs"], vars) + + # Store stats in policy for stats_fn. + #policy._total_loss = total_loss + #policy._mean_policy_loss = mean_policy_loss + #policy._mean_vf_loss = mean_vf_loss + #policy._vf_explained_var = explained_variance( + # train_batch[Postprocessing.VALUE_TARGETS], + # policy._value_out) #policy.model.value_function() + #policy._mean_entropy = mean_entropy + #policy._mean_kl = mean_kl + + ret = policy.gradient_loss({k: train_batch[k] for k, v in train_batch.items()}, vars) + + +class ValueNetworkMixin: + """Assigns the `_value()` method to the PPOPolicy. + + This way, Policy can call `_value()` to get the current VF estimate on a + single(!) observation (as done in `postprocess_trajectory_fn`). + Note: When doing this, an actual forward pass is being performed. + This is different from only calling `model.value_function()`, where + the result of the most recent forward pass is being used to return an + already calculated tensor. + """ + + def __init__(self, obs_space, action_space, config): + # When doing GAE, we need the value function estimate on the + # observation. + if config["use_gae"]: + + # Input dict is provided to us automatically via the Model's + # requirements. It's a single-timestep (last one in trajectory) + # input_dict. + assert config["_use_trajectory_view_api"] + + def value(**input_dict): + _, value_out, _ = self.model.from_batch( + input_dict, is_training=False) + # [0] = remove the batch dim. + return value_out[0] + + # When not doing GAE, we do not require the value function's output. + else: + + def value(*args, **kwargs): + return 0.0 + + self._value = value + + +def setup_mixins(policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict) -> None: + """Call all mixin classes' constructors before PPOPolicy initialization. + + Args: + policy (Policy): The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config (TrainerConfigDict): The Policy's config. + """ + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) + KLCoeffMixin.__init__(policy, config) + EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"], + config["entropy_coeff_schedule"]) + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + + +# Build a child class of `JAXPolicy`, given the custom functions defined +# above. 
+PPOJAXPolicy = build_policy_class( + name="PPOJAXPolicy", + framework="jax", + get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, + loss_fn=ppo_surrogate_loss, + #stats_fn=kl_and_loss_stats, + #extra_action_out_fn=vf_preds_fetches, + postprocess_fn=compute_gae_for_sample_batch, + extra_grad_process_fn=apply_grad_clipping, + before_init=setup_config, + before_loss_init=setup_mixins, + mixins=[ + LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin, + ValueNetworkMixin + ], +) diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index de0bc90f64cc..ded2ace48490 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -45,6 +45,12 @@ def _check_lr_torch(policy, policy_id): for p in opt.param_groups: assert p["lr"] == policy.cur_lr, "LR scheduling error!" + @staticmethod + def _check_lr_jax(policy, policy_id): + for j, opt in enumerate(policy._optimizers): + assert opt.optimizer_def.hyper_params.learning_rate == \ + policy.cur_lr, "LR scheduling error!" + @staticmethod def _check_lr_tf(policy, policy_id): lr = policy.cur_lr @@ -58,14 +64,16 @@ def _check_lr_tf(policy, policy_id): assert lr == optim_lr, "LR scheduling error!" def on_train_result(self, *, trainer, result: dict, **kwargs): - trainer.workers.foreach_policy(self._check_lr_torch if trainer.config[ - "framework"] == "torch" else self._check_lr_tf) + fw = trainer.config["framework"] + trainer.workers.foreach_policy( + self._check_lr_tf if fw.startswith("tf") else self._check_lr_torch + if fw == "torch" else self._check_lr_jax) class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): - ray.init() + ray.init(local_mode=True)#TODO @classmethod def tearDownClass(cls): @@ -83,19 +91,26 @@ def test_ppo_compilation_and_lr_schedule(self): config["model"]["lstm_cell_size"] = 10 config["model"]["max_seq_len"] = 20 config["train_batch_size"] = 128 - num_iterations = 2 - - for _ in framework_iterator(config): - for env in ["CartPole-v0", "MsPacmanNoFrameskip-v4"]: + num_iterations = 1#TODO:2 + + for fw in framework_iterator( + config, frameworks=("jax")):#TODO, "tf2", "tf", "torch")): + envs = ["CartPole-v0"] + #if fw != "jax": + # envs.append("MsPacmanNoFrameskip-v4") + for env in envs: print("Env={}".format(env)) - for lstm in [True, False]: + lstms = [False] + #if fw != "jax": + # lstms.append(True) + for lstm in lstms: print("LSTM={}".format(lstm)) config["model"]["use_lstm"] = lstm config["model"]["lstm_use_prev_action"] = lstm config["model"]["lstm_use_prev_reward"] = lstm trainer = ppo.PPOTrainer(config=config, env=env) for i in range(num_iterations): - trainer.train() + print(trainer.train()) check_compute_single_action( trainer, include_prev_action_reward=True, diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index 7d1801cf6566..c29a33fdfff2 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -40,8 +40,11 @@ def compute_advantages(rollout: SampleBatch, processed rewards. 
""" - assert SampleBatch.VF_PREDS in rollout or not use_critic, \ - "use_critic=True but values not found" + try:#TODO + assert SampleBatch.VF_PREDS in rollout or not use_critic, \ + "use_critic=True but values not found" + except Exception as e: + raise e assert use_critic or not use_gae, \ "Can't use gae without using a value function" diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index baaa2635795c..3237d89f88da 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -246,7 +246,10 @@ def __call__(self, samples: SampleBatchType) -> SampleBatchType: for policy_id in samples.policy_batches: batch = samples.policy_batches[policy_id] for field in self.fields: - batch[field] = standardized(batch[field]) + try:#TODO + batch[field] = standardized(batch[field]) + except Exception as e: + raise e if wrapped: samples = samples.policy_batches[DEFAULT_POLICY_ID] diff --git a/rllib/models/jax/fcnet.py b/rllib/models/jax/fcnet.py index 1cec5eb5e8a6..74b59fb1d08c 100644 --- a/rllib/models/jax/fcnet.py +++ b/rllib/models/jax/fcnet.py @@ -1,13 +1,15 @@ import logging import numpy as np -import time from ray.rllib.models.jax.jax_modelv2 import JAXModelV2 -from ray.rllib.models.jax.misc import SlimFC +from ray.rllib.models.jax.modules.fc_stack import FCStack from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_jax jax, flax = try_import_jax() +jnp = None +if jax: + import jax.numpy as jnp logger = logging.getLogger(__name__) @@ -20,83 +22,82 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, super().__init__(obs_space, action_space, num_outputs, model_config, name) - self.key = jax.random.PRNGKey(int(time.time())) - activation = model_config.get("fcnet_activation") hiddens = model_config.get("fcnet_hiddens", []) no_final_linear = model_config.get("no_final_linear") + in_features = int(np.product(obs_space.shape)) self.vf_share_layers = model_config.get("vf_share_layers") self.free_log_std = model_config.get("free_log_std") - # Generate free-floating bias variables for the second half of - # the outputs. if self.free_log_std: - assert num_outputs % 2 == 0, ( - "num_outputs must be divisible by two", num_outputs) - num_outputs = num_outputs // 2 + raise ValueError("`free_log_std` not supported for JAX yet!") - self._hidden_layers = [] - prev_layer_size = int(np.product(obs_space.shape)) self._logits = None - - # Create layers 0 to second-last. - for size in hiddens[:-1]: - self._hidden_layers.append( - SlimFC( - in_size=prev_layer_size, - out_size=size, - activation_fn=activation)) - prev_layer_size = size + self._logits_params = None + self._hidden_layers = None # The last layer is adjusted to be of size num_outputs, but it's a # layer with activation. if no_final_linear and num_outputs: - self._hidden_layers.append( - SlimFC( - in_size=prev_layer_size, - out_size=num_outputs, - activation_fn=activation)) - prev_layer_size = num_outputs + self._hidden_layers = FCStack( + in_features=in_features, + layers=hiddens + [num_outputs], + activation=activation, + prng_key=self.prng_key, + ) + # Finish the layers with the provided sizes (`hiddens`), plus - # iff num_outputs > 0 - a last linear layer of size num_outputs. 
else: + prev_layer_size = in_features if len(hiddens) > 0: - self._hidden_layers.append( - SlimFC( - in_size=prev_layer_size, - out_size=hiddens[-1], - activation_fn=activation)) + self._hidden_layers = FCStack( + in_features=prev_layer_size, + layers=hiddens, + activation=activation, + prng_key=self.prng_key, + ) prev_layer_size = hiddens[-1] if num_outputs: - self._logits = SlimFC( - in_size=prev_layer_size, - out_size=num_outputs, - activation_fn=None) + self._logits = FCStack( + in_features=prev_layer_size, + layers=[num_outputs], + activation=None, + prng_key=self.prng_key, + ) + self._logits_params = self._logits.init( + self.prng_key, jnp.zeros((1, prev_layer_size))) else: self.num_outputs = ( [int(np.product(obs_space.shape))] + hiddens[-1:])[-1] - # Layer to add the log std vars to the state-dependent means. - if self.free_log_std and self._logits: - raise ValueError("`free_log_std` not supported for JAX yet!") + # Init hidden layers. + self._hidden_layers_params = None + if self._hidden_layers: + in_ = jnp.zeros((1, in_features)) + self._hidden_layers_params = self._hidden_layers.init( + self.prng_key, in_) self._value_branch_separate = None if not self.vf_share_layers: # Build a parallel set of hidden layers for the value net. - prev_vf_layer_size = int(np.product(obs_space.shape)) - vf_layers = [] - for size in hiddens: - vf_layers.append( - SlimFC( - in_size=prev_vf_layer_size, - out_size=size, - activation_fn=activation, - )) - prev_vf_layer_size = size - self._value_branch_separate = vf_layers - - self._value_branch = SlimFC( - in_size=prev_layer_size, out_size=1, activation_fn=None) + self._value_branch_separate = FCStack( + in_features=in_features, + layers=hiddens, + activation=activation, + prng_key=self.prng_key, + ) + in_ = jnp.zeros((1, in_features)) + self._value_branch_separate_params = \ + self._value_branch_separate.init(self.prng_key, in_) + + self._value_branch = FCStack( + in_features=prev_layer_size, + layers=[1], + prng_key=self.prng_key, + ) + in_ = jnp.zeros((1, prev_layer_size)) + self._value_branch_params = self._value_branch.init(self.prng_key, in_) # Holds the current "base" output (before logits layer). self._features = None # Holds the last input, in case value branch is separate. 
@@ -104,22 +105,48 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, @override(JAXModelV2) def forward(self, input_dict, state, seq_lens): - self._last_flat_in = input_dict["obs_flat"] - x = self._last_flat_in - for layer in self._hidden_layers: - x = layer(x) - self._features = x - logits = self._logits(self._features) if self._logits else \ - self._features - if self.free_log_std: - logits = self._append_free_log_std(logits) - return logits, state - - @override(JAXModelV2) - def value_function(self): - assert self._features is not None, "must call forward() first" - if self._value_branch_separate: - return self._value_branch( - self._value_branch_separate(self._last_flat_in)).squeeze(1) - else: - return self._value_branch(self._features).squeeze(1) + #self.jit_forward = jax.jit(lambda i, s, sl: self.forward_(i, s, sl)) + + if not hasattr(self, "forward_"): + def forward_(flat_in): + self._last_flat_in = self._features = flat_in#input_dict["obs_flat"]#flat_in + if self._hidden_layers: + self._features = self._hidden_layers.apply( + self._hidden_layers_params, self._last_flat_in) + logits = self._logits.apply(self._logits_params, self._features) if \ + self._logits else self._features + + if self._value_branch_separate: + x = self._value_branch_separate.apply( + self._value_branch_separate_params, self._last_flat_in) + value_out = self._value_branch.apply(self._value_branch_params, + x).squeeze(1) + else: + value_out = self._value_branch.apply(self._value_branch_params, + self._features).squeeze(1) + + return logits, value_out, state + + self.forward_ = forward_ + self.jit_forward = jax.jit(forward_) + + return self.jit_forward(input_dict["obs_flat"]) + + #@override(JAXModelV2) + #def value_function(self): + # assert self._features is not None, "must call forward() first" + + # if not hasattr(self, "value_function_"): + # def value_function_(): + # if self._value_branch_separate: + # x = self._value_branch_separate.apply( + # self._value_branch_separate_params, self._last_flat_in) + # return self._value_branch.apply(self._value_branch_params, + # x).squeeze(1) + # else: + # return self._value_branch.apply(self._value_branch_params, + # self._features).squeeze(1) + + # self.value_function_ = jax.jit(value_function_) + + # return self.value_function_() diff --git a/rllib/models/jax/jax_action_dist.py b/rllib/models/jax/jax_action_dist.py index dd8309cb73de..3e9d07653734 100644 --- a/rllib/models/jax/jax_action_dist.py +++ b/rllib/models/jax/jax_action_dist.py @@ -8,6 +8,9 @@ jax, flax = try_import_jax() tfp = try_import_tfp() +jnp = None +if jax: + from jax import numpy as jnp class JAXDistribution(ActionDistribution): @@ -33,13 +36,6 @@ def entropy(self) -> TensorType: def kl(self, other: ActionDistribution) -> TensorType: return self.dist.kl_divergence(other.dist) - @override(ActionDistribution) - def sample(self) -> TensorType: - # Update the state of our PRNG. - _, self.prng_key = jax.random.split(self.prng_key) - self.last_sample = jax.random.categorical(self.prng_key, self.inputs) - return self.last_sample - @override(ActionDistribution) def sampled_action_logp(self) -> TensorType: assert self.last_sample is not None @@ -59,11 +55,33 @@ def __init__(self, inputs, model=None, temperature=1.0): self.dist = tfp.experimental.substrates.jax.distributions.Categorical( logits=self.inputs) + @override(ActionDistribution) + def sample(self) -> TensorType: + # Update the state of our PRNG. 
+ _, self.prng_key = jax.random.split(self.prng_key) + self.last_sample = jax.random.categorical(self.prng_key, self.inputs) + return self.last_sample + @override(ActionDistribution) def deterministic_sample(self): self.last_sample = self.inputs.argmax(axis=1) return self.last_sample + @override(JAXDistribution) + def entropy(self) -> TensorType: + m = jnp.max(self.inputs, axis=-1, keepdims=True) + x = self.inputs - m + sum_exp_x = jnp.sum(jnp.exp(x), axis=-1) + lse_logits = m[..., 0] + jnp.log(sum_exp_x) + is_inf_logits = jnp.isinf(self.inputs).astype(jnp.float32) + is_negative_logits = (self.inputs < 0).astype(jnp.float32) + masked_logits = jnp.where( + (is_inf_logits * is_negative_logits).astype(jnp.bool_), + jnp.array(1.0).astype(self.inputs.dtype), self.inputs) + + return lse_logits - jnp.sum( + jnp.multiply(masked_logits, jnp.exp(x)), axis=-1) / sum_exp_x + @staticmethod @override(ActionDistribution) def required_model_output_shape(action_space, model_config): diff --git a/rllib/models/jax/jax_modelv2.py b/rllib/models/jax/jax_modelv2.py index d527542e9c9a..0f647581877e 100644 --- a/rllib/models/jax/jax_modelv2.py +++ b/rllib/models/jax/jax_modelv2.py @@ -1,14 +1,25 @@ import gym +import time +import tree +from typing import Dict, List, Union from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.typing import ModelConfigDict +from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.framework import try_import_jax +from ray.rllib.utils.typing import ModelConfigDict, TensorType + +jax, flax = try_import_jax() +fd = None +if flax: + from flax.core.frozen_dict import FrozenDict as fd @PublicAPI class JAXModelV2(ModelV2): """JAX version of ModelV2. + + Note that this class by itself is not a valid model unless you implement forward() in a subclass.""" @@ -25,3 +36,31 @@ def __init__(self, obs_space: gym.spaces.Space, model_config, name, framework="jax") + + self.prng_key = jax.random.PRNGKey(int(time.time())) + + @PublicAPI + @override(ModelV2) + def variables(self, as_dict: bool = False + ) -> Union[List[TensorType], Dict[str, TensorType]]: + params = { + k: v + for k, v in self.__dict__.items() + if isinstance(v, fd) and "params" in v + } + if as_dict: + return params + return tree.flatten(params) + + @PublicAPI + @override(ModelV2) + def trainable_variables( + self, as_dict: bool = False + ) -> Union[List[TensorType], Dict[str, TensorType]]: + return self.variables(as_dict=as_dict) + + @PublicAPI + @override(ModelV2) + def value_function(self) -> TensorType: + raise ValueError("JAXModelV2 does not have a separate " + "`value_function()` call!") diff --git a/rllib/models/jax/misc.py b/rllib/models/jax/misc.py index 8a9397a1fdb4..b69ac50cfad1 100644 --- a/rllib/models/jax/misc.py +++ b/rllib/models/jax/misc.py @@ -1,65 +1,53 @@ -import time -from typing import Callable, Optional +from typing import Callable, Optional, Union -from ray.rllib.utils.framework import get_activation_fn, try_import_jax +from ray.rllib.models.utils import get_activation_fn, get_initializer +from ray.rllib.utils.framework import try_import_jax jax, flax = try_import_jax() -nn = np = None +nn = jnp = None if flax: import flax.linen as nn - import jax.numpy as np -class SlimFC: - """Simple JAX version of a fully connected layer.""" +class SlimFC(nn.Module if nn else object): + """Simple JAX version of a fully connected layer. 
- def __init__(self, - in_size, - out_size, - initializer: Optional[Callable] = None, - activation_fn: Optional[str] = None, - use_bias: bool = True, - prng_key: Optional[jax.random.PRNGKey] = None, - name: Optional[str] = None): - """Initializes a SlimFC instance. + Properties: + in_size (int): The input size of the input data that will be passed + into this layer. + out_size (int): The number of nodes in this FC layer. + initializer (flax.: + activation (str): An activation string specifier, e.g. "relu". + use_bias (bool): Whether to add biases to the dot product or not. + #bias_init (float): + """ - Args: - in_size (int): The input size of the input data that will be passed - into this layer. - out_size (int): The number of nodes in this FC layer. - initializer (flax.: - activation_fn (str): An activation string specifier, e.g. "relu". - use_bias (bool): Whether to add biases to the dot product or not. - #bias_init (float): - prng_key (Optional[jax.random.PRNGKey]): An optional PRNG key to - use for initialization. If None, create a new random one. - name (Optional[str]): An optional name for this layer. - """ + in_size: int + out_size: int + initializer: Optional[Union[Callable, str]] = None + activation: Optional[Union[Callable, str]] = None + use_bias: bool = True + def setup(self): # By default, use Glorot unform initializer. - if initializer is None: - initializer = flax.nn.initializers.xavier_uniform() + if self.initializer is None: + self.initializer = "xavier_uniform" + + # Activation function (if any; default=None (linear)). + self.initializer_fn = get_initializer( + self.initializer, framework="jax") + self.activation_fn = get_activation_fn( + self.activation, framework="jax") - self.prng_key = prng_key or jax.random.PRNGKey(int(time.time())) - _, self.prng_key = jax.random.split(self.prng_key) # Create the flax dense layer. - self._dense = nn.Dense( - out_size, - use_bias=use_bias, - kernel_init=initializer, - name=name, + self.dense = nn.Dense( + self.out_size, + use_bias=self.use_bias, + kernel_init=self.initializer_fn, ) - # Initialize it. - dummy_in = jax.random.normal( - self.prng_key, (in_size, ), dtype=np.float32) - _, self.prng_key = jax.random.split(self.prng_key) - self._params = self._dense.init(self.prng_key, dummy_in) - - # Activation function (if any; default=None (linear)). - self.activation_fn = get_activation_fn(activation_fn, "jax") def __call__(self, x): - out = self._dense.apply(self._params, x) + out = self.dense(x) if self.activation_fn: out = self.activation_fn(out) return out diff --git a/rllib/models/jax/modules/__init__.py b/rllib/models/jax/modules/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/models/jax/modules/fc_stack.py b/rllib/models/jax/modules/fc_stack.py new file mode 100644 index 000000000000..de858d67bda2 --- /dev/null +++ b/rllib/models/jax/modules/fc_stack.py @@ -0,0 +1,59 @@ +import logging +import time +from typing import Callable, Optional, Union + +from ray.rllib.models.jax.misc import SlimFC +from ray.rllib.models.utils import get_initializer +from ray.rllib.utils.framework import try_import_jax + +jax, flax = try_import_jax() +nn = None +if flax: + nn = flax.linen + +logger = logging.getLogger(__name__) + + +class FCStack(nn.Module if nn else object): + """Generic fully connected FLAX module. + + Properties: + in_features (int): Number of input features (input dim). + layers (List[int]): List of Dense layer sizes. 
+ activation (Optional[Union[Callable, str]]): An optional activation + function or activation function specifier (str), such as + "relu". Use None or "linear" for no activation. + initializer (): + """ + + in_features: int + layers: [] + activation: Optional[Union[Callable, str]] = None + initializer: Optional[Union[Callable, str]] = None + use_bias: bool = True + prng_key: Optional[jax.random.PRNGKey] = jax.random.PRNGKey( + int(time.time())) + + def setup(self): + self.initializer = get_initializer(self.initializer, framework="jax") + + # Create all layers. + hidden_layers = [] + prev_layer_size = self.in_features + for i, size in enumerate(self.layers): + hidden_layers.append( + SlimFC( + in_size=prev_layer_size, + out_size=size, + use_bias=self.use_bias, + initializer=self.initializer, + activation=self.activation, + )) + prev_layer_size = size + self.hidden_layers = hidden_layers + + def __call__(self, inputs): + x = inputs + for layer in self.hidden_layers: + x = layer(x) + return x diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 70ad50202421..e272dc0d5b0c 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -210,11 +210,11 @@ def __call__( with self.context(): res = self.forward(restored, state or [], seq_lens) if ((not isinstance(res, list) and not isinstance(res, tuple)) - or len(res) != 2): + or len(res) not in [2, 3]): raise ValueError( - "forward() must return a tuple of (output, state) tensors, " - "got {}".format(res)) - outputs, state = res + "forward() must return a tuple of (output, [value-out]?, " + "state) tensors, got {}".format(res)) + outputs, state = res[0], res[-1] try: shape = outputs.shape @@ -229,7 +229,7 @@ def __call__( raise ValueError("State output is not a list: {}".format(state)) self._last_output = outputs - return outputs, state + return res @PublicAPI def from_batch(self, train_batch: SampleBatch, diff --git a/rllib/models/tests/test_jax_models.py b/rllib/models/tests/test_jax_models.py new file mode 100644 index 000000000000..5376ca94c63f --- /dev/null +++ b/rllib/models/tests/test_jax_models.py @@ -0,0 +1,49 @@ +from gym.spaces import Box, Discrete +import unittest + +from ray.rllib.models.jax.fcnet import FullyConnectedNetwork +from ray.rllib.models.jax.misc import SlimFC +from ray.rllib.models.jax.modules.fc_stack import FCStack +from ray.rllib.utils.framework import try_import_jax + +jax, flax = try_import_jax() +jnp = None +if jax: + import jax.numpy as jnp + + +class TestJAXModels(unittest.TestCase): + def test_jax_slimfc(self): + slimfc = SlimFC(5, 2) + prng = jax.random.PRNGKey(0) + params = slimfc.init(prng, jnp.zeros((1, 5))) + assert params + + def test_jax_fcstack(self): + fc_stack = FCStack(5, [2, 2], "relu") + prng = jax.random.PRNGKey(0) + params = fc_stack.init(prng, jnp.zeros((1, 5))) + assert params + + def test_jax_fcnet(self): + """Tests the JAX FCNet class.""" + obs_space = Box(-10.0, 10.0, shape=(4, )) + action_space = Discrete(2) + fc_net = FullyConnectedNetwork( + obs_space, + action_space, + num_outputs=2, + model_config={ + "fcnet_hiddens": [10], + "fcnet_activation": "relu", + }, + name="jax_model") + inputs = jnp.array([obs_space.sample()]) + print(fc_net({"obs": inputs})) + fc_net.variables() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/policy/__init__.py b/rllib/policy/__init__.py index 67868182a07a..e1a0f50b4898 100644 --- a/rllib/policy/__init__.py +++ b/rllib/policy/__init__.py @@ -1,4 +1,5 @@ from 
ray.rllib.policy.policy import Policy +from ray.rllib.policy.jax_policy import JAXPolicy from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.policy.policy_template import build_policy_class @@ -7,6 +8,7 @@ __all__ = [ "Policy", + "JAXPolicy", "TFPolicy", "TorchPolicy", "build_policy_class", diff --git a/rllib/policy/jax_policy.py b/rllib/policy/jax_policy.py new file mode 100644 index 000000000000..bcc5ceef7887 --- /dev/null +++ b/rllib/policy/jax_policy.py @@ -0,0 +1,544 @@ +import functools +import gym +import numpy as np +import logging +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.jax.jax_modelv2 import JAXModelV2 +from ray.rllib.models.jax.jax_action_dist import JAXDistribution +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.framework import try_import_jax +from ray.rllib.utils.jax_ops import convert_to_jax_device_array, \ + convert_to_non_jax_type +from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule +from ray.rllib.utils.tracking_dict import UsageTrackingDict +from ray.rllib.utils.typing import ModelGradients, ModelWeights, \ + TensorType, TrainerConfigDict + +jax, flax = try_import_jax() +jnp = None +if jax: + import jax.numpy as jnp + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class JAXPolicy(Policy): + """Template for a JAX policy and loss to use with RLlib. + """ + + @DeveloperAPI + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, + *, + model: ModelV2, + loss: Callable[[ + Policy, ModelV2, Type[JAXDistribution], SampleBatch + ], Union[TensorType, List[TensorType]]], + action_distribution_class: Type[JAXDistribution], + action_sampler_fn: Optional[Callable[[ + TensorType, List[TensorType] + ], Tuple[TensorType, TensorType]]] = None, + action_distribution_fn: Optional[Callable[[ + Policy, ModelV2, TensorType, TensorType, TensorType + ], Tuple[TensorType, Type[JAXDistribution], List[ + TensorType]]]] = None, + max_seq_len: int = 20, + get_batch_divisibility_req: Optional[Callable[[Policy], + int]] = None, + ): + """Initializes a JAXPolicy instance. + + Args: + observation_space (gym.spaces.Space): Observation space of the + policy. + action_space (gym.spaces.Space): Action space of the policy. + config (TrainerConfigDict): The Policy config dict. + model (ModelV2): PyTorch policy module. Given observations as + input, this module must return a list of outputs where the + first item is action logits, and the rest can be any value. + loss (Callable[[Policy, ModelV2, Type[JAXDistribution], + SampleBatch], Union[TensorType, List[TensorType]]]): Callable + that returns a single scalar loss or a list of loss terms. + action_distribution_class (Type[JAXDistribution]): Class + for a torch action distribution. + action_sampler_fn (Callable[[TensorType, List[TensorType]], + Tuple[TensorType, TensorType]]): A callable returning a + sampled action and its log-likelihood given Policy, ModelV2, + input_dict, explore, timestep, and is_training. 
+            action_distribution_fn (Optional[Callable[[Policy, ModelV2,
+                Dict[str, TensorType], TensorType, TensorType],
+                Tuple[TensorType, type, List[TensorType]]]]): A callable
+                returning distribution inputs (parameters), a dist-class to
+                generate an action distribution object from, and
+                internal-state outputs (or an empty list if not applicable).
+                Note: No Exploration hooks have to be called from within
+                `action_distribution_fn`. It should only perform a simple
+                forward pass through some model.
+                If None, pass inputs through `self.model()` to get distribution
+                inputs.
+                The callable takes as inputs: Policy, ModelV2, input_dict,
+                explore, timestep, is_training.
+            max_seq_len (int): Max sequence length for LSTM training.
+            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
+                Optional callable that returns the divisibility requirement
+                for sample batches given the Policy.
+        """
+        self.framework = "jax"
+        super().__init__(observation_space, action_space, config)
+        self.model = model
+        # Auto-update model's inference view requirements, if recurrent.
+        self._update_model_view_requirements_from_init_state()
+        # Combine view_requirements for Model and Policy.
+        self.view_requirements.update(self.model.view_requirements)
+
+        self.exploration = self._create_exploration()
+        self.unwrapped_model = model  # used to support DistributedDataParallel
+        self._loss = loss
+        #self._gradient_loss = jax.grad(self._loss, argnums=4)
+        self._optimizers = force_list(self.optimizer())
+
+        self.dist_class = action_distribution_class
+        self.action_sampler_fn = action_sampler_fn
+        self.action_distribution_fn = action_distribution_fn
+
+        # If set, means we are using distributed allreduce during learning.
+        self.distributed_world_size = None
+
+        self.max_seq_len = max_seq_len
+        self.batch_divisibility_req = get_batch_divisibility_req(self) if \
+            callable(get_batch_divisibility_req) else \
+            (get_batch_divisibility_req or 1)
+
+    @override(Policy)
+    def compute_actions(
+            self,
+            obs_batch: Union[List[TensorType], TensorType],
+            state_batches: Optional[List[TensorType]] = None,
+            prev_action_batch: Union[List[TensorType], TensorType] = None,
+            prev_reward_batch: Union[List[TensorType], TensorType] = None,
+            info_batch: Optional[Dict[str, list]] = None,
+            episodes: Optional[List["MultiAgentEpisode"]] = None,
+            explore: Optional[bool] = None,
+            timestep: Optional[int] = None,
+            **kwargs) -> \
+            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
+
+        input_dict = self._lazy_tensor_dict({
+            SampleBatch.CUR_OBS: np.asarray(obs_batch),
+            "is_training": False,
+        })
+        if prev_action_batch is not None:
+            input_dict[SampleBatch.PREV_ACTIONS] = \
+                np.asarray(prev_action_batch)
+        if prev_reward_batch is not None:
+            input_dict[SampleBatch.PREV_REWARDS] = \
+                np.asarray(prev_reward_batch)
+        for i, s in enumerate(state_batches or []):
+            input_dict["state_in_{}".format(i)] = s
+
+        return self.compute_actions_from_input_dict(input_dict, explore,
+                                                    timestep, **kwargs)
+
+    @override(Policy)
+    def compute_actions_from_input_dict(
+            self,
+            input_dict: Dict[str, TensorType],
+            explore: bool = None,
+            timestep: Optional[int] = None,
+            **kwargs) -> \
+            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
+
+        import time  # TODO: remove (timing debug).
+        start = time.time()
+
+        explore = explore if explore is not None else self.config["explore"]
+        timestep = timestep if timestep is not None else self.global_timestep
+
+        # Pack internal state inputs into (separate) list.
+ state_batches = [ + input_dict[k] for k in input_dict.keys() if "state_in" in k[:8] + ] + # Calculate RNN sequence lengths. + seq_lens = np.array([1] * len(input_dict["obs"])) \ + if state_batches else None + + #if self.action_sampler_fn: + # action_dist = dist_inputs = None + # state_out = state_batches + # actions, logp, state_out = self.action_sampler_fn( + # self, + # self.model, + # input_dict, + # state_out, + # explore=explore, + # timestep=timestep) + #else: + # Call the exploration before_compute_actions hook. + #self.exploration.before_compute_actions( + # explore=explore, timestep=timestep) + #if self.action_distribution_fn: + # dist_inputs, dist_class, state_out = \ + # self.action_distribution_fn( + # self, + # self.model, + # input_dict[SampleBatch.CUR_OBS], + # explore=explore, + # timestep=timestep, + # is_training=False) + #else: + dist_class = self.dist_class + dist_inputs, value_out, state_out = self.model( + input_dict, state_batches, seq_lens) + + #if not (isinstance(dist_class, functools.partial) + # or issubclass(dist_class, JAXDistribution)): + # raise ValueError( + # "`dist_class` ({}) not a JAXDistribution " + # "subclass! Make sure your `action_distribution_fn` or " + # "`make_model_and_action_dist` return a correct " + # "distribution class.".format(dist_class.__name__)) + action_dist = dist_class(dist_inputs, self.model) + + # Get the exploration action from the forward results. + actions, logp = \ + self.exploration.get_exploration_action( + action_distribution=action_dist, + timestep=timestep, + explore=explore) + + input_dict[SampleBatch.ACTIONS] = actions + + # Add default and custom fetches. + extra_fetches = self.extra_action_out(input_dict, state_batches, + self.model, action_dist) + extra_fetches[SampleBatch.VF_PREDS] = value_out + + # Action-dist inputs. + if dist_inputs is not None: + extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs + + # Action-logp and action-prob. + if logp is not None: + extra_fetches[SampleBatch.ACTION_PROB] = \ + jnp.exp(logp.astype(jnp.float32)) + extra_fetches[SampleBatch.ACTION_LOGP] = logp + + # Update our global timestep by the batch size. + self.global_timestep += len(input_dict[SampleBatch.CUR_OBS]) + + #TODO + print("action pass={}".format(time.time() - start)) + + return convert_to_non_jax_type((actions, state_out, extra_fetches)) + + @override(Policy) + @DeveloperAPI + def compute_log_likelihoods( + self, + actions: Union[List[TensorType], TensorType], + obs_batch: Union[List[TensorType], TensorType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Optional[Union[List[TensorType], + TensorType]] = None, + prev_reward_batch: Optional[Union[List[ + TensorType], TensorType]] = None) -> TensorType: + + if self.action_sampler_fn and self.action_distribution_fn is None: + raise ValueError("Cannot compute log-prob/likelihood w/o an " + "`action_distribution_fn` and a provided " + "`action_sampler_fn`!") + + input_dict = { + SampleBatch.CUR_OBS: obs_batch, + SampleBatch.ACTIONS: actions + } + if prev_action_batch is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch + if prev_reward_batch is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch + seq_lens = jnp.ones(len(obs_batch), dtype=jnp.int32) + state_batches = [s for s in (state_batches or [])] + + # Exploration hook before each forward pass. + self.exploration.before_compute_actions(explore=False) + + # Action dist class and inputs are generated via custom function. 
+ if self.action_distribution_fn: + dist_inputs, dist_class, _ = self.action_distribution_fn( + policy=self, + model=self.model, + obs_batch=input_dict[SampleBatch.CUR_OBS], + explore=False, + is_training=False) + # Default action-dist inputs calculation. + else: + dist_class = self.dist_class + dist_inputs, _ = self.model(input_dict, state_batches, seq_lens) + + action_dist = dist_class(dist_inputs, self.model) + log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS]) + return log_likelihoods + + @override(Policy) + @DeveloperAPI + def learn_on_batch( + self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]: + # Callback handling. + self.callbacks.on_learn_on_batch( + policy=self, train_batch=postprocessed_batch) + + # Compute gradients (will calculate all losses and `backward()` + # them to get the grads). + grads, fetches = self.compute_gradients(postprocessed_batch) + + # Step the optimizer(s). + for i in range(len(self._optimizers)): + opt = self._optimizers[i] + self._optimizers[i] = opt.apply_gradient(grads[i]) + + # Update model's params. + for k, v in self._optimizers[0].target.items(): + setattr(self.model, k, v) + + if self.model: + fetches["model"] = self.model.metrics() + return fetches + + @override(Policy) + @DeveloperAPI + def compute_gradients(self, + postprocessed_batch: SampleBatch) -> ModelGradients: + # Get batch ready for RNNs, if applicable. + pad_batch_to_sequences_of_same_size( + postprocessed_batch, + max_seq_len=self.max_seq_len, + shuffle=False, + batch_divisibility_req=self.batch_divisibility_req, + ) + + train_batch = self._lazy_tensor_dict(postprocessed_batch) + model_params = self.model.variables(as_dict=True) + + # Calculate the actual policy loss. + all_grads = force_list( + self.gradient_loss({k: train_batch[k] for k in train_batch.keys() if k != "infos"}, model_params))#self, self.model, self.dist_class, train_batch, + #model_params)) + + #remove: assert not any(torch.isnan(l) for l in loss_out) + fetches = self.extra_compute_grad_fetches() + + # Loop through all optimizers. + grad_info = {"allreduce_latency": 0.0} + + grad_info["allreduce_latency"] /= len(self._optimizers) + grad_info.update(self.extra_grad_info(train_batch)) + + return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info}) + + @override(Policy) + @DeveloperAPI + def apply_gradients(self, gradients: ModelGradients) -> None: + # TODO(sven): Not supported for multiple optimizers yet. + assert len(self._optimizers) == 1 + + # Step the optimizer(s). 
+        self._optimizers[0] = self._optimizers[0].apply_gradient(gradients)
+
+    @override(Policy)
+    @DeveloperAPI
+    def get_weights(self) -> ModelWeights:
+        cpu = jax.devices("cpu")[0]
+        return {
+            k: jax.device_put(v, cpu)
+            for k, v in self.model.variables(as_dict=True).items()
+        }
+
+    @override(Policy)
+    @DeveloperAPI
+    def set_weights(self, weights: ModelWeights) -> None:
+        for k, v in weights.items():
+            setattr(self.model, k, v)
+
+    @override(Policy)
+    @DeveloperAPI
+    def is_recurrent(self) -> bool:
+        return len(self.model.get_initial_state()) > 0
+
+    @override(Policy)
+    @DeveloperAPI
+    def num_state_tensors(self) -> int:
+        return len(self.model.get_initial_state())
+
+    @override(Policy)
+    @DeveloperAPI
+    def get_initial_state(self) -> List[TensorType]:
+        # JAX arrays have no `detach()`/`cpu()`; plain numpy conversion.
+        return [
+            np.asarray(s) for s in self.model.get_initial_state()
+        ]
+
+    @override(Policy)
+    @DeveloperAPI
+    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
+        state = super().get_state()
+        state["_optimizer_variables"] = []
+        for i, o in enumerate(self._optimizers):
+            # flax optimizers carry their state in `o.state` (no torch-style
+            # `state_dict()`).
+            optim_state_dict = convert_to_non_jax_type(o.state)
+            state["_optimizer_variables"].append(optim_state_dict)
+        return state
+
+    @override(Policy)
+    @DeveloperAPI
+    def set_state(self, state: object) -> None:
+        state = state.copy()  # shallow copy
+        # Set optimizer vars first.
+        optimizer_vars = state.pop("_optimizer_variables", None)
+        if optimizer_vars:
+            assert len(optimizer_vars) == len(self._optimizers)
+            for i, (o, s) in enumerate(zip(self._optimizers, optimizer_vars)):
+                self._optimizers[i].optimizer_def.state = s
+        # Then the Policy's (NN) weights.
+        super().set_state(state)
+
+    @DeveloperAPI
+    def extra_grad_process(self, optimizer: "flax.optim.Optimizer",
+                           loss: TensorType):
+        """Called after the gradients for a loss term have been computed.
+
+        Called for each self._optimizers/loss-value pair.
+        Allows for gradient processing before the optimizer step is applied.
+        E.g. for gradient clipping.
+
+        Args:
+            optimizer (flax.optim.Optimizer): A FLAX optimizer object.
+            loss (TensorType): The loss value associated with the optimizer.
+
+        Returns:
+            Dict[str, TensorType]: A dict with information on the gradient
+                processing step.
+        """
+        return {}
+
+    @DeveloperAPI
+    def extra_compute_grad_fetches(self) -> Dict[str, any]:
+        """Extra values to fetch and return from compute_gradients().
+
+        Returns:
+            Dict[str, any]: Extra fetch dict to be added to the fetch dict
+                of the compute_gradients call.
+        """
+        return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.
+
+    @DeveloperAPI
+    def extra_action_out(
+            self, input_dict: Dict[str, TensorType],
+            state_batches: List[TensorType], model: JAXModelV2,
+            action_dist: JAXDistribution) -> Dict[str, TensorType]:
+        """Returns dict of extra info to include in experience batch.
+
+        Args:
+            input_dict (Dict[str, TensorType]): Dict of model input tensors.
+            state_batches (List[TensorType]): List of state tensors.
+            model (JAXModelV2): Reference to the model object.
+            action_dist (JAXDistribution): JAX action dist object
+                to get log-probs (e.g. for already sampled actions).
+
+        Returns:
+            Dict[str, TensorType]: Extra outputs to return in a
+                compute_actions() call (3rd return value).
+        """
+        return {}
+
+    @DeveloperAPI
+    def extra_grad_info(self,
+                        train_batch: SampleBatch) -> Dict[str, TensorType]:
+        """Return dict of extra grad info.
+
+        Args:
+            train_batch (SampleBatch): The training batch for which to produce
+                extra grad info.
+
+        Returns:
+            Dict[str, TensorType]: The info dict carrying grad info per str
+                key.
+        """
+        return {}
+
+    @DeveloperAPI
+    def optimizer(
+            self
+    ) -> Union[List["flax.optim.Optimizer"], "flax.optim.Optimizer"]:
+        """Returns the local FLAX optimizer(s) to use.
+
+        Returns:
+            Union[List[flax.optim.Optimizer], flax.optim.Optimizer]:
+                The local FLAX optimizer(s) to use for this Policy.
+        """
+        if hasattr(self, "config"):
+            adam = flax.optim.Adam(learning_rate=self.config["lr"])
+        else:
+            adam = flax.optim.Adam()
+        weights = self.get_weights()
+        adam_state = adam.init_state(weights)
+        return flax.optim.Optimizer(adam, adam_state, target=weights)
+
+    @override(Policy)
+    @DeveloperAPI
+    def export_model(self, export_dir: str) -> None:
+        """TODO(sven): implement for JAX.
+        """
+        raise NotImplementedError
+
+    @override(Policy)
+    @DeveloperAPI
+    def export_checkpoint(self, export_dir: str) -> None:
+        """TODO(sven): implement for JAX.
+        """
+        raise NotImplementedError
+
+    @override(Policy)
+    @DeveloperAPI
+    def import_model_from_h5(self, import_file: str) -> None:
+        """Imports weights into the JAX model."""
+        return self.model.import_from_h5(import_file)
+
+    def _lazy_tensor_dict(self, data):
+        tensor_dict = UsageTrackingDict(data)
+        tensor_dict.set_get_interceptor(
+            functools.partial(convert_to_jax_device_array, device=None))
+        return tensor_dict
+
+
+# TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
+# and for all possible hyperparams, not just lr.
+@DeveloperAPI
+class LearningRateSchedule:
+    """Mixin for JAXPolicy that adds a learning rate schedule."""
+
+    @DeveloperAPI
+    def __init__(self, lr, lr_schedule):
+        self.cur_lr = lr
+        if lr_schedule is None:
+            self.lr_schedule = ConstantSchedule(lr, framework=None)
+        else:
+            self.lr_schedule = PiecewiseSchedule(
+                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
+
+    @override(Policy)
+    def on_global_var_update(self, global_vars):
+        super().on_global_var_update(global_vars)
+        self.cur_lr = self.lr_schedule.value(global_vars["timestep"])
+        for i in range(len(self._optimizers)):
+            opt = self._optimizers[i]
+            new_hyperparams = opt.optimizer_def.update_hyper_params(
+                learning_rate=self.cur_lr)
+            opt.optimizer_def.hyper_params = new_hyperparams
diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py
index 577ac3d68c75..380ef93b163d 100644
--- a/rllib/policy/policy.py
+++ b/rllib/policy/policy.py
@@ -588,6 +588,7 @@ def _get_default_view_requirements(self):
             SampleBatch.EPS_ID: ViewRequirement(),
             SampleBatch.UNROLL_ID: ViewRequirement(),
             SampleBatch.AGENT_INDEX: ViewRequirement(),
+            SampleBatch.VF_PREDS: ViewRequirement(),
             "t": ViewRequirement(),
         }
 
diff --git a/rllib/policy/policy_template.py b/rllib/policy/policy_template.py
index eb08ecc651e3..fddf5d38f1b3 100644
--- a/rllib/policy/policy_template.py
+++ b/rllib/policy/policy_template.py
@@ -6,6 +6,7 @@
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
+from ray.rllib.policy.jax_policy import JAXPolicy
 from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import TorchPolicy
@@ -74,7 +75,7 @@ def build_policy_class(
             [Policy, "torch.optim.Optimizer"], None]] = None,
         mixins: Optional[List[type]] = None,
         get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
-) -> Type[TorchPolicy]:
+) -> Type[Union[JAXPolicy,
TorchPolicy]]: """Helper function for creating a new Policy class at runtime. Supports frameworks JAX and PyTorch. @@ -188,7 +189,7 @@ def build_policy_class( """ original_kwargs = locals().copy() - parent_cls = TorchPolicy + parent_cls = TorchPolicy if framework == "torch" else JAXPolicy base = add_mixins(parent_cls, mixins) class policy_cls(base): diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 19d576d3776a..443c583d3df0 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -190,6 +190,10 @@ def compute_actions_from_input_dict( **kwargs) -> \ Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + import time + #TODO + start = time.time() + explore = explore if explore is not None else self.config["explore"] timestep = timestep if timestep is not None else self.global_timestep @@ -204,8 +208,11 @@ def compute_actions_from_input_dict( seq_lens = np.array([1] * len(input_dict["obs"])) \ if state_batches else None - return self._compute_action_helper(input_dict, state_batches, + ret = self._compute_action_helper(input_dict, state_batches, seq_lens, explore, timestep) + #TODO + print("action pass={}".format(time.time() - start)) + return ret @with_lock def _compute_action_helper(self, input_dict, state_batches, seq_lens, @@ -222,48 +229,48 @@ def _compute_action_helper(self, input_dict, state_batches, seq_lens, if self.model: self.model.eval() - if self.action_sampler_fn: - action_dist = dist_inputs = None - actions, logp, state_out = self.action_sampler_fn( - self, - self.model, - input_dict, - state_batches, - explore=explore, - timestep=timestep) - else: - # Call the exploration before_compute_actions hook. - self.exploration.before_compute_actions( - explore=explore, timestep=timestep) - if self.action_distribution_fn: - dist_inputs, dist_class, state_out = \ - self.action_distribution_fn( - self, - self.model, - input_dict[SampleBatch.CUR_OBS], - explore=explore, - timestep=timestep, - is_training=False) - else: - dist_class = self.dist_class - dist_inputs, state_out = self.model(input_dict, state_batches, - seq_lens) - - if not (isinstance(dist_class, functools.partial) - or issubclass(dist_class, TorchDistributionWrapper)): - raise ValueError( - "`dist_class` ({}) not a TorchDistributionWrapper " - "subclass! Make sure your `action_distribution_fn` or " - "`make_model_and_action_dist` return a correct " - "distribution class.".format(dist_class.__name__)) - action_dist = dist_class(dist_inputs, self.model) + #if self.action_sampler_fn: + # action_dist = dist_inputs = None + # actions, logp, state_out = self.action_sampler_fn( + # self, + # self.model, + # input_dict, + # state_batches, + # explore=explore, + # timestep=timestep) + #else: + # Call the exploration before_compute_actions hook. + #self.exploration.before_compute_actions( + # explore=explore, timestep=timestep) + #if self.action_distribution_fn: + # dist_inputs, dist_class, state_out = \ + # self.action_distribution_fn( + # self, + # self.model, + # input_dict[SampleBatch.CUR_OBS], + # explore=explore, + # timestep=timestep, + # is_training=False) + #else: + dist_class = self.dist_class + dist_inputs, state_out = self.model(input_dict, state_batches, + seq_lens) - # Get the exploration action from the forward results. 
- actions, logp = \ - self.exploration.get_exploration_action( - action_distribution=action_dist, - timestep=timestep, - explore=explore) + #if not (isinstance(dist_class, functools.partial) + # or issubclass(dist_class, TorchDistributionWrapper)): + # raise ValueError( + # "`dist_class` ({}) not a TorchDistributionWrapper " + # "subclass! Make sure your `action_distribution_fn` or " + # "`make_model_and_action_dist` return a correct " + # "distribution class.".format(dist_class.__name__)) + action_dist = dist_class(dist_inputs, self.model) + + # Get the exploration action from the forward results. + actions, logp = \ + self.exploration.get_exploration_action( + action_distribution=action_dist, + timestep=timestep, + explore=explore) input_dict[SampleBatch.ACTIONS] = actions diff --git a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index c40fae254b5c..7875c8c2f62f 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -8,9 +8,10 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.exploration.random import Random -from ray.rllib.utils.framework import get_variable, try_import_tf, \ - try_import_torch, TensorType +from ray.rllib.utils.framework import get_variable, try_import_jax, \ + try_import_tf, try_import_torch, TensorType +jax, _ = try_import_jax() tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -36,7 +37,7 @@ def __init__(self, Args: action_space (gym.spaces.Space): The gym action space used by the environment. - framework (str): One of None, "tf", "torch". + framework (str): One of None, "jax", "tf", "torch". model (ModelV2): The ModelV2 used by the owning Policy. random_timesteps (int): The number of timesteps for which to act completely randomly. Only after this number of timesteps, @@ -65,9 +66,9 @@ def get_exploration_action(self, action_distribution: ActionDistribution, timestep: Union[int, TensorType], explore: bool = True): - if self.framework == "torch": - return self._get_torch_exploration_action(action_distribution, - timestep, explore) + if self.framework in ["torch", "jax"]: + return self._get_exploration_action(action_distribution, timestep, + explore) else: return self._get_tf_exploration_action_op(action_distribution, timestep, explore) @@ -114,9 +115,9 @@ def logp_false_fn(): with tf1.control_dependencies([assign_op]): return action, logp - def _get_torch_exploration_action(self, action_dist: ActionDistribution, - timestep: Union[TensorType, int], - explore: Union[TensorType, bool]): + def _get_exploration_action(self, action_dist: ActionDistribution, + timestep: Union[TensorType, int], + explore: Union[TensorType, bool]): # Set last timestep or (if not given) increase by one. self.last_timestep = timestep if timestep is not None else \ self.last_timestep + 1 @@ -136,6 +137,7 @@ def _get_torch_exploration_action(self, action_dist: ActionDistribution, # No exploration -> Return deterministic actions. 
         else:
             action = action_dist.deterministic_sample()
-            logp = torch.zeros_like(action_dist.sampled_action_logp())
+            fw = torch if self.framework == "torch" else jax.numpy
+            logp = fw.zeros_like(action_dist.sampled_action_logp())
 
         return action, logp
diff --git a/rllib/utils/jax_ops.py b/rllib/utils/jax_ops.py
new file mode 100644
index 000000000000..a51d6b57492a
--- /dev/null
+++ b/rllib/utils/jax_ops.py
@@ -0,0 +1,92 @@
+import numpy as np
+import tree
+
+from ray.rllib.utils.framework import try_import_jax
+
+jax, _ = try_import_jax()
+jnp = None
+if jax:
+    import jax.numpy as jnp
+
+
+def convert_to_jax_device_array(x, device=None):
+    """Converts any struct to jax.numpy.DeviceArray.
+
+    Args:
+        x (any): Any (possibly nested) struct, the values in which will be
+            converted and returned as a new struct with all leaves converted
+            to jax.numpy.DeviceArrays.
+
+    Returns:
+        Any: A new struct with the same structure as `x`, but with all
+            values converted to jax.numpy.DeviceArray types.
+    """
+
+    def mapping(item):
+        # Already a JAX DeviceArray -> make sure it's on the right device.
+        if isinstance(item, jnp.DeviceArray):
+            return item if device is None else jax.device_put(item, device)
+        # Numpy arrays.
+        if isinstance(item, np.ndarray):
+            # np.object_ type (e.g. info dicts in train batch): leave as-is.
+            if item.dtype == np.object_:
+                return item
+            # Already numpy: Wrap as jax.numpy array.
+            else:
+                tensor = jnp.array(item)
+        # Everything else: Convert to a jax.numpy array.
+        else:
+            tensor = jnp.asarray(item)
+        # Floatify all float64 tensors.
+        if tensor.dtype == jnp.double:
+            tensor = tensor.astype(jnp.float32)
+        return tensor if device is None else jax.device_put(tensor, device)
+
+    return tree.map_structure(mapping, x)
+
+
+def convert_to_non_jax_type(stats):
+    """Converts values in `stats` to non-JAX numpy or python types.
+
+    Args:
+        stats (any): Any (possibly nested) struct, the values in which will be
+            converted and returned as a new struct with all JAX DeviceArrays
+            being converted to numpy types.
+
+    Returns:
+        Any: A new struct with the same structure as `stats`, but with all
+            values converted to non-JAX Tensor types.
+    """
+
+    # The mapping function used to numpyize JAX DeviceArrays.
+    def mapping(item):
+        if isinstance(item, jnp.DeviceArray):
+            return np.array(item)
+        else:
+            return item
+
+    return tree.map_structure(mapping, stats)
+
+
+def explained_variance(y, pred):
+    y_var = jnp.var(y, axis=[0])
+    diff_var = jnp.var(y - pred, axis=[0])
+    # Clip the explained variance at -1.0 (as in the tf/torch versions).
+    min_ = -1.0
+    return jnp.maximum(min_, 1 - (diff_var / y_var))
+
+
+def sequence_mask(lengths, maxlen=None, dtype=None, time_major=False):
+    """Offers same behavior as tf.sequence_mask for JAX numpy.
+
+    Thanks to Dimitris Papatheodorou
+    (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
+    39036).
+    """
+    if maxlen is None:
+        maxlen = int(lengths.max())
+
+    mask = ~(jnp.ones((len(lengths), maxlen)).cumsum(axis=1).T > lengths)
+    if not time_major:
+        mask = mask.T
+    mask = mask.astype(dtype or jnp.bool_)
+
+    return mask
diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py
index b5b72d44d37c..678c0afbc5a4 100644
--- a/rllib/utils/sgd.py
+++ b/rllib/utils/sgd.py
@@ -27,7 +27,10 @@ def averaged(kv, axis=None):
     out = {}
     for k, v in kv.items():
         if v[0] is not None and not isinstance(v[0], dict):
-            out[k] = np.mean(v, axis=axis)
+            try:#TODO
+                out[k] = np.mean(v, axis=axis)
+            except Exception as e:
+                raise e
         else:
             out[k] = v[0]
     return out
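For reference, a minimal, self-contained sketch of the `jax.jit` + `jax.grad`
pattern that `ppo_surrogate_loss` and `JAXPolicy.compute_gradients` rely on:
the loss closure is jitted once, and its gradient function is taken w.r.t.
the model parameters passed in as the second positional argument
(`argnums=1`). The toy loss and parameter names below are illustrative only,
not taken from this patch:

    import jax
    import jax.numpy as jnp

    def loss_(train_batch, params):
        # Toy surrogate: squared error of a linear model vs. targets.
        preds = train_batch["obs"] @ params["w"]
        return jnp.mean((preds - train_batch["targets"]) ** 2)

    jit_loss = jax.jit(loss_)
    # Differentiate w.r.t. the 2nd positional arg (the params dict), as the
    # PPO JAX policy does via `jax.grad(policy.jit_loss, argnums=1)`.
    gradient_loss = jax.grad(jit_loss, argnums=1)

    batch = {
        "obs": jnp.ones((4, 3)),
        "targets": jnp.zeros((4, )),
    }
    params = {"w": jnp.ones((3, ))}
    grads = gradient_loss(batch, params)  # same pytree structure as `params`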