Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing TPU tests #2632

Merged
merged 55 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
bacc208
init
Borda Jul 17, 2020
b93fab4
rename
Borda Jul 17, 2020
15cb65c
tpu_core_idx
Borda Jul 17, 2020
5ac3c61
idx 8
Borda Jul 17, 2020
724b7b4
idxs
Borda Jul 17, 2020
d817f0e
@pl_multi_process_test
Borda Jul 17, 2020
be79c50
assert
Borda Jul 17, 2020
f35437f
assert
Borda Jul 17, 2020
773e52a
deamon
Borda Jul 17, 2020
30cda41
no close
Borda Jul 17, 2020
be98712
imort
Borda Jul 18, 2020
f93c5aa
msg
Borda Jul 19, 2020
1dab0c8
use_single_gpu
Borda Jul 21, 2020
910f1c4
dataset
Borda Jul 21, 2020
410742a
idx
Borda Jul 21, 2020
2ad7b42
fix idx
Borda Jul 22, 2020
7576549
dataset
Borda Jul 22, 2020
accea7f
format
Borda Jul 23, 2020
85759bd
add pickable
Borda Jul 23, 2020
c8e7a70
typo
Borda Jul 25, 2020
19919a3
apex
Borda Jul 25, 2020
4ed34a6
typo
Borda Jul 25, 2020
b11c6bc
wip
Borda Jul 27, 2020
dee7fa5
wip
Borda Jul 27, 2020
e2ece1b
wip
Borda Jul 27, 2020
6c11dae
wip
Borda Jul 27, 2020
c216283
wip
Borda Jul 27, 2020
4f726a7
wip
Borda Jul 27, 2020
91d1f56
wip
Borda Jul 27, 2020
d6f1137
wip
Borda Jul 27, 2020
9465404
docs
Borda Jul 27, 2020
75dc5d2
typo
Borda Jul 27, 2020
5a7bfa0
tests
Borda Jul 27, 2020
22dd8c0
tests
Borda Jul 27, 2020
e4f2088
tests
Borda Jul 27, 2020
d89de9b
tests
Borda Jul 27, 2020
1614724
tests
Borda Jul 27, 2020
def846b
tests
Borda Jul 27, 2020
f6060ec
tests
Borda Jul 27, 2020
f0a6174
tests
Borda Jul 27, 2020
bdaaaef
tests
Borda Jul 27, 2020
458471d
tests
Borda Jul 27, 2020
6e6588a
tests
Borda Jul 27, 2020
6dc0e85
tests
Borda Jul 27, 2020
6196724
tests
Borda Jul 27, 2020
183fcc2
tests
Borda Jul 27, 2020
e7a5295
tests
Borda Jul 27, 2020
1999198
tests
Borda Jul 27, 2020
9579bce
tests
Borda Jul 27, 2020
fde01de
docs
Borda Jul 27, 2020
6abee20
docs
Borda Jul 27, 2020
bf6ac74
Apply suggestions from code review
Borda Jul 27, 2020
6f30bc5
Apply suggestions from code review
Borda Jul 27, 2020
45233a5
docs
Borda Jul 27, 2020
6c9495e
Apply suggestions from code review
Borda Jul 27, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,11 @@ references:
# happened to the job in Kubernetes. If we try MAX_CHECKS times and
# still the job hasn't finished, give up and return the starting
# non-zero status code.
while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else echo "Job not finished yet"; fi; sleep 30; done && \
printf "Waiting for job to finish: " && \
while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "."; fi; sleep $CHECK_SPEEP; done && \
echo "Done waiting. Job status code: $status_code" && \
# Allow time for logs to flush.
sleep 30 && \
sleep 10 && \
echo "JOB_NAME: $job_name" && \
gcloud logging read "resource.type=k8s_container resource.labels.project_id=$GOOGLE_PROJECT_ID resource.labels.location=$GOOGLE_COMPUTE_ZONE resource.labels.cluster_name=$GKE_CLUSTER resource.labels.namespace_name=default resource.labels.pod_name:$job_name" --limit 10000000 --order asc --format 'value(textPayload)' --project=$GOOGLE_PROJECT_ID > /tmp/full_output.txt && \
if grep -q '<?xml version="1.0" ?>' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '/<?xml version="1.0" ?>/'; else mv /tmp/full_output.txt xx00; fi && \
Expand Down Expand Up @@ -101,7 +102,8 @@ jobs:
docker:
- image: circleci/python:3.7
environment:
- MAX_CHECKS: 60
- MAX_CHECKS: 240
- CHECK_SPEEP: 5
steps:
- checkout
- go/install
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ jobs:
# TODO: temporary fix till hanging jobs on macOS for py38 is resolved
- python-version: 3.8
os: macOS-10.15
# TODO: temporary fix till pyYaml can be installed, see: https://github.com/actions/setup-python/issues/114
- python-version: 3.7
os: ubuntu-18.04
requires: 'minimal'

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 25
Expand Down
12 changes: 7 additions & 5 deletions .github/workflows/tpu-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ env:
GKE_CLUSTER: lightning-cluster
GKE_ZONE: us-central1-a
IMAGE: gcr.io/${{ secrets.GKE_PROJECT }}/tpu-testing-image
MAX_CHECKS: 240
CHECK_SPEEP: 5

