
Commit

ddp backend refactor (#3207)
williamFalcon authored Aug 26, 2020
1 parent ff3c2f4 commit a8daf91
Showing 6 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/base_backend.py
@@ -8,7 +8,7 @@ class Accelerator(object):
     def __init__(self, trainer):
         self.trainer = trainer

-    def setup(self):
+    def setup(self, model):
         pass

     def teardown(self):
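In effect, the base Accelerator now receives the LightningModule in setup() directly, so concrete backends no longer have to fetch it from the trainer themselves. A minimal sketch of a backend written against the new signature (the class name ExampleBackend and its body are illustrative only, not part of this commit):

from pytorch_lightning.accelerators.base_backend import Accelerator

class ExampleBackend(Accelerator):
    def setup(self, model):
        # the model now arrives as an argument instead of being read off self.trainer
        self._model = model

    def teardown(self):
        # release the reference taken in setup()
        self._model = None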
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -43,7 +43,7 @@ def __init__(self, trainer):
         super().__init__(trainer)
         self.task_idx = None

-    def setup(self):
+    def setup(self, model):
         self._resolve_task_idx()

     def _resolve_task_idx(self):
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_backend.py
@@ -49,7 +49,7 @@ def __init__(self, trainer, mode: str = 'ddp'):
         self._has_spawned_children = False
         self.mode = mode

-    def setup(self):
+    def setup(self, model):
         if self.mode == 'ddp':
             pass
         elif self.mode == 'slurm_ddp':
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -34,7 +34,7 @@ def __init__(self, trainer, nprocs):
         self.mp_queue = None
         self.nprocs = nprocs

-    def setup(self):
+    def setup(self, model):
         os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))

         # pass in a state q
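The MASTER_PORT line above is unchanged context: the spawn backend picks a free port for the DDP process group when one is not already set. The find_free_network_port helper is imported elsewhere in that module; a typical implementation (illustrative, not this commit's code) binds to port 0 and lets the OS choose:

import socket

def find_free_network_port() -> int:
    # bind to port 0 so the OS assigns an unused port, read it back, then release it
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(('', 0))
    port = s.getsockname()[1]
    s.close()
    return port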
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/tpu_backend.py
@@ -41,7 +41,7 @@ def __init__(self, trainer):
         self.start_method = None
         self.mp_queue = None

-    def setup(self):
+    def setup(self, model):
         rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

         if not XLA_AVAILABLE:
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -1039,26 +1039,26 @@ def fit(
         # DDP2 (cluster only)
         if self.use_ddp2:
             self.accelerator_backend = DDP2Backend(self)
-            self.accelerator_backend.setup()
+            self.accelerator_backend.setup(model)
             results = self.accelerator_backend.train(model)
             self.accelerator_backend.teardown()

         elif use_slurm_ddp:
             self.accelerator_backend = DDPBackend(self, mode='slurm_ddp')
-            self.accelerator_backend.setup()
+            self.accelerator_backend.setup(model)
             results = self.accelerator_backend.train(model)
             self.accelerator_backend.teardown()

         elif use_torchelastic_ddp:
             self.accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
-            self.accelerator_backend.setup()
+            self.accelerator_backend.setup(model)
             results = self.accelerator_backend.train(model)
             self.accelerator_backend.teardown()

         # regular ddp using .spawn
         elif use_ddp_spawn:
             self.accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
-            self.accelerator_backend.setup()
+            self.accelerator_backend.setup(model)
             results = self.accelerator_backend.train(model)
             self.accelerator_backend.teardown()

@@ -1088,7 +1088,7 @@ def fit(

         elif self.use_tpu:
             self.accelerator_backend = TPUBackend(self)
-            self.accelerator_backend.setup()
+            self.accelerator_backend.setup(model)
             self.accelerator_backend.train(model)
             self.accelerator_backend.teardown(model)
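After this change every branch of fit() follows the same shape, with the model threaded through setup(). A condensed sketch of that call pattern (the choose_backend helper is hypothetical and stands in for the if/elif chain above):

def run_with_backend(trainer, model):
    backend = choose_backend(trainer)   # hypothetical helper, e.g. returns DDP2Backend(trainer)
    backend.setup(model)                # the model is now passed to setup()
    results = backend.train(model)      # training still receives the model explicitly
    backend.teardown()                  # the TPU branch calls teardown(model) instead
    return results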
