131 changes: 131 additions & 0 deletions python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py
@@ -0,0 +1,131 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os
import torch.distributed as dist
import torch.utils.data

from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner

logger = logging.getLogger(__name__)


class DistributedPyTorchRunner(PyTorchRunner):
"""Manages a distributed PyTorch model replica."""

def __init__(self,
model_creator,
data_creator,
optimizer_creator,
config=None,
batch_size=16,
backend="gloo"):
"""Initializes the runner.

Args:
model_creator (dict -> torch.nn.Module): see pytorch_trainer.py.
data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py.
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
see pytorch_trainer.py.
config (dict): see pytorch_trainer.py.
batch_size (int): batch size used by one replica for an update.
backend (string): see pytorch_trainer.py.
"""
super(DistributedPyTorchRunner, self).__init__(
model_creator, data_creator, optimizer_creator, config, batch_size)
self.backend = backend

def setup(self, url, world_rank, world_size):
"""Connects to the distributed PyTorch backend and initializes the model.

Args:
url (str): the URL used to connect to distributed PyTorch.
world_rank (int): the index of the runner.
world_size (int): the total number of runners.
"""
self._setup_distributed_pytorch(url, world_rank, world_size)
self._setup_training()

def _setup_distributed_pytorch(self, url, world_rank, world_size):
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
with self._timers["setup_proc"]:
self.world_rank = world_rank
logger.debug(
"Connecting to {} world_rank: {} world_size: {}".format(
url, world_rank, world_size))
logger.debug("using {}".format(self.backend))
dist.init_process_group(
backend=self.backend,
init_method=url,
rank=world_rank,
world_size=world_size)

def _setup_training(self):
logger.debug("Creating model")
self.model = self.model_creator(self.config)
if torch.cuda.is_available():
self.model = torch.nn.parallel.DistributedDataParallel(
self.model.cuda())
else:
self.model = torch.nn.parallel.DistributedDataParallelCPU(
self.model)

logger.debug("Creating optimizer")
self.criterion, self.optimizer = self.optimizer_creator(
self.model, self.config)
if torch.cuda.is_available():
self.criterion = self.criterion.cuda()

logger.debug("Creating dataset")
self.training_set, self.validation_set = self.data_creator(self.config)

# TODO: make num_workers configurable
self.train_sampler = torch.utils.data.distributed.DistributedSampler(
self.training_set)
self.train_loader = torch.utils.data.DataLoader(
self.training_set,
batch_size=self.batch_size,
shuffle=(self.train_sampler is None),
num_workers=2,
pin_memory=False,
sampler=self.train_sampler)

self.validation_sampler = (
torch.utils.data.distributed.DistributedSampler(
self.validation_set))
self.validation_loader = torch.utils.data.DataLoader(
self.validation_set,
batch_size=self.batch_size,
shuffle=(self.validation_sampler is None),
num_workers=2,
pin_memory=False,
sampler=self.validation_sampler)

def step(self):
"""Runs a training epoch and updates the model parameters."""
logger.debug("Starting step")
self.train_sampler.set_epoch(self.epoch)
return super(DistributedPyTorchRunner, self).step()

def get_state(self):
"""Returns the state of the runner."""
return {
"epoch": self.epoch,
"model": self.model.module.state_dict(),
"optimizer": self.optimizer.state_dict(),
"stats": self.stats()
}

def set_state(self, state):
"""Sets the state of the model."""
# TODO: restore timer stats
self.model.module.load_state_dict(state["model"])
self.optimizer.load_state_dict(state["optimizer"])
self.epoch = state["stats"]["epoch"]

def shutdown(self):
"""Attempts to shut down the worker."""
super(DistributedPyTorchRunner, self).shutdown()
dist.destroy_process_group()
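
A minimal usage sketch for the runner above (hypothetical driver code; the creator callables, the address, and the single-replica wiring are placeholders — the real orchestration lives in pytorch_trainer.py):

import torch
import torch.nn as nn
import torch.utils.data

def model_creator(config):
    # Toy model; real creators are supplied by the user via the trainer.
    return nn.Linear(1, 1)

def data_creator(config):
    dataset = torch.utils.data.TensorDataset(
        torch.randn(64, 1), torch.randn(64, 1))
    return dataset, dataset

def optimizer_creator(model, config):
    return nn.MSELoss(), torch.optim.SGD(model.parameters(), lr=0.01)

runner = DistributedPyTorchRunner(
    model_creator, data_creator, optimizer_creator, batch_size=16)
runner.setup("tcp://127.0.0.1:23456", world_rank=0, world_size=1)
train_stats = runner.step()   # one epoch through the DistributedSampler
state = runner.get_state()    # unwraps the DDP wrapper via model.module
runner.shutdown()             # tears down the process group
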
105 changes: 34 additions & 71 deletions python/ray/experimental/sgd/pytorch/pytorch_runner.py
@@ -3,9 +3,7 @@
from __future__ import print_function

import logging
import os
import torch
import torch.distributed as dist
import torch.utils.data

import ray
@@ -15,36 +13,30 @@


class PyTorchRunner(object):
"""Manages a distributed PyTorch model replica"""
"""Manages a PyTorch model for training."""

def __init__(self,
model_creator,
data_creator,
optimizer_creator,
config=None,
batch_size=16,
backend="gloo"):
batch_size=16):
"""Initializes the runner.

Args:
model_creator (dict -> torch.nn.Module): creates the model using
the config.
data_creator (dict -> Dataset, Dataset): creates the training and
validation data sets using the config.
model_creator (dict -> torch.nn.Module): see pytorch_trainer.py.
data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py.
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
creates the loss and optimizer using the model and the config.
config (dict): configuration passed to 'model_creator',
'data_creator', and 'optimizer_creator'.
batch_size (int): batch size used in an update.
backend (string): backend used by distributed PyTorch.
see pytorch_trainer.py.
config (dict): see pytorch_trainer.py.
batch_size (int): see pytorch_trainer.py.
"""

self.model_creator = model_creator
self.data_creator = data_creator
self.optimizer_creator = optimizer_creator
self.config = {} if config is None else config
self.batch_size = batch_size
self.backend = backend
self.verbose = True

self.epoch = 0
@@ -56,82 +48,45 @@ def __init__(self,
]
}

def setup(self, url, world_rank, world_size):
"""Connects to the distributed PyTorch backend and initializes the model.

Args:
url (str): the URL used to connect to distributed PyTorch.
world_rank (int): the index of the runner.
world_size (int): the total number of runners.
"""
self._setup_distributed_pytorch(url, world_rank, world_size)
self._setup_training()

def _setup_distributed_pytorch(self, url, world_rank, world_size):
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
with self._timers["setup_proc"]:
self.world_rank = world_rank
logger.debug(
"Connecting to {} world_rank: {} world_size: {}".format(
url, world_rank, world_size))
logger.debug("using {}".format(self.backend))
dist.init_process_group(
backend=self.backend,
init_method=url,
rank=world_rank,
world_size=world_size)

def _setup_training(self):
def setup(self):
"""Initializes the model."""
logger.debug("Creating model")
self.model = self.model_creator(self.config)
if torch.cuda.is_available():
self.model = torch.nn.parallel.DistributedDataParallel(
self.model.cuda())
else:
self.model = torch.nn.parallel.DistributedDataParallelCPU(
self.model)
self.model = self.model.cuda()

logger.debug("Creating optimizer")
self.criterion, self.optimizer = self.optimizer_creator(
self.model, self.config)

if torch.cuda.is_available():
self.criterion = self.criterion.cuda()

logger.debug("Creating dataset")
self.training_set, self.validation_set = self.data_creator(self.config)

# TODO: make num_workers configurable
self.train_sampler = torch.utils.data.distributed.DistributedSampler(
self.training_set)
self.train_loader = torch.utils.data.DataLoader(
self.training_set,
batch_size=self.batch_size,
shuffle=(self.train_sampler is None),
shuffle=True,
num_workers=2,
pin_memory=False,
sampler=self.train_sampler)
pin_memory=False)

self.validation_sampler = (
torch.utils.data.distributed.DistributedSampler(
self.validation_set))
self.validation_loader = torch.utils.data.DataLoader(
self.validation_set,
batch_size=self.batch_size,
shuffle=(self.validation_sampler is None),
shuffle=True,
num_workers=2,
pin_memory=False,
sampler=self.validation_sampler)
pin_memory=False)

def get_node_ip(self):
"""Returns the IP address of the current node"""
"""Returns the IP address of the current node."""
return ray.services.get_node_ip_address()

def step(self):
"""Runs a training epoch and updates the model parameters"""
logger.debug("Starting step")
self.train_sampler.set_epoch(self.epoch)
def find_free_port(self):
"""Finds a free port on the current node."""
return utils.find_free_port()

def step(self):
"""Runs a training epoch and updates the model parameters."""
logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
with self._timers["training"]:
train_stats = utils.train(self.train_loader, self.model,
@@ -144,7 +99,7 @@ def step(self):
return train_stats

def validate(self):
"""Evaluates the model on the validation data set"""
"""Evaluates the model on the validation data set."""
with self._timers["validation"]:
validation_stats = utils.validate(self.validation_loader,
self.model, self.criterion)
@@ -153,7 +108,7 @@ def validate(self):
return validation_stats

def stats(self):
"""Returns a dictionary of statistics collected"""
"""Returns a dictionary of statistics collected."""
stats = {"epoch": self.epoch}
for k, t in self._timers.items():
stats[k + "_time_mean"] = t.mean
@@ -162,7 +117,7 @@ def stats(self):
return stats

def get_state(self):
"""Returns the state of the runner"""
"""Returns the state of the runner."""
return {
"epoch": self.epoch,
"model": self.model.state_dict(),
@@ -171,12 +126,20 @@ def get_state(self):
}

def set_state(self, state):
"""Sets the state of the model"""
"""Sets the state of the model."""
# TODO: restore timer stats
self.model.load_state_dict(state["model"])
self.optimizer.load_state_dict(state["optimizer"])
self.epoch = state["stats"]["epoch"]

def shutdown(self):
"""Attempts to shut down the worker"""
dist.destroy_process_group()
"""Attempts to shut down the worker."""
del self.validation_loader
del self.validation_set
del self.train_loader
del self.training_set
del self.criterion
del self.optimizer
del self.model
if torch.cuda.is_available():
torch.cuda.empty_cache()
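
For the plain single-process runner the same pattern is simpler; here is a hypothetical checkpoint round trip using get_state/set_state (assuming the toy creator functions from the sketch above, not the trainer's actual persistence path):

runner = PyTorchRunner(model_creator, data_creator, optimizer_creator)
runner.setup()                  # no process group; plain model on CPU/GPU
runner.step()                   # one training epoch
checkpoint = runner.get_state()

restored = PyTorchRunner(model_creator, data_creator, optimizer_creator)
restored.setup()
restored.set_state(checkpoint)  # reloads model/optimizer weights and epoch
print(restored.validate())      # validation stats from utils.validate
runner.shutdown()
restored.shutdown()
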