[feat] Logging refactor 2/n - train #4495

Merged: 53 commits, Nov 5, 2020
Changes shown from 10 commits.

Commits (53)
2b3c4bc
update logging
tchaton Nov 3, 2020
e2814ad
Merge branch 'master' into feat/train_logging
tchaton Nov 3, 2020
f487a5d
Merge branch 'master' into feat/train_logging
tchaton Nov 3, 2020
ba0427f
solve more bugs
tchaton Nov 3, 2020
c68995a
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 3, 2020
8337394
replace Mapping by Dict
tchaton Nov 3, 2020
3862ef7
update on comments
tchaton Nov 3, 2020
23a62ac
resolve pep8
tchaton Nov 3, 2020
ebf6573
Merge branch 'master' into feat/train_logging
tchaton Nov 3, 2020
3921725
Apply suggestions from code review
Borda Nov 3, 2020
09ace23
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
c9308d4
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 4, 2020
e459131
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Nov 4, 2020
abd0fc0
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
a8371bf
update on comments
tchaton Nov 4, 2020
92994d9
typo
tchaton Nov 4, 2020
3539faa
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
f3b4f1f
update for coverage
tchaton Nov 4, 2020
453abed
update test
tchaton Nov 4, 2020
fb72bff
update
tchaton Nov 4, 2020
93c596d
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 4, 2020
0decd22
Update tests/models/test_hooks.py
tchaton Nov 4, 2020
005e91b
Update tests/models/test_hooks.py
tchaton Nov 4, 2020
21da81f
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
8f879db
update on comments
tchaton Nov 4, 2020
1983cc1
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
7b4e9e0
Merge branch 'master' into feat/train_logging
tchaton Nov 5, 2020
0e41cad
remove deepcopy
tchaton Nov 5, 2020
fcf74e5
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 5, 2020
25692c8
remove useless look for
tchaton Nov 5, 2020
2859f5c
another small optim
tchaton Nov 5, 2020
f0a13bb
extra optim
tchaton Nov 5, 2020
5535b0a
remove lastest optim, can be source of bug
tchaton Nov 5, 2020
ae0c00f
resolve bug
tchaton Nov 5, 2020
3e6fc63
add docstring
tchaton Nov 5, 2020
43f5c45
optimize coverage
tchaton Nov 5, 2020
aa393c3
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Nov 5, 2020
bc62cff
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Nov 5, 2020
85317ad
Update tests/trainer/logging_tests/test_distributed_logging.py
tchaton Nov 5, 2020
d492d94
Update pytorch_lightning/trainer/evaluation_loop.py
tchaton Nov 5, 2020
caea74c
Update tests/trainer/logging/test_logger_connector.py
tchaton Nov 5, 2020
5bc3847
Update tests/trainer/logging_tests/test_train_loop_logging_1_0.py
tchaton Nov 5, 2020
66a89a8
Merge branch 'master' into feat/train_logging
tchaton Nov 5, 2020
6c7373a
update on comments
tchaton Nov 5, 2020
60f95a8
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 5, 2020
ef2065c
update
tchaton Nov 5, 2020
6a9bcc5
update on comments
tchaton Nov 5, 2020
e22d9f8
update parity speed
tchaton Nov 5, 2020
395df7f
get it down to 0.65
tchaton Nov 5, 2020
ae64091
update
tchaton Nov 5, 2020
59ca975
Merge branch 'master' into feat/train_logging
williamFalcon Nov 5, 2020
4c19f96
0.8 max_dif
tchaton Nov 5, 2020
c86643b
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 5, 2020
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from typing import Union, Tuple, Any, Mapping

from typing import Union, Tuple, Any, Dict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result


@@ -96,14 +96,16 @@ def __init__(self, fx_name):
def get_reduced_metrics(self):
return self._internals_reduced

def add_dataloader_idx(self):
return len(self._internals) > 1
@property
def add_dataloader_idx(self) -> bool:
return self.num_dataloaders > 1

@property
def num_dataloaders(self):
return len(self._internals)
def num_dataloaders(self) -> int:
_inter = self._internals_reduced if self.has_reduced else self._internals
return len(_inter)

def get_latest_from_dict(self, dl_idx):
def get_latest_from_dict(self, dl_idx: str) -> Result:
num_opt_idx = len(self._internals[dl_idx]) - 1
assert num_opt_idx >= 0
num_opt_idx = str(num_opt_idx)
@@ -125,7 +127,7 @@ def check_dataloader_idx(self, result: Result) -> bool:
except Exception:
return add_dataloader_idx

def get_lastest_from_func_name(self, func_name, *args, latest=True, **kwargs):
def get_lastest_from_func_name(self, func_name: str, *args, latest: bool = True, **kwargs) -> Dict:
results = {}
if latest:
for dl_idx in range(self.num_dataloaders):
@@ -157,7 +159,7 @@ def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None:
else:
raise Exception("The provided opt_metric should be a Result Object. Something is wrong")

def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> Mapping:
def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> Dict:
results = {}
for dl_idx in range(self.num_dataloaders):
dl_idx = str(dl_idx)
@@ -169,13 +171,13 @@ def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> Mapping:
self.run_epoch_func(results, opt_metrics, func_name, *args, **kwargs)
return results

def get_epoch_pbar_metrics(self, *args, **kwargs) -> Mapping:
def get_epoch_pbar_metrics(self, *args, **kwargs) -> Dict:
return self.get_epoch_from_func_name("get_epoch_pbar_metrics")

def get_epoch_log_metrics(self, *args, **kwargs) -> Mapping:
def get_epoch_log_metrics(self, *args, **kwargs) -> Dict:
return self.get_epoch_from_func_name("get_epoch_log_metrics")

def get_forked_metrics(self, *args, **kwargs) -> Mapping:
def get_forked_metrics(self, *args, **kwargs) -> Dict:
return self.get_epoch_from_func_name("get_forked_metrics")

@staticmethod
@@ -271,7 +273,7 @@ def auto_reduce_results_on_epoch_end(self) -> None:
self._internals_reduced[dl_idx][str(opt_idx)] = opt_outputs

# free memory
del self._internals[dl_idx]
del self._internals[dl_idx][opt_idx]
else:
# no need to reduce as called only once
if len(epoch_metrics) == 1:
@@ -301,21 +303,16 @@ def __repr__(self):
class EpochResultStore:
"""
This class is defined for internal usage.

It holds all metrics logged using the self.log function, stored in a `HookResultStore` object per hook.

The internal data structure is as follows:

self._internals = {"fx_name_0": HookResultStore(), ..., "fx_name_n": HookResultStore()}

Pseudo Code Example:
```
model._current_fx_name = 'something'
model._results = Result()
model.log('a', ...)
epoch_result_store.cache_result()
```

"""
def __init__(self, trainer, stage):
self.trainer = trainer
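
For orientation only (not part of the diff): the docstring above describes `self._internals` as a mapping from hook name to a `HookResultStore`. Below is a minimal standalone sketch of that shape, using hypothetical stand-in classes rather than the real Lightning objects.

```python
# Standalone sketch of the layout described in the EpochResultStore docstring.
# MiniResult and MiniHookResultStore are hypothetical stand-ins, not the real
# pytorch_lightning classes.
from collections import defaultdict


class MiniResult(dict):
    """Stand-in for pytorch_lightning.core.step_result.Result."""


class MiniHookResultStore:
    def __init__(self):
        # keyed by dataloader idx -> optimizer idx -> list of per-batch results
        self._internals = defaultdict(lambda: defaultdict(list))

    def append(self, result, dl_idx="0", opt_idx="0"):
        self._internals[dl_idx][opt_idx].append(result)


# one store per logged hook name, e.g. {"fx_name_0": HookResultStore(), ...}
internals = {"training_step": MiniHookResultStore()}
internals["training_step"].append(MiniResult(a=1.0))
print(internals["training_step"]._internals["0"]["0"])  # [{'a': 1.0}]
```
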
@@ -365,7 +362,7 @@ def current_model_info(self):
model_ref = self.trainer.get_model()
# extract hook information
fx_name = model_ref._current_hook_fx_name
if fx_name == '':
if fx_name is None:
fx_name = model_ref._current_fx_name
dataloader_idx = model_ref._current_dataloader_idx
return fx_name, dataloader_idx
@@ -456,18 +453,22 @@ def update_logger_connector(self, fx_name: str = None) -> None:
logger_connector.callback_metrics.update(callback_metrics)
logger_connector.callback_metrics.pop("epoch", None)

def run_batch_from_func_name(self, func_name) -> Mapping:
def run_batch_from_func_name(self, func_name) -> Dict:
results = {}
for fx_name, hook_result in self._internals.items():
func = getattr(hook_result, func_name)
results.update(func(latest=True, include_forked_originals=False))
return results

def get_latest_batch_log_metrics(self) -> Mapping:
return self.run_batch_from_func_name("get_batch_log_metrics")
def get_latest_batch_log_metrics(self) -> Dict:
batch_log_metrics: Dict = self.run_batch_from_func_name("get_batch_log_metrics")
batch_log_metrics.update(self.legacy_batch_log_metrics)
return batch_log_metrics

def get_latest_batch_pbar_metrics(self) -> Mapping:
return self.run_batch_from_func_name("get_batch_pbar_metrics")
def get_latest_batch_pbar_metrics(self) -> Dict:
batch_pbar_metrics: Dict = self.run_batch_from_func_name("get_batch_pbar_metrics")
batch_pbar_metrics.update(self.legacy_batch_pbar_metrics)
return batch_pbar_metrics

@property
def has_reduced(self) -> bool:
@@ -495,7 +496,7 @@ def has_batch_loop_finished(self, has_batch_loop_finished):
self._has_batch_loop_finished = has_batch_loop_finished
self.update_logger_connector()

def run_epoch_by_func_name(self, func_name) -> Mapping:
def run_epoch_by_func_name(self, func_name) -> Dict:
if not self.has_reduced:
self.auto_reduce_results_on_epoch_end()
results = {}
@@ -504,16 +505,16 @@ def run_epoch_by_func_name(self, func_name) -> Mapping:
results.update(func())
return results

def get_epoch_pbar_metrics(self) -> Mapping:
def get_epoch_pbar_metrics(self) -> Dict:
return self.run_epoch_by_func_name("get_epoch_pbar_metrics")

def get_epoch_log_metrics(self) -> Mapping:
def get_epoch_log_metrics(self) -> Dict:
return self.run_epoch_by_func_name("get_epoch_log_metrics")

def get_forked_metrics(self) -> Mapping:
def get_forked_metrics(self) -> Dict:
return self.run_epoch_by_func_name("get_forked_metrics")

def get_reduced_metrics(self) -> Mapping:
def get_reduced_metrics(self) -> Dict:
return self.run_epoch_by_func_name("get_reduced_metrics")

def reset(self):
@@ -523,6 +524,56 @@ def reset(self):
self._opt_idx: Union[int, None] = None
self._batch_size: Union[int, None] = None
self._has_batch_loop_finished = False
self.legacy_batch_log_metrics = {}
self.legacy_batch_pbar_metrics = {}

def __call__(self,
fx_name: Union[str, int, None] = None,
dl_idx: Union[str, int, None] = None,
opt_idx: Union[str, int, None] = None,
batch_idx: Union[str, int, None] = None,
split_idx: Union[str, int, None] = None,
reduced=False):
"""
This function is used to easily access saved logged data.
"""

hook_result = self[str(fx_name)]

dl_idx = str(dl_idx) if dl_idx is not None else None
opt_idx = str(opt_idx) if opt_idx is not None else None
batch_idx = str(batch_idx) if batch_idx is not None else None
split_idx = int(split_idx) if split_idx is not None else None

internal_type = hook_result._internal_type
if internal_type is None:
return Result()

if reduced:
result = hook_result._internals_reduced
else:
result = hook_result._internals

if internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP:
if not reduced:
if dl_idx is not None:
result = result[dl_idx]
if opt_idx is not None:
result = result[opt_idx]
if batch_idx is not None:
result = result[batch_idx]
if split_idx is not None:
result = result[split_idx]
else:
if dl_idx is not None:
result = result[dl_idx]
if opt_idx is not None:
result = result[opt_idx]
else:
if dl_idx is not None:
result = result[dl_idx]

return result

def __repr__(self):
return f"{self.__class__.__name__}(stage={self._stage}, internals={self._internals})"
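
For reference (not part of the diff): a simplified, standalone imitation of the lookup order implemented by the new `__call__` accessor above. The store layout, keys, and values are illustrative assumptions, not taken from this PR or its tests.

```python
# Simplified, standalone imitation of EpochResultStore.__call__'s lookup order
# (dl_idx -> opt_idx -> batch_idx -> split_idx). Illustration only; not the
# real pytorch_lightning class.
from typing import Union


def lookup(store: dict,
           dl_idx: Union[int, str, None] = None,
           opt_idx: Union[int, str, None] = None,
           batch_idx: Union[int, str, None] = None,
           split_idx: Union[int, str, None] = None):
    # as in the diff, indices are stored under string keys, except split_idx,
    # which stays an int (it indexes the list of truncated-BPTT splits)
    keys = [str(k) for k in (dl_idx, opt_idx, batch_idx) if k is not None]
    if split_idx is not None:
        keys.append(int(split_idx))
    result = store
    for key in keys:
        result = result[key]
    return result


# hypothetical cached layout: dataloader idx -> optimizer idx -> batch idx -> splits
store = {"0": {"0": {"0": [{"loss": 0.3}]}}}
print(lookup(store, dl_idx=0, opt_idx=0, batch_idx=0, split_idx=0))  # {'loss': 0.3}
```

With `reduced=True`, the real accessor reads from `_internals_reduced` instead, so only `fx_name`, `dl_idx`, and `opt_idx` are relevant in that case.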