diff --git a/deep_rl/component/replay.py b/deep_rl/component/replay.py index d688a63c..54168f54 100644 --- a/deep_rl/component/replay.py +++ b/deep_rl/component/replay.py @@ -202,12 +202,12 @@ class ReplayWrapper(mp.Process): EXIT = 2 UPDATE_PRIORITIES = 3 - def __init__(self, replay_cls, replay_kwargs, async=True): + def __init__(self, replay_cls, replay_kwargs, async_flag=True): mp.Process.__init__(self) self.replay_kwargs = replay_kwargs self.replay_cls = replay_cls self.cache_len = 2 - if async: + if async_flag: self.pipe, self.worker_pipe = mp.Pipe() self.start() else: diff --git a/examples.py b/examples.py index c6e49a30..bc3f344e 100644 --- a/examples.py +++ b/examples.py @@ -113,7 +113,7 @@ def quantile_regression_dqn_feature(**kwargs): replay_kwargs = dict( memory_size=int(1e4), batch_size=config.batch_size) - config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True) + config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs) config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4) config.discount = 0.99 @@ -146,7 +146,7 @@ def quantile_regression_dqn_pixel(**kwargs): batch_size=config.batch_size, history_length=4, ) - config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True) + config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs) config.state_normalizer = ImageNormalizer() config.reward_normalizer = SignNormalizer() @@ -177,7 +177,7 @@ def categorical_dqn_feature(**kwargs): replay_kwargs = dict( memory_size=int(1e4), batch_size=config.batch_size) - config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True) + config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs) config.discount = 0.99 config.target_network_update_freq = 200 @@ -211,7 +211,7 @@ def categorical_dqn_pixel(**kwargs): batch_size=config.batch_size, history_length=4, ) - config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True) + config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs) config.discount = 0.99 config.state_normalizer = ImageNormalizer()