Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle checkpoint dirpath suffix in NeptuneLogger #18863

Merged
merged 33 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
51f0b92
handle checkpoint dirpath suffix
AleksanderWWW Oct 25, 2023
1adfb10
handle without relying on system separator
AleksanderWWW Oct 25, 2023
923b6fe
fix for windows
AleksanderWWW Oct 25, 2023
c3971de
fix
AleksanderWWW Oct 25, 2023
c1cfcfa
Apply suggestions from code review
AleksanderWWW Oct 25, 2023
5af40d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2023
08cbda8
Update src/lightning/pytorch/loggers/neptune.py
AleksanderWWW Oct 26, 2023
66bd7ab
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Oct 26, 2023
3ab88c0
add doctest + function rename
AleksanderWWW Oct 27, 2023
fdd2728
single quotes in doctest
AleksanderWWW Oct 27, 2023
a513dcb
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Oct 27, 2023
1ec4687
Apply suggestions from code review
AleksanderWWW Oct 27, 2023
84b4b32
use os.path.normpath
AleksanderWWW Oct 29, 2023
6d7f9f7
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Oct 29, 2023
1db1922
fix unit tests
AleksanderWWW Oct 30, 2023
d4b696f
fix for windows
AleksanderWWW Oct 30, 2023
804d59d
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Oct 30, 2023
1a20eee
Update src/lightning/pytorch/loggers/neptune.py
AleksanderWWW Oct 30, 2023
72ed887
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Oct 30, 2023
c4f0eb2
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Oct 30, 2023
0ef1ff5
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Nov 2, 2023
6e9d09c
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Nov 2, 2023
9f8c816
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Nov 6, 2023
05a27ec
Merge branch 'master' into aw/fix-path-suffix
Borda Nov 18, 2023
8de17fe
use pathlib
AleksanderWWW Nov 22, 2023
cba5a3c
Merge branch 'master' into aw/fix-path-suffix
AleksanderWWW Nov 22, 2023
3a127eb
typing
AleksanderWWW Nov 22, 2023
404acf9
typing
AleksanderWWW Nov 22, 2023
bdb978a
norm the path at return
AleksanderWWW Nov 22, 2023
e0f246d
fix for windows
AleksanderWWW Nov 22, 2023
9f2c968
add test case
awaelchli Nov 25, 2023
e711a13
Merge branch 'master' into aw/fix-path-suffix
awaelchli Nov 25, 2023
1f5130e
chlog
awaelchli Nov 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/lightning/pytorch/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@
_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"


def _get_expected_model_path(dir_path: str) -> str:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
expected_model_path = dir_path

while expected_model_path and expected_model_path[-1] in ("/", "\\"):
expected_model_path = expected_model_path[:-1]

return f"{expected_model_path}{os.path.sep}"


class NeptuneLogger(Logger):
r"""Log using `Neptune <https://neptune.ai>`_.

Expand Down Expand Up @@ -551,8 +560,8 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str:
"""Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
if hasattr(checkpoint_callback, "dirpath"):
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
if not model_path.startswith(expected_model_path):
expected_model_path = _get_expected_model_path(checkpoint_callback.dirpath)
if not model_path.startswith(expected_model_path[:-1]):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
# Remove extension from filepath
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
Expand Down
19 changes: 19 additions & 0 deletions tests/tests_pytorch/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import os
import pickle
import sys
from collections import namedtuple
from unittest import mock
from unittest.mock import MagicMock, call
Expand All @@ -23,6 +24,11 @@
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import NeptuneLogger
from lightning.pytorch.loggers.neptune import _get_expected_model_path

IS_WINDOWS = sys.platform == "win32"
skip_if_on_windows = pytest.mark.skipif(IS_WINDOWS, reason="Those tests are specific to non-windows systems")
skip_if_not_windows = pytest.mark.skipif(not IS_WINDOWS, reason="Those tests are specific to windows os")
AleksanderWWW marked this conversation as resolved.
Show resolved Hide resolved


def _fetchable_paths(value):
Expand Down Expand Up @@ -303,3 +309,16 @@ def test_get_full_model_names_from_exp_structure():
}
expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys


@skip_if_on_windows
def test_get_expected_model_path():
assert _get_expected_model_path("my_model/checkpoints") == "my_model/checkpoints/"
assert _get_expected_model_path("my_model/checkpoints/") == "my_model/checkpoints/"
assert _get_expected_model_path("my_model/checkpoints//") == "my_model/checkpoints/"


@skip_if_not_windows
def test_get_expected_model_path_windows():
assert _get_expected_model_path("my_model\\checkpoints\\") == "my_model\\checkpoints\\"
AleksanderWWW marked this conversation as resolved.
Show resolved Hide resolved
assert _get_expected_model_path("my_model\\checkpoints") == "my_model\\checkpoints\\"