From 42a7bb7db0cdf53bb5beca76c9792a1f06b1e727 Mon Sep 17 00:00:00 2001
From: thomfoster <36595849+thomfoster@users.noreply.github.com>
Date: Mon, 5 Dec 2022 14:43:40 +0000
Subject: [PATCH] Add local rollout logging (#124)

---
 trlx/data/configs.py               |  5 +++++
 trlx/model/accelerate_ppo_model.py | 24 ++++++++++++++++++++++++
 trlx/pipeline/ppo_pipeline.py      | 13 ++++++++++++-
 3 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/trlx/data/configs.py b/trlx/data/configs.py
index badbec496..a8fff0569 100644
--- a/trlx/data/configs.py
+++ b/trlx/data/configs.py
@@ -98,6 +98,9 @@ class TrainConfig:
 
     :param entity_name: Entity name for wandb
     :type entity_name: str
+
+    :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOModel.
+    :type rollout_logging_dir: Optional[str]
     """
 
     total_steps: int
@@ -122,6 +125,8 @@ class TrainConfig:
     entity_name: Optional[str] = None
     seed: int = 1000
 
+    rollout_logging_dir: Optional[str] = None
+
     @classmethod
     def from_dict(cls, config: Dict[str, Any]):
         return cls(**config)
diff --git a/trlx/model/accelerate_ppo_model.py b/trlx/model/accelerate_ppo_model.py
index 40dbdd295..b3aaa120d 100644
--- a/trlx/model/accelerate_ppo_model.py
+++ b/trlx/model/accelerate_ppo_model.py
@@ -1,4 +1,5 @@
 from typing import Tuple
+import uuid, os, json
 
 import torch
 from torchtyping import TensorType
@@ -21,6 +22,12 @@ class AcceleratePPOModel(AccelerateRLModel):
     def __init__(self, config):
         super().__init__(config)
 
+        if config.train.rollout_logging_dir is not None:
+            self.log_rollouts = True
+            self.setup_rollout_logging(config)
+        else:
+            self.log_rollouts = False
+
         self.store = PPORolloutStorage(self.tokenizer.pad_token_id)
 
         rollout_loader = self.store.create_loader(
@@ -103,7 +110,24 @@ def loss(self, batch: PPORLBatch):
         self.approx_kl = stats["policy/approx_kl"]  # Update kl controller stats
         return loss, stats
 
+    def setup_rollout_logging(self, config):
+        # Make rollout logging dir for this run and store config
+        exists = os.path.exists(config.train.rollout_logging_dir)
+        isdir = os.path.isdir(config.train.rollout_logging_dir)
+        assert exists and isdir
+
+        self.run_id = f"run-{uuid.uuid4()}"
+        self.rollout_logging_dir = os.path.join(
+            config.train.rollout_logging_dir, self.run_id
+        )
+        os.mkdir(self.rollout_logging_dir)
+
+        with open(os.path.join(self.rollout_logging_dir, "config.json"), "w") as f:
+            f.write(json.dumps(config.to_dict(), indent=2))
+
     def post_epoch_callback(self):
+        if self.log_rollouts:
+            self.store.export_history(location=self.rollout_logging_dir)
         self.store.clear_history()
         self.orch.make_experience(
             self.config.method.num_rollouts, self.iter_count
diff --git a/trlx/pipeline/ppo_pipeline.py b/trlx/pipeline/ppo_pipeline.py
index 91cf7ef84..44661ff7b 100644
--- a/trlx/pipeline/ppo_pipeline.py
+++ b/trlx/pipeline/ppo_pipeline.py
@@ -1,4 +1,6 @@
-from typing import Iterable
+import os, json, time
+
+from typing import Iterable, Optional
 
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
@@ -25,6 +27,15 @@ def push(self, exps: Iterable[PPORLElement]):
 
     def clear_history(self):
         self.history = []
 
+    def export_history(self, location: str):
+        assert os.path.exists(location)
+
+        fpath = os.path.join(location, f"epoch-{str(time.time())}.json")
+        exp_to_dict = lambda exp: {k: v.cpu().tolist() for k, v in exp.__dict__.items()}
+        data = [exp_to_dict(exp) for exp in self.history]
+        with open(fpath, "w") as f:
+            f.write(json.dumps(data, indent=2))
+
     def __getitem__(self, index: int) -> PPORLElement:
         return self.history[index]
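
Usage note: below is a minimal sketch of how the new option might be enabled
from the training side. It assumes trlx's `TRLConfig.load_yaml` and
`trlx.train` entry points and a toy reward function; the config path, model
name, and prompts are illustrative, not part of this patch.

    import os

    import trlx
    from trlx.data.configs import TRLConfig

    # setup_rollout_logging asserts that the root directory already exists
    # (it only creates the per-run "run-<uuid4>" subfolder inside it), so
    # create it up front.
    os.makedirs("rollout_logs", exist_ok=True)

    config = TRLConfig.load_yaml("configs/ppo_config.yml")  # illustrative path
    config.train.rollout_logging_dir = "rollout_logs"

    # With logging enabled, each post_epoch_callback exports the rollout
    # history to rollout_logs/run-<uuid4>/epoch-<unix time>.json before
    # clearing it, next to a config.json snapshot of this run's settings.
    trlx.train(
        "gpt2",
        reward_fn=lambda samples: [float(len(s)) for s in samples],  # toy reward
        prompts=["Hello world"] * 64,
        config=config,
    )

Each epoch gets its own timestamped JSON file of PPORLElement fields (tensors
serialized via .cpu().tolist()), so a run's rollouts can later be replayed in
training order for Algorithm Distillation.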