Skip to content

Commit

Permalink
Fix line lengths.
Browse files (browse the repository at this point in the history)
  • Loading branch information
Kevin Yang committed Jul 2, 2019
1 parent 5c7ab38 commit 9912129
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions src/mnist-mixed.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,7 +53,11 @@ def forward(self, x):

def train(gpu, args):
rank = args.nr * args.gpus + gpu
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
dist.init_process_group(
backend='nccl',
init_method='env://',
world_size=args.world_size,
rank=rank)

model = ConvNet()
torch.cuda.set_device(gpu)
Expand All @@ -66,19 +70,24 @@ def train(gpu, args):
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
model = DDP(model)
# Data loading code
train_dataset = torchvision.datasets.MNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
download=True)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
num_replicas=args.world_size,
rank=rank)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
sampler=train_sampler)
train_dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
transform=transforms.ToTensor(),
download=True
)
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=args.world_size,
rank=rank)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
sampler=train_sampler
)

start = datetime.now()
total_step = len(train_loader)
Expand All @@ -96,8 +105,13 @@ def train(gpu, args):
scaled_loss.backward()
optimizer.step()
if (i + 1) % 100 == 0 and gpu == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step,
loss.item()))
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
epoch + 1,
args.epochs,
i + 1,
total_step,
loss.item())
)
if gpu == 0:
print("Training complete in: " + str(datetime.now() - start))

Expand Down

0 comments on commit 9912129

Please sign in to comment.