Skip to content

Commit 2e3b9ce

Browse files
committed
minor updates
1 parent c5eb9b8 commit 2e3b9ce

File tree

6 files changed

+20
-16
lines changed

6 files changed

+20
-16
lines changed

components/replay.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ class FiniteReplay(object):
4343
'''
4444
def __init__(self, memory_size, keys=None):
4545
if keys is None:
46-
keys = []
47-
self.keys = keys + ['action', 'reward', 'mask']
46+
keys = ['action', 'reward', 'mask']
47+
self.keys = keys
4848
self.memory_size = int(memory_size)
4949
self.clear()
5050

envs/env.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,24 @@
22
import gym_pygame
33
import gym_minatar
44
import gym_exploration
5-
try:
6-
import pybullet
7-
import pybullet_envs
8-
except ImportError:
9-
pass
105
from gym.wrappers.time_limit import TimeLimit
116

127
from envs.wrapper import *
138

149

1510
def make_env(env_name, max_episode_steps, episode_life=True):
16-
env = gym.make(env_name)
17-
env_group_title = get_env_group_title(env)
18-
# print(env_group_title, env_name)
11+
if 'DMC' in env_name:
12+
import dmc2gym
13+
domain, task, _ = env_name.split('-') # reacher-hard-DMC
14+
env = dmc2gym.make(domain_name=domain, task_name=task)
15+
env_group_title = 'dmc'
16+
else:
17+
if 'BulletEnv' in env_name:
18+
import pybullet
19+
import pybullet_envs
20+
env = gym.make(env_name)
21+
env_group_title = get_env_group_title(env)
22+
1923
if env_group_title == 'gym_minatar':
2024
env = make_minatar(env, max_episode_steps, scale=False)
2125
if len(env.observation_space.shape) == 3:
@@ -33,7 +37,7 @@ def make_env(env_name, max_episode_steps, episode_life=True):
3337
if len(env.observation_space.shape) == 3:
3438
env = TransposeImage(env)
3539
env = FrameStack(env, 4)
36-
elif env_group_title in ['classic_control', 'box2d', 'gym_pygame', 'gym_exploration', 'pybullet', 'mujoco', 'robotics']:
40+
elif env_group_title in ['classic_control', 'box2d', 'gym_pygame', 'gym_exploration', 'pybullet', 'mujoco', 'robotics', 'dmc']:
3741
if max_episode_steps > 0: # Set max episode steps
3842
env = TimeLimit(env.unwrapped, max_episode_steps)
3943
return env

run.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def main(argv):
1818
# Job time
1919
'time': '0-10:00:00',
2020
# GPU/CPU type
21-
'--cpus-per-task': 1,
21+
'cpus-per-task': 1,
2222
# Memory
2323
'mem-per-cpu': '2000M',
2424
# Email address

utils/logger.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class Logger(object):
77
def __init__(self, logs_dir, file_name='log.txt', filemode='w'):
88
logging.basicConfig(
99
format='%(asctime)s - %(levelname)s: %(message)s',
10-
filename=f'{logs_dir}/{file_name}',
10+
filename=f'{logs_dir}{file_name}',
1111
filemode=filemode
1212
)
1313
logger = logging.getLogger()

utils/plotter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def merge_index(self, config_idx, mode, processed, exp=None):
7272
result_list[i] = result_list[i][:n]
7373
result_list[i].loc[:, self.x_label] = new_x
7474
result_list[i].loc[:, self.y_label] = new_y
75-
else:
75+
elif processed == False:
7676
# Cut off redundant results
7777
n = min(len(result) for result in result_list)
7878
for i in range(len(result_list)):
@@ -381,7 +381,7 @@ def find_key_value(config_dict, key):
381381
return config_dict[k]
382382
elif type(v) == dict:
383383
value = find_key_value(v, key)
384-
if value is not '/':
384+
if value != '/':
385385
return value
386386
return '/'
387387

utils/submitter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ def submit(self):
4949
if len(self.job_list) == 0:
5050
print("Finish submitting all jobs!")
5151
exit(1)
52-
time.sleep(self.check-time-interval)
52+
time.sleep(self.check_time_interval)

0 commit comments

Comments
 (0)