jobs:
setup-build-publish-deploy:
Expand Down Expand Up @@ -82,17 +84,17 @@ jobs:
job_name=${job_name% created} && \
echo "Waiting on kubernetes job: $job_name in cluster: $GKE_CLUSTER" && \
i=0 && \
# 30 checks spaced 30s apart = 900s total.
max_checks=30 && \
# 60 checks spaced 30s apart = 900s total.
status_code=2 && \
# Check on the job periodically. Set the status code depending on what
# happened to the job in Kubernetes. If we try max_checks times and
# happened to the job in Kubernetes. If we try MAX_CHECKS times and
# still the job hasn't finished, give up and return the starting
# non-zero status code.
while [ $i -lt $max_checks ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else echo "Job not finished yet"; fi; sleep 30; done && \
printf "Waiting for job to finish: " && \
while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "." ; fi; sleep $CHECK_SPEEP; done && \
echo "Done waiting. Job status code: $status_code" && \
# Allow time for logs to flush.
sleep 60 && \
sleep 10 && \
echo "JOB_NAME: $job_name" && \
echo "GKE_CLUSTER: $GKE_CLUSTER" && \
echo "GKE_ZONE: $GKE_ZONE" && \
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))

- Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))

## [0.8.5] - 2020-07-09

### Added
Expand Down
2 changes: 0 additions & 2 deletions docs/source/new-project.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms


Quick Start
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
else:
from pytorch_lightning.core import LightningDataModule, LightningModule, data_loader
from pytorch_lightning.core.step_result import TrainResult, EvalResult
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import metrics
from pytorch_lightning.core.step_result import TrainResult, EvalResult

