ref: ddp backend refactor (3) (#3208)
* ddp backend refactor

* ddp backend refactor
williamFalcon authored Aug 27, 2020
1 parent a8daf91 commit 6bae404
Showing 7 changed files with 32 additions and 16 deletions.
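
Across all seven files the change follows one pattern: each accelerator backend's setup(model) now stores the model on the trainer (self.trainer.model = model), and train() no longer takes a model argument but reads it back from the trainer. A minimal sketch of that convention, using a simplified Backend class that is illustrative only and not the actual Lightning implementation:

    class Backend:
        def __init__(self, trainer):
            self.trainer = trainer

        def setup(self, model):
            # backend-specific preparation (device placement, wrapping, ...)
            # ends by stashing the model on the trainer so train() can find it
            self.trainer.model = model

        def train(self):
            # read the model back instead of receiving it as an argument
            model = self.trainer.model
            return self.trainer.run_pretrain_routine(model)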
4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/cpu_backend.py
@@ -36,8 +36,10 @@ def setup(self, model):
         self.trainer.optimizers = optimizers
         self.trainer.lr_schedulers = lr_schedulers
         self.trainer.optimizer_frequencies = optimizer_frequencies
+        self.trainer.model = model

-    def train(self, model):
+    def train(self):
+        model = self.trainer.model
         results = self.trainer.run_pretrain_routine(model)
         return results

5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -46,6 +46,8 @@ def __init__(self, trainer):
     def setup(self, model):
         self._resolve_task_idx()

+        self.trainer.model = model
+
     def _resolve_task_idx(self):
         if self.trainer.is_slurm_managing_tasks:
             self.task_idx = int(os.environ['SLURM_LOCALID'])
@@ -57,7 +59,8 @@ def _resolve_task_idx(self):
                 m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags'
                 raise MisconfigurationException(m)

-    def train(self, model):
+    def train(self):
+        model = self.trainer.model
         self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

     def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):

5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/ddp_backend.py
@@ -57,13 +57,16 @@ def setup(self, model):
         elif self.mode == 'torchelastic_ddp':
             self.__torchelastic_setup()

+        self.trainer.model = model
+
     def __slurm_setup(self):
         self.task_idx = int(os.environ['SLURM_LOCALID'])

     def __torchelastic_setup(self):
         self.task_idx = int(os.environ['LOCAL_RANK'])

-    def train(self, model):
+    def train(self):
+        model = self.trainer.model
         self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

     def spawn_ddp_children(self, model):

6 changes: 5 additions & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -41,7 +41,11 @@ def setup(self, model):
         smp = mp.get_context('spawn')
         self.mp_queue = smp.SimpleQueue()

-    def train(self, model):
+        self.trainer.model = model
+
+    def train(self):
+        model = self.trainer.model
+
         # train in children process
         mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,))

6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/gpu_backend.py
@@ -46,9 +46,11 @@ def setup(self, model):

         if self.trainer.amp_backend == AMPType.APEX:
             model = self._setup_nvidia_apex(model)
-        return model

-    def train(self, model):
+        self.trainer.model = model
+
+    def train(self):
+        model = self.trainer.model
         results = self.trainer.run_pretrain_routine(model)
         return results

6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -54,6 +54,8 @@ def setup(self, model):
         smp = mp.get_context(self.start_method)
         self.mp_queue = smp.SimpleQueue()

+        self.trainer.model = model
+
     def teardown(self, model):
         # restore main state with best weights
         best_path = self.mp_queue.get()
@@ -75,8 +77,8 @@ def teardown(self, model):
         self.__load_weights_on_main_process()
         return results

-    def train(self, model: LightningModule):
-        self.trainer.model = model
+    def train(self):
+        model = self.trainer.model

         # train
         if self.trainer.tpu_id is not None:

16 changes: 8 additions & 8 deletions pytorch_lightning/trainer/trainer.py
@@ -1040,26 +1040,26 @@ def fit(
         if self.use_ddp2:
             self.accelerator_backend = DDP2Backend(self)
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         elif use_slurm_ddp:
             self.accelerator_backend = DDPBackend(self, mode='slurm_ddp')
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         elif use_torchelastic_ddp:
             self.accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         # regular ddp using .spawn
         elif use_ddp_spawn:
             self.accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         # ddp
@@ -1082,20 +1082,20 @@

         elif self.use_single_gpu:
             self.accelerator_backend = GPUBackend(self)
-            model = self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            self.accelerator_backend.setup(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         elif self.use_tpu:
             self.accelerator_backend = TPUBackend(self)
             self.accelerator_backend.setup(model)
-            self.accelerator_backend.train(model)
+            self.accelerator_backend.train()
             self.accelerator_backend.teardown(model)

         else:
             self.accelerator_backend = CPUBackend(self)
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         # hook
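
With the split above, every branch of Trainer.fit drives its backend through the same three calls. An illustrative sequence (the CPU backend is shown only as an example; any backend touched by this commit follows the same shape):

    backend = CPUBackend(trainer)   # or DDPBackend(trainer, mode='slurm_ddp'), GPUBackend(trainer), ...
    backend.setup(model)            # prepares the model and sets trainer.model
    results = backend.train()       # no model argument; the backend reads trainer.model
    backend.teardown()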
