10 changes: 6 additions & 4 deletions trlx/data/configs.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Set, Tuple
+from typing import Any, Dict, Optional, Set

import yaml

@@ -137,8 +137,8 @@ class TrainConfig:
:param batch_size: Batch size for training
:type batch_size: int

-:param trackers: Tuple of trackers to use for logging. Default: ("wandb",)
-:type trackers: Tuple[str]
+:param tracker: Tracker to use for logging. Default: "wandb"
+:type tracker: str

:param checkpoint_interval: Save model every checkpoint_interval steps
:type checkpoint_interval: int
@@ -198,7 +198,9 @@ class TrainConfig:
rollout_logging_dir: Optional[str] = None
save_best: bool = True

-trackers: Tuple[str] = ("wandb",)
+tracker: Optional[str] = "wandb"
+logging_dir: Optional[str] = None

seed: int = 1000

@classmethod
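For illustration, a minimal sketch of how the new fields could be set from Python; the config path and values here are assumptions for the example, not part of this PR:

from trlx.data.configs import TRLConfig

# Load an experiment config (assumed example path).
config = TRLConfig.load_yaml("configs/ppo_config.yml")

# A single tracker name replaces the old tuple of trackers.
config.train.tracker = "tensorboard"
# Directory handed to Accelerate as logging_dir (assumed value).
config.train.logging_dir = "runs/ppo_example"

Any tracker value other than "wandb" or "tensorboard" raises the ValueError added in accelerate_base_trainer.py below.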
37 changes: 30 additions & 7 deletions trlx/trainer/accelerate_base_trainer.py
@@ -28,6 +28,7 @@
significant,
)
from trlx.utils.modeling import (
+flatten_dict,
freeze_bottom_causal_layers,
freeze_bottom_seq2seq_layers,
get_delta_model_class,
@@ -44,7 +45,9 @@ class AccelerateRLTrainer(BaseRLTrainer):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.max_length = config.train.seq_length
-self.accelerator = Accelerator(log_with=config.train.trackers)
+self.accelerator = Accelerator(
+log_with=config.train.tracker, logging_dir=config.train.logging_dir
+)
if int(os.environ.get("WORLD_SIZE", 1)) > 1:
torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])

@@ -78,19 +81,39 @@ def __init__(self, config, **kwargs):
dist_config = get_distributed_config(self.accelerator)
config_dict["distributed"] = dist_config
init_trackers_kwargs = {}
if "wandb" in config.train.trackers:

if config.train.tracker not in ("wandb", "tensorboard"):
raise ValueError(
f"Only supported trackers are wandb and tensorboard, got {config.train.tracker}"
)

if config.train.tracker == "wandb":
init_trackers_kwargs["wandb"] = {
"name": run_name,
"entity": self.config.train.entity_name,
"group": self.config.train.group_name,
"tags": ["/".join(get_git_tag())],
"mode": "disabled" if os.environ.get("debug", False) else "online",
}
-self.accelerator.init_trackers(
-project_name=self.config.train.project_name,
-config=config_dict,
-init_kwargs=init_trackers_kwargs,
-)
+
+self.accelerator.init_trackers(
+project_name=self.config.train.project_name,
+config=config_dict,
+init_kwargs=init_trackers_kwargs,
+)
+else:
@cat-state (Collaborator) commented on Jan 23, 2023:

Thanks!
Could you add back a comment explaining what this branch is for and the flattening? Aside from that it LGTM

Contributor (Author) replied:

Sure!

  • I had run into problems running trlx without wandb (I don't have an account as of now), and found an already-open issue on precisely this. The branch has minor modifications (which don't change the previous interface for wandb users) to allow tensorboard tracking.
  • The only tricky part is that wandb is pretty fancy and takes nested dicts as logging params; this is not the case for tensorboard, so the experiment config is fully flattened, and the only list is simply split apart (for the same reason).
    Do let me know if anything is unclear or if I should add comments to the tensorboard-specific logging.

@cat-state (Collaborator) replied:

Oh sorry, by branch I meant the else branch - i.e. a short comment like
else: # tracker == 'tensorboard' and # flatten config for tensorboard, flatten lists in hparams into flattened config

Contributor (Author) replied:

Got it. I was reading the comment in between my regular work and got fully confused. Thanks for the feedback!

+config_dict_flat = flatten_dict(config_dict)
+config_dict_flat["optimizer/kwargs/beta_1"] = config_dict_flat[
+"optimizer/kwargs/betas"
+][0]
+config_dict_flat["optimizer/kwargs/beta_2"] = config_dict_flat[
+"optimizer/kwargs/betas"
+][1]
+config_dict_flat.pop("optimizer/kwargs/betas", None)
+self.accelerator.init_trackers(
+project_name=self.config.train.project_name,
+config=config_dict_flat,
+)

def setup_model(self):
"""
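For reference, a self-contained sketch of the flattening discussed in the review thread above. The helper below is a local stand-in for trlx.utils.modeling.flatten_dict and assumes it joins nested keys with "/" (consistent with the "optimizer/kwargs/betas" keys in the diff); the config values are hypothetical:

from typing import Any, Dict


def flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
    # Recursively flatten nested dicts into "a/b/c"-style keys (assumed behaviour).
    items = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            items.update(flatten_dict(value, new_key, sep))
        else:
            items[new_key] = value
    return items


# Hypothetical slice of the experiment config.
config_dict = {"optimizer": {"kwargs": {"lr": 1.0e-5, "betas": [0.9, 0.95]}}}

config_dict_flat = flatten_dict(config_dict)
# {'optimizer/kwargs/lr': 1e-05, 'optimizer/kwargs/betas': [0.9, 0.95]}

# tensorboard hparams only accept scalar-like values, so the betas list is
# split apart and the original key dropped, mirroring the else branch above.
config_dict_flat["optimizer/kwargs/beta_1"] = config_dict_flat["optimizer/kwargs/betas"][0]
config_dict_flat["optimizer/kwargs/beta_2"] = config_dict_flat["optimizer/kwargs/betas"][1]
config_dict_flat.pop("optimizer/kwargs/betas", None)
# {'optimizer/kwargs/lr': 1e-05, 'optimizer/kwargs/beta_1': 0.9, 'optimizer/kwargs/beta_2': 0.95}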