-
Notifications
You must be signed in to change notification settings - Fork 6.1k
/
Copy pathrollout_worker.py
1175 lines (1042 loc) · 49.7 KB
/
rollout_worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import random
import numpy as np
import gym
import logging
import pickle
import platform
import os
from typing import Callable, Any, List, Dict, Tuple, Union, Optional, \
TYPE_CHECKING, Type, TypeVar
import ray
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
OffPolicyEstimate
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \
ModelConfigDict, ModelGradients, ModelWeights, \
MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \
SampleBatchType, TrainerConfigDict
from ray.util.debug import log_once, disable_log_once_globally, \
enable_periodic_logging
from ray.util.iter import ParallelIteratorWorker
if TYPE_CHECKING:
from ray.rllib.evaluation.observation_function import ObservationFunction
# Generic type var for foreach_* methods (e.g. `foreach_env` returns List[T]).
T = TypeVar("T")

# Deep-learning framework handles. Presumably None when the framework is not
# installed: `RolloutWorker.__init__` checks `if not tf1` and
# `torch is not None` before using them.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

logger = logging.getLogger(__name__)

# Handle to the current rollout worker, which will be set to the most recently
# created RolloutWorker in this process. This can be helpful to access in
# custom env or policy classes for debugging or advanced use cases.
# It stays None until the first RolloutWorker is constructed.
_global_worker: Optional["RolloutWorker"] = None
@DeveloperAPI
def get_global_worker() -> "RolloutWorker":
    """Return the RolloutWorker most recently constructed in this process.

    Returns:
        RolloutWorker: The module-level `_global_worker` handle, which every
            `RolloutWorker.__init__` overwrites with itself. It is None if no
            worker has been created yet.
    """
    # Reading a module global needs no `global` declaration.
    return _global_worker
@DeveloperAPI
class RolloutWorker(ParallelIteratorWorker):
    """Common experience collection class.

    This class wraps a policy instance and an environment class to
    collect experiences from the environment. You can create many replicas of
    this class as Ray actors to scale RL training.

    This class supports vectorized and multi-agent policy evaluation (e.g.,
    VectorEnv, MultiAgentEnv, etc.)

    Examples:
        >>> # Create a rollout worker and use it to collect experiences.
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: gym.make("CartPole-v0"),
        ...   policy_spec=PGTFPolicy)
        >>> print(worker.sample())
        SampleBatch({
            "obs": [[...]], "actions": [[...]], "rewards": [[...]],
            "dones": [[...]], "new_obs": [[...]]})

        >>> # Create a multi-agent rollout worker.
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
        ...   policy_spec={
        ...       # Use an ensemble of two policies for car agents
        ...       "car_policy1":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
        ...       "car_policy2":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
        ...       # Use a single shared policy for all traffic lights
        ...       "traffic_light_policy":
        ...         (PGTFPolicy, Box(...), Discrete(...), {}),
        ...   },
        ...   policy_mapping_fn=lambda agent_id:
        ...     random.choice(["car_policy1", "car_policy2"])
        ...     if agent_id.startswith("car_") else "traffic_light_policy")
        >>> print(worker.sample())
        MultiAgentBatch({
            "car_policy1": SampleBatch(...),
            "car_policy2": SampleBatch(...),
            "traffic_light_policy": SampleBatch(...)})
    """
@DeveloperAPI
@classmethod
def as_remote(cls,
              num_cpus: int = None,
              num_gpus: int = None,
              memory: int = None,
              object_store_memory: int = None,
              resources: dict = None) -> type:
    """Turn this class into a Ray actor class with the given resources.

    All arguments are forwarded unchanged as keyword arguments to
    `ray.remote(...)`, which is then applied to `cls`.

    Args:
        num_cpus: CPU request for each actor.
        num_gpus: GPU request for each actor.
        memory: Heap memory request for each actor.
        object_store_memory: Object store memory request for each actor.
        resources: Custom resource requests for each actor.

    Returns:
        type: The Ray actor class produced by `ray.remote(...)(cls)`.
    """
    actor_options = dict(
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        memory=memory,
        object_store_memory=object_store_memory,
        resources=resources,
    )
    remote_decorator = ray.remote(**actor_options)
    return remote_decorator(cls)
