From 520940c6c10696f8f3339c8a02210475b6c5b62d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 6 Jan 2022 21:49:45 +0800 Subject: [PATCH] allow extract backbone weights on non-GPU machines --- extras.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extras.py b/extras.py index a313070..07b9858 100644 --- a/extras.py +++ b/extras.py @@ -108,9 +108,9 @@ def forward(self, batch, target): def extract_backbone_weights(lightning_ckpt_path, save_path): - ckpt = torch.load(lightning_ckpt_path) + ckpt = torch.load(lightning_ckpt_path, map_location='cpu') state_dict = ckpt["state_dict"] - backbone_weights = {k[len("backbone."):]: v.cpu() for k, v in state_dict.items() if k.startswith("backbone.")} + backbone_weights = {k[len("backbone."):]: v for k, v in state_dict.items() if k.startswith("backbone.")} torch.save(backbone_weights, save_path)