
Commit f4b313e

SieversLeonericl authored and committed
[rllib] Moved clip_action into policy_graph; Clip actions in compute_single_action (#4459)
* Moved clip_action into policy_graph; Clip actions in compute_single_action
* Update policy_graph.py
* Changed formatting
* Updated codebase for convenience
1 parent 5133b10 commit f4b313e

File tree

4 files changed: +46 -34 lines


python/ray/rllib/agents/agent.py

Lines changed: 12 additions & 2 deletions
@@ -448,9 +448,19 @@ def compute_action(self,
             preprocessed, update=False)
         if state:
             return self.get_policy(policy_id).compute_single_action(
-                filtered_obs, state, prev_action, prev_reward, info)
+                filtered_obs,
+                state,
+                prev_action,
+                prev_reward,
+                info,
+                clip_actions=self.config["clip_actions"])
         return self.get_policy(policy_id).compute_single_action(
-            filtered_obs, state, prev_action, prev_reward, info)[0]
+            filtered_obs,
+            state,
+            prev_action,
+            prev_reward,
+            info,
+            clip_actions=self.config["clip_actions"])[0]
 
     @property
     def iteration(self):
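
With this change, Agent.compute_action forwards the "clip_actions" config flag to PolicyGraph.compute_single_action, so callers receive actions already clipped to the action space. A minimal sketch of the user-facing behaviour, assuming a PPO agent and the Pendulum-v0 environment (both illustrative, not part of this diff):

    import gym
    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    # Enable clipping in the agent config (illustrative; the default is not shown in this diff).
    agent = PPOAgent(env="Pendulum-v0", config={"clip_actions": True})

    env = gym.make("Pendulum-v0")
    obs = env.reset()
    # compute_action now calls compute_single_action(..., clip_actions=config["clip_actions"]),
    # so the returned action already lies inside env.action_space (a Box in [-2, 2]).
    action = agent.compute_action(obs)
    obs, reward, done, _ = env.step(action)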

python/ray/rllib/evaluation/policy_graph.py

Lines changed: 32 additions & 0 deletions
@@ -2,6 +2,9 @@
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+import gym
+
 from ray.rllib.utils.annotations import DeveloperAPI
 
 
@@ -81,6 +84,7 @@ def compute_single_action(self,
                               prev_reward=None,
                               info=None,
                               episode=None,
+                              clip_actions=False,
                               **kwargs):
         """Unbatched version of compute_actions.
 
@@ -93,6 +97,7 @@ def compute_single_action(self,
             episode (MultiAgentEpisode): this provides access to all of the
                 internal episode state, which may be useful for model-based or
                 multi-agent algorithms.
+            clip_actions (bool): should the action be clipped
             kwargs: forward compatibility placeholder
 
         Returns:
@@ -119,6 +124,8 @@ def compute_single_action(self,
             prev_reward_batch=prev_reward_batch,
             info_batch=info_batch,
             episodes=episodes)
+        if clip_actions:
+            action = clip_action(action, self.action_space)
         return action, [s[0] for s in state_out], \
             {k: v[0] for k, v in info.items()}
 
@@ -263,3 +270,28 @@ def export_checkpoint(self, export_dir):
             export_dir (str): Local writable directory.
         """
         raise NotImplementedError
+
+
+def clip_action(action, space):
+    """Called to clip actions to the specified range of this policy.
+
+    Arguments:
+        action: Single action.
+        space: Action space the actions should be present in.
+
+    Returns:
+        Clipped batch of actions.
+    """
+
+    if isinstance(space, gym.spaces.Box):
+        return np.clip(action, space.low, space.high)
+    elif isinstance(space, gym.spaces.Tuple):
+        if type(action) not in (tuple, list):
+            raise ValueError("Expected tuple space for actions {}: {}".format(
+                action, space))
+        out = []
+        for a, s in zip(action, space.spaces):
+            out.append(clip_action(a, s))
+        return out
+    else:
+        return action
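
The relocated clip_action helper clips Box components element-wise and recurses through Tuple spaces; anything else is returned unchanged. A short, self-contained sketch (the space definitions and values below are illustrative):

    import gym
    import numpy as np

    from ray.rllib.evaluation.policy_graph import clip_action

    box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
    print(clip_action(np.array([2.5, -3.0]), box))
    # -> [ 1. -1.]  (clipped element-wise to [low, high])

    tup = gym.spaces.Tuple([box, gym.spaces.Discrete(4)])
    print(clip_action((np.array([0.3, 7.0]), 2), tup))
    # -> [array([0.3, 1. ]), 2]  (Box member clipped, Discrete member passed through)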

python/ray/rllib/evaluation/sampler.py

Lines changed: 1 addition & 26 deletions
@@ -2,7 +2,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import gym
 from collections import defaultdict, namedtuple
 import logging
 import numpy as np
@@ -21,6 +20,7 @@
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.debug import log_once, summarize
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
+from ray.rllib.evaluation.policy_graph import clip_action
 
 logger = logging.getLogger(__name__)
 
@@ -224,31 +224,6 @@ def get_extra_batches(self):
         return extra
 
 
-def clip_action(action, space):
-    """Called to clip actions to the specified range of this policy.
-
-    Arguments:
-        action: Single action.
-        space: Action space the actions should be present in.
-
-    Returns:
-        Clipped batch of actions.
-    """
-
-    if isinstance(space, gym.spaces.Box):
-        return np.clip(action, space.low, space.high)
-    elif isinstance(space, gym.spaces.Tuple):
-        if type(action) not in (tuple, list):
-            raise ValueError("Expected tuple space for actions {}: {}".format(
-                action, space))
-        out = []
-        for a, s in zip(action, space.spaces):
-            out.append(clip_action(a, s))
-        return out
-    else:
-        return action
-
-
 def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
                 unroll_length, horizon, preprocessors, obs_filters,
                 clip_rewards, clip_actions, pack, callbacks, tf_sess,

python/ray/rllib/rollout.py

Lines changed: 1 addition & 6 deletions
@@ -13,7 +13,6 @@
 import ray
 from ray.rllib.agents.registry import get_agent_class
 from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.evaluation.sampler import clip_action
 from ray.tune.util import merge_dicts
 
 EXAMPLE_USAGE = """
@@ -155,11 +154,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
         else:
             action = agent.compute_action(state)
 
-        if agent.config["clip_actions"]:
-            clipped_action = clip_action(action, env.action_space)
-            next_state, reward, done, _ = env.step(clipped_action)
-        else:
-            next_state, reward, done, _ = env.step(action)
+        next_state, reward, done, _ = env.step(action)
 
         if multiagent:
             done = done["__all__"]