@DeveloperAPI
def __init__(
        self,
        *,
        env_creator: Callable[[EnvContext], EnvType],
        validate_env: Optional[Callable[[EnvType, EnvContext],
                                        None]] = None,
        policy_spec: Union[type, Dict[
            str, Tuple[Optional[type], gym.Space, gym.Space,
                       PartialTrainerConfigDict]]] = None,
        policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
        policies_to_train: Optional[List[PolicyID]] = None,
        tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
        rollout_fragment_length: int = 100,
        batch_mode: str = "truncate_episodes",
        episode_horizon: int = None,
        preprocessor_pref: str = "deepmind",
        sample_async: bool = False,
        compress_observations: bool = False,
        num_envs: int = 1,
        observation_fn: "ObservationFunction" = None,
        observation_filter: str = "NoFilter",
        clip_rewards: bool = None,
        clip_actions: bool = True,
        env_config: EnvConfigDict = None,
        model_config: ModelConfigDict = None,
        policy_config: TrainerConfigDict = None,
        worker_index: int = 0,
        num_workers: int = 0,
        monitor_path: str = None,
        log_dir: str = None,
        log_level: str = None,
        callbacks: Type["DefaultCallbacks"] = None,
        input_creator: Callable[[
            IOContext
        ], InputReader] = lambda ioctx: ioctx.default_sampler_input(),
        input_evaluation: List[str] = frozenset([]),
        output_creator: Callable[
            [IOContext], OutputWriter] = lambda ioctx: NoopOutput(),
        remote_worker_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        seed: int = None,
        extra_python_environs: dict = None,
        fake_sampler: bool = False,
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
        policy: Union[type, Dict[
            str, Tuple[Optional[type], gym.Space, gym.Space,
                       PartialTrainerConfigDict]]] = None,
):
    """Initialize a rollout worker.

    Args:
        env_creator (Callable[[EnvContext], EnvType]): Function that
            returns a gym.Env given an EnvContext wrapped configuration.
        validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
            Optional callable to validate the generated environment. If
            given, it is called as `validate_env(env, env_context)` right
            after this worker creates its env (no call if this worker
            creates no env).
        policy_spec (Union[type, Dict[str, Tuple[Type[Policy], gym.Space,
            gym.Space, PartialTrainerConfigDict]]]): Either a Policy class
            or a dict of policy id strings to
            (Policy class, obs_space, action_space, config)-tuples. If a
            dict is specified, then we are in multi-agent mode and a
            policy_mapping_fn can also be set (if not, will map all agents
            to DEFAULT_POLICY_ID).
        policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): A
            callable that maps agent ids to policy ids in multi-agent mode.
            This function will be called each time a new agent appears in
            an episode, to bind that agent to a policy for the duration of
            the episode. If not provided, will map all agents to
            DEFAULT_POLICY_ID.
        policies_to_train (Optional[List[PolicyID]]): Optional list of
            policies to train, or None for all policies.
        tf_session_creator (Optional[Callable[[], tf1.Session]]): A
            function that returns a TF session. This is optional and only
            useful with TFPolicy.
        rollout_fragment_length (int): The target number of env transitions
            to include in each sample batch returned from this worker.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() will return a batch
                of at most `rollout_fragment_length * num_envs` in size.
                The batch will be exactly
                `rollout_fragment_length * num_envs` in size if
                postprocessing does not change batch sizes. Episodes may be
                truncated in order to meet this size requirement.
            "complete_episodes": Each call to sample() will return a batch
                of at least `rollout_fragment_length * num_envs` in size.
                Episodes will not be truncated, but multiple episodes may
                be packed within one batch to meet the batch size. Note
                that when `num_envs > 1`, episode steps will be buffered
                until the episode completes, and hence batches may contain
                significant amounts of off-policy data.
        episode_horizon (int): Horizon at which to stop episodes; forwarded
            to the sampler as `horizon`.
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy.
        compress_observations (bool): If true, compress the observations of
            every sampled batch; the string "bulk" selects bulk compression.
            They can be decompressed with rllib/utils/compression.
        num_envs (int): If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect if
            the env already implements VectorEnv.
        observation_fn (ObservationFunction): Optional multi-agent
            observation function.
        observation_filter (str): Name of observation filter to use.
        clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
            experience postprocessing. Setting to None means clip for Atari
            only.
        clip_actions (bool): Whether to clip action values to the range
            specified by the policy action space.
        env_config (EnvConfigDict): Config to pass to the env creator.
        model_config (ModelConfigDict): Config to use when creating the
            policy model.
        policy_config (TrainerConfigDict): Config to pass to the policy.
            In the multi-agent case, this config will be merged with the
            per-policy configs specified by `policy_spec`.
        worker_index (int): For remote workers, this should be set to a
            non-zero and unique value. This index is passed to created envs
            through EnvContext so that envs can be configured per worker.
        num_workers (int): For remote workers, the total number of remote
            workers that have been created.
        monitor_path (str): Write out episode stats and videos to this
            directory if specified.
        log_dir (str): Directory where logs can be placed.
        log_level (str): Set the root log level on creation.
        callbacks (DefaultCallbacks): Custom training callbacks.
        input_creator (Callable[[IOContext], InputReader]): Function that
            returns an InputReader object for loading previously generated
            experiences.
        input_evaluation (List[str]): How to evaluate the policy
            performance. This only makes sense to set when the input is
            reading offline data. The possible values include:
              - "is": the step-wise importance sampling estimator.
              - "wis": the weighted step-wise is estimator.
              - "simulation": run the environment in the background, but
                use this data for evaluation only and never for learning.
        output_creator (Callable[[IOContext], OutputWriter]): Function that
            returns an OutputWriter object for saving generated
            experiences.
        remote_worker_envs (bool): If using num_envs > 1, whether to create
            those new envs in remote processes instead of in the current
            process. This adds overheads, but can make sense if your envs
            are slow to step / reset.
        remote_env_batch_wait_ms (float): Timeout that remote workers
            are waiting when polling environments. 0 (continue when at
            least one env is ready) is a reasonable default, but optimal
            value could be obtained by measuring your environment
            step / reset and model inference perf.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        seed (int): If set, used to seed numpy, Python's `random`, the env
            (if it has a `seed()` method), torch (if installed) and the TF
            graph (if one is built). Use a distinct value per worker to
            give each remote worker unique exploration behavior.
        extra_python_environs (dict): Extra environment variables to set
            in `os.environ` (values converted with `str`) before anything
            else is initialized.
        fake_sampler (bool): Use a fake (inf speed) sampler for testing.
        spaces (Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
            gym.spaces.Space]]]): An optional space dict mapping policy IDs
            to (obs_space, action_space)-tuples. This is used in case no
            Env is created on this RolloutWorker.
        policy: Deprecated alias of `policy_spec`. Use `policy_spec`
            instead.
    """
    # Deprecated arg.
    if policy is not None:
        deprecation_warning("policy", "policy_spec", error=False)
        policy_spec = policy
    assert policy_spec is not None, "Must provide `policy_spec` when " \
                                    "creating RolloutWorker!"

    # Snapshot of the constructor arguments. Must be taken before any other
    # local variable is created, otherwise `locals()` would capture it too.
    self._original_kwargs: dict = locals().copy()
    del self._original_kwargs["self"]

    global _global_worker
    _global_worker = self

    # set extra environs first
    if extra_python_environs:
        for key, value in extra_python_environs.items():
            os.environ[key] = str(value)

    def gen_rollouts():
        while True:
            yield self.sample()

    ParallelIteratorWorker.__init__(self, gen_rollouts, False)

    policy_config: TrainerConfigDict = policy_config or {}
    if (tf1 and policy_config.get("framework") in ["tf2", "tfe"]
            # This eager check is necessary for certain all-framework tests
            # that use tf's eager_mode() context generator.
            and not tf1.executing_eagerly()):
        tf1.enable_eager_execution()

    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    if worker_index > 1:
        disable_log_once_globally()  # only need 1 worker to log
    elif log_level == "DEBUG":
        enable_periodic_logging()

    env_context = EnvContext(env_config or {}, worker_index)
    self.env_context = env_context
    self.policy_config: TrainerConfigDict = policy_config
    if callbacks:
        self.callbacks: "DefaultCallbacks" = callbacks()
    else:
        from ray.rllib.agents.callbacks import DefaultCallbacks
        self.callbacks: "DefaultCallbacks" = DefaultCallbacks()
    self.worker_index: int = worker_index
    self.num_workers: int = num_workers
    model_config: ModelConfigDict = model_config or {}
    policy_mapping_fn = (policy_mapping_fn
                         or (lambda agent_id: DEFAULT_POLICY_ID))
    if not callable(policy_mapping_fn):
        raise ValueError("Policy mapping function not callable?")
    self.env_creator: Callable[[EnvContext], EnvType] = env_creator
    self.rollout_fragment_length: int = rollout_fragment_length * num_envs
    self.batch_mode: str = batch_mode
    self.compress_observations: bool = compress_observations
    self.preprocessing_enabled: bool = True
    self.last_batch: SampleBatchType = None
    self.global_vars: dict = None
    self.fake_sampler: bool = fake_sampler

    # No Env will be used in this particular worker (not needed).
    # A missing "create_env_on_driver" key is treated as False instead of
    # raising KeyError (e.g. with the default empty `policy_config`).
    if worker_index == 0 and num_workers > 0 and \
            policy_config.get("create_env_on_driver", False) is False:
        self.env = None
    # Create an env for this worker.
    else:
        self.env = _validate_env(env_creator(env_context))
        if validate_env is not None:
            validate_env(self.env, self.env_context)

        if isinstance(self.env, (BaseEnv, MultiAgentEnv)):

            def wrap(env):
                return env  # we can't auto-wrap these env types

        elif is_atari(self.env) and \
                not model_config.get("custom_preprocessor") and \
                preprocessor_pref == "deepmind":

            # Deepmind wrappers already handle all preprocessing.
            self.preprocessing_enabled = False

            # If clip_rewards not explicitly set to False, switch it
            # on here (clip between -1.0 and 1.0).
            if clip_rewards is None:
                clip_rewards = True

            def wrap(env):
                env = wrap_deepmind(
                    env,
                    dim=model_config.get("dim"),
                    framestack=model_config.get("framestack"))
                if monitor_path:
                    from gym import wrappers
                    env = wrappers.Monitor(env, monitor_path, resume=True)
                return env

        else:

            def wrap(env):
                if monitor_path:
                    from gym import wrappers
                    env = wrappers.Monitor(env, monitor_path, resume=True)
                return env

        self.env: EnvType = wrap(self.env)

        def make_env(vector_index):
            return wrap(
                env_creator(
                    env_context.copy_with_overrides(
                        worker_index=worker_index,
                        vector_index=vector_index,
                        remote=remote_worker_envs)))

        self.make_env_fn = make_env

    self.tf_sess = None
    policy_dict = _validate_and_canonicalize(
        policy_spec, self.env, spaces=spaces)
    self.policies_to_train: List[PolicyID] = policies_to_train or list(
        policy_dict.keys())
    self.policy_map: Dict[PolicyID, Policy] = None
    self.preprocessors: Dict[PolicyID, Preprocessor] = None

    # set numpy and python seed
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        if not hasattr(self.env, "seed"):
            logger.info("Env doesn't support env.seed(): {}".format(
                self.env))
        else:
            self.env.seed(seed)
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would turn a missing torch into an
        # AttributeError on `None.manual_seed`.
        if torch is not None:
            torch.manual_seed(seed)
        else:
            logger.info("Could not seed torch")

    if _has_tensorflow_graph(policy_dict) and not (
            tf1 and tf1.executing_eagerly()):
        if not tf1:
            raise ImportError("Could not import tensorflow")
        with tf1.Graph().as_default():
            if tf_session_creator:
                self.tf_sess = tf_session_creator()
            else:
                self.tf_sess = tf1.Session(
                    config=tf1.ConfigProto(
                        gpu_options=tf1.GPUOptions(allow_growth=True)))
            with self.tf_sess.as_default():
                # set graph-level seed
                if seed is not None:
                    tf1.set_random_seed(seed)
                self.policy_map, self.preprocessors = \
                    self._build_policy_map(policy_dict, policy_config)
    else:
        self.policy_map, self.preprocessors = self._build_policy_map(
            policy_dict, policy_config)

    if (ray.is_initialized()
            and ray.worker._mode() != ray.worker.LOCAL_MODE):
        # Check available number of GPUs
        if not ray.get_gpu_ids():
            logger.debug("Creating policy evaluation worker {}".format(
                worker_index) +
                         " on CPU (please ignore any CUDA init errors)")
        elif (policy_config["framework"] in ["tf2", "tf", "tfe"] and
              not tf.config.experimental.list_physical_devices("GPU")) or \
                (policy_config["framework"] == "torch" and
                 not torch.cuda.is_available()):
            raise RuntimeError(
                "GPUs were assigned to this worker by Ray, but "
                "your DL framework ({}) reports GPU acceleration is "
                "disabled. This could be due to a bad CUDA- or {} "
                "installation.".format(policy_config["framework"],
                                       policy_config["framework"]))

    self.multiagent: bool = set(
        self.policy_map.keys()) != {DEFAULT_POLICY_ID}
    if self.multiagent and self.env is not None:
        if not ((isinstance(self.env, MultiAgentEnv)
                 or isinstance(self.env, ExternalMultiAgentEnv))
                or isinstance(self.env, BaseEnv)):
            raise ValueError(
                "Have multiple policies {}, but the env ".format(
                    self.policy_map) +
                "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                "ExternalMultiAgentEnv?".format(self.env))

    self.filters: Dict[PolicyID, Filter] = {
        policy_id: get_filter(observation_filter,
                              policy.observation_space.shape)
        for (policy_id, policy) in self.policy_map.items()
    }
    if self.worker_index == 0:
        logger.info("Built filter map: {}".format(self.filters))

    self.num_envs: int = num_envs

    if self.env is None:
        self.async_env = None
    elif "custom_vector_env" in policy_config:
        custom_vec_wrapper = policy_config["custom_vector_env"]
        self.async_env = custom_vec_wrapper(self.env)
    else:
        # Always use vector env for consistency even if num_envs = 1.
        self.async_env: BaseEnv = BaseEnv.to_base_env(
            self.env,
            make_env=make_env,
            num_envs=num_envs,
            remote_envs=remote_worker_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms)

    # `truncate_episodes`: Allow a batch to contain more than one episode
    # (fragments) and always make the batch `rollout_fragment_length`
    # long.
    if self.batch_mode == "truncate_episodes":
        pack = True
    # `complete_episodes`: Never cut episodes and sampler will return
    # exactly one (complete) episode per poll.
    elif self.batch_mode == "complete_episodes":
        rollout_fragment_length = float("inf")
        pack = False
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))

    self.io_context: IOContext = IOContext(log_dir, policy_config,
                                           worker_index, self)
    self.reward_estimators: List[OffPolicyEstimator] = []
    for method in input_evaluation:
        if method == "simulation":
            logger.warning(
                "Requested 'simulation' input evaluation method: "
                "will discard all sampler outputs and keep only metrics.")
            sample_async = True
        elif method == "is":
            ise = ImportanceSamplingEstimator.create(self.io_context)
            self.reward_estimators.append(ise)
        elif method == "wis":
            wise = WeightedImportanceSamplingEstimator.create(
                self.io_context)
            self.reward_estimators.append(wise)
        else:
            raise ValueError(
                "Unknown evaluation method: {}".format(method))

    if self.env is None:
        self.sampler = None
    elif sample_async:
        self.sampler = AsyncSampler(
            worker=self,
            env=self.async_env,
            policies=self.policy_map,
            policy_mapping_fn=policy_mapping_fn,
            preprocessors=self.preprocessors,
            obs_filters=self.filters,
            clip_rewards=clip_rewards,
            rollout_fragment_length=rollout_fragment_length,
            callbacks=self.callbacks,
            horizon=episode_horizon,
            multiple_episodes_in_batch=pack,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            blackhole_outputs="simulation" in input_evaluation,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn,
            _use_trajectory_view_api=policy_config.get(
                "_use_trajectory_view_api", False))
        # Start the Sampler thread.
        self.sampler.start()
    else:
        self.sampler = SyncSampler(
            worker=self,
            env=self.async_env,
            policies=self.policy_map,
            policy_mapping_fn=policy_mapping_fn,
            preprocessors=self.preprocessors,
            obs_filters=self.filters,
            clip_rewards=clip_rewards,
            rollout_fragment_length=rollout_fragment_length,
            callbacks=self.callbacks,
            horizon=episode_horizon,
            multiple_episodes_in_batch=pack,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn,
            _use_trajectory_view_api=policy_config.get(
                "_use_trajectory_view_api", False))

    self.input_reader: InputReader = input_creator(self.io_context)
    self.output_writer: OutputWriter = output_creator(self.io_context)

    logger.debug(
        "Created rollout worker with env {} ({}), policies {}".format(
            self.async_env, self.env, self.policy_map))
