Skip to content

Commit

Permalink
🚀 [RofuncRL] encoder can load pre-trained sub module with name in cfg
Browse files Browse the repository at this point in the history
  • Loading branch information
Skylark0924 committed Aug 27, 2023
1 parent b8125c2 commit 3ee7408
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions rofunc/learning/RofuncRL/state_encoders/base_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch.nn as nn
from omegaconf import DictConfig

import rofunc as rf
from rofunc.config.utils import omegaconf_to_dict
from rofunc.learning.RofuncRL.models.base_models import BaseMLP

Expand All @@ -41,6 +42,8 @@ def __init__(self, cfg: DictConfig, cfg_name: str = 'state_encoder'):
self.use_pretrained = self.cfg_dict[cfg_name]['use_pretrained']
self.freeze = self.cfg_dict[cfg_name]['freeze']
self.model_ckpt = self.cfg_dict[cfg_name]['model_ckpt']
self.model_module_name = self.cfg_dict[cfg_name]['model_module_name'] if 'model_module_name' in self.cfg_dict[
cfg_name] else None

def set_up(self):
if self.freeze:
Expand All @@ -52,6 +55,7 @@ def freeze_network(self):
for net in self.freeze_net_list:
for param in net.parameters():
param.requires_grad = False
rf.logger.beauty_print(f"Freeze state encoder", type="info")

def pre_trained_mode(self):
if self.use_pretrained is True and self.model_ckpt is None:
Expand All @@ -70,6 +74,8 @@ def save_ckpt(self, path: str):

def load_ckpt(self, path: str):
modules = torch.load(path)
if self.model_module_name is not None:
modules = modules[self.model_module_name]
if type(modules) is dict:
for name, data in modules.items():
module = self.checkpoint_modules.get(name, None)
Expand All @@ -80,6 +86,7 @@ def load_ckpt(self, path: str):
module.eval()
else:
raise NotImplementedError
rf.logger.beauty_print(f"Loaded pretrained state encoder model from {self.model_ckpt}", type="info")


class MLPEncoder(BaseMLP):
Expand Down

0 comments on commit 3ee7408

Please sign in to comment.