@@ -193,13 +193,22 @@ def postprocess_ppo_gae(
         last_r = 0.0
     # Trajectory has been truncated -> last r=VF estimate of last obs.
     else:
-        next_state = []
-        for i in range(policy.num_state_tensors()):
-            next_state.append(sample_batch["state_out_{}".format(i)][-1])
-        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
-                               sample_batch[SampleBatch.ACTIONS][-1],
-                               sample_batch[SampleBatch.REWARDS][-1],
-                               *next_state)
+        # Input dict is provided to us automatically via the Model's
+        # requirements. It's a single-timestep (last one in trajectory)
+        # input_dict.
+        if policy.config["_use_trajectory_view_api"]:
+            # Create an input dict according to the Model's requirements.
+            input_dict = policy.model.get_input_dict(sample_batch, index=-1)
+            last_r = policy._value(**input_dict)
+        # TODO: (sven) Remove once trajectory view API is all-algo default.
+        else:
+            next_state = []
+            for i in range(policy.num_state_tensors()):
+                next_state.append(sample_batch["state_out_{}".format(i)][-1])
+            last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
+                                   sample_batch[SampleBatch.ACTIONS][-1],
+                                   sample_batch[SampleBatch.REWARDS][-1],
+                                   *next_state)
 
     # Adds the policy logits, VF preds, and advantages to the batch,
     # using GAE ("generalized advantage estimation") or not.
@@ -208,7 +217,9 @@ def postprocess_ppo_gae(
         last_r,
         policy.config["gamma"],
         policy.config["lambda"],
-        use_gae=policy.config["use_gae"])
+        use_gae=policy.config["use_gae"],
+        use_critic=policy.config.get("use_critic", True))
+
     return batch
 
 
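The bootstrap value `last_r` computed above, together with the batch's rewards and value-function predictions, is what `compute_advantages` turns into GAE advantages and value targets; the new `use_critic` kwarg is read with `.get(..., True)`, so configs that do not set the key keep the previous behavior. Below is a minimal NumPy sketch of the GAE computation, with assumed array names, not RLlib's actual `compute_advantages` implementation.

import numpy as np

def gae_sketch(rewards, vf_preds, last_r, gamma=0.99, lam=0.95):
    # Append the bootstrap value so every step has a "next" value estimate.
    vpred_t = np.concatenate([vf_preds, np.array([last_r], dtype=np.float32)])
    # One-step TD errors: r_t + gamma * V(s_{t+1}) - V(s_t).
    deltas = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
    # Advantages: discounted (gamma * lam) backward cumulative sum of deltas.
    advantages = np.zeros_like(rewards, dtype=np.float32)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    # Targets the value-function branch is regressed towards.
    value_targets = advantages + vf_preds
    return advantages, value_targets
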
@@ -292,25 +303,40 @@ def __init__(self, obs_space, action_space, config):
         # observation.
         if config["use_gae"]:
 
-            @make_tf_callable(self.get_session())
-            def value(ob, prev_action, prev_reward, *state):
-                model_out, _ = self.model({
-                    SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
-                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
-                        [prev_action]),
-                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
-                        [prev_reward]),
-                    "is_training": tf.convert_to_tensor([False]),
-                }, [tf.convert_to_tensor([s]) for s in state],
-                    tf.convert_to_tensor([1]))
-                # [0] = remove the batch dim.
-                return self.model.value_function()[0]
+            # Input dict is provided to us automatically via the Model's
+            # requirements. It's a single-timestep (last one in trajectory)
+            # input_dict.
+            if config["_use_trajectory_view_api"]:
+
+                @make_tf_callable(self.get_session())
+                def value(**input_dict):
+                    model_out, _ = self.model.from_batch(
+                        input_dict, is_training=False)
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]
+
+            # TODO: (sven) Remove once trajectory view API is all-algo default.
+            else:
+
+                @make_tf_callable(self.get_session())
+                def value(ob, prev_action, prev_reward, *state):
+                    model_out, _ = self.model({
+                        SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
+                        SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
+                            [prev_action]),
+                        SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
+                            [prev_reward]),
+                        "is_training": tf.convert_to_tensor([False]),
+                    }, [tf.convert_to_tensor([s]) for s in state],
+                        tf.convert_to_tensor([1]))
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]
 
         # When not doing GAE, we do not require the value function's output.
         else:
 
             @make_tf_callable(self.get_session())
-            def value(ob, prev_action, prev_reward, *state):
+            def value(*args, **kwargs):
                 return tf.constant(0.0)
 
         self._value = value
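The `**input_dict` signature lets the TF callable accept exactly the single-timestep columns that the model's view requirements put into `get_input_dict(sample_batch, index=-1)`, while the no-GAE dummy is widened to `value(*args, **kwargs)` so it can be called under either convention. A hypothetical, self-contained illustration of that dispatch follows; the column names, shapes, and the dummy value function are made up for the example and are not RLlib's API.

import numpy as np

# Stand-in for a 3-step trajectory carrying the columns a model might request.
sample_batch = {
    "obs": np.random.randn(3, 4).astype(np.float32),
    "state_out_0": np.zeros((3, 8), dtype=np.float32),
}

# Conceptually what get_input_dict(sample_batch, index=-1) provides: each
# requested column sliced down to the last timestep (batch dim kept), with
# state outputs fed back in as state inputs.
input_dict = {
    "obs": sample_batch["obs"][-1:],
    "state_in_0": sample_batch["state_out_0"][-1:],
}

def value(**input_dict):
    # Dummy value function standing in for the model forward pass above.
    return 0.0

last_r = value(**input_dict)  # scalar bootstrap estimate for postprocessing
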