__all__ = [
'Trainer',
Expand Down
19 changes: 10 additions & 9 deletions pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,27 @@ class DDPSpawnBackend(object):

def __init__(self, trainer):
self.trainer = trainer
self.q = None
self.mp_queue = None

def setup(self):
self.trainer.set_random_port()

# pass in a state q
smp = mp.get_context('spawn')
self.q = smp.SimpleQueue()
self.mp_queue = smp.SimpleQueue()

def train(self, model, nprocs):
mp.spawn(self.ddp_train, nprocs=nprocs, args=(self.q, model,))
mp.spawn(self.ddp_train, nprocs=nprocs, args=(self.mp_queue, model,))

def teardown(self, model):
# restore main state with best weights
best_path = self.q.get()
results = self.q.get()
last_path = self.q.get()
best_path = self.mp_queue.get()
results = self.mp_queue.get()
last_path = self.mp_queue.get()

# transfer back the best path to the trainer
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also bets score

# load last weights
if last_path is not None and not self.trainer.testing:
Expand All @@ -59,13 +60,13 @@ def teardown(self, model):
self.trainer.model = model
return results

def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
"""
Entry point for ddp

Args:
process_idx:
q:
mp_queue: multiprocessing queue
model:
is_master:
proc_offset:
Expand Down Expand Up @@ -166,7 +167,7 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
model = self.trainer.get_model()

# persist info in ddp_spawn
self.trainer.transfer_ddp_spawn_state_on_fit_end(model, q, results)
self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

# clean up memory
torch.cuda.empty_cache()
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerator_backends/gpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch

from pytorch_lightning.core import LightningModule
try:
from apex import amp
except ImportError:
Expand Down Expand Up @@ -45,15 +46,15 @@ def setup(self, model):

# TODO: remove with dropping NVIDIA AMP support
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if self.trainer.use_amp and not native_amp_available:
if APEX_AVAILABLE and self.trainer.use_amp and not native_amp_available:
model = self._setup_nvidia_apex(model)
return model

def train(self, model):
results = self.trainer.run_pretrain_routine(model)
return results

def _setup_nvidia_apex(self, model):
def _setup_nvidia_apex(self, model: LightningModule):
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
Expand Down
105 changes: 68 additions & 37 deletions pytorch_lightning/accelerator_backends/tpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
# limitations under the License.

import os

import torch
import torch.multiprocessing as mp

from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning import _logger as log


try:
import torch_xla
import torch_xla.core.xla_model as xm
Expand All @@ -33,31 +37,52 @@ class TPUBackend(object):
def __init__(self, trainer):
self.trainer = trainer
self.start_method = None
self.mp_queue = None

def setup(self):
rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

if not XLA_AVAILABLE:
raise MisconfigurationException('No TPU devices found.')
raise MisconfigurationException('PyTorch XLA not installed.')

# see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
self.start_method = 'fork'

# pass in a state q
smp = mp.get_context(self.start_method)
self.mp_queue = smp.SimpleQueue()

def teardown(self, model):
# restore main state with best weights
best_path = self.mp_queue.get()
results = self.mp_queue.get()
last_path = self.mp_queue.get()

# COLAB_GPU is an env var available by default in Colab environments.
self.start_method = 'fork' if self.trainer.on_colab_kaggle else 'spawn'
# transfer back the best path to the trainer
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also bets score

def teardown(self):
# load last weights
if last_path and not self.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self.trainer.model = model

# when training completes, load the weights back in main process
self.__load_weights_on_main_process()
return results

def train(self, model):
def train(self, model: LightningModule):
self.trainer.model = model

# train
if self.trainer.tpu_id is not None:
self.tpu_train_in_process(self.trainer.tpu_id, model)
self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue)
else:
xmp.spawn(
self.tpu_train_in_process,
args=(model,),
args=(model, self.trainer, self.mp_queue),
nprocs=self.trainer.tpu_cores,
start_method=self.start_method
)
Expand All @@ -71,63 +96,69 @@ def __load_weights_on_main_process(self):

self.trainer.model = model

def tpu_train_in_process(self, tpu_core_idx, model):
def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None):
"""
Here we are inside each individual process
"""
if not self.trainer.testing:
self.trainer.setup('fit')
if not trainer:
trainer = self.trainer
if not trainer.testing:
trainer.setup('fit')
model.setup('fit')

# setup TPU training
self.__setup_tpu_training(model)
self.__setup_tpu_training(model, trainer)

# Run the pretrain routine
self.trainer.run_pretrain_routine(model)
results = trainer.run_pretrain_routine(model)

# save weights at the end of training
self.__save_end_of_training_weights(model)
self.__save_end_of_training_weights(model, trainer)

def __save_end_of_training_weights(self, model):
# persist info in spawn
trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

def __save_end_of_training_weights(self, model: LightningModule, trainer):
# when training ends on these platforms dump weights to get out of the main process
if self.trainer.on_colab_kaggle:
if trainer.on_colab_kaggle:
rank_zero_warn('cleaning up... please do not interrupt')
self.trainer.save_spawn_weights(model)
trainer.save_spawn_weights(model)

def __setup_tpu_training(self, model):
def __setup_tpu_training(self, model: LightningModule, trainer):
# use the default device from the process
tpu_device = xm.xla_device()
# tpu_device = xm.xla_device()

# if given an ordinal device, use this as the device
if self.trainer.tpu_id is not None:
tpu_device = xm.xla_device(self.trainer.tpu_id)

if trainer.tpu_id is not None:
tpu_device = xm.xla_device(trainer.tpu_id)
else:
tpu_device = xm.xla_device()
# track the device and move model to it
self.trainer._device = tpu_device
model.to(self.trainer._device)
trainer._device = tpu_device
model.to(trainer._device)

# get the appropriate tpu ranks
self.trainer.tpu_local_core_rank = xm.get_local_ordinal()
self.trainer.tpu_global_core_rank = xm.get_ordinal()
trainer.tpu_local_core_rank = xm.get_local_ordinal()
trainer.tpu_global_core_rank = xm.get_ordinal()

# avoid duplicating progress bar
if self.trainer.tpu_global_core_rank != 0 and self.trainer.progress_bar_callback is not None:
self.trainer.progress_bar_callback.disable()
if trainer.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()

self.trainer.global_rank = self.trainer.tpu_local_core_rank
rank_zero_only.rank = self.trainer.global_rank
trainer.global_rank = trainer.tpu_local_core_rank
rank_zero_only.rank = trainer.global_rank

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
self.trainer.optimizers = optimizers
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies
optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
trainer.optimizers = optimizers
trainer.lr_schedulers = lr_schedulers
trainer.optimizer_frequencies = optimizer_frequencies

# init 16 bit for TPU
if self.trainer.precision == 16:
if trainer.precision == 16:
os.environ['XLA_USE_BF16'] = str(1)

log.info(f'INIT TPU local core: {self.trainer.tpu_local_core_rank},'
f' global rank: {self.trainer.tpu_global_core_rank}')
log.info(f'INIT TPU local core: {trainer.tpu_local_core_rank},'
f' global rank: {trainer.tpu_global_core_rank}'
f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')
6 changes: 5 additions & 1 deletion pytorch_lightning/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,5 +305,9 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.core.decorators import data_loader
from pytorch_lightning.core.lightning import LightningModule

__all__ = ['LightningDataModule', 'LightningModule', 'data_loader']
__all__ = [
'LightningDataModule',
'LightningModule',
'data_loader',
]
# __call__ = __all__
2 changes: 0 additions & 2 deletions pytorch_lightning/core/decorators.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from functools import wraps
from typing import Callable

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn

Expand Down
Loading