Skip to content

Commit

Permalink
allow extract backbone weights on non-GPU machines
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jan 6, 2022
1 parent 0b7a031 commit 520940c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ def forward(self, batch, target):


def extract_backbone_weights(lightning_ckpt_path, save_path):
ckpt = torch.load(lightning_ckpt_path)
ckpt = torch.load(lightning_ckpt_path, map_location='cpu')
state_dict = ckpt["state_dict"]
backbone_weights = {k[len("backbone."):]: v.cpu() for k, v in state_dict.items() if k.startswith("backbone.")}
backbone_weights = {k[len("backbone."):]: v for k, v in state_dict.items() if k.startswith("backbone.")}
torch.save(backbone_weights, save_path)


Expand Down

0 comments on commit 520940c

Please sign in to comment.