trlx/data/configs.py (6 additions, 4 deletions)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Set, Tuple
+from typing import Any, Dict, Optional, Set
 
 import yaml
 
@@ -137,8 +137,8 @@ class TrainConfig:
     :param batch_size: Batch size for training
     :type batch_size: int
 
-    :param trackers: Tuple of trackers to use for logging. Default: ("wandb",)
-    :type trackers: Tuple[str]
+    :param tracker: Tracker to use for logging. Default: "wandb"
+    :type tracker: str
 
     :param checkpoint_interval: Save model every checkpoint_interval steps
     :type checkpoint_interval: int
@@ -198,7 +198,9 @@ class TrainConfig:
     rollout_logging_dir: Optional[str] = None
     save_best: bool = True
 
-    trackers: Tuple[str] = ("wandb",)
+    tracker: Optional[str] = "wandb"
+    logging_dir: Optional[str] = None
 
     seed: int = 1000
 
     @classmethod
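For reference, a config opting into the new fields might be built as in the sketch below. This is illustrative only, not part of the diff: the values are made up, and any required TrainConfig fields that fall outside these hunks still need to be supplied.

    from trlx.data.configs import TrainConfig

    # Illustrative values; only `tracker` and `logging_dir` are new in this PR,
    # and required fields not shown in this diff must still be provided.
    config = TrainConfig(
        seq_length=1024,
        batch_size=8,
        tracker="tensorboard",    # default remains "wandb"; the Optional type suggests None disables tracking
        logging_dir="runs/trlx",  # forwarded to Accelerator(logging_dir=...)
    )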
trlx/trainer/accelerate_base_trainer.py (32 additions, 7 deletions)
@@ -28,6 +28,7 @@
     significant,
 )
 from trlx.utils.modeling import (
+    flatten_dict,
     freeze_bottom_causal_layers,
     freeze_bottom_seq2seq_layers,
     get_delta_model_class,
@@ -44,7 +45,9 @@ class AccelerateRLTrainer(BaseRLTrainer):
     def __init__(self, config, **kwargs):
         super().__init__(config, **kwargs)
         self.max_length = config.train.seq_length
-        self.accelerator = Accelerator(log_with=config.train.trackers)
+        self.accelerator = Accelerator(
+            log_with=config.train.tracker, logging_dir=config.train.logging_dir
+        )
         if int(os.environ.get("WORLD_SIZE", 1)) > 1:
             torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])
@@ -78,19 +81,41 @@ def __init__(self, config, **kwargs):
         dist_config = get_distributed_config(self.accelerator)
         config_dict["distributed"] = dist_config
         init_trackers_kwargs = {}
-        if "wandb" in config.train.trackers:
+
+        if config.train.tracker not in ("wandb", "tensorboard"):
+            raise ValueError(
+                f"Only supported trackers are wandb and tensorboard, got {config.train.tracker}"
+            )
+
+        if config.train.tracker == "wandb":
             init_trackers_kwargs["wandb"] = {
                 "name": run_name,
                 "entity": self.config.train.entity_name,
                 "group": self.config.train.group_name,
                 "tags": ["/".join(get_git_tag())],
                 "mode": "disabled" if os.environ.get("debug", False) else "online",
             }
-        self.accelerator.init_trackers(
-            project_name=self.config.train.project_name,
-            config=config_dict,
-            init_kwargs=init_trackers_kwargs,
-        )
+            self.accelerator.init_trackers(
+                project_name=self.config.train.project_name,
+                config=config_dict,
+                init_kwargs=init_trackers_kwargs,
+            )
+        else:  # the only other supported tracker is tensorboard
+            # Flatten the nested config for TensorBoard's hparams logging
+            config_dict_flat = flatten_dict(config_dict)
+            # hparams cannot take a tuple, so split betas into two scalars
+            config_dict_flat["optimizer/kwargs/beta_1"] = config_dict_flat[
+                "optimizer/kwargs/betas"
+            ][0]
+            config_dict_flat["optimizer/kwargs/beta_2"] = config_dict_flat[
+                "optimizer/kwargs/betas"
+            ][1]
+            config_dict_flat.pop("optimizer/kwargs/betas", None)
+            self.accelerator.init_trackers(
+                project_name=self.config.train.project_name,
+                config=config_dict_flat,
+            )
 
     def setup_model(self):
         """
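The tensorboard branch leans on flatten_dict, newly imported from trlx.utils.modeling above. Its definition is outside this diff; the key names it produces ("optimizer/kwargs/betas") imply behavior like the following sketch, which recursively joins nested keys with a "/" separator:

    from collections.abc import MutableMapping
    from typing import Any, Dict

    def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
        # Turn {"optimizer": {"kwargs": {"lr": 1e-4}}} into {"optimizer/kwargs/lr": 1e-4}
        items = []
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, MutableMapping):
                items.extend(flatten_dict(v, new_key, sep=sep).items())
            else:
                items.append((new_key, v))
        return dict(items)

The betas tuple is split into beta_1/beta_2 because TensorBoard hparams accept scalar and string values, not tuples. Logged runs can then be inspected with `tensorboard --logdir <logging_dir>`.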