Add debug mode and logger
Add debug flag to run in debug mode (useful for gdb) and use logger
instead of print.
scientist1642 committed Mar 30, 2017
1 parent e7b2794 commit 02d6f65
Showing 5 changed files with 101 additions and 17 deletions.
31 changes: 21 additions & 10 deletions main.py
@@ -3,6 +3,7 @@
import argparse
import os
import sys
import math

import torch
import torch.optim as optim
@@ -13,8 +14,11 @@
from model import ActorCritic
from train import train
from test import test
from utils import logger
import my_optim

logger = logger.getLogger('main')

# Based on
# https://github.com/pytorch/examples/tree/master/mnist_hogwild
# Training settings
@@ -37,13 +41,16 @@
help='environment to train on (default: PongDeterministic-v3)')
parser.add_argument('--no-shared', default=False, metavar='O',
help='use an optimizer without shared momentum.')
parser.add_argument('--max-iters', type=int, default=math.inf,
help='maximum iterations per process.')

parser.add_argument('--debug', action='store_true', default=False,
help='run in a way that is easier to debug')

if __name__ == '__main__':
args = parser.parse_args()

torch.manual_seed(args.seed)

env = create_atari_env(args.env_name)
shared_model = ActorCritic(
env.observation_space.shape[0], env.action_space)
@@ -55,15 +62,19 @@
optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
optimizer.share_memory()

-    processes = []
-
-    p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
-    p.start()
-    processes.append(p)
-
-    for rank in range(0, args.num_processes):
-        p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
-        p.start()
-        processes.append(p)
-    for p in processes:
-        p.join()
+    if not args.debug:
+        processes = []
+
+        p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
+        p.start()
+        processes.append(p)
+
+        for rank in range(0, args.num_processes):
+            p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
+            p.start()
+            processes.append(p)
+        for p in processes:
+            p.join()
+    else:  # debug is enabled
+        # run only one process in the main process, so it is easier to debug
+        train(0, args, shared_model, optimizer)
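With the flag in place, the training loop can stay in a single process so a debugger remains attached to one PID; illustrative invocations (not part of this commit) would be gdb --args python main.py --debug or python -m pdb main.py --debug, optionally bounded with --max-iters.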
4 changes: 3 additions & 1 deletion model.py
@@ -44,8 +44,11 @@ def __init__(self, num_inputs, action_space):
self.lstm = nn.LSTMCell(32 * 3 * 3, 256)

num_outputs = action_space.n

self.critic_linear = nn.Linear(256, 1)
self.actor_linear = nn.Linear(256, num_outputs)
#self.critic_linear = nn.Linear(288, 1)
#self.actor_linear = nn.Linear(288, num_outputs)

self.apply(weights_init)
self.actor_linear.weight.data = normalized_columns_initializer(
@@ -66,7 +69,6 @@ def forward(self, inputs):
x = F.elu(self.conv2(x))
x = F.elu(self.conv3(x))
x = F.elu(self.conv4(x))

x = x.view(-1, 32 * 3 * 3)
hx, cx = self.lstm(x, (hx, cx))
x = hx
4 changes: 3 additions & 1 deletion test.py
@@ -11,7 +11,9 @@
from torchvision import datasets, transforms
import time
from collections import deque
from utils import logger

logger = logger.getLogger('test')

def test(rank, args, shared_model):
torch.manual_seed(args.seed + rank)
@@ -59,7 +61,7 @@ def test(rank, args, shared_model):
done = True

if done:
- print("Time {}, episode reward {}, episode length {}".format(
+ logger.info("Time {}, episode reward {}, episode length {}".format(
time.strftime("%Hh %Mm %Ss",
time.gmtime(time.time() - start_time)),
reward_sum, episode_length))
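Since the standard logging module defers string interpolation until a record is actually emitted, the call above could also pass the values as arguments instead of pre-formatting them. A minimal alternative sketch (not what the commit does):

    logger.info("Time %s, episode reward %s, episode length %s",
                time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time)),
                reward_sum, episode_length)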
30 changes: 25 additions & 5 deletions train.py
@@ -1,6 +1,8 @@
import math
import os
import sys
import resource
import gc

import torch
import torch.nn.functional as F
@@ -9,15 +11,16 @@
from model import ActorCritic
from torch.autograd import Variable
from torchvision import datasets, transforms
from utils import logger

logger = logger.getLogger('main')

def ensure_shared_grads(model, shared_model):
for param, shared_param in zip(model.parameters(), shared_model.parameters()):
if shared_param.grad is not None:
return
shared_param._grad = param.grad


def train(rank, args, shared_model, optimizer=None):
torch.manual_seed(args.seed + rank)

@@ -36,8 +39,29 @@ def train(rank, args, shared_model, optimizer=None):
done = True

episode_length = 0

iteration = 0

while True:

values = []
log_probs = []
rewards = []
entropies = []

if iteration == args.max_iters:
logger.info('Max iteration {} reached..'.format(args.max_iters))
break

if iteration % 200 == 0 and rank == 0:
mem_used = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
mem_used_mb = mem_used / 1024
logger.info('Memory usage of one proc: {} (mb)'.format(mem_used_mb))


iteration += 1
episode_length += 1

# Sync with the shared model
model.load_state_dict(shared_model.state_dict())
if done:
@@ -47,10 +71,6 @@
cx = Variable(cx.data)
hx = Variable(hx.data)

- values = []
- log_probs = []
- rewards = []
- entropies = []

for step in range(args.num_steps):
value, logit, (hx, cx) = model(
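The periodic memory report added above uses resource.getrusage; note that ru_maxrss is reported in kilobytes on Linux (so dividing by 1024 gives megabytes) but in bytes on macOS. A small self-contained sketch of the same measurement under that Linux assumption:

    import resource

    def peak_rss_mb():
        # Peak resident set size of the current process.
        # ru_maxrss is KiB on Linux, so /1024 yields MiB; on macOS it is bytes.
        return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024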
49 changes: 49 additions & 0 deletions utils/logger.py
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
import os
import logging
import logging.config


LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
LOGGING = {
'version': 1,
'disable_existing_loggers': True,
'formatters': {
'verbose': {
'format': "[%(asctime)s] %(levelname)s " \
"[%(threadName)s:%(lineno)s] %(message)s",
'datefmt': "%Y-%m-%d %H:%M:%S"
},
'simple': {
'format': '%(levelname)s %(message)s'
},
},
'handlers': {
'console': {
'level': LOG_LEVEL,
'class': 'logging.StreamHandler',
'formatter': 'verbose'
},
'file': {
'level': LOG_LEVEL,
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'verbose',
'filename': 'rl.log',
'maxBytes': 10*10**6,
'backupCount': 3
}
},
'loggers': {
'': {
'handlers': ['console', 'file'],
'level': LOG_LEVEL,
},
}
}


logging.config.dictConfig(LOGGING)

def getLogger(name):

return logging.getLogger(name)
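A minimal usage sketch for the new module (assuming the utils directory is importable as a package, i.e. it has an __init__.py, which this diff does not show):

    from utils import logger

    log = logger.getLogger('train')
    log.info('hello from the logger')  # written to the console and to rl.log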
