Skip to content

Commit

Permalink
Add local rollout logging (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomfoster authored Dec 5, 2022
1 parent 5df5219 commit 42a7bb7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
5 changes: 5 additions & 0 deletions trlx/data/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ class TrainConfig:
:param entity_name: Entity name for wandb
:type entity_name: str
:param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOModel.
:type rollout_logging_dir: Optional[str]
"""

total_steps: int
Expand All @@ -122,6 +125,8 @@ class TrainConfig:
entity_name: Optional[str] = None
seed: int = 1000

rollout_logging_dir: Optional[str] = None

@classmethod
def from_dict(cls, config: Dict[str, Any]):
    """Alternate constructor: build the config from a plain mapping.

    :param config: Mapping of field names to values, forwarded as
        keyword arguments to the dataclass constructor.
    """
    kwargs = dict(config)
    return cls(**kwargs)
Expand Down
24 changes: 24 additions & 0 deletions trlx/model/accelerate_ppo_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Tuple
import uuid, os, json

import torch
from torchtyping import TensorType
Expand All @@ -21,6 +22,12 @@ class AcceleratePPOModel(AccelerateRLModel):
def __init__(self, config):
super().__init__(config)

if config.train.rollout_logging_dir is not None:
self.log_rollouts = True
self.setup_rollout_logging(config)
else:
self.log_rollouts = False

self.store = PPORolloutStorage(self.tokenizer.pad_token_id)

rollout_loader = self.store.create_loader(
Expand Down Expand Up @@ -103,7 +110,24 @@ def loss(self, batch: PPORLBatch):
self.approx_kl = stats["policy/approx_kl"] # Update kl controller stats
return loss, stats

def setup_rollout_logging(self, config):
    """Create a unique per-run subdirectory under the configured rollout
    logging directory and persist the run config there as ``config.json``.

    Sets ``self.run_id`` and ``self.rollout_logging_dir`` as side effects.

    :param config: Full TRLX config; reads ``config.train.rollout_logging_dir``
        and serializes ``config.to_dict()``.
    :raises NotADirectoryError: if ``rollout_logging_dir`` does not exist or
        is not a directory.
    """
    base_dir = config.train.rollout_logging_dir
    # Validate with an explicit raise rather than `assert`, which is
    # stripped under `python -O`. os.path.isdir() is already False for
    # nonexistent paths, so a separate exists() check is redundant.
    if not os.path.isdir(base_dir):
        raise NotADirectoryError(
            f"rollout_logging_dir must be an existing directory: {base_dir}"
        )

    self.run_id = f"run-{uuid.uuid4()}"
    self.rollout_logging_dir = os.path.join(base_dir, self.run_id)
    os.mkdir(self.rollout_logging_dir)

    # Store the config alongside the rollouts so a run can be reproduced
    # and identified later (used for Algorithm Distillation datasets).
    with open(os.path.join(self.rollout_logging_dir, "config.json"), "w") as f:
        json.dump(config.to_dict(), f, indent=2)

def post_epoch_callback(self):
if self.log_rollouts:
self.store.export_history(location=self.rollout_logging_dir)
self.store.clear_history()
self.orch.make_experience(
self.config.method.num_rollouts, self.iter_count
Expand Down
13 changes: 12 additions & 1 deletion trlx/pipeline/ppo_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Iterable
import os, json, time

from typing import Iterable, Optional

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
Expand All @@ -25,6 +27,15 @@ def push(self, exps: Iterable[PPORLElement]):
def clear_history(self):
    """Drop all stored rollout elements by rebinding to a fresh list."""
    self.history = list()

def export_history(self, location: str):
    """Serialize the stored rollout history to a timestamped JSON file.

    Writes ``epoch-<unix time>.json`` into *location*, containing one dict
    per stored element with each tensor field moved to CPU and converted
    to nested Python lists.

    :param location: Existing directory to write the file into.
    :raises NotADirectoryError: if ``location`` is not an existing directory.
    """
    # Explicit raise instead of `assert` (stripped under `python -O`);
    # isdir() also catches the "exists but is a file" case.
    if not os.path.isdir(location):
        raise NotADirectoryError(
            f"location must be an existing directory: {location}"
        )

    fpath = os.path.join(location, f"epoch-{time.time()}.json")

    def exp_to_dict(exp):
        # Tensors must leave the GPU and become plain lists to be
        # JSON-serializable.
        return {k: v.cpu().tolist() for k, v in exp.__dict__.items()}

    with open(fpath, "w") as f:
        json.dump([exp_to_dict(exp) for exp in self.history], f, indent=2)

def __getitem__(self, index: int) -> PPORLElement:
    """Return the rollout element stored at ``index``."""
    elements = self.history
    return elements[index]

Expand Down

0 comments on commit 42a7bb7

Please sign in to comment.