-
Notifications
You must be signed in to change notification settings - Fork 7.2k
[sgd] Add non-distributed PyTorch runner #4933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
richardliaw
merged 7 commits into
ray-project:master
from
pschafhalter:sgd-pytorch-osx-compat
Jun 12, 2019
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
9496c68
Add non-distributed PyTorch runner
pschafhalter 0157932
use dist.is_available() instead of checking OS
pschafhalter 0903c77
Nicer exception
pschafhalter 5def336
Fix bug in choosing port
pschafhalter e250fd6
Refactor some code
pschafhalter 6054269
Address comments
pschafhalter bf2bfc8
Address comments
pschafhalter File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
131 changes: 131 additions & 0 deletions
131
python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| from __future__ import absolute_import | ||
| from __future__ import division | ||
| from __future__ import print_function | ||
|
|
||
| import logging | ||
| import os | ||
| import torch.distributed as dist | ||
| import torch.utils.data | ||
|
|
||
| from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class DistributedPyTorchRunner(PyTorchRunner): | ||
| """Manages a distributed PyTorch model replica.""" | ||
|
|
||
| def __init__(self, | ||
| model_creator, | ||
| data_creator, | ||
| optimizer_creator, | ||
| config=None, | ||
| batch_size=16, | ||
| backend="gloo"): | ||
| """Initializes the runner. | ||
|
|
||
| Args: | ||
| model_creator (dict -> torch.nn.Module): see pytorch_trainer.py. | ||
| data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py. | ||
| optimizer_creator (torch.nn.Module, dict -> loss, optimizer): | ||
| see pytorch_trainer.py. | ||
| config (dict): see pytorch_trainer.py. | ||
| batch_size (int): batch size used by one replica for an update. | ||
| backend (string): see pytorch_trainer.py. | ||
| """ | ||
| super(DistributedPyTorchRunner, self).__init__( | ||
| model_creator, data_creator, optimizer_creator, config, batch_size) | ||
| self.backend = backend | ||
|
|
||
| def setup(self, url, world_rank, world_size): | ||
| """Connects to the distributed PyTorch backend and initializes the model. | ||
|
|
||
| Args: | ||
| url (str): the URL used to connect to distributed PyTorch. | ||
| world_rank (int): the index of the runner. | ||
| world_size (int): the total number of runners. | ||
| """ | ||
| self._setup_distributed_pytorch(url, world_rank, world_size) | ||
| self._setup_training() | ||
|
|
||
| def _setup_distributed_pytorch(self, url, world_rank, world_size): | ||
| os.environ["CUDA_LAUNCH_BLOCKING"] = "1" | ||
| with self._timers["setup_proc"]: | ||
| self.world_rank = world_rank | ||
| logger.debug( | ||
| "Connecting to {} world_rank: {} world_size: {}".format( | ||
| url, world_rank, world_size)) | ||
| logger.debug("using {}".format(self.backend)) | ||
| dist.init_process_group( | ||
| backend=self.backend, | ||
| init_method=url, | ||
| rank=world_rank, | ||
| world_size=world_size) | ||
|
|
||
| def _setup_training(self): | ||
| logger.debug("Creating model") | ||
| self.model = self.model_creator(self.config) | ||
| if torch.cuda.is_available(): | ||
| self.model = torch.nn.parallel.DistributedDataParallel( | ||
| self.model.cuda()) | ||
| else: | ||
| self.model = torch.nn.parallel.DistributedDataParallelCPU( | ||
| self.model) | ||
|
|
||
| logger.debug("Creating optimizer") | ||
| self.criterion, self.optimizer = self.optimizer_creator( | ||
| self.model, self.config) | ||
| if torch.cuda.is_available(): | ||
| self.criterion = self.criterion.cuda() | ||
|
|
||
| logger.debug("Creating dataset") | ||
| self.training_set, self.validation_set = self.data_creator(self.config) | ||
|
|
||
| # TODO: make num_workers configurable | ||
| self.train_sampler = torch.utils.data.distributed.DistributedSampler( | ||
| self.training_set) | ||
| self.train_loader = torch.utils.data.DataLoader( | ||
| self.training_set, | ||
| batch_size=self.batch_size, | ||
| shuffle=(self.train_sampler is None), | ||
| num_workers=2, | ||
| pin_memory=False, | ||
| sampler=self.train_sampler) | ||
|
|
||
| self.validation_sampler = ( | ||
| torch.utils.data.distributed.DistributedSampler( | ||
| self.validation_set)) | ||
| self.validation_loader = torch.utils.data.DataLoader( | ||
| self.validation_set, | ||
| batch_size=self.batch_size, | ||
| shuffle=(self.validation_sampler is None), | ||
| num_workers=2, | ||
| pin_memory=False, | ||
| sampler=self.validation_sampler) | ||
|
|
||
| def step(self): | ||
| """Runs a training epoch and updates the model parameters.""" | ||
| logger.debug("Starting step") | ||
| self.train_sampler.set_epoch(self.epoch) | ||
| return super(DistributedPyTorchRunner, self).step() | ||
|
|
||
| def get_state(self): | ||
| """Returns the state of the runner.""" | ||
| return { | ||
| "epoch": self.epoch, | ||
| "model": self.model.module.state_dict(), | ||
| "optimizer": self.optimizer.state_dict(), | ||
| "stats": self.stats() | ||
| } | ||
|
|
||
| def set_state(self, state): | ||
| """Sets the state of the model.""" | ||
| # TODO: restore timer stats | ||
| self.model.module.load_state_dict(state["model"]) | ||
| self.optimizer.load_state_dict(state["optimizer"]) | ||
| self.epoch = state["stats"]["epoch"] | ||
|
|
||
| def shutdown(self): | ||
| """Attempts to shut down the worker.""" | ||
| super(DistributedPyTorchRunner, self).shutdown() | ||
| dist.destroy_process_group() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.