@DeveloperAPI
def sample(self) -> SampleBatchType:
    """Collect one batch of experience from this worker's input reader.

    Reads from `self.input_reader` until at least
    `self.rollout_fragment_length` steps have been gathered. In
    "truncate_episodes" mode, at most `self.num_envs` reads are made per
    call. The pieces are concatenated, then passed through the sample-end
    callback, the output writer, any off-policy estimators and optional
    observation compression.

    Returns:
        SampleBatchType: A columnar batch of experiences (e.g., tensors).

    Examples:
        >>> print(worker.sample())
        SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
    """
    # Fake-sampler mode: once a batch exists, keep returning it.
    if self.fake_sampler and self.last_batch is not None:
        return self.last_batch

    if log_once("sample_start"):
        logger.info("Generating sample batch of size {}".format(
            self.rollout_fragment_length))

    # In truncate_episodes mode, read at most once per env so we don't
    # overshoot the target batch size. Otherwise there is no read cap.
    if self.batch_mode == "truncate_episodes":
        read_limit = self.num_envs
    else:
        read_limit = float("inf")

    # Always read at least once; stop as soon as either the step target
    # or the read cap is reached.
    collected = []
    step_total = 0
    while True:
        piece = self.input_reader.next()
        collected.append(piece)
        step_total += piece.count
        if step_total >= self.rollout_fragment_length or \
                len(collected) >= read_limit:
            break

    if len(collected) > 1:
        batch = collected[0].concat_samples(collected)
    else:
        batch = collected[0]

    self.callbacks.on_sample_end(worker=self, samples=batch)

    # Write before compressing, so written data is consistent and the
    # writer can compress it more effectively itself.
    self.output_writer.write(batch)

    # Feed each episode to every off-policy estimator.
    if self.reward_estimators:
        for episode in batch.split_by_episode():
            for estimator in self.reward_estimators:
                estimator.process(episode)

    if log_once("sample_end"):
        logger.info("Completed sample batch:\n\n{}\n".format(
            summarize(batch)))

    if self.compress_observations == "bulk":
        batch.compress(bulk=True)
    elif self.compress_observations:
        batch.compress()

    if self.fake_sampler:
        self.last_batch = batch
    return batch
