diff --git a/trlx/model/__init__.py b/trlx/model/__init__.py
index 241f7fbd6..48e0148a3 100644
--- a/trlx/model/__init__.py
+++ b/trlx/model/__init__.py
@@ -96,41 +96,14 @@ def learn(
         pass
 
     @abstractmethod
-    def get_components(self) -> Dict[str, Any]:
-        """
-        Get pytorch components (mainly for saving/loading)
-        """
+    def save(self, directory=None):
+        """Creates a checkpoint of training states"""
         pass
 
-    def save(self, fp: str, title: str = "OUT"):
-        """
-        Try to save all components to specified path under a folder with given title
-        """
-        path = os.path.join(fp, title)
-        safe_mkdir(path)
-
-        components = self.get_components()
-        for name in components:
-            try:
-                torch.save(components[name], os.path.join(path, name) + ".pt")
-            except:
-                print(f"Failed to save component: {name}, continuing.")
-
-    def load(self, fp: str, title: str = "OUT"):
-        """
-        Try to load all components from specified path under a folder with given title
-        """
-
-        path = os.path.join(fp, title)
-
-        components = self.get_components()
-        for name in components:
-            try:
-                components[name] = torch.load(
-                    os.path.join(path, name) + ".pt", map_location="cpu"
-                )
-            except:
-                print(f"Failed to load component: {name}, continuing.")
+    @abstractmethod
+    def load(self, directory=None):
+        """Loads a checkpoint created from `save`"""
+        pass
 
     def intervals(self, steps: int) -> Dict[str, bool]:
         """
diff --git a/trlx/model/accelerate_base_model.py b/trlx/model/accelerate_base_model.py
index 1da086fc0..96f25e8f9 100644
--- a/trlx/model/accelerate_base_model.py
+++ b/trlx/model/accelerate_base_model.py
@@ -128,18 +128,14 @@ def generate(self, input_ids, attention_mask=None, **kwargs):
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )
 
-    def get_components(self) -> Dict[str, Any]:
-        components = (
-            {"model": self.model, "opt": self.opt, "scheduler": self.scheduler}
-            if self.train_mode
-            else {"model": self.model}
-        )
-        return components
-
     def save(self, directory=None):
         """Creates checkpoint of optimizer, scheduler and a model"""
         self.accelerator.save_state(directory or self.config.train.checkpoint_dir)
 
+    def load(self, directory=None):
+        """Load checkpoint of optimizer, scheduler and a model"""
+        self.accelerator.load_state(directory or self.config.train.checkpoint_dir)
+
     def add_eval_pipeline(self, eval_pipeline):
         """Adds pipeline from with validation prompts"""
         self.eval_pipeline = eval_pipeline
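
For context, a minimal usage sketch of the checkpoint API after this change: only `save()` and `load()` come from the patch itself; the driver-loop names (`train_with_checkpoints`, `run_step`, `save_every`) are hypothetical placeholders, and the default directory is the `config.train.checkpoint_dir` field referenced in the diff.

    # Hypothetical driver loop; only model.save() / model.load() mirror the new API.
    def train_with_checkpoints(model, total_steps, save_every=1000, checkpoint_dir=None):
        for step in range(total_steps):
            run_step(model, step)  # placeholder for one optimization step
            if step > 0 and step % save_every == 0:
                # Delegates to accelerator.save_state under the hood, which writes
                # model, optimizer and scheduler states together.
                model.save(checkpoint_dir)  # None falls back to config.train.checkpoint_dir

    def resume(model, checkpoint_dir=None):
        # Restores whatever accelerator.save_state previously wrote.
        model.load(checkpoint_dir)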