Updates: Support the latest Atari environments and state entropy maximization-based exploration #298

Open · wants to merge 2 commits into master
5 changes: 5 additions & 0 deletions a2c_ppo_acktr/arguments.py
@@ -149,6 +149,11 @@ def get_args():
         action='store_true',
         default=False,
         help='use a linear schedule on the learning rate')
+    parser.add_argument(
+        '--use-sem',
+        action='store_true',
+        default=False,
+        help='use state entropy maximization to improve exploration')
     args = parser.parse_args()
 
     args.cuda = not args.no_cuda and torch.cuda.is_available()
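Editor's note, not part of the PR: a minimal standalone sketch of how the new flag surfaces elsewhere in the code. argparse maps the dashed option name to an underscored attribute, which is why main.py below checks args.use_sem.

import argparse

# Minimal sketch of the flag added above: '--use-sem' becomes args.use_sem.
parser = argparse.ArgumentParser()
parser.add_argument('--use-sem', action='store_true', default=False,
                    help='use state entropy maximization to improve exploration')

args = parser.parse_args(['--use-sem'])
print(args.use_sem)  # True

args = parser.parse_args([])
print(args.use_sem)  # False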
5 changes: 3 additions & 2 deletions a2c_ppo_acktr/envs.py
@@ -42,7 +42,7 @@ def _thunk():
         env = gym.make(env_id)
 
         is_atari = hasattr(gym.envs, 'atari') and isinstance(
-            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
+            env.unwrapped, gym.envs.atari.AtariEnv)
         if is_atari:
             env = NoopResetEnv(env, noop_max=30)
             env = MaxAndSkipEnv(env, skip=4)
@@ -177,7 +177,8 @@ def reset(self):
         return obs
 
     def step_async(self, actions):
-        if isinstance(actions, torch.LongTensor):
+        if actions.dtype is torch.int64:
+        #if isinstance(actions, torch.LongTensor):
             # Squeeze the dimension for discrete actions
             actions = actions.squeeze(1)
         actions = actions.cpu().numpy()
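Editor's note, not part of the PR: the dtype comparison in step_async is presumably meant to also handle actions that live on the GPU. A CUDA tensor of 64-bit integers is a torch.cuda.LongTensor, so the old isinstance check against torch.LongTensor fails for it, while its dtype is still torch.int64. A small sketch of the difference, assuming a recent PyTorch build:

import torch

# A CPU tensor of dtype int64 satisfies the old isinstance check, but the same
# tensor moved to the GPU does not, even though its dtype is unchanged.
actions = torch.zeros(4, 1, dtype=torch.int64)
print(isinstance(actions, torch.LongTensor))  # True
print(actions.dtype is torch.int64)           # True

if torch.cuda.is_available():
    cuda_actions = actions.cuda()
    print(isinstance(cuda_actions, torch.LongTensor))  # False (torch.cuda.LongTensor)
    print(cuda_actions.dtype is torch.int64)           # True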
68 changes: 68 additions & 0 deletions a2c_ppo_acktr/utils.py
@@ -63,3 +63,71 @@ def cleanup_log_dir(log_dir):
         files = glob.glob(os.path.join(log_dir, '*.monitor.csv'))
         for f in files:
             os.remove(f)
+
+
+# State entropy maximization with random encoders for efficient exploration (RE3)
+class CNNEmbeddingNetwork(nn.Module):
+    def __init__(self, kwargs):
+        super(CNNEmbeddingNetwork, self).__init__()
+        self.main = nn.Sequential(
+            nn.Conv2d(kwargs['in_channels'], 32, (8, 8), stride=(4, 4)), nn.ReLU(),
+            nn.Conv2d(32, 64, (4, 4), stride=(2, 2)), nn.ReLU(),
+            nn.Conv2d(64, 32, (3, 3), stride=(1, 1)), nn.ReLU(), nn.Flatten(),
+            nn.Linear(32 * 7 * 7, kwargs['embedding_size']))
+
+    def forward(self, ob):
+        x = self.main(ob)
+
+        return x
+
+class MLPEmbeddingNetwork(nn.Module):
+    def __init__(self, kwargs):
+        super(MLPEmbeddingNetwork, self).__init__()
+        self.main = nn.Sequential(
+            nn.Linear(kwargs['input_dim'], 64), nn.ReLU(),
+            nn.Linear(64, 64), nn.ReLU(),
+            nn.Linear(64, kwargs['embedding_size'])
+        )
+
+    def forward(self, ob):
+        x = self.main(ob)
+
+        return x
+
+class SEM:
+    def __init__(self,
+                 ob_space,
+                 action_space,
+                 device,
+                 num_updates
+                 ):
+        self.device = device
+        self.num_updates = num_updates
+        if action_space.__class__.__name__ == "Discrete":
+            self.embedding_network = CNNEmbeddingNetwork(
+                kwargs={'in_channels': ob_space.shape[0], 'embedding_size': 128})
+        elif action_space.__class__.__name__ == 'Box':
+            self.embedding_network = MLPEmbeddingNetwork(
+                kwargs={'input_dim': ob_space.shape[0], 'embedding_size': 128})
+        else:
+            raise NotImplementedError('Please check the supported environments!')
+
+        self.embedding_network.to(self.device)
+
+        # fixed and random encoder
+        for p in self.embedding_network.parameters():
+            p.requires_grad = False
+
+    def compute_intrinsic_rewards(self, obs_buffer, update_step, k=5):
+        size = obs_buffer.size()
+        obs = obs_buffer[:size[0] - 1]
+        intrinsic_rewards = torch.zeros(size=(size[0] - 1, size[1], 1))
+
+        for process in range(size[1]):
+            encoded_obs = self.embedding_network(obs[:, process].to(self.device))
+            for step in range(size[0] - 1):
+                dist = torch.norm(encoded_obs[step] - encoded_obs, p=2, dim=1)
+                H_step = torch.log(dist.sort().values[k + 1] + 1.)
+                intrinsic_rewards[step, process, 0] = H_step
+
+        return intrinsic_rewards * (1. - update_step / self.num_updates)
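Editor's note, not part of the PR: a minimal usage sketch of the SEM helper above, with illustrative shapes. It assumes the repo's package layout (from a2c_ppo_acktr import utils) and that observations match rollouts.obs, which is laid out as [num_steps + 1, num_processes, *obs_shape]; the CNN encoder hard-codes a 32 * 7 * 7 flatten size, so the Discrete branch expects stacked 84x84 Atari frames. Each returned entry is log(1 + d), where d is a k-nearest-neighbour distance between that step's fixed random embedding and the other embeddings from the same process, annealed by (1 - update_step / num_updates).

import numpy as np
import torch
from gym import spaces

from a2c_ppo_acktr import utils

# Illustrative spaces: 4 stacked 84x84 frames and a 6-action Atari-style env.
ob_space = spaces.Box(low=0, high=255, shape=(4, 84, 84), dtype=np.uint8)
action_space = spaces.Discrete(6)

sem = utils.SEM(ob_space=ob_space, action_space=action_space,
                device=torch.device('cpu'), num_updates=100)

# Stand-in for rollouts.obs: num_steps + 1 = 17 observations for 2 processes.
obs_buffer = torch.rand(17, 2, 4, 84, 84)
intrinsic_rewards = sem.compute_intrinsic_rewards(obs_buffer, update_step=0, k=5)
print(intrinsic_rewards.shape)  # torch.Size([16, 2, 1])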
16 changes: 15 additions & 1 deletion main.py
@@ -102,6 +102,15 @@ def main():
     start = time.time()
     num_updates = int(
         args.num_env_steps) // args.num_steps // args.num_processes
+
+    # use state entropy maximization to improve exploration
+    if args.use_sem:
+        sem = utils.SEM(
+            ob_space=envs.observation_space,
+            action_space=envs.action_space,
+            device=device,
+            num_updates=num_updates)
+
     for j in range(num_updates):
 
         if args.use_linear_lr_decay:
@@ -117,7 +126,7 @@ def main():
                     rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                     rollouts.masks[step])
 
-            # Obser reward and next obs
+            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
 
            for info in infos:
@@ -138,6 +147,11 @@ def main():
                 rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                 rollouts.masks[-1]).detach()
 
+        # compute intrinsic rewards
+        if args.use_sem:
+            intrinsic_rewards = sem.compute_intrinsic_rewards(rollouts.obs, update_step=j)
+            rollouts.rewards += intrinsic_rewards.to(device)
+
         if args.gail:
             if j >= 10:
                 envs.venv.eval()
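Editor's note, not part of the PR: a small sketch with illustrative stand-in tensors showing why the addition in the hunk above is shape-compatible. rollouts.rewards in this repository has shape [num_steps, num_processes, 1], matching what compute_intrinsic_rewards returns, and the (1 - update_step / num_updates) factor applied inside that method anneals the bonus toward zero over training.

import torch

# Stand-ins for rollouts.rewards and the SEM output at a given update.
num_steps, num_processes, num_updates = 128, 8, 100
rewards = torch.zeros(num_steps, num_processes, 1)
intrinsic_rewards = torch.ones(num_steps, num_processes, 1)

# Element-wise addition, as in main.py above.
rewards += intrinsic_rewards
print(rewards.shape)  # torch.Size([128, 8, 1])

# The annealing factor applied inside compute_intrinsic_rewards.
for j in (0, 50, 99):
    print(j, 1. - j / num_updates)  # 1.0, 0.5, ~0.01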