Skip to content

Commit d3a3978

Browse files
authored
use weights_only to True (#136)
1 parent f13ac8c commit d3a3978

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

shimmer/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def migrate_model(ckpt_path: str | PathLike, **torch_load_kwargs):
7171
torch_load_kwargs: additional args given to torch.load.
7272
"""
7373
ckpt_path = Path(ckpt_path)
74-
ckpt = torch.load(ckpt_path, **torch_load_kwargs)
74+
default_torch_kwargs: dict[str, Any] = {"weights_only": True}
75+
default_torch_kwargs.update(torch_load_kwargs)
76+
ckpt = torch.load(ckpt_path, **default_torch_kwargs)
7577
new_ckpt, done_migrations = migrate_from_folder(ckpt, MIGRATION_DIR)
7678
done_migration_log = ", ".join(map(lambda x: x.name, done_migrations))
7779
print(f"Migrating: {done_migration_log}")

tests/test_ckpt_migrations.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def test_ckpt_migration_2_domains():
1616
old_ckpt_path = here / "data" / "old_gw_2_domains.ckpt"
17-
old_ckpt = torch.load(old_ckpt_path)
17+
old_ckpt = torch.load(old_ckpt_path, weights_only=True)
1818
new_ckpt, done_migrations = migrate_from_folder(old_ckpt, MIGRATION_DIR)
1919

2020
old_keys = set(old_ckpt["state_dict"].keys())
@@ -74,7 +74,7 @@ def test_ckpt_migration_2_domains():
7474

7575
def test_ckpt_migration_gw():
7676
old_ckpt_path = here / "data" / "old_gw.ckpt"
77-
old_ckpt = torch.load(old_ckpt_path)
77+
old_ckpt = torch.load(old_ckpt_path, weights_only=True)
7878
new_ckpt, done_migrations = migrate_from_folder(old_ckpt, MIGRATION_DIR)
7979

8080
old_keys = set(old_ckpt["state_dict"].keys())

0 commit comments

Comments
 (0)