@DeveloperAPI
@ray.method(num_returns=2)
def sample_with_count(self) -> Tuple[SampleBatchType, int]:
    """Like `sample()`, but also return the batch's step count.

    Declared with `num_returns=2`, so when invoked on a Ray actor the count
    is returned as its own separate object ref.

    Returns:
        Tuple[SampleBatchType, int]: The sampled batch and `batch.count`.
    """
    samples = self.sample()
    count = samples.count
    return samples, count
@DeveloperAPI
def get_weights(self,
                policies: List[PolicyID] = None
                ) -> Dict[PolicyID, ModelWeights]:
    """Return the model weights of (a subset of) this worker's policies.

    Args:
        policies (Optional[List[PolicyID]]): IDs of the policies whose
            weights to return. None (default) means every policy in
            `self.policy_map`. IDs that are not in the policy map are
            ignored. Any iterable of IDs is accepted.

    Returns:
        Dict[PolicyID, ModelWeights]: Mapping from policy ID to that
            policy's `get_weights()` result. This is the format accepted
            by `set_weights()` on another worker. (Only this single value
            is returned; there is no separate info dict.)

    Examples:
        >>> weights = worker.get_weights()
    """
    if policies is None:
        return {
            pid: policy.get_weights()
            for pid, policy in self.policy_map.items()
        }
    # Materialize once: O(1) membership checks, and a one-shot iterator
    # is not exhausted after the first policy is tested.
    wanted = set(policies)
    return {
        pid: policy.get_weights()
        for pid, policy in self.policy_map.items() if pid in wanted
    }
