This repository was archived by the owner on Jan 26, 2021. It is now read-only.

PyTorch binding with mnist examples #148

Open

wants to merge 17 commits into master
2 changes: 2 additions & 0 deletions .gitignore
@@ -270,3 +270,5 @@ _Pvt_Extensions

# Python
*.pyc
*.egg
binding/python/multiverso_python.egg-info/
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -43,3 +43,4 @@ configure_file(

add_custom_target(uninstall
COMMAND ${CMAKE_COMMAND} -P ${CMAKE_CURRENT_BINARY_DIR}/cmake_uninstall.cmake)
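# Build the server with the delay-compensated ASGD (DC-ASGD) updater enabled.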
add_definitions(-DENABLE_DCASGD)
129 changes: 129 additions & 0 deletions binding/python/examples/torch/mnist.py
@@ -0,0 +1,129 @@
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

import numpy as np
import multiverso as mv
from multiverso.torch_ext import torchmodel

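# Initialize multiverso before any tables are created.  sync=True runs the
# workers in lock step; the updater flag selects the server-side update rule
# added in this PR.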
mv.init(sync=True, updater='sgd')

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0, metavar='M',
help='SGD momentum (default: 0)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)


kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x)

model = torchmodel.MVTorchModel(Net())

if args.cuda:
model.cuda()

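# Scale the learning rate by the number of workers so the combined update
# across one sweep of the data roughly matches single-worker SGD (a common
# heuristic for synchronous data-parallel training).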
optimizer = optim.SGD(model.parameters(), lr=args.lr * mv.workers_num(), momentum=args.momentum)

def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
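        # Round-robin data partition: each worker takes every
        # workers_num()-th batch, selected by its worker id.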
if batch_idx % mv.workers_num() == mv.worker_id():
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()

            # Exchange updates with the server after each local step.  The
            # ArrayTable exchange happens in host memory, so move the model
            # to the CPU first and back to the GPU afterwards when CUDA is on.
            model.cpu()
            model.mv_sync()
            if args.cuda:
                model.cuda()

            if (batch_idx // mv.workers_num()) % args.log_interval == 0:
print('Worker: {}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
mv.worker_id(), epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data[0]))

def test(epoch):
model.eval()
test_loss = 0
correct = 0
for data, target in test_loader:
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
test_loss += F.nll_loss(output, target).data[0]
pred = output.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data).cpu().sum()

    test_loss /= len(test_loader)  # loss function already averages over batch size
print('\nWorker: {}\tTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
mv.worker_id(), test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))


for epoch in range(1, args.epochs + 1):
train(epoch)
test(epoch)

mv.shutdown()
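Multiverso coordinates its workers over MPI, so this example would typically be launched with something like `mpiexec -n 4 python mnist.py`; the exact launcher depends on the local MPI installation.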
3 changes: 2 additions & 1 deletion binding/python/multiverso/api.py
@@ -9,7 +9,7 @@
mv_lib = Loader.get_lib()


-def init(sync=False):
+def init(sync=False, updater='sgd'):
'''Initialize mutliverso.

This should be called only once before training at the beginning of the
@@ -29,6 +29,7 @@ def init(sync=False):
args = [b""] # the first argument will be ignored. So we put a placeholder here
if sync:
args.append(b"-sync=true")
args.append(b"-updater=sgd")
Contributor comment on the line above:

This should be optional?

n = len(args)
args_type = ctypes.c_char_p * n
mv_lib.MV_Init(ctypes.pointer(ctypes.c_int(n)), args_type(*[ctypes.c_char_p(arg) for arg in args]))
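One way the reviewer's point could be addressed — an untested sketch that reuses the module's existing `ctypes` and `mv_lib`, and only emits the flag when the caller actually passes an updater instead of hard-coding "sgd":

def init(sync=False, updater=None):
    args = [b""]  # the first argument is ignored, so keep a placeholder
    if sync:
        args.append(b"-sync=true")
    if updater is not None:
        # Honor the new parameter rather than always sending -updater=sgd.
        args.append(b"-updater=" + updater.encode('utf-8'))
    n = len(args)
    args_type = ctypes.c_char_p * n
    mv_lib.MV_Init(ctypes.pointer(ctypes.c_int(n)),
                   args_type(*[ctypes.c_char_p(arg) for arg in args]))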
Empty file: binding/python/multiverso/torch_ext/__init__.py (makes torch_ext importable as a package).
46 changes: 46 additions & 0 deletions binding/python/multiverso/torch_ext/torchmodel.py
@@ -0,0 +1,46 @@
#!/usr/bin/env python
# coding:utf8

import torch
import torch.nn as nn

import multiverso as mv


class MVTorchModel(object):
    '''Wrap an nn.Module and mirror its parameters in multiverso ArrayTables.'''

    def __init__(self, tmobj):
        assert isinstance(tmobj, nn.Module)
        self._tmobj = tmobj
        # Register one ArrayTableHandler per parameter tensor, seeded with
        # the local initial values.
        self._mv_params = []
        for param in self._tmobj.parameters():
            self._mv_params.append(mv.ArrayTableHandler(param.data.numpy().size,
                                                        param.data.numpy().reshape((-1,))))
        mv.barrier()
        # Pull the initial global parameters and copy them into the model
        # in place (rebinding the loop variable would be a silent no-op).
        self._last_mv_params = []
        for mv_param in self._mv_params:
            self._last_mv_params.append(mv_param.get())
        for param, last_mv_param in zip(self._tmobj.parameters(), self._last_mv_params):
            param.data.copy_(torch.from_numpy(
                last_mv_param.reshape(param.data.numpy().shape)))

    def mv_sync(self):
        # Push this worker's update (the change since the last sync) to the
        # server tables.
        for mv_param, last_mv_param, param in zip(self._mv_params, self._last_mv_params, self._tmobj.parameters()):
            mv_param.add(last_mv_param - param.data.numpy().reshape((-1,)))

        # Pull the latest global parameters.  Update the cached copies by
        # index and the model tensors in place; plain assignment to the loop
        # variables would update nothing.
        for i, param in enumerate(self._tmobj.parameters()):
            self._last_mv_params[i] = self._mv_params[i].get()
            param.data.copy_(torch.from_numpy(
                self._last_mv_params[i].reshape(param.data.numpy().shape)))

def __call__(self, *args, **kwargs):
return self._tmobj(*args, **kwargs)

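    # Serve the wrapper's own state and methods directly; delegate everything
    # else (cuda(), cpu(), train(), parameters(), ...) to the wrapped module.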
def __getattribute__(self, attr):
if attr in ['_tmobj', '_mv_params', '_last_mv_params']:
return object.__getattribute__(self, attr)
elif attr in ['mv_sync', '__call__']:
return getattr(MVTorchModel, attr).__get__(self)
else:
return getattr(self._tmobj, attr)
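For orientation, a minimal usage sketch (the one-layer net is hypothetical; the mnist.py example above shows a full training loop):

import torch.nn as nn
import multiverso as mv
from multiverso.torch_ext import torchmodel

mv.init(sync=True, updater='sgd')
net = torchmodel.MVTorchModel(nn.Linear(4, 2))  # any nn.Module works
# ... run local forward/backward/optimizer steps on this worker ...
net.mv_sync()  # push the local delta, pull the latest global weights
mv.shutdown()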
2 changes: 1 addition & 1 deletion binding/python/setup.py
@@ -16,7 +16,7 @@ def readme():
url='https://github.com/Microsoft/multiverso',
author='Microsoft',
license='MIT',
-    packages=['multiverso', 'multiverso.theano_ext', 'multiverso.theano_ext.lasagne_ext'],
+    packages=['multiverso', 'multiverso.torch_ext', 'multiverso.theano_ext', 'multiverso.theano_ext.lasagne_ext'],
# TODO: The lasagne on pypi is too old. multiverso need some functions in
# lasagne-0.2 which is not released yet. Please replace the dev version
# with the stable release later.