Commit 99c81c6

[RLlib] Attention Net prep PR #3. (ray-project#12450)
1 parent 401d342 commit 99c81c6

32 files changed, +355 −248 lines

rllib/agents/ppo/appo_tf_policy.py

Lines changed: 6 additions & 23 deletions
@@ -13,12 +13,12 @@
 from ray.rllib.agents.impala import vtrace_tf as vtrace
 from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
     clip_gradients, choose_optimizer
+from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae
 from ray.rllib.evaluation.episode import MultiAgentEpisode
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.models.tf.tf_action_dist import Categorical
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.policy.tf_policy import LearningRateSchedule, TFPolicy
 from ray.rllib.agents.ppo.ppo_tf_policy import KLCoeffMixin, ValueNetworkMixin

@@ -338,31 +338,14 @@ def postprocess_trajectory(
         SampleBatch: The postprocessed, modified SampleBatch (or a new one).
     """
     if not policy.config["vtrace"]:
-        completed = sample_batch["dones"][-1]
-        if completed:
-            last_r = 0.0
-        else:
-            next_state = []
-            for i in range(policy.num_state_tensors()):
-                next_state.append([sample_batch["state_out_{}".format(i)][-1]])
-            last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
-                                   sample_batch[SampleBatch.ACTIONS][-1],
-                                   sample_batch[SampleBatch.REWARDS][-1],
-                                   *next_state)
-        batch = compute_advantages(
-            sample_batch,
-            last_r,
-            policy.config["gamma"],
-            policy.config["lambda"],
-            use_gae=policy.config["use_gae"],
-            use_critic=policy.config["use_critic"])
-    else:
-        batch = sample_batch
+        sample_batch = postprocess_ppo_gae(policy, sample_batch,
+                                           other_agent_batches, episode)
+
     # TODO: (sven) remove this del once we have trajectory view API fully in
     # place.
-    del batch.data["new_obs"]  # not used, so save some bandwidth
+    del sample_batch.data["new_obs"]  # not used, so save some bandwidth

-    return batch
+    return sample_batch


 def add_values(policy):
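
A condensed sketch of the resulting function (abbreviated; based on the hunks above, not the verbatim file contents): when "vtrace" is off, APPO now delegates advantage computation to PPO's shared postprocess_ppo_gae helper instead of re-implementing it.

    from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae

    def postprocess_trajectory(policy, sample_batch, other_agent_batches=None,
                               episode=None):
        if not policy.config["vtrace"]:
            # Reuse PPO's GAE postprocessing (computes advantages and value
            # targets, bootstrapping from the value function if the
            # trajectory was truncated).
            sample_batch = postprocess_ppo_gae(policy, sample_batch,
                                               other_agent_batches, episode)
        # Not used downstream, so save some bandwidth (see TODO in the diff).
        del sample_batch.data["new_obs"]
        return sample_batch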

rllib/agents/ppo/ppo.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@
     # If true, use the Generalized Advantage Estimator (GAE)
     # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
     "use_gae": True,
-    # The GAE(lambda) parameter.
+    # The GAE (lambda) parameter.
     "lambda": 1.0,
     # Initial coefficient for KL divergence.
     "kl_coeff": 0.2,

rllib/agents/ppo/ppo_tf_policy.py

Lines changed: 48 additions & 22 deletions
@@ -193,13 +193,22 @@ def postprocess_ppo_gae(
         last_r = 0.0
     # Trajectory has been truncated -> last r=VF estimate of last obs.
     else:
-        next_state = []
-        for i in range(policy.num_state_tensors()):
-            next_state.append(sample_batch["state_out_{}".format(i)][-1])
-        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
-                               sample_batch[SampleBatch.ACTIONS][-1],
-                               sample_batch[SampleBatch.REWARDS][-1],
-                               *next_state)
+        # Input dict is provided to us automatically via the Model's
+        # requirements. It's a single-timestep (last one in trajectory)
+        # input_dict.
+        if policy.config["_use_trajectory_view_api"]:
+            # Create an input dict according to the Model's requirements.
+            input_dict = policy.model.get_input_dict(sample_batch, index=-1)
+            last_r = policy._value(**input_dict)
+        # TODO: (sven) Remove once trajectory view API is all-algo default.
+        else:
+            next_state = []
+            for i in range(policy.num_state_tensors()):
+                next_state.append(sample_batch["state_out_{}".format(i)][-1])
+            last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
+                                   sample_batch[SampleBatch.ACTIONS][-1],
+                                   sample_batch[SampleBatch.REWARDS][-1],
+                                   *next_state)

     # Adds the policy logits, VF preds, and advantages to the batch,
     # using GAE ("generalized advantage estimation") or not.

@@ -208,7 +217,9 @@ def postprocess_ppo_gae(
         last_r,
         policy.config["gamma"],
         policy.config["lambda"],
-        use_gae=policy.config["use_gae"])
+        use_gae=policy.config["use_gae"],
+        use_critic=policy.config.get("use_critic", True))
+
     return batch


@@ -292,25 +303,40 @@ def __init__(self, obs_space, action_space, config):
         # observation.
         if config["use_gae"]:

-            @make_tf_callable(self.get_session())
-            def value(ob, prev_action, prev_reward, *state):
-                model_out, _ = self.model({
-                    SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
-                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
-                        [prev_action]),
-                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
-                        [prev_reward]),
-                    "is_training": tf.convert_to_tensor([False]),
-                }, [tf.convert_to_tensor([s]) for s in state],
-                    tf.convert_to_tensor([1]))
-                # [0] = remove the batch dim.
-                return self.model.value_function()[0]
+            # Input dict is provided to us automatically via the Model's
+            # requirements. It's a single-timestep (last one in trajectory)
+            # input_dict.
+            if config["_use_trajectory_view_api"]:
+
+                @make_tf_callable(self.get_session())
+                def value(**input_dict):
+                    model_out, _ = self.model.from_batch(
+                        input_dict, is_training=False)
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]
+
+            # TODO: (sven) Remove once trajectory view API is all-algo default.
+            else:
+
+                @make_tf_callable(self.get_session())
+                def value(ob, prev_action, prev_reward, *state):
+                    model_out, _ = self.model({
+                        SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
+                        SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
+                            [prev_action]),
+                        SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
+                            [prev_reward]),
+                        "is_training": tf.convert_to_tensor([False]),
+                    }, [tf.convert_to_tensor([s]) for s in state],
+                        tf.convert_to_tensor([1]))
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]

         # When not doing GAE, we do not require the value function's output.
         else:

             @make_tf_callable(self.get_session())
-            def value(ob, prev_action, prev_reward, *state):
+            def value(*args, **kwargs):
                 return tf.constant(0.0)

         self._value = value
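
Taken together, the hunks above change how the bootstrap value last_r for a truncated trajectory is obtained, and use_critic is now read with a default of True, presumably so that configs which do not define the key still work with this shared helper. A minimal, illustrative juxtaposition of the two sides of the new call path (get_input_dict, from_batch, _value and make_tf_callable are taken from the diff; the surrounding code is not the verbatim file contents):

    # Postprocessing side: the model assembles a single-timestep input dict
    # for the last row of the trajectory (including any recurrent state the
    # model requires) ...
    input_dict = policy.model.get_input_dict(sample_batch, index=-1)
    # ... and the value callable simply consumes it.
    last_r = policy._value(**input_dict)

    # Policy side: under the trajectory view API, `value` no longer rebuilds
    # the feed dict by hand; it runs the model on the provided input dict and
    # returns the value-function output for that single timestep.
    @make_tf_callable(self.get_session())
    def value(**input_dict):
        model_out, _ = self.model.from_batch(input_dict, is_training=False)
        return self.model.value_function()[0]  # [0] = remove the batch dim.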

rllib/agents/ppo/ppo_torch_policy.py

Lines changed: 30 additions & 16 deletions
@@ -210,22 +210,36 @@ def __init__(self, obs_space, action_space, config):
         # When doing GAE, we need the value function estimate on the
         # observation.
         if config["use_gae"]:
-
-            def value(ob, prev_action, prev_reward, *state):
-                model_out, _ = self.model({
-                    SampleBatch.CUR_OBS: convert_to_torch_tensor(
-                        np.asarray([ob]), self.device),
-                    SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
-                        np.asarray([prev_action]), self.device),
-                    SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
-                        np.asarray([prev_reward]), self.device),
-                    "is_training": False,
-                }, [
-                    convert_to_torch_tensor(np.asarray([s]), self.device)
-                    for s in state
-                ], convert_to_torch_tensor(np.asarray([1]), self.device))
-                # [0] = remove the batch dim.
-                return self.model.value_function()[0]
+            # Input dict is provided to us automatically via the Model's
+            # requirements. It's a single-timestep (last one in trajectory)
+            # input_dict.
+            if config["_use_trajectory_view_api"]:
+
+                def value(**input_dict):
+                    model_out, _ = self.model.from_batch(
+                        convert_to_torch_tensor(input_dict, self.device),
+                        is_training=False)
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]
+
+            # TODO: (sven) Remove once trajectory view API is all-algo default.
+            else:
+
+                def value(ob, prev_action, prev_reward, *state):
+                    model_out, _ = self.model({
+                        SampleBatch.CUR_OBS: convert_to_torch_tensor(
+                            np.asarray([ob]), self.device),
+                        SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
+                            np.asarray([prev_action]), self.device),
+                        SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
+                            np.asarray([prev_reward]), self.device),
+                        "is_training": False,
+                    }, [
+                        convert_to_torch_tensor(np.asarray([s]), self.device)
+                        for s in state
+                    ], convert_to_torch_tensor(np.asarray([1]), self.device))
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]

         # When not doing GAE, we do not require the value function's output.
         else:

rllib/agents/qmix/model.py

Lines changed: 4 additions & 11 deletions
@@ -1,9 +1,6 @@
-from gym.spaces import Box
-
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.preprocessors import get_preprocessor
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
-from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch

@@ -25,17 +22,13 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
         self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
         self.n_agents = model_config["n_agents"]

-        self.inference_view_requirements.update({
-            "state_in_0": ViewRequirement(
-                "state_out_0",
-                data_rel_pos=-1,
-                space=Box(-1.0, 1.0, (self.n_agents, self.rnn_hidden_dim)))
-        })
-
     @override(ModelV2)
     def get_initial_state(self):
         # Place hidden states on same device as model.
-        return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
+        return [
+            self.fc1.weight.new(self.n_agents,
+                                self.rnn_hidden_dim).zero_().squeeze(0)
+        ]

     @override(ModelV2)
     def forward(self, input_dict, hidden_state, seq_lens):
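
The get_initial_state change means the initial RNN hidden state now carries one row per agent instead of a single squeezed row. A small illustrative example of the resulting shapes (n_agents, rnn_hidden_dim, and the Linear input size are made-up values, not taken from the model):

    import torch

    n_agents, rnn_hidden_dim = 3, 64            # hypothetical values
    fc1 = torch.nn.Linear(10, rnn_hidden_dim)   # stand-in for the model's fc1

    # Old: (1, hidden) -> squeeze(0) -> shape (64,)
    old_h0 = fc1.weight.new(1, rnn_hidden_dim).zero_().squeeze(0)
    # New: (n_agents, hidden) -> squeeze(0) is a no-op -> shape (3, 64)
    new_h0 = fc1.weight.new(n_agents, rnn_hidden_dim).zero_().squeeze(0)

    print(old_h0.shape)  # torch.Size([64])
    print(new_h0.shape)  # torch.Size([3, 64])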

rllib/agents/qmix/qmix_policy.py

Lines changed: 0 additions & 3 deletions
@@ -215,9 +215,6 @@ def __init__(self, obs_space, action_space, config):
             name="target_model",
             default_model=RNNModel).to(self.device)

-        # Combine view_requirements for Model and Policy.
-        self.view_requirements.update(self.model.inference_view_requirements)
-
         self.exploration = self._create_exploration()

         # Setup the mixer network.

rllib/contrib/maddpg/maddpg_policy.py

Lines changed: 4 additions & 1 deletion
@@ -28,7 +28,7 @@ def postprocess_trajectory(self,
                                other_agent_batches=None,
                                episode=None):
         # FIXME: Get done from info is required since agentwise done is not
-        # supported now.
+        # supported now.
         sample_batch.data[SampleBatch.DONES] = self.get_done_from_info(
             sample_batch.data[SampleBatch.INFOS])

@@ -251,6 +251,9 @@ def _make_loss_inputs(placeholders):
             loss_inputs=loss_inputs,
             dist_inputs=actor_feature)

+        del self.view_requirements["prev_actions"]
+        del self.view_requirements["prev_rewards"]
+
         self.sess.run(tf1.global_variables_initializer())

         # Hard initial update

rllib/evaluation/collectors/sample_collector.py

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ def try_build_truncated_episode_multi_agent_batch(self) -> \
             postprocessor.
         This is usually called to collect samples for policy training.
         If not enough data has been collected yet (`rollout_fragment_length`),
-        returns None.
+        returns an empty list.

         Returns:
             List[Union[MultiAgentBatch, SampleBatch]]: Returns a (possibly
