ref: ddp backend refactor (3) #3208

Merged 2 commits on Aug 27, 2020
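Summary of the change shown in the diffs below: each accelerator backend's setup(model) now stores the model on the trainer (self.trainer.model = model), and train() drops its model parameter, reading the model back from self.trainer.model instead; the call sites in Trainer.fit() are updated to call train() with no arguments. A minimal, hedged sketch of that contract follows — the Trainer and Backend classes here are simplified stand-ins for illustration, not the real pytorch_lightning classes:

    # Hedged sketch of the refactored backend contract (assumption: Trainer and
    # Backend are simplified stand-ins, not the real pytorch_lightning classes).

    class Trainer:
        def __init__(self):
            self.model = None          # populated by the backend's setup()


    class Backend:
        def __init__(self, trainer):
            self.trainer = trainer

        def setup(self, model):
            # setup() is now the single place where the model is handed over;
            # it stashes the model on the trainer.
            self.trainer.model = model

        def train(self):
            # train() no longer takes a model argument; it reads the model
            # back from the trainer instead.
            model = self.trainer.model
            return f"running pretrain routine for {model}"


    trainer = Trainer()
    backend = Backend(trainer)
    backend.setup("my_model")     # model handed to the backend exactly once
    print(backend.train())        # train() called with no arguments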
4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/cpu_backend.py
@@ -36,8 +36,10 @@ def setup(self, model):
         self.trainer.optimizers = optimizers
         self.trainer.lr_schedulers = lr_schedulers
         self.trainer.optimizer_frequencies = optimizer_frequencies
+        self.trainer.model = model

-    def train(self, model):
+    def train(self):
+        model = self.trainer.model
         results = self.trainer.run_pretrain_routine(model)
         return results

5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -46,6 +46,8 @@ def __init__(self, trainer):
     def setup(self, model):
         self._resolve_task_idx()

+        self.trainer.model = model
+
     def _resolve_task_idx(self):
         if self.trainer.is_slurm_managing_tasks:
             self.task_idx = int(os.environ['SLURM_LOCALID'])
@@ -57,7 +59,8 @@ def _resolve_task_idx(self):
                 m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags'
                 raise MisconfigurationException(m)

-    def train(self, model):
+    def train(self):
+        model = self.trainer.model
         self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

     def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/ddp_backend.py
@@ -57,13 +57,16 @@ def setup(self, model):
         elif self.mode == 'torchelastic_ddp':
             self.__torchelastic_setup()

+        self.trainer.model = model
+
     def __slurm_setup(self):
         self.task_idx = int(os.environ['SLURM_LOCALID'])

     def __torchelastic_setup(self):
         self.task_idx = int(os.environ['LOCAL_RANK'])

-    def train(self, model):
+    def train(self):
+        model = self.trainer.model
         self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

     def spawn_ddp_children(self, model):
6 changes: 5 additions & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -41,7 +41,11 @@ def setup(self, model):
         smp = mp.get_context('spawn')
         self.mp_queue = smp.SimpleQueue()

-    def train(self, model):
+        self.trainer.model = model
+
+    def train(self):
+        model = self.trainer.model
+
         # train in children process
         mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,))
6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/gpu_backend.py
@@ -46,9 +46,11 @@ def setup(self, model):

         if self.trainer.amp_backend == AMPType.APEX:
             model = self._setup_nvidia_apex(model)
-        return model

-    def train(self, model):
+        self.trainer.model = model
+
+    def train(self):
+        model = self.trainer.model
         results = self.trainer.run_pretrain_routine(model)
         return results

6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -54,6 +54,8 @@ def setup(self, model):
         smp = mp.get_context(self.start_method)
         self.mp_queue = smp.SimpleQueue()

+        self.trainer.model = model
+
     def teardown(self, model):
         # restore main state with best weights
         best_path = self.mp_queue.get()
@@ -75,8 +77,8 @@ def teardown(self, model):
         self.__load_weights_on_main_process()
         return results

-    def train(self, model: LightningModule):
-        self.trainer.model = model
+    def train(self):
+        model = self.trainer.model

         # train
         if self.trainer.tpu_id is not None:
16 changes: 8 additions & 8 deletions pytorch_lightning/trainer/trainer.py
@@ -1040,26 +1040,26 @@ def fit(
         if self.use_ddp2:
             self.accelerator_backend = DDP2Backend(self)
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         elif use_slurm_ddp:
             self.accelerator_backend = DDPBackend(self, mode='slurm_ddp')
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         elif use_torchelastic_ddp:
             self.accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         # regular ddp using .spawn
         elif use_ddp_spawn:
             self.accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         # ddp
@@ -1082,20 +1082,20 @@ def fit(

         elif self.use_single_gpu:
             self.accelerator_backend = GPUBackend(self)
-            model = self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            self.accelerator_backend.setup(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         elif self.use_tpu:
             self.accelerator_backend = TPUBackend(self)
             self.accelerator_backend.setup(model)
-            self.accelerator_backend.train(model)
+            self.accelerator_backend.train()
             self.accelerator_backend.teardown(model)

         else:
             self.accelerator_backend = CPUBackend(self)
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()

         # hook