@DeveloperAPI
def set_weights(self, weights: ModelWeights,
                global_vars: dict = None) -> None:
    """Load per-policy weights into this worker's policies.

    Args:
        weights: Mapping of policy ID -> weights, e.g. as produced by
            `get_weights()`. Each entry is passed to the matching policy's
            `set_weights()`; an unknown policy ID raises KeyError.
        global_vars: If given and non-empty, also forwarded to
            `self.set_global_vars()`.

    Examples:
        >>> weights = worker.get_weights()
        >>> worker.set_weights(weights)
    """
    for policy_id, policy_weights in weights.items():
        target_policy = self.policy_map[policy_id]
        target_policy.set_weights(policy_weights)
    if not global_vars:
        return
    self.set_global_vars(global_vars)
@DeveloperAPI
def compute_gradients(
        self, samples: SampleBatchType) -> Tuple[ModelGradients, dict]:
    """Compute gradients with respect to the given samples.

    For a MultiAgentBatch, gradients are computed only for policies in
    `self.policies_to_train`. When a TF session exists, the ops are staged
    on a shared TFRunBuilder and then fetched. Otherwise each policy's own
    `compute_gradients()` is called.

    Returns:
        (grads, info): Gradients that can be applied on a compatible
            worker (a dict keyed by policy id in the multi-agent case) and
            an info dict of extra metadata. The info dict always contains
            a "batch_count" entry.

    Examples:
        >>> batch = worker.sample()
        >>> grads, info = worker.compute_gradients(batch)
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        grad_out, info_out = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
    else:
        # Only policies that are being trained produce gradients.
        trainable = [
            (pid, batch) for pid, batch in samples.policy_batches.items()
            if pid in self.policies_to_train
        ]
        grad_out, info_out = {}, {}
        if self.tf_sess is None:
            for pid, batch in trainable:
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(batch))
        else:
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in trainable:
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid]._build_compute_gradients(
                        builder, batch))
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}

    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out
@DeveloperAPI
def apply_gradients(self, grads: ModelGradients) -> Dict[PolicyID, Any]:
    """Apply the given gradients to this worker's policy weights.

    A dict of gradients is treated as multi-agent (policy id -> grads).
    Anything else is applied to the DEFAULT_POLICY_ID policy.

    Returns:
        The per-policy results of `apply_gradients` (a dict in the
        multi-agent case), or the default policy's result otherwise.

    Examples:
        >>> samples = worker.sample()
        >>> grads, info = worker.compute_gradients(samples)
        >>> worker.apply_gradients(grads)
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))

    # Single-agent: a bare gradient structure for the default policy.
    if not isinstance(grads, dict):
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    if self.tf_sess is None:
        return {
            pid: self.policy_map[pid].apply_gradients(g)
            for pid, g in grads.items()
        }

    # TF graph mode: stage every policy's apply op, then fetch the results.
    builder = TFRunBuilder(self.tf_sess, "apply_gradients")
    pending = {
        pid: self.policy_map[pid]._build_apply_gradients(builder, grad)
        for pid, grad in grads.items()
    }
    return {pid: builder.get(op) for pid, op in pending.items()}
