22 changes: 16 additions & 6 deletions python/ray/rllib/agents/agent.py
@@ -437,9 +437,19 @@ def compute_action(self,
            preprocessed, update=False)
        if state:
            return self.get_policy(policy_id).compute_single_action(
-               filtered_obs, state, prev_action, prev_reward, info)
+               filtered_obs,
+               state,
+               prev_action,
+               prev_reward,
+               info,
+               clip_actions=self.config["clip_actions"])
        return self.get_policy(policy_id).compute_single_action(
-           filtered_obs, state, prev_action, prev_reward, info)[0]
+           filtered_obs,
+           state,
+           prev_action,
+           prev_reward,
+           info,
+           clip_actions=self.config["clip_actions"])[0]

@property
def iteration(self):
@@ -657,12 +667,12 @@ def session_creator():
        input_creator = (lambda ioctx: ioctx.default_sampler_input())
    elif isinstance(config["input"], dict):
        input_creator = (lambda ioctx: ShuffledInput(
-           MixedInput(config["input"], ioctx),
-           config["shuffle_buffer_size"]))
+           MixedInput(config["input"], ioctx), config[
+               "shuffle_buffer_size"]))
    else:
        input_creator = (lambda ioctx: ShuffledInput(
-           JsonReader(config["input"], ioctx),
-           config["shuffle_buffer_size"]))
+           JsonReader(config["input"], ioctx), config[
+               "shuffle_buffer_size"]))

if isinstance(config["output"], FunctionType):
output_creator = config["output"]
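With this change, clipping moves behind the agent API: when config["clip_actions"] is set, compute_action returns an action already clipped to the policy's action space. A minimal usage sketch, assuming the era's PPOAgent class and the continuous-control Pendulum-v0 env (both illustrative choices, not part of this diff):

import gym
import ray
from ray.rllib.agents.ppo import PPOAgent  # illustrative agent choice

ray.init()
# compute_action() now forwards config["clip_actions"] to the policy's
# compute_single_action(), so the returned action is already clipped.
agent = PPOAgent(env="Pendulum-v0", config={"clip_actions": True})

env = gym.make("Pendulum-v0")
obs = env.reset()
action = agent.compute_action(obs)  # within env.action_space bounds
next_obs, reward, done, _ = env.step(action)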
32 changes: 32 additions & 0 deletions python/ray/rllib/evaluation/policy_graph.py
@@ -2,6 +2,9 @@
from __future__ import division
from __future__ import print_function

+import numpy as np
+import gym

from ray.rllib.utils.annotations import DeveloperAPI


@@ -81,6 +84,7 @@ def compute_single_action(self,
                              prev_reward=None,
                              info=None,
                              episode=None,
+                             clip_actions=False,
                              **kwargs):
"""Unbatched version of compute_actions.

@@ -93,6 +97,7 @@ def compute_single_action(self,
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multi-agent algorithms.
+            clip_actions (bool): whether to clip the action to its valid range
kwargs: forward compatibility placeholder

Returns:
@@ -119,6 +124,8 @@ def compute_single_action(self,
            prev_reward_batch=prev_reward_batch,
            info_batch=info_batch,
            episodes=episodes)
+       if clip_actions:
+           action = clip_action(action, self.action_space)
        return action, [s[0] for s in state_out], \
            {k: v[0] for k, v in info.items()}

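Direct callers of the unbatched policy API can opt in as well; a hypothetical call sketch (the policy instance, obs, and the empty RNN state list are assumed stand-ins):

# `policy` is any PolicyGraph instance; with clip_actions=True the
# returned action is safe to feed straight into env.step().
action, rnn_state, info_out = policy.compute_single_action(
    obs, state=[], clip_actions=True)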
@@ -263,3 +270,28 @@ def export_checkpoint(self, export_dir):
export_dir (str): Local writable directory.
"""
raise NotImplementedError


+def clip_action(action, space):
+    """Clips an action to the valid range of the given action space.
+
+    Arguments:
+        action: Single action.
+        space: Action space the action should be contained in.
+
+    Returns:
+        The clipped action.
+    """
+
+    if isinstance(space, gym.spaces.Box):
+        return np.clip(action, space.low, space.high)
+    elif isinstance(space, gym.spaces.Tuple):
+        if type(action) not in (tuple, list):
+            raise ValueError("Expected tuple space for actions {}: {}".format(
+                action, space))
+        out = []
+        for a, s in zip(action, space.spaces):
+            out.append(clip_action(a, s))
+        return out
+    else:
+        return action
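A quick sanity check of the relocated helper (values illustrative): Box actions clip elementwise, Tuple spaces recurse per component, and any other space passes the action through unchanged.

import numpy as np
import gym

from ray.rllib.evaluation.policy_graph import clip_action

box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))
print(clip_action(np.array([1.7, -3.2]), box))  # -> [ 1. -1.]

# Tuple spaces clip recursively; the Discrete component falls through
# the final else branch untouched.
tup = gym.spaces.Tuple([box, gym.spaces.Discrete(4)])
print(clip_action([np.array([5.0, 0.3]), 2], tup))  # -> [array([1. , 0.3]), 2]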
27 changes: 1 addition & 26 deletions python/ray/rllib/evaluation/sampler.py
@@ -2,7 +2,6 @@
from __future__ import division
from __future__ import print_function

-import gym
from collections import defaultdict, namedtuple
import logging
import numpy as np
@@ -21,6 +20,7 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils.tf_run_builder import TFRunBuilder
+from ray.rllib.evaluation.policy_graph import clip_action

logger = logging.getLogger(__name__)

@@ -224,31 +224,6 @@ def get_extra_batches(self):
return extra


-def clip_action(action, space):
-    """Called to clip actions to the specified range of this policy.
-
-    Arguments:
-        action: Single action.
-        space: Action space the actions should be present in.
-
-    Returns:
-        Clipped batch of actions.
-    """
-
-    if isinstance(space, gym.spaces.Box):
-        return np.clip(action, space.low, space.high)
-    elif isinstance(space, gym.spaces.Tuple):
-        if type(action) not in (tuple, list):
-            raise ValueError("Expected tuple space for actions {}: {}".format(
-                action, space))
-        out = []
-        for a, s in zip(action, space.spaces):
-            out.append(clip_action(a, s))
-        return out
-    else:
-        return action


def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
unroll_length, horizon, preprocessors, obs_filters,
clip_rewards, clip_actions, pack, callbacks, tf_sess,
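For downstream code that imported the helper from its old home, only the import path changes; the function behaves the same:

# Before this PR (import now removed):
#     from ray.rllib.evaluation.sampler import clip_action
# After this PR:
from ray.rllib.evaluation.policy_graph import clip_action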
7 changes: 1 addition & 6 deletions python/ray/rllib/rollout.py
@@ -13,7 +13,6 @@
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.evaluation.sampler import clip_action
from ray.tune.util import merge_dicts

EXAMPLE_USAGE = """
@@ -155,11 +154,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
        else:
            action = agent.compute_action(state)

-       if agent.config["clip_actions"]:
-           clipped_action = clip_action(action, env.action_space)
-           next_state, reward, done, _ = env.step(clipped_action)
-       else:
-           next_state, reward, done, _ = env.step(action)
+       next_state, reward, done, _ = env.step(action)

if multiagent:
done = done["__all__"]
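The rollout loop now trusts compute_action to hand back env-ready actions. A condensed sketch of the resulting loop (setup omitted; variable names follow rollout.py):

# No clip_action call remains here: compute_action() already clips
# when config["clip_actions"] is set on the agent.
state = env.reset()
done = False
while not done:
    action = agent.compute_action(state)
    next_state, reward, done, _ = env.step(action)
    state = next_state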