diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 29266dfccab9..ec56863a9007 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -193,22 +193,25 @@ def postprocess_ppo_gae( last_r = 0.0 # Trajectory has been truncated -> last r=VF estimate of last obs. else: - # Input dict is provided to us automatically via the Model's - # requirements. It's a single-timestep (last one in trajectory) - # input_dict. - if policy.config["_use_trajectory_view_api"]: - # Create an input dict according to the Model's requirements. - input_dict = policy.model.get_input_dict(sample_batch, index=-1) - last_r = policy._value(**input_dict) - # TODO: (sven) Remove once trajectory view API is all-algo default. + state_in_view_req = policy.model.inference_view_requirements.get( + "state_in_0") + # Attention net. + if state_in_view_req and state_in_view_req.shift_from is not None: + next_state = [] + for i in range(policy.num_state_tensors()): + view_req = policy.model.inference_view_requirements.get( + "state_in_{}".format(i)) + next_state.append(sample_batch["state_out_{}".format(i)][ + view_req.shift_from:view_req.shift_to + 1]) + # Everything else. else: next_state = [] for i in range(policy.num_state_tensors()): next_state.append(sample_batch["state_out_{}".format(i)][-1]) - last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], - sample_batch[SampleBatch.ACTIONS][-1], - sample_batch[SampleBatch.REWARDS][-1], - *next_state) + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], + sample_batch[SampleBatch.ACTIONS][-1], + sample_batch[SampleBatch.REWARDS][-1], + *next_state) # Adds the policy logits, VF preds, and advantages to the batch, # using GAE ("generalized advantage estimation") or not. @@ -303,34 +306,19 @@ def __init__(self, obs_space, action_space, config): # observation. if config["use_gae"]: - # Input dict is provided to us automatically via the Model's - # requirements. It's a single-timestep (last one in trajectory) - # input_dict. - if config["_use_trajectory_view_api"]: - - @make_tf_callable(self.get_session()) - def value(**input_dict): - model_out, _ = self.model.from_batch( - input_dict, is_training=False) - # [0] = remove the batch dim. - return self.model.value_function()[0] - - # TODO: (sven) Remove once trajectory view API is all-algo default. - else: - - @make_tf_callable(self.get_session()) - def value(ob, prev_action, prev_reward, *state): - model_out, _ = self.model({ - SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]), - SampleBatch.PREV_ACTIONS: tf.convert_to_tensor( - [prev_action]), - SampleBatch.PREV_REWARDS: tf.convert_to_tensor( - [prev_reward]), - "is_training": tf.convert_to_tensor([False]), - }, [tf.convert_to_tensor([s]) for s in state], - tf.convert_to_tensor([1])) - # [0] = remove the batch dim. - return self.model.value_function()[0] + @make_tf_callable(self.get_session()) + def value(ob, prev_action, prev_reward, *state): + model_out, _ = self.model({ + SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]), + SampleBatch.PREV_ACTIONS: tf.convert_to_tensor( + [prev_action]), + SampleBatch.PREV_REWARDS: tf.convert_to_tensor( + [prev_reward]), + "is_training": tf.convert_to_tensor([False]), + }, [tf.convert_to_tensor([s]) for s in state], + tf.convert_to_tensor([1])) + # [0] = remove the batch dim. + return self.model.value_function()[0] # When not doing GAE, we do not require the value function's output. 
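A minimal NumPy sketch (not RLlib code) of the two bootstrap cases handled in `postprocess_ppo_gae` above: for attention nets, `state_in_0` carries a range shift (`shift_from`/`shift_to`), so the last `memory_inference` rows of each `state_out_*` buffer are fed to `policy._value(...)`; for plain RNNs only the final row is used. All values below are illustrative.

import numpy as np

state_out = np.arange(10)                 # fake state_out_0 buffer, T=10
shift_from, shift_to = -3, -1             # e.g. parsed from shift="-3:-1"

# Inclusive slice up to and including index -1. Note that a naive
# `state_out[shift_from:shift_to + 1]` would be empty when shift_to == -1,
# so the end index needs special-casing (cf. the inference-time slicing in
# simple_list_collector.py further below).
end = None if shift_to == -1 else shift_to + 1
attention_state = state_out[shift_from:end]   # last 3 entries
rnn_state = state_out[-1]                     # plain-RNN case: last entry only

assert attention_state.tolist() == [7, 8, 9]
assert rnn_state == 9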
else: diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index a268e748720d..58637fa0a64b 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -210,36 +210,22 @@ def __init__(self, obs_space, action_space, config): # When doing GAE, we need the value function estimate on the # observation. if config["use_gae"]: - # Input dict is provided to us automatically via the Model's - # requirements. It's a single-timestep (last one in trajectory) - # input_dict. - if config["_use_trajectory_view_api"]: - - def value(**input_dict): - model_out, _ = self.model.from_batch( - convert_to_torch_tensor(input_dict, self.device), - is_training=False) - # [0] = remove the batch dim. - return self.model.value_function()[0] - - # TODO: (sven) Remove once trajectory view API is all-algo default. - else: - - def value(ob, prev_action, prev_reward, *state): - model_out, _ = self.model({ - SampleBatch.CUR_OBS: convert_to_torch_tensor( - np.asarray([ob]), self.device), - SampleBatch.PREV_ACTIONS: convert_to_torch_tensor( - np.asarray([prev_action]), self.device), - SampleBatch.PREV_REWARDS: convert_to_torch_tensor( - np.asarray([prev_reward]), self.device), - "is_training": False, - }, [ - convert_to_torch_tensor(np.asarray([s]), self.device) - for s in state - ], convert_to_torch_tensor(np.asarray([1]), self.device)) - # [0] = remove the batch dim. - return self.model.value_function()[0] + + def value(ob, prev_action, prev_reward, *state): + model_out, _ = self.model({ + SampleBatch.CUR_OBS: convert_to_torch_tensor( + np.asarray([ob]), self.device), + SampleBatch.PREV_ACTIONS: convert_to_torch_tensor( + np.asarray([prev_action]), self.device), + SampleBatch.PREV_REWARDS: convert_to_torch_tensor( + np.asarray([prev_reward]), self.device), + "is_training": False, + }, [ + convert_to_torch_tensor(np.asarray([s]), self.device) + for s in state + ], convert_to_torch_tensor(np.asarray([1]), self.device)) + # [0] = remove the batch dim. + return self.model.value_function()[0] # When not doing GAE, we do not require the value function's output. else: diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index efcadf32f96d..f71cc17a7855 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -1,6 +1,7 @@ import collections from gym.spaces import Space import logging +import math import numpy as np from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union @@ -34,6 +35,9 @@ def to_float_np_array(v: List[Any]) -> np.ndarray: return arr +_INIT_COLS = [SampleBatch.OBS] + + class _AgentCollector: """Collects samples for one agent in one trajectory (episode). @@ -45,9 +49,18 @@ class _AgentCollector: _next_unroll_id = 0 # disambiguates unrolls within a single episode - def __init__(self, shift_before: int = 0): - self.shift_before = max(shift_before, 1) + def __init__(self, view_reqs): + # Determine the size of the buffer we need for data before the actual + # episode starts. This is used for 0-buffering of e.g. prev-actions, + # or internal state inputs. + self.shift_before = -min( + [(int(vr.shift.split(":")[0]) + if isinstance(vr.shift, str) else vr.shift) + + (-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0) + for k, vr in view_reqs.items()]) + # The actual data buffers (lists holding each timestep's data). 
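A stand-alone sketch of the `shift_before` computation in `_AgentCollector.__init__` above: the collector reserves enough zero-padded pre-episode slots to serve the most negative shift any view requirement asks for; string shifts like "-50:-1" contribute their left end, and OBS-like columns get one extra slot because the initial observation arrives one timestep earlier than all other columns. `Req` and the entries in `view_reqs` are hypothetical stand-ins for ViewRequirement objects.

from collections import namedtuple

Req = namedtuple("Req", ["shift", "data_col"])
view_reqs = {
    "obs": Req(0, None),                        # init col -> one extra slot
    "prev_rewards": Req(-1, "rewards"),
    "state_in_0": Req("-50:-1", "state_out_0"),
}
INIT_COLS = ["obs"]

def left_end(shift):
    # "-50:-1" -> -50; plain int shifts pass through unchanged.
    return int(shift.split(":")[0]) if isinstance(shift, str) else shift

shift_before = -min(
    left_end(vr.shift) + (-1 if vr.data_col in INIT_COLS or k in INIT_COLS else 0)
    for k, vr in view_reqs.items())
assert shift_before == 50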
self.buffers: Dict[str, List] = {} + # The episode ID for the agent for which we collect data. self.episode_id = None # The simple timestep count for this agent. Gets increased by one # each time a (non-initial!) observation is added. @@ -137,31 +150,86 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: # -> skip. if data_col not in self.buffers: continue + # OBS are already shifted by -1 (the initial obs starts one ts # before all other data columns). - shift = view_req.shift - \ - (1 if data_col == SampleBatch.OBS else 0) + obs_shift = -1 if data_col == SampleBatch.OBS else 0 + + # Keep an np-array cache so we don't have to regenerate the + # np-array for different view_cols using to the same data_col. if data_col not in np_data: np_data[data_col] = to_float_np_array(self.buffers[data_col]) - # Shift is exactly 0: Send trajectory as is. - if shift == 0: - data = np_data[data_col][self.shift_before:] - # Shift is positive: We still need to 0-pad at the end here. - elif shift > 0: - data = to_float_np_array( - self.buffers[data_col][self.shift_before + shift:] + [ - np.zeros( - shape=view_req.space.shape, - dtype=view_req.space.dtype) for _ in range(shift) + + # Range of indices on time-axis, e.g. "-50:-1". Together with + # the `batch_repeat_value`, this determines the data produced. + # Example: + # batch_repeat_value=10, shift_from=-3, shift_to=-1 + # buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # resulting data=[[-3, -2, -1], [7, 8, 9]] + # Range of 3 consecutive items repeats every 10 timesteps. + if view_req.shift_from is not None: + if view_req.batch_repeat_value > 1: + count = int( + math.ceil((len(np_data[data_col]) - self.shift_before) + / view_req.batch_repeat_value)) + data = np.asarray([ + np_data[data_col][self.shift_before + + (i * view_req.batch_repeat_value) + + view_req.shift_from + + obs_shift:self.shift_before + + (i * view_req.batch_repeat_value) + + view_req.shift_to + 1 + obs_shift] + for i in range(count) ]) - # Shift is negative: Shift into the already existing and 0-padded - # "before" area of our buffers. + else: + data = np_data[data_col][self.shift_before + + view_req.shift_from + + obs_shift:self.shift_before + + view_req.shift_to + 1 + obs_shift] + # Set of (probably non-consecutive) indices. + # Example: + # shift=[-3, 0] + # buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...] + elif isinstance(view_req.shift, np.ndarray): + data = np_data[data_col][self.shift_before + obs_shift + + view_req.shift] + # Single shift int value. Use the trajectory as-is, and if + # `shift` != 0: shifted by that value. else: - data = np_data[data_col][self.shift_before + shift:shift] + shift = view_req.shift + obs_shift + + # Batch repeat (only provide a value every n timesteps). + if view_req.batch_repeat_value > 1: + count = int( + math.ceil((len(np_data[data_col]) - self.shift_before) + / view_req.batch_repeat_value)) + data = np.asarray([np_data[data_col][self.shift_before + ( + i * view_req.batch_repeat_value) + shift] for i + in range(count)]) + # Shift is exactly 0: Use trajectory as is. + elif shift == 0: + data = np_data[data_col][self.shift_before:] + # Shift is positive: We still need to 0-pad at the end. 
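A runnable NumPy illustration of the range-shift plus `batch_repeat_value` branch handled here, reproducing the example from the comment above (buffer values are the hypothetical timestep ids; `shift_before` is the size of the pre-episode padding).

import math
import numpy as np

buf = np.arange(-3, 13)        # [-3, -2, -1, 0, 1, ..., 12]
shift_before, shift_from, shift_to, repeat = 3, -3, -1, 10

count = int(math.ceil((len(buf) - shift_before) / repeat))
data = np.asarray([
    buf[shift_before + i * repeat + shift_from:
        shift_before + i * repeat + shift_to + 1]
    for i in range(count)
])
# A range of 3 consecutive items, repeated every 10 timesteps:
assert data.tolist() == [[-3, -2, -1], [7, 8, 9]]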
+ elif shift > 0: + data = to_float_np_array( + self.buffers[data_col][self.shift_before + shift:] + [ + np.zeros( + shape=view_req.space.shape, + dtype=view_req.space.dtype) + for _ in range(shift) + ]) + # Shift is negative: Shift into the already existing and + # 0-padded "before" area of our buffers. + else: + data = np_data[data_col][self.shift_before + shift:shift] + if len(data) > 0: batch_data[view_col] = data - batch = SampleBatch(batch_data) + # Due to possible batch-repeats > 1, columns in the resulting batch + # may not all have the same batch size. + batch = SampleBatch(batch_data, _dont_check_lens=True) # Add EPS_ID and UNROLL_ID to batch. batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id, @@ -230,15 +298,22 @@ class _PolicyCollector: appended to this policy's buffers. """ - def __init__(self): - """Initializes a _PolicyCollector instance.""" + def __init__(self, policy): + """Initializes a _PolicyCollector instance. + + Args: + policy (Policy): The policy object. + """ self.buffers: Dict[str, List] = collections.defaultdict(list) + self.policy = policy # The total timestep count for all agents that use this policy. # NOTE: This is not an env-step count (across n agents). AgentA and # agentB, both using this policy, acting in the same episode and both # doing n steps would increase the count by 2*n. self.agent_steps = 0 + # Seq-lens list of already added agent batches. + self.seq_lens = [] if policy.is_recurrent() else None def add_postprocessed_batch_for_training( self, batch: SampleBatch, @@ -257,11 +332,18 @@ def add_postprocessed_batch_for_training( # 1) If col is not in view_requirements, we must have a direct # child of the base Policy that doesn't do auto-view req creation. # 2) Col is in view-reqs and needed for training. - if view_col not in view_requirements or \ - view_requirements[view_col].used_for_training: + view_req = view_requirements.get(view_col) + if view_req is None or view_req.used_for_training: self.buffers[view_col].extend(data) # Add the agent's trajectory length to our count. self.agent_steps += batch.count + # Adjust the seq-lens array depending on the incoming agent sequences. + if self.seq_lens is not None: + max_seq_len = self.policy.config["model"]["max_seq_len"] + count = batch.count + while count > 0: + self.seq_lens.append(min(count, max_seq_len)) + count -= max_seq_len def build(self): """Builds a SampleBatch for this policy from the collected data. @@ -273,20 +355,22 @@ def build(self): this policy. """ # Create batch from our buffers. - batch = SampleBatch(self.buffers) - assert SampleBatch.UNROLL_ID in batch.data + batch = SampleBatch( + self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True) # Clear buffers for future samples. self.buffers.clear() - # Reset agent steps to 0. + # Reset agent steps to 0 and seq-lens to empty list. self.agent_steps = 0 + if self.seq_lens is not None: + self.seq_lens = [] return batch class _PolicyCollectorGroup: def __init__(self, policy_map): self.policy_collectors = { - pid: _PolicyCollector() - for pid in policy_map.keys() + pid: _PolicyCollector(policy) + for pid, policy in policy_map.items() } # Total env-steps (1 env-step=up to N agents stepped). 
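A tiny sketch of the seq-lens bookkeeping `_PolicyCollector.add_postprocessed_batch_for_training` now does (above): each incoming agent batch of `count` timesteps is split into chunks of at most `max_seq_len`. The numbers mirror the attention-net test further below (201-step minibatch, max_seq_len=50).

def split_into_seq_lens(count, max_seq_len):
    seq_lens = []
    while count > 0:
        seq_lens.append(min(count, max_seq_len))
        count -= max_seq_len
    return seq_lens

assert split_into_seq_lens(201, 50) == [50, 50, 50, 50, 1]
assert sum(split_into_seq_lens(201, 50)) == 201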
self.env_steps = 0 @@ -396,11 +480,14 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, self.agent_key_to_policy_id[agent_key] = policy_id else: assert self.agent_key_to_policy_id[agent_key] == policy_id + policy = self.policy_map[policy_id] + view_reqs = policy.model.inference_view_requirements if \ + getattr(policy, "model", None) else policy.view_requirements # Add initial obs to Trajectory. assert agent_key not in self.agent_collectors # TODO: determine exact shift-before based on the view-req shifts. - self.agent_collectors[agent_key] = _AgentCollector() + self.agent_collectors[agent_key] = _AgentCollector(view_reqs) self.agent_collectors[agent_key].add_init_obs( episode_id=episode.episode_id, agent_index=episode._agent_index(agent_id), @@ -466,11 +553,19 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ for view_col, view_req in view_reqs.items(): # Create the batch of data from the different buffers. data_col = view_req.data_col or view_col - time_indices = \ - view_req.shift - ( - 1 if data_col in [SampleBatch.OBS, "t", "env_id", - SampleBatch.AGENT_INDEX] else 0) + delta = -1 if data_col in [ + SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX + ] else 0 + # Range of shifts, e.g. "-100:0". Note: This includes index 0! + if view_req.shift_from is not None: + time_indices = (view_req.shift_from + delta, + view_req.shift_to + delta) + # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0]. + else: + time_indices = view_req.shift + delta data_list = [] + # Loop through agents and add-up their data (batch). for k in keys: if data_col == SampleBatch.EPS_ID: data_list.append(self.agent_collectors[k].episode_id) @@ -482,7 +577,15 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ self.agent_collectors[k]._build_buffers({ data_col: fill_value }) - data_list.append(buffers[k][data_col][time_indices]) + if isinstance(time_indices, tuple): + if time_indices[1] == -1: + data_list.append( + buffers[k][data_col][time_indices[0]:]) + else: + data_list.append(buffers[k][data_col][time_indices[ + 0]:time_indices[1] + 1]) + else: + data_list.append(buffers[k][data_col][time_indices]) input_dict[view_col] = np.array(data_list) self._reset_inference_calls(policy_id) diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index a19411433a74..0cb25d5c7927 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -50,8 +50,6 @@ def compute_advantages(rollout: SampleBatch, processed rewards. """ - rollout_size = len(rollout[SampleBatch.ACTIONS]) - assert SampleBatch.VF_PREDS in rollout or not use_critic, \ "use_critic=True but values not found" assert use_critic or not use_gae, \ @@ -90,6 +88,4 @@ def compute_advantages(rollout: SampleBatch, rollout[Postprocessing.ADVANTAGES] = rollout[ Postprocessing.ADVANTAGES].astype(np.float32) - assert all(val.shape[0] == rollout_size for key, val in rollout.items()), \ - "Rollout stacked incorrectly!" 
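For reference, a minimal sketch of what `compute_advantages` (touched above) computes when `use_gae=True`, assuming the standard GAE(lambda) formulation: deltas are built from rewards and value predictions (extended by the bootstrap value `last_r` for truncated trajectories) and then discounted backwards with gamma * lambda.

import numpy as np

def gae_advantages(rewards, vf_preds, last_r, gamma=0.99, lambda_=0.95):
    vpred_t = np.concatenate([vf_preds, [last_r]])
    delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
    advantages = np.zeros_like(delta_t)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = delta_t[t] + gamma * lambda_ * running
        advantages[t] = running
    return advantages.astype(np.float32)

adv = gae_advantages(
    rewards=np.array([1.0, 1.0, 1.0]),
    vf_preds=np.array([0.5, 0.5, 0.5]),
    last_r=0.4)
assert adv.shape == (3,)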
return rollout diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index a50978bfdce4..1dc29af82481 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -7,19 +7,41 @@ import ray from ray import tune +from ray.rllib.agents.callbacks import DefaultCallbacks import ray.rllib.agents.dqn as dqn import ray.rllib.agents.ppo as ppo from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.examples.policy.episode_env_aware_policy import \ - EpisodeEnvAwareLSTMPolicy + EpisodeEnvAwareLSTMPolicy, EpisodeEnvAwareAttentionPolicy +from ray.rllib.models.tf.attention_net import GTrXLNet from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.annotations import override from ray.rllib.utils.test_utils import framework_iterator, check +class MyCallbacks(DefaultCallbacks): + @override(DefaultCallbacks) + def on_learn_on_batch(self, *, policy, train_batch, **kwargs): + assert train_batch.count == 201 + assert sum(train_batch.seq_lens) == 201 + for k, v in train_batch.data.items(): + if k == "state_in_0": + assert len(v) == len(train_batch.seq_lens) + else: + assert len(v) == 201 + current = None + for o in train_batch[SampleBatch.OBS]: + if current: + assert o == current + 1 + current = o + if o == 15: + current = None + + class TestTrajectoryViewAPI(unittest.TestCase): @classmethod def setUpClass(cls) -> None: @@ -116,6 +138,45 @@ def test_traj_view_lstm_prev_actions_and_rewards(self): assert view_req_policy[key].shift == 1 trainer.stop() + def test_traj_view_attention_net(self): + config = ppo.DEFAULT_CONFIG.copy() + # Setup attention net. + config["model"] = config["model"].copy() + config["model"]["max_seq_len"] = 50 + config["model"]["custom_model"] = GTrXLNet + config["model"]["custom_model_config"] = { + "num_transformer_units": 1, + "attn_dim": 64, + "num_heads": 2, + "memory_inference": 50, + "memory_training": 50, + "head_dim": 32, + "ff_hidden_dim": 32, + } + # Test with odd batch numbers. + config["train_batch_size"] = 1031 + config["sgd_minibatch_size"] = 201 + config["num_sgd_iter"] = 5 + config["num_workers"] = 0 + config["callbacks"] = MyCallbacks + config["env_config"] = { + "config": { + "start_at_t": 1 + } + } # first obs is [1.0] + + for _ in framework_iterator(config, frameworks="tf2"): + trainer = ppo.PPOTrainer( + config, + env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv", + ) + rw = trainer.workers.local_worker() + sample = rw.sample() + assert sample.count == config["rollout_fragment_length"] + results = trainer.train() + assert results["train_batch_size"] == config["train_batch_size"] + trainer.stop() + def test_traj_view_simple_performance(self): """Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`. 
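The invariants `MyCallbacks.on_learn_on_batch` asserts above, written out as a stand-alone sketch: in a recurrent/attention train batch, the per-timestep columns have sum(seq_lens) rows while the state-input columns carry one row per sequence (numbers mirror the test: sgd_minibatch_size=201, max_seq_len=50).

seq_lens = [50, 50, 50, 50, 1]
num_timestep_rows = sum(seq_lens)    # rows in "obs", "actions", "rewards", ...
num_state_rows = len(seq_lens)       # rows in "state_in_0"

assert num_timestep_rows == 201
assert num_state_rows == 5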
""" @@ -298,6 +359,40 @@ def policy_fn(agent_id): pol_batch_wo = result.policy_batches["pol0"] check(pol_batch_w.data, pol_batch_wo.data) + def test_traj_view_attention_functionality(self): + action_space = Box(-float("inf"), float("inf"), shape=(3, )) + obs_space = Box(float("-inf"), float("inf"), (4, )) + max_seq_len = 50 + rollout_fragment_length = 201 + policies = { + "pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space, + {}), + } + + def policy_fn(agent_id): + return "pol0" + + config = { + "multiagent": { + "policies": policies, + "policy_mapping_fn": policy_fn, + }, + "model": { + "max_seq_len": max_seq_len, + }, + }, + + rollout_worker_w_api = RolloutWorker( + env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}), + policy_config=dict(config, **{"_use_trajectory_view_api": True}), + rollout_fragment_length=rollout_fragment_length, + policy_spec=policies, + policy_mapping_fn=policy_fn, + num_envs=1, + ) + batch = rollout_worker_w_api.sample() + print(batch) + def test_counting_by_agent_steps(self): """Test whether a PPOTrainer can be built with all frameworks.""" config = copy.deepcopy(ppo.DEFAULT_CONFIG) diff --git a/rllib/examples/attention_net.py b/rllib/examples/attention_net.py index 49884d9f308a..de3f06c29beb 100644 --- a/rllib/examples/attention_net.py +++ b/rllib/examples/attention_net.py @@ -39,6 +39,7 @@ config = { "env": args.env, + # This env_config is only used for the RepeatAfterMeEnv env. "env_config": { "repeat_delay": 2, }, @@ -48,7 +49,7 @@ "num_workers": 0, "num_envs_per_worker": 20, "entropy_coeff": 0.001, - "num_sgd_iter": 5, + "num_sgd_iter": 10, "vf_loss_coeff": 1e-5, "model": { "custom_model": GTrXLNet, @@ -56,9 +57,10 @@ "custom_model_config": { "num_transformer_units": 1, "attn_dim": 64, - "num_heads": 2, - "memory_tau": 50, + "memory_inference": 100, + "memory_training": 50, "head_dim": 32, + "num_heads": 2, "ff_hidden_dim": 32, }, }, @@ -71,7 +73,7 @@ "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, config=config, stop=stop, verbose=1) + results = tune.run(args.run, config=config, stop=stop, verbose=2) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/custom_metrics_and_callbacks.py b/rllib/examples/custom_metrics_and_callbacks.py index d7a2c849d66a..745a94029a2e 100644 --- a/rllib/examples/custom_metrics_and_callbacks.py +++ b/rllib/examples/custom_metrics_and_callbacks.py @@ -28,6 +28,7 @@ def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode, env_index: int, **kwargs): print("episode {} (env-idx={}) started.".format( episode.episode_id, env_index)) + episode.user_data["pole_angles"] = [] episode.hist_data["pole_angles"] = [] diff --git a/rllib/examples/env/debug_counter_env.py b/rllib/examples/env/debug_counter_env.py index c14d49951094..b2004bce0e22 100644 --- a/rllib/examples/env/debug_counter_env.py +++ b/rllib/examples/env/debug_counter_env.py @@ -12,18 +12,24 @@ class DebugCounterEnv(gym.Env): Reward is always: current ts % 3. 
""" - def __init__(self): + def __init__(self, config): self.action_space = gym.spaces.Discrete(2) - self.observation_space = gym.spaces.Box(0, 100, (1, )) - self.i = 0 + self.observation_space = \ + gym.spaces.Box(0, 100, (1, ), dtype=np.float32) + self.start_at_t = int(config.get("start_at_t", 0)) + self.i = self.start_at_t def reset(self): - self.i = 0 - return [self.i] + self.i = self.start_at_t + return self._get_obs() def step(self, action): self.i += 1 - return [self.i], self.i % 3, self.i >= 15, {} + return self._get_obs(), float(self.i % 3), \ + self.i >= 15 + self.start_at_t, {} + + def _get_obs(self): + return np.array([self.i], dtype=np.float32) class MultiAgentDebugCounterEnv(MultiAgentEnv): diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 38478857c778..11de172a1515 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -61,7 +61,7 @@ def __init__(self, obs_space: gym.spaces.Space, self.time_major = self.model_config.get("_time_major") # Basic view requirement for all models: Use the observation as input. self.inference_view_requirements = { - SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space), + SampleBatch.OBS: ViewRequirement(space=self.obs_space), } # TODO: (sven): Get rid of `get_initial_state` once Trajectory @@ -238,16 +238,20 @@ def from_batch(self, train_batch: SampleBatch, right input dict, state, and seq len arguments. """ - train_batch["is_training"] = is_training + input_dict = train_batch.copy() + input_dict["is_training"] = is_training states = [] i = 0 - while "state_in_{}".format(i) in train_batch: - states.append(train_batch["state_in_{}".format(i)]) + while "state_in_{}".format(i) in input_dict: + states.append(input_dict["state_in_{}".format(i)]) i += 1 - ret = self.__call__(train_batch, states, train_batch.get("seq_lens")) - del train_batch["is_training"] + ret = self.__call__(input_dict, states, input_dict.get("seq_lens")) return ret + # TODO: (sven) Experimental method. + def preprocess_train_batch(self, train_batch): + return train_batch + def import_from_h5(self, h5_file: str) -> None: """Imports weights from an h5 file. @@ -314,29 +318,6 @@ def is_time_major(self) -> bool: """ return self.time_major is True - # TODO: (sven) Experimental method. - def get_input_dict(self, sample_batch, - index: int = -1) -> Dict[str, TensorType]: - if index < 0: - index = sample_batch.count - 1 - - input_dict = {} - for view_col, view_req in self.inference_view_requirements.items(): - # Create batches of size 1 (single-agent input-dict). - - # Index range. - if isinstance(index, tuple): - data = sample_batch[view_col][index[0]:index[1] + 1] - input_dict[view_col] = np.array([data]) - # Single index. - else: - input_dict[view_col] = sample_batch[view_col][index:index + 1] - - # Add valid `seq_lens`, just in case RNNs need it. - input_dict["seq_lens"] = np.array([1], dtype=np.int32) - - return input_dict - class NullContextManager: """No-op context manager""" diff --git a/rllib/models/tests/test_attention_nets.py b/rllib/models/tests/test_attention_nets.py index ac6ec134ddba..5c3b146a8cf7 100644 --- a/rllib/models/tests/test_attention_nets.py +++ b/rllib/models/tests/test_attention_nets.py @@ -93,36 +93,21 @@ def train_torch_layer(self, model, inputs, outputs, num_epochs=250): # that the model is learning from the training data. 
self.assertLess(final_loss / init_loss, 0.5) - def train_tf_model(self, - model, - inputs, - outputs, - num_epochs=250, - minibatch_size=32): - """Convenience method that trains a Tensorflow model for num_epochs - epochs and tests whether loss decreased, as expected. - - Args: - model (tf.Model): Torch model to be trained. - inputs (np.array): Training data - outputs (np.array): Training labels - num_epochs (int): Number of training epochs - batch_size (int): Number of samples in each minibatch - """ - - # Configure a model for mean-squared error loss. - model.compile(optimizer="SGD", loss="mse", metrics=["mae"]) - - hist = model.fit( - inputs, - outputs, - verbose=0, - epochs=num_epochs, - batch_size=minibatch_size).history - init_loss = hist["loss"][0] - final_loss = hist["loss"][-1] - - self.assertLess(final_loss / init_loss, 0.5) + def train_tf_model(self, model, inputs, labels, num_epochs=250): + optim = tf.keras.optimizers.Adam(lr=0.0001) + init_loss = final_loss = None + for _ in range(num_epochs): + with tf.GradientTape() as tape: + outputs = model(inputs) + final_loss = tf.reduce_mean(tf.square(outputs[0] - labels[0])) + if init_loss is None: + init_loss = final_loss + # Optimizer step. + grads = tape.gradient(final_loss, model.trainable_variables) + optim.apply_gradients( + [(g, v) for g, v in zip(grads, model.trainable_variables)]) + + self.assertLess(final_loss, init_loss) def test_multi_head_attention(self): """Tests the MultiHeadAttention mechanism of Vaswani et al.""" @@ -171,35 +156,53 @@ def test_attention_net(self): relative_position_embedding(20, 15).eval(session=sess), relative_position_embedding_torch(20, 15).numpy()) - # B is batch size + # Batch size. B = 32 - # D_in is attention dim, L is memory_tau - L, D_in, D_out = 2, 16, 2 + # Max seq-len. + max_seq_len = 10 + # Memory size (inference). + memory_size = max_seq_len * 2 + # Memory size (training). + memory_training = max_seq_len + # Number of transformer units. + num_transformer_units = 2 + # Input dim. + observation_dim = 8 + # Head dim. + head_dim = 12 + # Attention dim. + attention_dim = 16 + # Action dim. + action_dim = 2 for fw, sess in framework_iterator(session=True): + # Create random Tensors to hold inputs and labels. 
+ x = np.random.random((B, max_seq_len, + observation_dim)).astype(np.float32) + y = np.random.random((B, max_seq_len, action_dim)) + # Create a single attention layer with 2 heads if fw == "torch": - # Create random Tensors to hold inputs and outputs - x = torch.randn(B, L, D_in) - y = torch.randn(B, L, D_out) - - value_labels = torch.randn(B, L, D_in) - memory_labels = torch.randn(B, L, D_out) + value_labels = torch.randn(B, max_seq_len) + memory_labels = torch.randn(B, max_seq_len, attention_dim) attention_net = TorchGTrXLNet( observation_space=gym.spaces.Box( - low=float("-inf"), high=float("inf"), shape=(D_in, )), - action_space=gym.spaces.Discrete(D_out), - num_outputs=D_out, + low=float("-inf"), + high=float("inf"), + shape=(observation_dim, )), + action_space=gym.spaces.Discrete(action_dim), + num_outputs=action_dim, model_config={"max_seq_len": 2}, name="TestTorchAttentionNet", - num_transformer_units=2, - attn_dim=D_in, + num_transformer_units=num_transformer_units, + attn_dim=attention_dim, num_heads=2, - memory_tau=L, - head_dim=D_out, - ff_hidden_dim=16, + memory_inference=memory_size, + memory_training=memory_training, + head_dim=head_dim, + ff_hidden_dim=24, init_gate_bias=2.0) init_state = attention_net.get_initial_state() @@ -207,7 +210,7 @@ def test_attention_net(self): # Get initial state and add a batch dimension. init_state = [np.expand_dims(s, 0) for s in init_state] seq_lens_init = torch.full( - size=(B, ), fill_value=L, dtype=torch.int32) + size=(B, ), fill_value=max_seq_len, dtype=torch.int32) # Torch implementation expects a formatted input_dict instead # of a numpy array as input. @@ -220,40 +223,42 @@ def test_attention_net(self): seq_lens=seq_lens_init) # Framework is tensorflow or tensorflow-eager. else: - x = np.random.random((B, L, D_in)) - y = np.random.random((B, L, D_out)) - - value_labels = np.random.random((B, L, 1)) - memory_labels = np.random.random((B, L, D_in)) - - # We need to create (N-1) MLP labels for N transformer units - mlp_labels = np.random.random((B, L, D_in)) + value_labels = np.random.random((B, max_seq_len)) + memory_labels = [ + np.random.random((B, memory_size, attention_dim)) + for _ in range(num_transformer_units) + ] attention_net = GTrXLNet( observation_space=gym.spaces.Box( - low=float("-inf"), high=float("inf"), shape=(D_in, )), - action_space=gym.spaces.Discrete(D_out), - num_outputs=D_out, - model_config={"max_seq_len": 2}, + low=float("-inf"), + high=float("inf"), + shape=(observation_dim, )), + action_space=gym.spaces.Discrete(action_dim), + num_outputs=action_dim, + model_config={"max_seq_len": max_seq_len}, name="TestTFAttentionNet", - num_transformer_units=2, - attn_dim=D_in, + num_transformer_units=num_transformer_units, + attn_dim=attention_dim, num_heads=2, - memory_tau=L, - head_dim=D_out, - ff_hidden_dim=16, + memory_inference=memory_size, + memory_training=memory_training, + head_dim=head_dim, + ff_hidden_dim=24, init_gate_bias=2.0) model = attention_net.trxl_model - # Get initial state and add a batch dimension. - init_state = attention_net.get_initial_state() + # Get initial state (for training!) and add a batch dimension. 
+ init_state = [ + np.zeros((memory_training, attention_dim), np.float32) + for _ in range(num_transformer_units) + ] init_state = [np.tile(s, (B, 1, 1)) for s in init_state] self.train_tf_model( - model, [x] + init_state, - [y, value_labels, memory_labels, mlp_labels], - num_epochs=200, - minibatch_size=B) + model, [x] + init_state + [np.array([True])], + [y, value_labels] + memory_labels, + num_epochs=20) if __name__ == "__main__": diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index 2ddbaf33b934..3e6b1c11196f 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -8,14 +8,18 @@ Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019. https://www.aclweb.org/anthology/P19-1285.pdf """ +from gym.spaces import Box import numpy as np import gym -from typing import Optional, Any +from typing import Any, Dict, Optional from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \ SkipConnection from ray.rllib.models.tf.recurrent_net import RecurrentNetwork +from ray.rllib.policy.rnn_sequencing import chop_into_sequences +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import ModelConfigDict, TensorType, List @@ -60,7 +64,7 @@ def __init__(self, observation_space: gym.spaces.Space, model_config: ModelConfigDict, name: str, num_transformer_units: int, attn_dim: int, num_heads: int, head_dim: int, ff_hidden_dim: int): - """Initializes a TfXLNet object. + """Initializes a TrXLNet object. Args: num_transformer_units (int): The number of Transformer repeats to @@ -88,8 +92,6 @@ def __init__(self, observation_space: gym.spaces.Space, self.max_seq_len = model_config["max_seq_len"] self.obs_dim = observation_space.shape[0] - pos_embedding = relative_position_embedding(self.max_seq_len, attn_dim) - inputs = tf.keras.layers.Input( shape=(self.max_seq_len, self.obs_dim), name="inputs") E_out = tf.keras.layers.Dense(attn_dim)(inputs) @@ -100,7 +102,6 @@ def __init__(self, observation_space: gym.spaces.Space, out_dim=attn_dim, num_heads=num_heads, head_dim=head_dim, - rel_pos_encoder=pos_embedding, input_layernorm=False, output_activation=None), fan_in_layer=None)(E_out) @@ -174,11 +175,12 @@ def __init__(self, num_transformer_units: int, attn_dim: int, num_heads: int, - memory_tau: int, + memory_inference: int, + memory_training: int, head_dim: int, ff_hidden_dim: int, init_gate_bias: float = 2.0): - """Initializes a GTrXLNet. + """Initializes a GTrXLNet instance. Args: num_transformer_units (int): The number of Transformer repeats to @@ -187,9 +189,15 @@ def __init__(self, unit. num_heads (int): The number of attention heads to use in parallel. Denoted as `H` in [3]. - memory_tau (int): The number of timesteps to store in each - transformer block's memory M (concat'd over time and fed into - next transformer block as input). + memory_inference (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as inference + input. The first transformer unit will receive this number of + past observations (plus the current one), instead. + memory_training (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as training + input (plus the actual input sequence of len=max_seq_len). 
+ The first transformer unit will receive this number of + past observations (plus the input sequence), instead. head_dim (int): The dimension of a single(!) head. Denoted as `d` in [3]. ff_hidden_dim (int): The dimension of the hidden layer within @@ -208,21 +216,18 @@ def __init__(self, self.num_transformer_units = num_transformer_units self.attn_dim = attn_dim self.num_heads = num_heads - self.memory_tau = memory_tau + self.memory_inference = memory_inference + self.memory_training = memory_training self.head_dim = head_dim self.max_seq_len = model_config["max_seq_len"] self.obs_dim = observation_space.shape[0] - # Constant (non-trainable) sinusoid rel pos encoding matrix. - Phi = relative_position_embedding(self.max_seq_len + self.memory_tau, - self.attn_dim) - - # Raw observation input. + # Raw observation input (plus (None) time axis). input_layer = tf.keras.layers.Input( - shape=(self.max_seq_len, self.obs_dim), name="inputs") + shape=(None, self.obs_dim), name="inputs") memory_ins = [ tf.keras.layers.Input( - shape=(self.memory_tau, self.attn_dim), + shape=(None, self.attn_dim), dtype=tf.float32, name="memory_in_{}".format(i)) for i in range(self.num_transformer_units) @@ -242,7 +247,6 @@ def __init__(self, out_dim=self.attn_dim, num_heads=num_heads, head_dim=head_dim, - rel_pos_encoder=Phi, input_layernorm=True, output_activation=tf.nn.relu), fan_in_layer=GRUGate(init_gate_bias), @@ -280,69 +284,86 @@ def __init__(self, self.register_variables(self.trxl_model.variables) self.trxl_model.summary() - @override(RecurrentNetwork) - def forward_rnn(self, inputs: TensorType, state: List[TensorType], - seq_lens: TensorType) -> (TensorType, List[TensorType]): - # To make Attention work with current RLlib's ModelV2 API: - # We assume `state` is the history of L recent observations (all - # concatenated into one tensor) and append the current inputs to the - # end and only keep the most recent (up to `max_seq_len`). This allows - # us to deal with timestep-wise inference and full sequence training - # within the same logic. - observations = state[0] - memory = state[1:] + # Setup inference view (`memory-inference` x past observations + + # current one (0)) + # 1 to `num_transformer_units`: Memory data (one per transformer unit). + for i in range(self.num_transformer_units): + space = Box(-1.0, 1.0, shape=(self.attn_dim, )) + self.inference_view_requirements["state_in_{}".format(i)] = \ + ViewRequirement( + "state_out_{}".format(i), + shift="-{}:-1".format(self.memory_inference), + # Repeat the incoming state every max-seq-len times. + batch_repeat_value=self.max_seq_len, + space=space) + self.inference_view_requirements["state_out_{}".format(i)] = \ + ViewRequirement( + space=space, + used_for_training=False) - observations = tf.concat( - (observations, inputs), axis=1)[:, -self.max_seq_len:] - all_out = self.trxl_model([observations] + memory) - logits, self._value_out = all_out[0], all_out[1] - memory_outs = all_out[2:] - # If memory_tau > max_seq_len -> overlap w/ previous `memory` input. - if self.memory_tau > self.max_seq_len: - memory_outs = [ - tf.concat( - [memory[i][:, -(self.memory_tau - self.max_seq_len):], m], - axis=1) for i, m in enumerate(memory_outs) - ] - else: - memory_outs = [m[:, -self.memory_tau:] for m in memory_outs] + @override(ModelV2) + def forward(self, input_dict, state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): + assert seq_lens is not None - T = tf.shape(inputs)[1] # Length of input segment (time). 
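A NumPy shape sketch of the time-axis handling in the new forward() (continued just below): the train batch arrives time-flattened as [B*T, obs_dim]; B is taken from seq_lens, T is recovered as rows // B before calling the keras TrXL model, and the memory outputs are flattened back to [B*T, attn_dim] afterwards. Shapes are illustrative.

import numpy as np

obs_dim, attn_dim = 4, 16
seq_lens = np.array([10, 10, 10], dtype=np.int32)     # B=3 padded sequences
flat_obs = np.zeros((30, obs_dim), dtype=np.float32)  # [B*T, obs_dim]

B = len(seq_lens)
T = flat_obs.shape[0] // B
obs_with_time = flat_obs.reshape((B, T) + flat_obs.shape[1:])
assert obs_with_time.shape == (3, 10, obs_dim)

memory_out = np.zeros((B, T, attn_dim), dtype=np.float32)
assert memory_out.reshape(-1, attn_dim).shape == (30, attn_dim)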
- logits = logits[:, -T:] - self._value_out = self._value_out[:, -T:] + # Add the time dim to observations. + B = tf.shape(seq_lens)[0] + observations = input_dict[SampleBatch.OBS] - return logits, [observations] + memory_outs + shape = tf.shape(observations) + T = shape[0] // B + observations = tf.reshape(observations, + tf.concat([[-1, T], shape[1:]], axis=0)) + + all_out = self.trxl_model([observations] + state) + + logits = all_out[0] + self._value_out = all_out[1] + memory_outs = all_out[2:] + + return tf.reshape(logits, [-1, self.num_outputs]), [ + tf.reshape(m, [-1, self.attn_dim]) for m in memory_outs + ] # TODO: (sven) Deprecate this once trajectory view API has fully matured. @override(RecurrentNetwork) def get_initial_state(self) -> List[np.ndarray]: - # State is the T last observations concat'd together into one Tensor. - # Plus all Transformer blocks' E(l) outputs concat'd together (up to - # tau timesteps). - return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \ - [np.zeros((self.memory_tau, self.attn_dim), np.float32) - for _ in range(self.num_transformer_units)] + return [] @override(ModelV2) def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) - -def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType: - """Creates a [seq_length x seq_length] matrix for rel. pos encoding. - - Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding - matrix. - - Args: - seq_length (int): The max. sequence length (time axis). - out_dim (int): The number of nodes to go into the first Tranformer - layer with. - - Returns: - tf.Tensor: The encoding matrix Phi. - """ - inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim)) - pos_offsets = tf.range(seq_length - 1., -1., -1.) - inputs = pos_offsets[:, None] * inverse_freq[None, :] - return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1) + @override(RecurrentNetwork) + def preprocess_train_batch(self, train_batch): + # Should be the same as for RecurrentNets, but with dynamic-max=False. 
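A small sketch of why this override fixes dynamic_max=False: the TrXL memory fed back as state spans a fixed number of timesteps, so every sequence is padded to exactly max_seq_len rather than to the longest sequence in the batch (which is what the default RNN path does with dynamic_max=True). Numbers are illustrative.

max_seq_len = 50

# One sequence already reaches max_seq_len -> both schemes coincide.
seq_lens = [50, 50, 50, 50, 1]
assert len(seq_lens) * max(seq_lens) == len(seq_lens) * max_seq_len == 250

# With only short sequences they differ.
seq_lens = [12, 7]
assert len(seq_lens) * max(seq_lens) == 24     # dynamic_max=True padding
assert len(seq_lens) * max_seq_len == 100      # dynamic_max=False padding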
+ assert "state_in_0" in train_batch + state_keys = [] + feature_keys_ = [] + for k, v in train_batch.items(): + if k.startswith("state_in_"): + state_keys.append(k) + elif not k.startswith( + "state_out_" + ) and k != "infos" and k != "seq_lens" and isinstance( + v, np.ndarray): + feature_keys_.append(k) + + feature_sequences, initial_states, seq_lens = \ + chop_into_sequences( + episode_ids=None, + unroll_ids=None, + agent_indices=None, + feature_columns=[train_batch[k] for k in feature_keys_], + state_columns=[train_batch[k] for k in state_keys], + max_seq_len=self.model_config["max_seq_len"], + dynamic_max=False, + seq_lens=train_batch.seq_lens, + states_already_reduced_to_init=True, + shuffle=False) + for i, k in enumerate(feature_keys_): + train_batch[k] = feature_sequences[i] + for i, k in enumerate(state_keys): + train_batch[k] = initial_states[i] + train_batch["seq_lens"] = np.array(seq_lens) + return train_batch diff --git a/rllib/models/tf/layers/relative_multi_head_attention.py b/rllib/models/tf/layers/relative_multi_head_attention.py index f7d70ab60eba..5d0e796bd0f1 100644 --- a/rllib/models/tf/layers/relative_multi_head_attention.py +++ b/rllib/models/tf/layers/relative_multi_head_attention.py @@ -1,4 +1,4 @@ -from typing import Optional, Any +from typing import Optional from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import TensorType @@ -16,9 +16,8 @@ def __init__(self, out_dim: int, num_heads: int, head_dim: int, - rel_pos_encoder: Any, input_layernorm: bool = False, - output_activation: Optional[Any] = None, + output_activation: Optional["tf.nn.activation"] = None, **kwargs): """Initializes a RelativeMultiHeadAttention keras Layer object. @@ -28,7 +27,6 @@ def __init__(self, Denoted `H` in [2]. head_dim (int): The dimension of a single(!) attention head Denoted `D` in [2]. - rel_pos_encoder (: input_layernorm (bool): Whether to prepend a LayerNorm before everything else. Should be True for building a GTrXL. output_activation (Optional[tf.nn.activation]): Optional tf.nn @@ -50,9 +48,14 @@ def __init__(self, self._uvar = self.add_weight(shape=(num_heads, head_dim)) self._vvar = self.add_weight(shape=(num_heads, head_dim)) + # Constant (non-trainable) sinusoid rel pos encoding matrix, which + # depends on this incoming time dimension. + # For inference, we prepend the memory to the current timestep's + # input: Tau + 1. For training, we prepend the memory to the input + # sequence: Tau + T. + self._pos_embedding = PositionalEmbedding(out_dim) self._pos_proj = tf.keras.layers.Dense( num_heads * head_dim, use_bias=False) - self._rel_pos_encoder = rel_pos_encoder self._input_layernorm = None if input_layernorm: @@ -66,9 +69,8 @@ def call(self, inputs: TensorType, # Add previous memory chunk (as const, w/o gradient) to input. # Tau (number of (prev) time slices in each memory chunk). - Tau = memory.shape.as_list()[1] if memory is not None else 0 - if memory is not None: - inputs = tf.concat((tf.stop_gradient(memory), inputs), axis=1) + Tau = tf.shape(memory)[1] + inputs = tf.concat([tf.stop_gradient(memory), inputs], axis=1) # Apply the Layer-Norm. if self._input_layernorm is not None: @@ -77,15 +79,17 @@ def call(self, inputs: TensorType, qkv = self._qkv_layer(inputs) queries, keys, values = tf.split(qkv, 3, -1) - # Cut out Tau memory timesteps from query. + # Cut out memory timesteps from query. queries = queries[:, -T:] + # Splitting up queries into per-head dims (d). 
queries = tf.reshape(queries, [-1, T, H, d]) - keys = tf.reshape(keys, [-1, T + Tau, H, d]) - values = tf.reshape(values, [-1, T + Tau, H, d]) + keys = tf.reshape(keys, [-1, Tau + T, H, d]) + values = tf.reshape(values, [-1, Tau + T, H, d]) - R = self._pos_proj(self._rel_pos_encoder) - R = tf.reshape(R, [T + Tau, H, d]) + R = self._pos_embedding(Tau + T) + R = self._pos_proj(R) + R = tf.reshape(R, [Tau + T, H, d]) # b=batch # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space) @@ -96,9 +100,9 @@ def call(self, inputs: TensorType, score = score + self.rel_shift(pos_score) score = score / d**0.5 - # causal mask of the same length as the sequence + # Causal mask of the same length as the sequence. mask = tf.sequence_mask( - tf.range(Tau + 1, T + Tau + 1), dtype=score.dtype) + tf.range(Tau + 1, Tau + T + 1), dtype=score.dtype) mask = mask[None, :, :, None] masked_score = score * mask + 1e30 * (mask - 1.) @@ -121,3 +125,14 @@ def rel_shift(x: TensorType) -> TensorType: x = tf.reshape(x, x_size) return x + + +class PositionalEmbedding(tf.keras.layers.Layer): + def __init__(self, out_dim, **kwargs): + super().__init__(**kwargs) + self.inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim)) + + def call(self, seq_length): + pos_offsets = tf.cast(tf.range(seq_length - 1, -1, -1), tf.float32) + inputs = pos_offsets[:, None] * self.inverse_freq[None, :] + return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1) diff --git a/rllib/models/tf/layers/skip_connection.py b/rllib/models/tf/layers/skip_connection.py index efb89f2e3387..a44ae2bc108d 100644 --- a/rllib/models/tf/layers/skip_connection.py +++ b/rllib/models/tf/layers/skip_connection.py @@ -16,7 +16,6 @@ class SkipConnection(tf.keras.layers.Layer if tf else object): def __init__(self, layer: Any, fan_in_layer: Optional[Any] = None, - add_memory: bool = False, **kwargs): """Initializes a SkipConnection keras layer object. diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index f939c7ae36a6..2a7a18269494 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -5,6 +5,7 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.models.utils import rnn_preprocess_train_batch from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement @@ -108,6 +109,11 @@ def get_initial_state(self): """ raise NotImplementedError("You must implement this for a RNN model") + @override(ModelV2) + def preprocess_train_batch(self, train_batch): + return rnn_preprocess_train_batch( + train_batch, self.model_config["max_seq_len"]) + class LSTMWrapper(RecurrentNetwork): """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm. diff --git a/rllib/models/torch/modules/skip_connection.py b/rllib/models/torch/modules/skip_connection.py index 126274b1da85..8d79b7826082 100644 --- a/rllib/models/torch/modules/skip_connection.py +++ b/rllib/models/torch/modules/skip_connection.py @@ -15,7 +15,6 @@ class SkipConnection(nn.Module): def __init__(self, layer: nn.Module, fan_in_layer: Optional[nn.Module] = None, - add_memory: bool = False, **kwargs): """Initializes a SkipConnection nn Module object. 
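A NumPy sketch of the causal mask built in RelativeMultiHeadAttention.call() above via tf.sequence_mask(tf.range(Tau + 1, Tau + T + 1)): query step t may attend to all Tau memory slots plus the first t + 1 input steps (Tau=2, T=3 in this illustration).

import numpy as np

Tau, T = 2, 3
lengths = np.arange(Tau + 1, Tau + T + 1)       # [3, 4, 5]
mask = (np.arange(Tau + T)[None, :] < lengths[:, None]).astype(np.float32)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]
assert mask.shape == (T, Tau + T)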
diff --git a/rllib/models/torch/recurrent_net.py b/rllib/models/torch/recurrent_net.py index d558bf3dbf74..c85b121a5281 100644 --- a/rllib/models/torch/recurrent_net.py +++ b/rllib/models/torch/recurrent_net.py @@ -6,6 +6,7 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.models.utils import rnn_preprocess_train_batch from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement @@ -106,6 +107,11 @@ def forward_rnn(self, inputs, state, seq_lens): """ raise NotImplementedError("You must implement this for an RNN model") + @override(ModelV2) + def preprocess_train_batch(self, train_batch): + return rnn_preprocess_train_batch( + train_batch, self.model_config["max_seq_len"]) + class LSTMWrapper(RecurrentNetwork, nn.Module): """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm. diff --git a/rllib/models/utils.py b/rllib/models/utils.py index 2c9f076f0ebe..9a677352361f 100644 --- a/rllib/models/utils.py +++ b/rllib/models/utils.py @@ -1,3 +1,7 @@ +import numpy as np + +from ray.rllib.policy.rnn_sequencing import chop_into_sequences +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_tf, try_import_torch @@ -65,3 +69,38 @@ def get_initializer(name, framework="tf"): raise ValueError("Unknown activation ({}) for framework={}!".format( name, framework)) + + +def rnn_preprocess_train_batch(train_batch, max_seq_len): + assert "state_in_0" in train_batch + state_keys = [] + feature_keys_ = [] + for k, v in train_batch.items(): + if k.startswith("state_in_"): + state_keys.append(k) + elif not k.startswith("state_out_") and k != "infos" and \ + isinstance(v, np.ndarray): + feature_keys_.append(k) + + states_already_reduced_to_init = \ + len(train_batch["state_in_0"]) < len(train_batch["obs"]) + + feature_sequences, initial_states, seq_lens = \ + chop_into_sequences( + feature_columns=[train_batch[k] for k in feature_keys_], + state_columns=[train_batch[k] for k in state_keys], + max_seq_len=max_seq_len, + episode_ids=train_batch.get(SampleBatch.EPS_ID), + unroll_ids=train_batch.get(SampleBatch.UNROLL_ID), + agent_indices=train_batch.get(SampleBatch.AGENT_INDEX), + dynamic_max=True, + shuffle=False, + seq_lens=getattr(train_batch, "seq_lens", train_batch.get("seq_lens")), + states_already_reduced_to_init=states_already_reduced_to_init, + ) + for i, k in enumerate(feature_keys_): + train_batch[k] = feature_sequences[i] + for i, k in enumerate(state_keys): + train_batch[k] = initial_states[i] + train_batch["seq_lens"] = seq_lens + return train_batch diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 92734302f924..f7afcf83708b 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -183,11 +183,12 @@ def __init__( else: if self.config["_use_trajectory_view_api"]: self._state_inputs = [ - tf1.placeholder( - shape=(None, ) + vr.space.shape, dtype=vr.space.dtype) - for k, vr in + get_placeholder( + space=vr.space, + time_axis=not isinstance(vr.shift, int), + ) for k, vr in self.model.inference_view_requirements.items() - if k[:9] == "state_in_" + if k.startswith("state_in_") ] else: self._state_inputs = [ @@ -422,9 +423,14 @@ def _get_input_dict_and_dummy_batch(self, view_requirements, input_dict[view_col] = 
existing_inputs[view_col] # All others. else: + time_axis = not isinstance(view_req.shift, int) if view_req.used_for_training: + # Create a +time-axis placeholder if the shift is not an + # int (range or list of ints). input_dict[view_col] = get_placeholder( - space=view_req.space, name=view_col) + space=view_req.space, + name=view_col, + time_axis=time_axis) dummy_batch = self._get_dummy_batch_from_view_requirements( batch_size=32) @@ -489,10 +495,10 @@ def fake_array(tensor): dummy_batch["seq_lens"] = np.array([1], dtype=np.int32) for k, v in self.extra_compute_action_fetches().items(): dummy_batch[k] = fake_array(v) + dummy_batch = SampleBatch(dummy_batch) - sb = SampleBatch(dummy_batch) - batch_for_postproc = UsageTrackingDict(sb) - batch_for_postproc.count = sb.count + batch_for_postproc = UsageTrackingDict(dummy_batch) + batch_for_postproc.count = dummy_batch.count logger.info("Testing `postprocess_trajectory` w/ dummy batch.") self.exploration.postprocess_trajectory(self, batch_for_postproc, self._sess) @@ -509,6 +515,11 @@ def fake_array(tensor): -1.0, 1.0, shape=batch_for_postproc[key].shape[1:], dtype=batch_for_postproc[key].dtype)) + # Model forward pass for the loss (needed after postprocess to + # overwrite any tensor state from that call) + ## TODO: replace with `compute_actions_from_input_dict` + #self.model(self._input_dict, self._state_inputs, self._seq_lens) + if not self.config["_use_trajectory_view_api"]: train_batch = UsageTrackingDict( dict({ @@ -518,6 +529,7 @@ def fake_array(tensor): train_batch.update({ SampleBatch.PREV_ACTIONS: self._prev_action_input, SampleBatch.PREV_REWARDS: self._prev_reward_input, + SampleBatch.CUR_OBS: self._obs_input, }) for k, v in postprocessed_batch.items(): @@ -577,7 +589,8 @@ def fake_array(tensor): for key in batch_for_postproc.accessed_keys: if key not in train_batch.accessed_keys and \ key not in self.model.inference_view_requirements: - self.view_requirements[key].used_for_training = False + if key in self.view_requirements: + self.view_requirements[key].used_for_training = False if key in self._loss_input_dict: del self._loss_input_dict[key] # Remove those not needed at all (leave those that are needed diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 40a2c7986326..7bd1c6783ba9 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -313,13 +313,18 @@ def learn_on_batch(self, postprocessed_batch): # Callback handling. self.callbacks.on_learn_on_batch( policy=self, train_batch=postprocessed_batch) - + # Get batch ready for multi-agent, if applicable. + if self.batch_divisibility_req > 1: + pad_batch_to_sequences_of_same_size( + postprocessed_batch, + shuffle=False, + max_seq_len=self._max_seq_len, + batch_divisibility_req=self.batch_divisibility_req) # Get batch ready for RNNs, if applicable. - pad_batch_to_sequences_of_same_size( - postprocessed_batch, - shuffle=False, - max_seq_len=self._max_seq_len, - batch_divisibility_req=self.batch_divisibility_req) + if getattr(self, "model", None): + postprocessed_batch = self.model.preprocess_train_batch(postprocessed_batch) + self._is_training = True + postprocessed_batch["is_training"] = True return self._learn_on_batch_eager(postprocessed_batch) @convert_eager_inputs @@ -332,12 +337,18 @@ def _learn_on_batch_eager(self, samples): @override(Policy) def compute_gradients(self, samples): + # Get batch ready for multi-agent, if applicable. 
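The model.preprocess_train_batch() call used in the eager learn/compute-gradients paths here delegates, for RNNs, to rnn_preprocess_train_batch in models/utils.py above. A tiny sketch of its states_already_reduced_to_init heuristic with made-up row counts: if the state-in column already has fewer rows than the per-timestep columns, the states were reduced to one entry per sequence earlier (e.g. by the collector) and must not be reduced again.

num_obs_rows = 201        # one row per timestep
num_state_in_rows = 5     # one row per sequence (seq_lens = [50, 50, 50, 50, 1])

states_already_reduced_to_init = num_state_in_rows < num_obs_rows
assert states_already_reduced_to_init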
+ if self.batch_divisibility_req > 1: + pad_batch_to_sequences_of_same_size( + samples, + shuffle=False, + max_seq_len=self._max_seq_len, + batch_divisibility_req=self.batch_divisibility_req) # Get batch ready for RNNs, if applicable. - pad_batch_to_sequences_of_same_size( - samples, - shuffle=False, - max_seq_len=self._max_seq_len, - batch_divisibility_req=self.batch_divisibility_req) + if getattr(self, "model", None): + samples = self.model.preprocess_train_batch(samples) + self._is_training = True + samples["is_training"] = True return self._compute_gradients_eager(samples) @convert_eager_inputs @@ -369,7 +380,7 @@ def compute_actions(self, # TODO: remove python side effect to cull sources of bugs. self._is_training = False - self._state_in = state_batches + self._state_in = state_batches or [] if not tf1.executing_eagerly(): tf1.enable_eager_execution() @@ -591,8 +602,6 @@ def _apply_gradients(self, grads_and_vars): def _compute_gradients(self, samples): """Computes and returns grads as eager tensors.""" - self._is_training = True - with tf.GradientTape(persistent=gradients_fn is not None) as tape: loss = loss_fn(self, self.model, self.dist_class, samples) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index a1e92ac37e17..ec1d0596431a 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -583,8 +583,8 @@ def _get_default_view_requirements(self): SampleBatch.DONES: ViewRequirement(), SampleBatch.INFOS: ViewRequirement(), SampleBatch.EPS_ID: ViewRequirement(), - SampleBatch.UNROLL_ID: ViewRequirement(), SampleBatch.AGENT_INDEX: ViewRequirement(), + SampleBatch.UNROLL_ID: ViewRequirement(), "t": ViewRequirement(), } @@ -631,8 +631,8 @@ def _initialize_loss_from_dummy_batch( postprocessed_batch = self.postprocess_trajectory(batch_for_postproc) if state_outs: B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size] - # TODO: (sven) This hack will not work for attention net traj. - # view setup. + #TODO: (sven) This hack will not work for attention net traj. + # view setup. i = 0 while "state_in_{}".format(i) in postprocessed_batch: postprocessed_batch["state_in_{}".format(i)] = \ @@ -642,11 +642,11 @@ def _initialize_loss_from_dummy_batch( postprocessed_batch["state_out_{}".format(i)][:B] i += 1 seq_len = sample_batch_size // B - postprocessed_batch["seq_lens"] = \ + postprocessed_batch.seq_lens = \ np.array([seq_len for _ in range(B)], dtype=np.int32) # Remove the UsageTrackingDict wrap to prep for wrapping the # train batch with a to-tensor UsageTrackingDict. - train_batch = {k: v for k, v in postprocessed_batch.items()} + train_batch = self.model.preprocess_train_batch(postprocessed_batch) train_batch = self._lazy_tensor_dict(train_batch) train_batch.count = self._dummy_batch.count # Call the loss function, if it exists. @@ -712,13 +712,33 @@ def _get_dummy_batch_from_view_requirements( ret[view_col] = \ np.zeros((batch_size, ) + shape[1:], np.float32) else: - if isinstance(view_req.space, gym.spaces.Space): - ret[view_col] = np.zeros_like( - [view_req.space.sample() for _ in range(batch_size)]) + # Range of indices on time-axis, e.g. "-50:-1". + if view_req.shift_from is not None: + ret[view_col] = np.zeros_like([[ + view_req.space.sample() + for _ in range(view_req.shift_to - + view_req.shift_from + 1) + ] for _ in range(batch_size)]) + # Set of (probably non-consecutive) indices. 
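A NumPy shape sketch of the dummy-batch columns produced by the branches around this point, assuming a gym Box observation space: a range shift (shift_from=-50, shift_to=-1) adds a time axis of length shift_to - shift_from + 1, a list shift adds one entry per listed index, and a plain int shift adds no time axis.

import numpy as np
from gym.spaces import Box

space = Box(-1.0, 1.0, shape=(4, ))
batch_size = 32

range_col = np.zeros_like([[space.sample() for _ in range(-1 - (-50) + 1)]
                           for _ in range(batch_size)])
list_col = np.zeros_like([[space.sample() for _ in [-4, -1, 0]]
                          for _ in range(batch_size)])
int_col = np.zeros_like([space.sample() for _ in range(batch_size)])

assert range_col.shape == (32, 50, 4)
assert list_col.shape == (32, 3, 4)
assert int_col.shape == (32, 4)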
+ elif isinstance(view_req.shift, (list, tuple)): + ret[view_col] = np.zeros_like([[ + view_req.space.sample() + for t in range(len(view_req.shift)) + ] for _ in range(batch_size)]) + # Single shift int value. else: - ret[view_col] = [view_req.space for _ in range(batch_size)] - - return SampleBatch(ret) + if isinstance(view_req.space, gym.spaces.Space): + ret[view_col] = np.zeros_like([ + view_req.space.sample() for _ in range(batch_size) + ]) + else: + ret[view_col] = [ + view_req.space for _ in range(batch_size) + ] + + # Due to different view requirements for the different columns, + # columns in the resulting batch may not all have the same batch size. + return SampleBatch(ret, _dont_check_lens=True) def _update_model_inference_view_requirements_from_init_state(self): """Uses Model's (or this Policy's) init state to add needed ViewReqs. @@ -737,8 +757,13 @@ def _update_model_inference_view_requirements_from_init_state(self): view_reqs = model.inference_view_requirements if model else \ self.view_requirements view_reqs["state_in_{}".format(i)] = ViewRequirement( - "state_out_{}".format(i), shift=-1, space=space) - view_reqs["state_out_{}".format(i)] = ViewRequirement(space=space) + "state_out_{}".format(i), + shift=-1, + batch_repeat_value=self.config.get("model", {}).get( + "max_seq_len", 1), + space=space) + #TODO: check, whether we can set: used_for_training=False here. + view_reqs["state_out_{}".format(i)] = ViewRequirement(space=space, used_for_training=False) def clip_action(action, action_space): diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 486bbf0db457..118f8deb9652 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -88,12 +88,13 @@ def pad_batch_to_sequences_of_same_size( feature_sequences, initial_states, seq_lens = \ chop_into_sequences( - batch[SampleBatch.EPS_ID], - batch[SampleBatch.UNROLL_ID], - batch[SampleBatch.AGENT_INDEX], - [batch[k] for k in feature_keys_], - [batch[k] for k in state_keys], - max_seq_len, + feature_columns=[batch[k] for k in feature_keys_], + state_columns=[batch[k] for k in state_keys], + episode_ids=batch[SampleBatch.EPS_ID], + unroll_ids=batch[SampleBatch.UNROLL_ID], + agent_indices=batch[SampleBatch.AGENT_INDEX], + seq_lens=batch.seq_lens, + max_seq_len=max_seq_len, dynamic_max=dynamic_max, shuffle=shuffle) for i, k in enumerate(feature_keys_): @@ -157,19 +158,20 @@ def add_time_dimension(padded_inputs: TensorType, return torch.reshape(padded_inputs, new_shape) -# NOTE: This function will be deprecated once chunks already come padded and -# correctly chopped from the _SampleCollector object (in time-major fashion -# or not). It is already no longer user iff `_use_trajectory_view_api` = True. @DeveloperAPI -def chop_into_sequences(episode_ids, - unroll_ids, - agent_indices, - feature_columns, - state_columns, - max_seq_len, - dynamic_max=True, - shuffle=False, - _extra_padding=0): +def chop_into_sequences( + *, + feature_columns, + state_columns, + max_seq_len, + episode_ids=None, + unroll_ids=None, + agent_indices=None, + dynamic_max=True, + shuffle=False, + seq_lens=None, + states_already_reduced_to_init=False, + _extra_padding=0): """Truncate and pad experiences into fixed-length sequences. 
Args: @@ -212,23 +214,24 @@ def chop_into_sequences(episode_ids, [2, 3, 1] """ - prev_id = None - seq_lens = [] - seq_len = 0 - unique_ids = np.add( - np.add(episode_ids, agent_indices), - np.array(unroll_ids, dtype=np.int64) << 32) - for uid in unique_ids: - if (prev_id is not None and uid != prev_id) or \ - seq_len >= max_seq_len: + if seq_lens is None: + prev_id = None + seq_lens = [] + seq_len = 0 + unique_ids = np.add( + np.add(episode_ids, agent_indices), + np.array(unroll_ids, dtype=np.int64) << 32) + for uid in unique_ids: + if (prev_id is not None and uid != prev_id) or \ + seq_len >= max_seq_len: + seq_lens.append(seq_len) + seq_len = 0 + seq_len += 1 + prev_id = uid + if seq_len: seq_lens.append(seq_len) - seq_len = 0 - seq_len += 1 - prev_id = uid - if seq_len: - seq_lens.append(seq_len) - assert sum(seq_lens) == len(unique_ids) - seq_lens = np.array(seq_lens, dtype=np.int32) + seq_lens = np.array(seq_lens, dtype=np.int32) + assert sum(seq_lens) == len(feature_columns[0]) # Dynamically shrink max len as needed to optimize memory usage if dynamic_max: @@ -252,18 +255,23 @@ def chop_into_sequences(episode_ids, f_pad[seq_base + seq_offset] = f[i] i += 1 seq_base += max_seq_len - assert i == len(unique_ids), f + assert i == len(f), f feature_sequences.append(f_pad) - initial_states = [] - for s in state_columns: - s = np.array(s) - s_init = [] - i = 0 - for len_ in seq_lens: - s_init.append(s[i]) - i += len_ - initial_states.append(np.array(s_init)) + if states_already_reduced_to_init: + initial_states = state_columns + else: + initial_states = [] + for s in state_columns: + # Skip unnecessary copy. + if not isinstance(s, np.ndarray): + s = np.array(s) + s_init = [] + i = 0 + for len_ in seq_lens: + s_init.append(s[i]) + i += len_ + initial_states.append(np.array(s_init)) if shuffle: permutation = np.random.permutation(len(seq_lens)) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index a2934fdb981f..a1b4c43bcb1b 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -61,6 +61,7 @@ def __init__(self, *args, **kwargs): # Possible seq_lens (TxB or BxT) setup. self.time_major = kwargs.pop("_time_major", None) self.seq_lens = kwargs.pop("_seq_lens", None) + self.dont_check_lens = kwargs.pop("_dont_check_lens", False) self.max_seq_len = None if self.seq_lens is not None and len(self.seq_lens) > 0: self.max_seq_len = max(self.seq_lens) @@ -76,8 +77,10 @@ def __init__(self, *args, **kwargs): self.data[k] = np.array(v) if not lengths: raise ValueError("Empty sample batch") - assert len(set(lengths)) == 1, \ - "Data columns must be same length, but lens are {}".format(lengths) + if not self.dont_check_lens: + assert len(set(lengths)) == 1, \ + "Data columns must be same length, but lens are " \ + "{}".format(lengths) if self.seq_lens is not None and len(self.seq_lens) > 0: self.count = sum(self.seq_lens) else: @@ -117,7 +120,8 @@ def concat_samples(samples: List["SampleBatch"]) -> \ return SampleBatch( out, _seq_lens=np.array(seq_lens, dtype=np.int32), - _time_major=concat_samples[0].time_major) + _time_major=concat_samples[0].time_major, + _dont_check_lens=True) @PublicAPI def concat(self, other: "SampleBatch") -> "SampleBatch": @@ -248,12 +252,35 @@ def slice(self, start: int, end: int) -> "SampleBatch": SampleBatch: A new SampleBatch, which has a slice of this batch's data. 
""" - if self.time_major is not None: + if self.seq_lens is not None and len(self.seq_lens) > 0: + data = {k: v[start:end] for k, v in self.data.items()} + # Fix state_in_x data. + count = 0 + state_start = None + seq_lens = None + for i, seq_len in enumerate(self.seq_lens): + count += seq_len + if count >= end: + state_idx = 0 + state_key = "state_in_{}".format(state_idx) + while state_key in self.data: + data[state_key] = self.data[state_key][state_start:i + + 1] + state_idx += 1 + state_key = "state_in_{}".format(state_idx) + seq_lens = list(self.seq_lens[state_start:i]) + [ + seq_len - (count - end) + ] + assert sum(seq_lens) == (end - start) + break + elif state_start is None and count > start: + state_start = i + return SampleBatch( - {k: v[:, start:end] - for k, v in self.data.items()}, - _seq_lens=self.seq_lens[start:end], - _time_major=self.time_major) + data, + _seq_lens=np.array(seq_lens, dtype=np.int32), + _time_major=self.time_major, + _dont_check_lens=True) else: return SampleBatch( {k: v[start:end] diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index abcfae503c6d..f40ffecbd39e 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -9,7 +9,6 @@ import ray.experimental.tf_utils from ray.util.debug import log_once from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY -from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import override, DeveloperAPI @@ -174,11 +173,6 @@ def __init__(self, raise ValueError( "Number of state input and output tensors must match, got: " "{} vs {}".format(self._state_inputs, self._state_outputs)) - if len(self.get_initial_state()) != len(self._state_inputs): - raise ValueError( - "Length of initial state must match number of state inputs, " - "got: {} vs {}".format(self.get_initial_state(), - self._state_inputs)) if self._state_inputs and self._seq_lens is None: raise ValueError( "seq_lens tensor must be given if state inputs are defined") @@ -790,11 +784,11 @@ def _get_grad_and_stats_fetches(self): **fetches[LEARNER_STATS_KEY]) return fetches - def _get_loss_inputs_dict(self, batch, shuffle): + def _get_loss_inputs_dict(self, train_batch, shuffle): """Return a feed dict from a batch. Args: - batch (SampleBatch): batch of data to derive inputs from + train_batch (SampleBatch): batch of data to derive inputs from. shuffle (bool): whether to shuffle batch sequences. Shuffle may be done in-place. This only makes sense if you're further applying minibatch SGD after getting the outputs. @@ -803,30 +797,25 @@ def _get_loss_inputs_dict(self, batch, shuffle): feed dict of data """ - # Get batch ready for RNNs, if applicable. - pad_batch_to_sequences_of_same_size( - batch, - shuffle=shuffle, - max_seq_len=self._max_seq_len, - batch_divisibility_req=self._batch_divisibility_req, - feature_keys=[ - k for k in self._loss_input_dict.keys() if k != "seq_lens" - ], - ) - batch["is_training"] = True + # Get batch ready for RNNs/Attention Nets, etc. + train_batch = self.model.preprocess_train_batch(train_batch) + + # Mark the batch as "is_training" so the Model can use this + # information. + train_batch["is_training"] = True # Build the feed dict from the batch. 
feed_dict = {} for key, placeholder in self._loss_input_dict.items(): - feed_dict[placeholder] = batch[key] + feed_dict[placeholder] = train_batch[key] state_keys = [ "state_in_{}".format(i) for i in range(len(self._state_inputs)) ] for key in state_keys: - feed_dict[self._loss_input_dict[key]] = batch[key] + feed_dict[self._loss_input_dict[key]] = train_batch[key] if state_keys: - feed_dict[self._seq_lens] = batch["seq_lens"] + feed_dict[self._seq_lens] = train_batch["seq_lens"] return feed_dict diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index f294b510dba0..1b246573d8f0 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -345,14 +345,18 @@ def learn_on_batch( @DeveloperAPI def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients: - # Get batch ready for RNNs, if applicable. - pad_batch_to_sequences_of_same_size( - postprocessed_batch, - max_seq_len=self.max_seq_len, - shuffle=False, - batch_divisibility_req=self.batch_divisibility_req, - ) - + # Get batch ready for multi-agent, if applicable. + if self.batch_divisibility_req > 1: + pad_batch_to_sequences_of_same_size( + postprocessed_batch, + max_seq_len=self.max_seq_len, + shuffle=False, + batch_divisibility_req=self.batch_divisibility_req, + ) + # Allow model to preprocess the batch before passing it through the + # loss function. + postprocessed_batch = self.model.preprocess_train_batch( + postprocessed_batch) train_batch = self._lazy_tensor_dict(postprocessed_batch) # Calculate the actual policy loss. diff --git a/rllib/policy/view_requirement.py b/rllib/policy/view_requirement.py index f9c7750d45eb..25a5e908a20f 100644 --- a/rllib/policy/view_requirement.py +++ b/rllib/policy/view_requirement.py @@ -1,4 +1,5 @@ import gym +import numpy as np from typing import List, Optional, Union from ray.rllib.utils.framework import try_import_torch @@ -29,8 +30,9 @@ class ViewRequirement: def __init__(self, data_col: Optional[str] = None, space: gym.Space = None, - shift: Union[int, List[int]] = 0, + shift: Union[int, str, List[int]] = 0, index: Optional[int] = None, + batch_repeat_value: int = 1, used_for_training: bool = True): """Initializes a ViewRequirement object. @@ -64,7 +66,19 @@ def __init__(self, self.space = space if space is not None else gym.spaces.Box( float("-inf"), float("inf"), shape=()) + self.shift = shift + if isinstance(self.shift, (list, tuple)): + self.shift = np.array(self.shift) + + # Special case: Providing a (probably larger) range of indices, e.g. + # "-100:0" (past 100 timesteps plus current one). 
+ self.shift_from = self.shift_to = None + if isinstance(self.shift, str): + f, t = self.shift.split(":") + self.shift_from = int(f) + self.shift_to = int(t) + self.index = index + self.batch_repeat_value = batch_repeat_value - self.shift = shift self.used_for_training = used_for_training diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index d5576e0fa57d..b5b72d44d37c 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -72,18 +72,23 @@ def minibatches(samples, sgd_minibatch_size): i = 0 slices = [] - if samples.seq_lens: - seq_no = 0 - while i < samples.count: - seq_no_end = seq_no - actual_count = 0 - while actual_count < sgd_minibatch_size and len( - samples.seq_lens) > seq_no_end: - actual_count += samples.seq_lens[seq_no_end] - seq_no_end += 1 - slices.append((seq_no, seq_no_end)) - i += actual_count - seq_no = seq_no_end + if samples.seq_lens is not None and len(samples.seq_lens) > 0: + start_pos = 0 + minibatch_size = 0 + idx = 0 + while idx < len(samples.seq_lens): + seq_len = samples.seq_lens[idx] + minibatch_size += seq_len + # Complete minibatch -> Append to slices. + if minibatch_size >= sgd_minibatch_size: + slices.append((start_pos, start_pos + sgd_minibatch_size)) + start_pos += sgd_minibatch_size + if minibatch_size > sgd_minibatch_size: + overhead = minibatch_size - sgd_minibatch_size + start_pos -= (seq_len - overhead) + idx -= 1 + minibatch_size = 0 + idx += 1 else: while i < samples.count: slices.append((i, i + sgd_minibatch_size))
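The reworked seq_lens branch of `minibatches()` above walks sequences at timestep granularity: once the running `minibatch_size` crosses `sgd_minibatch_size`, a slice of exactly `sgd_minibatch_size` timesteps is emitted, and a sequence that straddles the boundary is rewound to its own first timestep (`start_pos -= seq_len - overhead`, `idx -= 1`) so the next slice picks it up from its start. Below is a self-contained distillation of just that boundary computation for illustration; the helper name is made up, and the real function goes on to yield slices of the SampleBatch, which is outside the shown hunk.

from typing import List, Tuple

def sequence_aware_slices(seq_lens: List[int],
                          sgd_minibatch_size: int) -> List[Tuple[int, int]]:
    """Distilled sketch of the seq_lens branch in rllib/utils/sgd.py::minibatches."""
    slices = []
    start_pos = 0
    minibatch_size = 0
    idx = 0
    while idx < len(seq_lens):
        seq_len = seq_lens[idx]
        minibatch_size += seq_len
        # Complete minibatch -> emit a slice of exactly sgd_minibatch_size steps.
        if minibatch_size >= sgd_minibatch_size:
            slices.append((start_pos, start_pos + sgd_minibatch_size))
            start_pos += sgd_minibatch_size
            if minibatch_size > sgd_minibatch_size:
                # The current sequence spills over the boundary: rewind to its
                # first timestep and visit it again for the next slice.
                overhead = minibatch_size - sgd_minibatch_size
                start_pos -= (seq_len - overhead)
                idx -= 1
            minibatch_size = 0
        idx += 1
    return slices

# Three sequences of lengths 2, 3 and 1 with a minibatch size of 3: the second
# sequence straddles the first boundary, so the second slice restarts at its
# first timestep (position 2).
assert sequence_aware_slices([2, 3, 1], 3) == [(0, 3), (2, 5)]

Note that, at least within the hunk shown, a trailing partial minibatch (the final length-1 sequence here) never reaches `sgd_minibatch_size` and is simply dropped.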
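`SampleBatch.slice()` now handles batches that carry `seq_lens`: row columns are cut at timestep granularity, while the per-sequence `state_in_*` columns and `seq_lens` are cut at sequence granularity, with the last overlapping sequence's length truncated to fit. A sketch of the expected behavior (outputs shown as comments), using the new `_dont_check_lens` flag because the state column intentionally has fewer rows than the data columns:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch(
    {
        "obs": np.arange(6),
        # One state row per sequence (3 sequences of lengths 2, 3, 1).
        "state_in_0": np.array([10.0, 11.0, 12.0]),
    },
    _seq_lens=np.array([2, 3, 1], dtype=np.int32),
    _dont_check_lens=True)

sub = batch.slice(0, 4)
print(sub["obs"])         # -> [0 1 2 3]
print(sub["state_in_0"])  # -> [10. 11.]  (sequences 0 and 1 overlap the slice)
print(sub.seq_lens)       # -> [2 2]      (sequence 1 truncated from 3 to 2)

One edge case that may deserve a second look: when `start` and `end - 1` fall inside the same sequence (e.g. `batch.slice(2, 5)` here), the `count > start` branch never fires for the sequence containing `start`, so `state_start` stays `None` and the `assert sum(seq_lens) == (end - start)` trips.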
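`chop_into_sequences()` is now keyword-only and can skip the episode/unroll/agent-id bookkeeping entirely when the caller already knows the sequence lengths (`seq_lens=...`) or has already reduced the state columns to one row per sequence (`states_already_reduced_to_init=True`). A small illustration with precomputed `seq_lens`; the expected outputs in the comments follow the padding and initial-state extraction shown in the hunk (exact dtypes may differ):

import numpy as np
from ray.rllib.policy.rnn_sequencing import chop_into_sequences

features, init_states, seq_lens = chop_into_sequences(
    feature_columns=[np.array([1, 2, 3, 4, 5, 6])],
    state_columns=[np.array([10, 11, 12, 13, 14, 15])],
    seq_lens=np.array([2, 3, 1], dtype=np.int32),
    max_seq_len=4,
    dynamic_max=True,  # shrink padding to the longest actual sequence (3)
    shuffle=False)

# features[0]    -> [1, 2, 0, 3, 4, 5, 6, 0, 0]  (each sequence padded to 3)
# init_states[0] -> [10, 12, 15]                 (first state row of each sequence)
# seq_lens       -> [2, 3, 1]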
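Several call sites above (`eager_tf_policy.learn_on_batch` / `compute_gradients`, `tf_policy._get_loss_inputs_dict`, `torch_policy.compute_gradients`, `Policy._initialize_loss_from_dummy_batch`) now defer sequence preparation to a `model.preprocess_train_batch()` hook and keep the unconditional `pad_batch_to_sequences_of_same_size()` call only for the multi-agent `batch_divisibility_req > 1` case. The hook itself is not defined anywhere in the shown hunks, so the following is only a plausible sketch of what a recurrent model's implementation could look like; the method name comes from the call sites above, everything else is an assumption:

from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size

class MyRNNModel:  # stands in for a ModelV2 subclass
    def __init__(self, model_config):
        self.model_config = model_config

    def get_initial_state(self):
        # Non-empty for recurrent models.
        return []

    def preprocess_train_batch(self, train_batch):
        """Hypothetical hook: prepare the flat train batch for the loss."""
        # A feed-forward model could simply pass the batch through unchanged.
        if not self.get_initial_state():
            return train_batch
        # A recurrent model pads/chops into fixed-length sequences, i.e. the
        # same step the policies used to perform themselves.
        pad_batch_to_sequences_of_same_size(
            train_batch,
            max_seq_len=self.model_config.get("max_seq_len", 20),
            shuffle=False,
            batch_divisibility_req=1)
        return train_batch

Note that the call sites reassign the result (`train_batch = self.model.preprocess_train_batch(train_batch)`), so the hook is expected to return the (possibly modified) batch.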
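`ViewRequirement.shift` can now also be a string range such as "-50:-1" (parsed into `shift_from`/`shift_to`) or a list of individual offsets, which is what the new time-axis placeholder logic and the extended `_get_dummy_batch_from_view_requirements` branches key off of. A quick illustration of the parsing and of the dummy-batch shape it implies; the concrete offsets and observation space are made up for the example:

import numpy as np
from gym.spaces import Box
from ray.rllib.policy.view_requirement import ViewRequirement

# A range of past observations, e.g. for an attention net: t-50 .. t-1.
vr = ViewRequirement(
    data_col="obs",
    space=Box(-1.0, 1.0, shape=(4,), dtype=np.float32),
    shift="-50:-1")
assert vr.shift_from == -50 and vr.shift_to == -1
# The dummy-batch branch above would then build a zero array of shape
# [batch_size, shift_to - shift_from + 1, 4] = [32, 50, 4] for this column.

# A set of individual (possibly non-consecutive) offsets is stored as an ndarray:
vr2 = ViewRequirement(shift=[-3, -1])
assert isinstance(vr2.shift, np.ndarray) and list(vr2.shift) == [-3, -1]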