@DeveloperAPI
def learn_on_batch(self, samples: SampleBatchType) -> dict:
    """Update the trainable policies on the given batch.

    Equivalent to `apply_gradients(compute_gradients(samples))`, but can
    avoid pulling gradients into CPU memory. In the multi-agent case only
    policies in `self.policies_to_train` are updated. With a TF session,
    policies that define `_build_learn_on_batch` are staged on a shared
    TFRunBuilder and fetched at the end.

    Returns:
        info: Dict of policy id -> learner stats returned by each policy.

    Examples:
        >>> batch = worker.sample()
        >>> worker.learn_on_batch(batch)
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        info_out = {
            DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]
            .learn_on_batch(samples)
        }
    else:
        info_out = {}
        to_fetch = {}
        builder = (TFRunBuilder(self.tf_sess, "learn_on_batch")
                   if self.tf_sess is not None else None)
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue
            policy = self.policy_map[pid]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                to_fetch[pid] = policy._build_learn_on_batch(builder, batch)
            else:
                info_out[pid] = policy.learn_on_batch(batch)
        # Resolve the staged TF fetches (empty when no builder was used).
        for pid, fetch in to_fetch.items():
            info_out[pid] = builder.get(fetch)

    if log_once("learn_out"):
        logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
    return info_out
def sample_and_learn(self, expected_batch_size: int, num_sgd_iter: int,
                     sgd_minibatch_size: int,
                     standardize_fields: List[str]) -> Tuple[dict, int]:
    """Sample a batch and learn on it with minibatch SGD.

    This is typically used in combination with distributed allreduce.

    Args:
        expected_batch_size (int): Expected number of samples to learn on.
        num_sgd_iter (int): Number of SGD iterations.
        sgd_minibatch_size (int): SGD minibatch size.
        standardize_fields (list): List of sample fields to normalize.

    Returns:
        info: dictionary of extra metadata from do_minibatch_sgd().
        count: number of samples learned on.

    Raises:
        AssertionError: If the sampled batch's size differs from
            `expected_batch_size`.
    """
    batch = self.sample()
    # Explicit check rather than `assert` so it still runs under
    # `python -O`; AssertionError is kept so callers see the same type.
    if batch.count != expected_batch_size:
        raise AssertionError(
            "Batch size possibly out of sync between workers, "
            "expected: {}, got: {}".format(expected_batch_size,
                                           batch.count))
    logger.info(
        "Executing distributed minibatch SGD "
        "with epoch size %s, minibatch size %s", batch.count,
        sgd_minibatch_size)
    info = do_minibatch_sgd(batch, self.policy_map, self, num_sgd_iter,
                            sgd_minibatch_size, standardize_fields)
    return info, batch.count
@DeveloperAPI
def get_metrics(self) -> List[Union[RolloutMetrics, OffPolicyEstimate]]:
    """Returns a list of new RolloutMetric objects from evaluation.

    The list starts with the sampler's metrics (if this worker has a
    sampler), followed by the metrics of each reward estimator in order.
    """
    metrics = ([] if self.sampler is None else
               self.sampler.get_metrics())
    for estimator in self.reward_estimators:
        metrics.extend(estimator.get_metrics())
    return metrics
@DeveloperAPI
def foreach_env(self, func: Callable[[BaseEnv], T]) -> List[T]:
    """Apply the given function to each underlying env instance.

    Returns an empty list when this worker has no async env. If the async
    env reports no unwrapped sub-envs, `func` is applied once to the async
    env itself.
    """
    base_env = self.async_env
    if base_env is None:
        return []
    sub_envs = base_env.get_unwrapped()
    targets = sub_envs if sub_envs else [base_env]
    return [func(env) for env in targets]
@DeveloperAPI
def get_policy(
        self, policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID
) -> Optional[Policy]:
    """Return policy for the specified id, or None.

    Args:
        policy_id (str): id of policy to return.

    Returns:
        The policy stored under `policy_id` in `self.policy_map`, or None
        if `self.policy_map.get` finds no entry for that id.
    """
    return self.policy_map.get(policy_id)
@DeveloperAPI
def for_policy(self,
               func: Callable[[Policy], T],
               policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID,
               **kwargs) -> T:
    """Apply the given function to the specified policy.

    Args:
        func: Called as `func(policy, **kwargs)`.
        policy_id: Key of the target policy in `self.policy_map`.
        **kwargs: Extra keyword arguments forwarded to `func`.

    Returns:
        Whatever `func` returns.
    """
    target_policy = self.policy_map[policy_id]
    return func(target_policy, **kwargs)
@DeveloperAPI
def foreach_policy(self, func: Callable[[Policy, PolicyID], T],
                   **kwargs) -> List[T]:
    """Apply the given function to each (policy, policy_id) tuple.

    Args:
        func: Called as `func(policy, policy_id, **kwargs)` for every entry
            in `self.policy_map`.
        **kwargs: Extra keyword arguments forwarded to every `func` call.

    Returns:
        The `func` results, in `self.policy_map` iteration order.
    """
    return [
        func(policy_obj, policy_key, **kwargs)
        for policy_key, policy_obj in self.policy_map.items()
    ]
@DeveloperAPI
def foreach_trainable_policy(self, func: Callable[[Policy, PolicyID], T],
                             **kwargs) -> List[T]:
    """Applies `func` to each (policy, policy_id) pair that is trainable.

    Only policies whose ID is in `self.policies_to_train` are visited.

    Args:
        func (callable): Called as `func(policy, policy_id, **kwargs)` on
            every policy within `self.policies_to_train`.
        **kwargs: Extra keyword arguments forwarded to every `func` call.

    Returns:
        List[any]: The return values of all `func` calls, in
            `self.policy_map` iteration order.
    """
    return [
        func(policy_obj, policy_key, **kwargs)
        for policy_key, policy_obj in self.policy_map.items()
        if policy_key in self.policies_to_train
    ]
@DeveloperAPI
def sync_filters(self, new_filters: dict) -> None:
    """Changes self's filters to the given ones and rebases any accumulated delta.

    Args:
        new_filters (dict): Filters with new state to update local copy.
            Must contain an entry for every key in `self.filters`.

    Raises:
        AssertionError: If `new_filters` lacks any key of `self.filters`.
            This is checked before any filter is synced, so a failed call
            leaves all local filters untouched.
    """
    # Validate up front. A bare `assert` vanishes under `python -O`; the
    # loop below would then raise KeyError only after syncing some filters.
    missing = [k for k in self.filters if k not in new_filters]
    if missing:
        raise AssertionError(
            "new_filters is missing filter keys: {}".format(missing))
    for k in self.filters:
        self.filters[k].sync(new_filters[k])
@DeveloperAPI
def get_filters(self, flush_after: bool = False) -> dict:
    """Returns a snapshot of filters.

    Args:
        flush_after (bool): If True, each filter's buffer is cleared right
            after that filter has been serialized.

    Returns:
        return_filters (dict): Filter name -> `as_serializable()` result.
    """
    def _snapshot(flt):
        # Serialize first, then optionally clear, one filter at a time.
        serialized = flt.as_serializable()
        if flush_after:
            flt.clear_buffer()
        return serialized

    return {name: _snapshot(flt) for name, flt in self.filters.items()}
@DeveloperAPI
def save(self) -> bytes:
    """Serializes this worker's filters and policy states.

    Taking the snapshot flushes the filter buffers
    (`get_filters(flush_after=True)`).

    Returns:
        Pickled bytes (from `pickle.dumps`) of a dict with keys "filters"
        (the `get_filters` snapshot) and "state" (policy ID -> that
        policy's `get_state()`).
    """
    filters = self.get_filters(flush_after=True)
    state = {
        pid: self.policy_map[pid].get_state()
        for pid in self.policy_map
    }
    return pickle.dumps({"filters": filters, "state": state})
@DeveloperAPI
def restore(self, objs: bytes) -> None:
    """Restores filters and policy states from a blob produced by `save()`.

    Filters are synced first; then each policy listed in the blob's
    "state" entry gets its state set.

    Args:
        objs: Pickled bytes as returned by `save()`.
    """
    # SECURITY: `pickle.loads` can execute arbitrary code. Only call this
    # with data from a trusted source (e.g. this worker's own `save()`).
    data = pickle.loads(objs)
    self.sync_filters(data["filters"])
    for pid, state in data["state"].items():
        self.policy_map[pid].set_state(state)
@DeveloperAPI
def set_global_vars(self, global_vars: dict) -> None:
    """Pushes `global_vars` to every policy, then stores them on the worker.

    Each policy's `on_global_var_update(global_vars)` is called (through
    `foreach_policy`) before `self.global_vars` is replaced.
    """
    def _notify(policy, _policy_id):
        return policy.on_global_var_update(global_vars)

    self.foreach_policy(_notify)
    self.global_vars = global_vars
@DeveloperAPI
def get_global_vars(self) -> dict:
    """Returns this worker's current `global_vars` dict."""
    current_vars = self.global_vars
    return current_vars
@DeveloperAPI
def export_policy_model(self,
                        export_dir: str,
                        policy_id: PolicyID = DEFAULT_POLICY_ID):
    """Exports the model of the given policy into `export_dir`.

    Delegates to that policy's `export_model(export_dir)`.
    """
    target_policy = self.policy_map[policy_id]
    target_policy.export_model(export_dir)
@DeveloperAPI
def import_policy_model_from_h5(self,
                                import_file: str,
                                policy_id: PolicyID = DEFAULT_POLICY_ID):
    """Loads model weights for the given policy from `import_file`.

    Delegates to that policy's `import_model_from_h5(import_file)`.
    """
    target_policy = self.policy_map[policy_id]
    target_policy.import_model_from_h5(import_file)
@DeveloperAPI
def export_policy_checkpoint(self,
                             export_dir: str,
                             filename_prefix: str = "model",
                             policy_id: PolicyID = DEFAULT_POLICY_ID):
    """Writes a checkpoint of the given policy into `export_dir`.

    Delegates to that policy's `export_checkpoint(export_dir,
    filename_prefix)`.
    """
    target_policy = self.policy_map[policy_id]
    target_policy.export_checkpoint(export_dir, filename_prefix)
@DeveloperAPI