Skip to content

Commit 42a7bb7

Browse files
authored
Add local rollout logging (#124)
1 parent 5df5219 commit 42a7bb7

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

trlx/data/configs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ class TrainConfig:
9898
9999
:param entity_name: Entity name for wandb
100100
:type entity_name: str
101+
102+
:param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOModel.
103+
:type rollout_logging_dir: Optional[str]
101104
"""
102105

103106
total_steps: int
@@ -122,6 +125,8 @@ class TrainConfig:
122125
entity_name: Optional[str] = None
123126
seed: int = 1000
124127

128+
rollout_logging_dir: Optional[str] = None
129+
125130
@classmethod
def from_dict(cls, config: Dict[str, Any]):
    """Construct an instance from a dict whose keys match this dataclass's fields.

    :param config: Mapping of field name -> value; passed through as ``cls(**config)``,
        so unknown keys raise ``TypeError`` and missing required fields raise as well.
    """
    return cls(**config)

trlx/model/accelerate_ppo_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Tuple
2+
import uuid, os, json
23

34
import torch
45
from torchtyping import TensorType
@@ -21,6 +22,12 @@ class AcceleratePPOModel(AccelerateRLModel):
2122
def __init__(self, config):
2223
super().__init__(config)
2324

25+
if config.train.rollout_logging_dir is not None:
26+
self.log_rollouts = True
27+
self.setup_rollout_logging(config)
28+
else:
29+
self.log_rollouts = False
30+
2431
self.store = PPORolloutStorage(self.tokenizer.pad_token_id)
2532

2633
rollout_loader = self.store.create_loader(
@@ -103,7 +110,24 @@ def loss(self, batch: PPORLBatch):
103110
self.approx_kl = stats["policy/approx_kl"] # Update kl controller stats
104111
return loss, stats
105112

113+
def setup_rollout_logging(self, config):
    """Prepare a per-run directory for rollout logging and persist the run config.

    Creates ``<config.train.rollout_logging_dir>/run-<uuid4>/`` and writes the
    full config as ``config.json`` inside it. Sets ``self.run_id`` and
    ``self.rollout_logging_dir`` for later use (e.g. by export_history).

    :param config: Config object exposing ``train.rollout_logging_dir`` and
        ``to_dict()``.
    :raises NotADirectoryError: if ``train.rollout_logging_dir`` does not exist
        or is not a directory. (Previously an ``assert``, which is silently
        stripped under ``python -O``.)
    """
    base_dir = config.train.rollout_logging_dir
    # isdir() is False for both missing paths and non-directories, so a single
    # check replaces the original `exists and isdir` pair.
    if not os.path.isdir(base_dir):
        raise NotADirectoryError(
            f"rollout_logging_dir is not an existing directory: {base_dir!r}"
        )

    self.run_id = f"run-{uuid.uuid4()}"
    self.rollout_logging_dir = os.path.join(base_dir, self.run_id)
    os.mkdir(self.rollout_logging_dir)

    # Snapshot the config so logged rollouts are reproducible/interpretable.
    with open(os.path.join(self.rollout_logging_dir, "config.json"), "w") as f:
        f.write(json.dumps(config.to_dict(), indent=2))
127+
106128
def post_epoch_callback(self):
129+
if self.log_rollouts:
130+
self.store.export_history(location=self.rollout_logging_dir)
107131
self.store.clear_history()
108132
self.orch.make_experience(
109133
self.config.method.num_rollouts, self.iter_count

trlx/pipeline/ppo_pipeline.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Iterable
1+
import os, json, time
2+
3+
from typing import Iterable, Optional
24

35
from torch.nn.utils.rnn import pad_sequence
46
from torch.utils.data import DataLoader
@@ -25,6 +27,15 @@ def push(self, exps: Iterable[PPORLElement]):
2527
def clear_history(self):
2628
self.history = []
2729

30+
def export_history(self, location: str):
    """Dump every rollout currently held in ``self.history`` to a JSON file.

    The file is written to ``<location>/epoch-<unix_time>.json`` and contains a
    list with one dict per element; every attribute of each element is
    converted via ``.cpu().tolist()`` (elements are assumed to hold tensors —
    TODO confirm against PPORLElement).

    :param location: Existing directory to write the JSON file into.
    :raises FileNotFoundError: if ``location`` does not exist. (Previously an
        ``assert``, which is stripped under ``python -O``.)
    """
    if not os.path.exists(location):
        raise FileNotFoundError(f"export location does not exist: {location!r}")

    # Timestamped name keeps successive epoch exports from clobbering each other.
    fpath = os.path.join(location, f"epoch-{time.time()}.json")

    def exp_to_dict(exp):
        # Move tensors to CPU and convert to plain lists so they JSON-serialize.
        return {k: v.cpu().tolist() for k, v in exp.__dict__.items()}

    data = [exp_to_dict(exp) for exp in self.history]
    with open(fpath, "w") as f:
        json.dump(data, f, indent=2)
38+
2839
def __getitem__(self, index: int) -> PPORLElement:
2940
return self.history[index]
3041

0 commit comments

Comments
 (0)