Skip to content

Commit 82435d8

Browse files
marcobellagente93 (Marco Bellagente) and daia99
authored
Enable training with Tensorboard tracking (#209)
* pass logging_dir from config Co-authored-by: Andrew <[email protected]> * test tensorboard tracker * single tracker string in config * specify ini_trackers_kwargs for wandb only * added workaround to avoid nested dicts with tensorboard * tracker and logging_dir config values optional * validate tracker before init_trackers * merged 2 lines for flat logging dir * removed accidentally committed tensorboard files * Formatting * use flatten_dict to log with tensorboard * added back comments * fixed typo Co-authored-by: Marco Bellagente <[email protected]> Co-authored-by: Andrew <[email protected]> Co-authored-by: Marco Bellagente <[email protected]>
1 parent 2f23001 commit 82435d8

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

trlx/data/configs.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Any, Dict, Optional, Set, Tuple
2+
from typing import Any, Dict, Optional, Set
33

44
import yaml
55

@@ -137,8 +137,8 @@ class TrainConfig:
137137
:param batch_size: Batch size for training
138138
:type batch_size: int
139139
140-
:param trackers: Tuple of trackers to use for logging. Default: ("wandb",)
141-
:type trackers: Tuple[str]
140+
:param tracker: Tracker to use for logging. Default: "wandb"
141+
:type tracker: str
142142
143143
:param checkpoint_interval: Save model every checkpoint_interval steps
144144
:type checkpoint_interval: int
@@ -198,7 +198,9 @@ class TrainConfig:
198198
rollout_logging_dir: Optional[str] = None
199199
save_best: bool = True
200200

201-
trackers: Tuple[str] = ("wandb",)
201+
tracker: Optional[str] = "wandb"
202+
logging_dir: Optional[str] = None
203+
202204
seed: int = 1000
203205

204206
@classmethod

trlx/trainer/accelerate_base_trainer.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
significant,
2929
)
3030
from trlx.utils.modeling import (
31+
flatten_dict,
3132
freeze_bottom_causal_layers,
3233
freeze_bottom_seq2seq_layers,
3334
get_delta_model_class,
@@ -44,7 +45,9 @@ class AccelerateRLTrainer(BaseRLTrainer):
4445
def __init__(self, config, **kwargs):
4546
super().__init__(config, **kwargs)
4647
self.max_length = config.train.seq_length
47-
self.accelerator = Accelerator(log_with=config.train.trackers)
48+
self.accelerator = Accelerator(
49+
log_with=config.train.tracker, logging_dir=config.train.logging_dir
50+
)
4851
if int(os.environ.get("WORLD_SIZE", 1)) > 1:
4952
torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])
5053

@@ -78,19 +81,41 @@ def __init__(self, config, **kwargs):
7881
dist_config = get_distributed_config(self.accelerator)
7982
config_dict["distributed"] = dist_config
8083
init_trackers_kwargs = {}
81-
if "wandb" in config.train.trackers:
84+
85+
if config.train.tracker not in ("wandb", "tensorboard"):
86+
raise ValueError(
87+
f"Only supported trackers are wandb and tensorboard, got {config.train.tracker}"
88+
)
89+
90+
if config.train.tracker == "wandb":
8291
init_trackers_kwargs["wandb"] = {
8392
"name": run_name,
8493
"entity": self.config.train.entity_name,
8594
"group": self.config.train.group_name,
8695
"tags": ["/".join(get_git_tag())],
8796
"mode": "disabled" if os.environ.get("debug", False) else "online",
8897
}
89-
self.accelerator.init_trackers(
90-
project_name=self.config.train.project_name,
91-
config=config_dict,
92-
init_kwargs=init_trackers_kwargs,
93-
)
98+
99+
self.accelerator.init_trackers(
100+
project_name=self.config.train.project_name,
101+
config=config_dict,
102+
init_kwargs=init_trackers_kwargs,
103+
)
104+
else: # only other supported tracker is tensorboard
105+
config_dict_flat = flatten_dict(
106+
config_dict
107+
) # flatten config for tensorboard, split list in hparams into flatten config
108+
config_dict_flat["optimizer/kwargs/beta_1"] = config_dict_flat[
109+
"optimizer/kwargs/betas"
110+
][0]
111+
config_dict_flat["optimizer/kwargs/beta_2"] = config_dict_flat[
112+
"optimizer/kwargs/betas"
113+
][1]
114+
config_dict_flat.pop("optimizer/kwargs/betas", None)
115+
self.accelerator.init_trackers(
116+
project_name=self.config.train.project_name,
117+
config=config_dict_flat,
118+
)
94119

95120
def setup_model(self):
96121
"""

0 commit comments

Comments (0)