2 changes: 1 addition & 1 deletion rllib/agents/ddpg/ddpg_torch_policy.py
@@ -274,7 +274,7 @@ def setup_late_mixins(policy, obs_space, action_space, config):
optimizer_fn=make_ddpg_optimizers,
validate_spaces=validate_spaces,
before_init=before_init_fn,
after_init=setup_late_mixins,
before_loss_init=setup_late_mixins,
[Review comment by Contributor Author]
This is the more accurate kwarg to use (torch did not have a loss-init step before, so this hook is new). The old after_init still works exactly the same, so this does not break the API.
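
For orientation, here is a minimal sketch of where the new hook sits in the torch policy template. This is illustrative only: the dummy loss, the policy name, and the attribute set in setup_late_mixins are placeholders, not the actual DDPG code.

```python
# Hypothetical policy built via RLlib's torch policy template.
from ray.rllib.policy.torch_policy_template import build_torch_policy


def dummy_loss(policy, model, dist_class, train_batch):
    # Placeholder loss so the template has something to call.
    return (train_batch["obs"].float() ** 2).mean()


def setup_late_mixins(policy, obs_space, action_space, config):
    # Anything the loss needs (e.g. target networks) must exist by now.
    policy.ready_before_first_loss_call = True


MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=dummy_loss,
    # New hook: called once the model and action dist exist, right before
    # the loss is used for the first time.
    before_loss_init=setup_late_mixins,
    # Passing after_init=setup_late_mixins instead keeps working unchanged.
)
```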

action_distribution_fn=get_distribution_inputs_and_class,
make_model_and_action_dist=build_ddpg_models_and_action_dist,
apply_gradients_fn=apply_gradients_fn,
8 changes: 4 additions & 4 deletions rllib/agents/dqn/dqn_torch_policy.py
@@ -317,9 +317,9 @@ def setup_early_mixins(policy: Policy, obs_space, action_space,
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def after_init(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
ComputeTDErrorMixin.__init__(policy)
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
# Move target net to device (this is done automatically for the
@@ -397,7 +397,7 @@ def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
extra_action_out_fn=extra_action_out_fn,
before_init=setup_early_mixins,
after_init=after_init,
before_loss_init=before_loss_init,
mixins=[
TargetNetworkMixin,
ComputeTDErrorMixin,
2 changes: 1 addition & 1 deletion rllib/agents/marwil/marwil_torch_policy.py
@@ -81,5 +81,5 @@ def setup_mixins(policy, obs_space, action_space, config):
get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
stats_fn=stats,
postprocess_fn=postprocess_advantages,
after_init=setup_mixins,
before_loss_init=setup_mixins,
mixins=[ValueNetworkMixin])
2 changes: 1 addition & 1 deletion rllib/agents/ppo/appo_torch_policy.py
@@ -331,7 +331,7 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=choose_optimizer,
before_init=setup_early_mixins,
after_init=setup_late_mixins,
before_loss_init=setup_late_mixins,
make_model=make_appo_model,
mixins=[
LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
7 changes: 6 additions & 1 deletion rllib/agents/ppo/ppo_tf_policy.py
@@ -47,7 +47,12 @@ def ppo_surrogate_loss(

# RNN case: Mask away 0-padded chunks at end of time axis.
if state:
max_seq_len = tf.reduce_max(train_batch["seq_lens"])
# Derive max_seq_len from the data itself, not from the seq_lens
[Review comment by Contributor Author]
Prep for attention nets, where dynamically computing the max over the given sequence lengths is not allowed.

# tensor. This is in case e.g. seq_lens=[2, 3], but the data is still
# 0-padded up to T=5 (as it's the case for attention nets).
B = tf.shape(train_batch["seq_lens"])[0]
max_seq_len = tf.shape(logits)[0] // B

mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
mask = tf.reshape(mask, [-1])

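A quick numeric check of the new shape logic above, using made-up values (TF2 eager mode assumed):

```python
import tensorflow as tf

# Hypothetical batch: seq_lens says [2, 3], but the data is 0-padded up to
# T=5 per sequence, so the flattened logits have B * T = 2 * 5 = 10 rows.
seq_lens = tf.constant([2, 3])
logits = tf.zeros([10, 4])

B = tf.shape(seq_lens)[0]               # 2
max_seq_len = tf.shape(logits)[0] // B  # 10 // 2 = 5 (reduce_max would give 3)

mask = tf.reshape(tf.sequence_mask(seq_lens, max_seq_len), [-1])
print(max_seq_len.numpy())  # 5
print(mask.numpy())         # [ True  True False False False  True  True  True False False]
```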
2 changes: 1 addition & 1 deletion rllib/agents/ppo/ppo_torch_policy.py
@@ -265,7 +265,7 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
postprocess_fn=postprocess_ppo_gae,
extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config,
after_init=setup_mixins,
before_loss_init=setup_mixins,
mixins=[
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
ValueNetworkMixin
1 change: 1 addition & 0 deletions rllib/agents/ppo/tests/test_appo.py
@@ -24,6 +24,7 @@ def test_appo_compilation(self):
for _ in framework_iterator(config):
print("w/o v-trace")
_config = config.copy()
_config["vtrace"] = False
trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0")
for i in range(num_iterations):
print(trainer.train())
2 changes: 1 addition & 1 deletion rllib/agents/sac/sac_torch_policy.py
@@ -489,7 +489,7 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=optimizer_fn,
validate_spaces=validate_spaces,
after_init=setup_late_mixins,
before_loss_init=setup_late_mixins,
make_model_and_action_dist=build_sac_model_and_action_dist,
mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
action_distribution_fn=action_distribution_fn,
2 changes: 1 addition & 1 deletion rllib/contrib/maddpg/maddpg_policy.py
@@ -1,8 +1,8 @@
import ray
from ray.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip, _adjust_nstep
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.policy import Policy
10 changes: 7 additions & 3 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -157,6 +157,10 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
batch = SampleBatch(batch_data)

if SampleBatch.UNROLL_ID not in batch.data:
# TODO: (sven) Once we have the additional
# model.preprocess_train_batch in place (attention net PR), we
# should not even need UNROLL_ID anymore:
# Add "if SampleBatch.UNROLL_ID in view_requirements:" here.
batch.data[SampleBatch.UNROLL_ID] = np.repeat(
_AgentCollector._next_unroll_id, batch.count)
_AgentCollector._next_unroll_id += 1
@@ -238,7 +242,7 @@ def add_postprocessed_batch_for_training(
"""
for view_col, data in batch.items():
# Skip columns that are not used for training.
if view_col in view_requirements and \
if view_col not in view_requirements or \
not view_requirements[view_col].used_for_training:
continue
self.buffers[view_col].extend(data)
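The old condition only skipped a column that was present in view_requirements but marked as not used for training; a column missing from view_requirements entirely fell through and was still written into the training buffers. A standalone illustration of the corrected predicate (column names are made up):

```python
from collections import namedtuple

# Stand-in for ViewRequirement, carrying only the flag that matters here.
ViewReq = namedtuple("ViewReq", ["used_for_training"])

view_requirements = {"obs": ViewReq(True), "infos": ViewReq(False)}
batch_columns = ["obs", "infos", "some_extra_action_fetch"]


def skip(view_col):
    # New logic: skip unknown columns OR known columns not used for training.
    return (view_col not in view_requirements
            or not view_requirements[view_col].used_for_training)


print([c for c in batch_columns if not skip(c)])  # ['obs']
```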
@@ -465,8 +469,7 @@ def postprocess_episode(self,
pre_batch = collector.build(policy.view_requirements)
pre_batches[agent_id] = (policy, pre_batch)

# Apply postprocessor.
post_batches = {}
# Apply reward clipping before calling postprocessing functions.
if self.clip_rewards is True:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
@@ -477,6 +480,7 @@
a_min=-self.clip_rewards,
a_max=self.clip_rewards)

post_batches = {}
for agent_id, (_, pre_batch) in pre_batches.items():
# Entire episode is said to be done.
# Error if no DONE at end of this agent's trajectory.
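The hunk above only moves the post_batches = {} initialization below the clipping block and reworks the comment; the clipping behavior itself is unchanged. For reference, a standalone sketch of the two modes, sign-clipping when clip_rewards is True and symmetric value-clipping when it is a float:

```python
import numpy as np


def clip(rewards, clip_rewards):
    rewards = np.asarray(rewards, dtype=np.float32)
    if clip_rewards is True:
        # Atari-style: keep only the sign of each reward.
        return np.sign(rewards)
    elif clip_rewards:  # a positive float, e.g. 2.0
        return np.clip(rewards, a_min=-clip_rewards, a_max=clip_rewards)
    return rewards  # False/None: leave rewards untouched


print(clip([-3.5, 0.0, 12.0], True))  # [-1.  0.  1.]
print(clip([-3.5, 0.0, 12.0], 2.0))   # [-2.  0.  2.]
```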
5 changes: 3 additions & 2 deletions rllib/evaluation/rollout_worker.py
@@ -257,7 +257,8 @@ def __init__(
directory if specified.
log_dir (str): Directory where logs can be placed.
log_level (str): Set the root log level on creation.
callbacks (DefaultCallbacks): Custom training callbacks.
callbacks (Type[DefaultCallbacks]): Custom sub-class of
DefaultCallbacks for training/policy/rollout-worker callbacks.
input_creator (Callable[[IOContext], InputReader]): Function that
returns an InputReader object for loading previous generated
experiences.
@@ -340,7 +341,7 @@ def gen_rollouts():
self.callbacks: "DefaultCallbacks" = callbacks()
else:
from ray.rllib.agents.callbacks import DefaultCallbacks
self.callbacks: "DefaultCallbacks" = DefaultCallbacks()
self.callbacks: DefaultCallbacks = DefaultCallbacks()
self.worker_index: int = worker_index
self.num_workers: int = num_workers
model_config: ModelConfigDict = model_config or {}
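As the updated docstring says, callbacks is the class itself (it gets instantiated inside the worker), not an instance. A hedged sketch of such a subclass; the method signature follows the DefaultCallbacks API of this RLlib version, and the metric name is made up:

```python
from ray.rllib.agents.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode,
                       env_index=None, **kwargs):
        # Record a custom per-episode metric (name is arbitrary).
        episode.custom_metrics["episode_len"] = episode.length


# Passed as the class, e.g. in a trainer config: config["callbacks"] = MyCallbacks
```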
15 changes: 11 additions & 4 deletions rllib/evaluation/sampler.py
@@ -1033,7 +1033,9 @@ def _process_observations_w_trajectory_view_api(
agent_id)
episode._set_last_observation(agent_id, filtered_obs)
episode._set_last_raw_obs(agent_id, raw_obs)
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
# Infos from the environment.
[Review comment by Contributor Author]
Adding "infos" to the collector's, if required.

agent_infos = infos[env_id].get(agent_id, {})
episode._set_last_info(agent_id, agent_infos)

# Record transition info if applicable.
if last_observation is None:
@@ -1058,15 +1060,20 @@
"new_obs": filtered_obs,
}
# Add extra-action-fetches to collectors.
values_dict.update(**episode.last_pi_info_for(agent_id))
pol = policies[policy_id]
for key, value in episode.last_pi_info_for(agent_id).items():
values_dict[key] = value
# Env infos for this agent.
if "infos" in pol.view_requirements:
values_dict["infos"] = agent_infos
_sample_collector.add_action_reward_next_obs(
episode.episode_id, agent_id, env_id, policy_id,
agent_done, values_dict)

if not agent_done:
item = PolicyEvalData(
env_id, agent_id, filtered_obs, infos[env_id].get(
agent_id, {}), None if last_observation is None else
env_id, agent_id, filtered_obs, agent_infos, None
if last_observation is None else
episode.rnn_state_for(agent_id), None
if last_observation is None else
episode.last_action_for(agent_id),
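With this change, env infos are only forwarded to the sample collector when the policy's view_requirements contain an "infos" entry (the default view requirements already include one, as visible in the policy.py hunk further down). A hedged sketch of the entry a policy needs so that values_dict["infos"] gets filled; the no-arg ViewRequirement() mirrors the example policies in this PR:

```python
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

# Inside a custom Policy's __init__ (sketch only):
view_requirements = {
    SampleBatch.OBS: ViewRequirement(),
    # Opting in: the sampler will now pass per-step env infos along.
    SampleBatch.INFOS: ViewRequirement(),
}
```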
14 changes: 6 additions & 8 deletions rllib/evaluation/tests/test_trajectory_view_api.py
@@ -10,7 +10,7 @@
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
EpisodeEnvAwarePolicy
EpisodeEnvAwareLSTMPolicy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
@@ -121,7 +121,6 @@ def test_traj_view_simple_performance(self):
obs_space = Box(-1.0, 1.0, shape=(700, ))

from ray.rllib.examples.env.random_env import RandomMultiAgentEnv

from ray.tune import register_env
register_env("ma_env", lambda c: RandomMultiAgentEnv({
"num_agents": 2,
@@ -147,7 +146,6 @@ def policy_fn(agent_id):
"policy_mapping_fn": policy_fn,
}
num_iterations = 2
# Only works in torch so far.
for _ in framework_iterator(config, frameworks="torch"):
print("w/ traj. view API")
config["_use_trajectory_view_api"] = True
@@ -253,7 +251,7 @@ def test_traj_view_lstm_functionality(self):
rollout_fragment_length = 200
assert rollout_fragment_length % max_seq_len == 0
policies = {
"pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}),
"pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
}

def policy_fn(agent_id):
@@ -316,8 +314,8 @@ def analyze_rnn_batch(batch, max_seq_len):
state_in_1 = batch["state_in_1"][idx]

# Check postprocessing outputs.
if "postprocessed_column" in batch:
postprocessed_col_t = batch["postprocessed_column"][idx]
if "2xobs" in batch:
postprocessed_col_t = batch["2xobs"][idx]
assert (obs_t == postprocessed_col_t / 2.0).all()

# Check state-in/out and next-obs values.
@@ -386,8 +384,8 @@ def analyze_rnn_batch(batch, max_seq_len):
r_t = batch["rewards"][k]

# Check postprocessing outputs.
if "postprocessed_column" in batch:
postprocessed_col_t = batch["postprocessed_column"][k]
if "2xobs" in batch:
postprocessed_col_t = batch["2xobs"][k]
assert (obs_t == postprocessed_col_t / 2.0).all()

# Check state-in/out and next-obs values.
66 changes: 64 additions & 2 deletions rllib/examples/policy/episode_env_aware_policy.py
@@ -8,7 +8,7 @@
from ray.rllib.utils.annotations import override


class EpisodeEnvAwarePolicy(RandomPolicy):
class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
"""A Policy that always knows the current EpisodeID and EnvID and
returns these in its actions."""

@@ -78,5 +78,67 @@ def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
sample_batch["postprocessed_column"] = sample_batch["obs"] * 2.0
sample_batch["2xobs"] = sample_batch["obs"] * 2.0
return sample_batch


class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
"""A Policy that always knows the current EpisodeID and EnvID and
returns these in its actions."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.state_space = Box(-1.0, 1.0, (1, ))
self.config["model"] = {"max_seq_len": 50}

class _fake_model:
pass

self.model = _fake_model()
self.model.inference_view_requirements = {
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
"env_id": ViewRequirement(),
"t": ViewRequirement(),
SampleBatch.OBS: ViewRequirement(),
"state_in_0": ViewRequirement(
"state_out_0",
# Provide state outs -50 to -1 as "state-in".
data_rel_pos="-50:-1",
# Repeat the incoming state every n time steps (usually max seq
# len).
batch_repeat_value=self.config["model"]["max_seq_len"],
space=self.state_space)
}

self.view_requirements = dict(super()._get_default_view_requirements(),
**self.model.inference_view_requirements)

@override(Policy)
def is_recurrent(self):
return True

@override(Policy)
def compute_actions_from_input_dict(self,
input_dict,
explore=None,
timestep=None,
**kwargs):
ts = input_dict["t"]
print(ts)
# Always return [episodeID, envID] as actions.
actions = np.array([[
input_dict[SampleBatch.AGENT_INDEX][i],
input_dict[SampleBatch.EPS_ID][i], input_dict["env_id"][i]
] for i, _ in enumerate(input_dict["obs"])])
states = [np.array([[ts[i]] for i in range(len(input_dict["obs"]))])]
self.global_timestep += 1
return actions, states, {}

@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
sample_batch["3xobs"] = sample_batch["obs"] * 3.0
return sample_batch
1 change: 1 addition & 0 deletions rllib/policy/policy.py
@@ -572,6 +572,7 @@ def _get_default_view_requirements(self):
SampleBatch.INFOS: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
"t": ViewRequirement(),
}

2 changes: 1 addition & 1 deletion rllib/policy/sample_batch.py
@@ -81,7 +81,7 @@ def __init__(self, *args, **kwargs):
if self.seq_lens is not None and len(self.seq_lens) > 0:
self.count = sum(self.seq_lens)
else:
self.count = len(self.data[k])
self.count = len(next(iter(self.data.values())))

# Keeps track of new columns added after initial ones.
self.new_columns = []
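The replaced line computed count from k, a variable left over from a preceding loop over the data dict; the new line simply measures the length of the first data column, independent of iteration order. A quick check (constructor usage as of this RLlib version):

```python
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({
    "obs": np.zeros((3, 4)),
    "actions": np.array([0, 1, 0]),
})
# No seq_lens given -> count falls back to the length of the first column.
print(batch.count)  # 3
```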
4 changes: 4 additions & 0 deletions rllib/policy/torch_policy.py
@@ -354,6 +354,8 @@ def compute_gradients(self,
)

train_batch = self._lazy_tensor_dict(postprocessed_batch)

# Calculate the actual policy loss.
loss_out = force_list(
self._loss(self, self.model, self.dist_class, train_batch))

@@ -369,13 +371,15 @@

assert len(loss_out) == len(self._optimizers)

# assert not any(torch.isnan(l) for l in loss_out)
fetches = self.extra_compute_grad_fetches()

# Loop through all optimizers.
grad_info = {"allreduce_latency": 0.0}

all_grads = []
for i, opt in enumerate(self._optimizers):
# Erase gradients in all vars of this optimizer.
opt.zero_grad()
# Recompute gradients of loss over all variables.
loss_out[i].backward(retain_graph=(i < len(self._optimizers) - 1))
2 changes: 1 addition & 1 deletion rllib/utils/exploration/exploration.py
@@ -184,7 +184,7 @@ def get_exploration_loss(self, policy_loss: List[TensorType],
Policy's own loss function and maybe the Model's custom loss.
train_batch (SampleBatch): The training data to calculate the
loss(es) for. This train data has already gone through
this Exploration's `preprocess_train_batch()` method.
this Exploration's `postprocess_trajectory()` method.

Returns:
List[TensorType]: The updated list of loss terms.
2 changes: 1 addition & 1 deletion rllib/utils/sgd.py
@@ -66,7 +66,7 @@ def minibatches(samples, sgd_minibatch_size):
# Replace with `if samples.seq_lens` check.
if "state_in_0" in samples.data or "state_out_0" in samples.data:
if log_once("not_shuffling_rnn_data_in_simple_mode"):
logger.warning("Not shuffling RNN data for SGD in simple mode")
logger.warning("Not time-shuffling RNN data for SGD.")
else:
samples.shuffle()
