8 changes: 5 additions & 3 deletions trlx/data/configs.py
@@ -137,8 +137,8 @@ class TrainConfig:
:param batch_size: Batch size for training
:type batch_size: int

- :param trackers: Tuple of trackers to use for logging. Default: ("wandb",)
- :type trackers: Tuple[str]
+ :param tracker: Tracker to use for logging. Default: "wandb"
+ :type tracker: str

:param checkpoint_interval: Save model every checkpoint_interval steps
:type checkpoint_interval: int
@@ -198,7 +198,9 @@ class TrainConfig:
rollout_logging_dir: Optional[str] = None
save_best: bool = True

- trackers: Tuple[str] = ("wandb",)
+ tracker: Optional[str] = "wandb"
+ logging_dir: Optional[str] = None

seed: int = 1000

@classmethod
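With the new fields in place, switching a run from wandb to tensorboard only requires setting two values on the training config. A minimal usage sketch, assuming a standard example config; the YAML path and the load_yaml entry point are illustrative, not part of this diff:

# Hypothetical usage sketch; the config path below is an example, not one added by this PR.
from trlx.data.configs import TRLConfig

config = TRLConfig.load_yaml("configs/ppo_config.yml")
config.train.tracker = "tensorboard"             # default remains "wandb"
config.train.logging_dir = "runs/my_experiment"  # forwarded to the Accelerator below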
39 changes: 26 additions & 13 deletions trlx/trainer/accelerate_base_trainer.py
@@ -44,7 +44,7 @@ class AccelerateRLTrainer(BaseRLTrainer):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.max_length = config.train.seq_length
- self.accelerator = Accelerator(log_with=config.train.trackers)
+ self.accelerator = Accelerator(log_with=config.train.tracker, logging_dir=config.train.logging_dir)
if int(os.environ.get("WORLD_SIZE", 1)) > 1:
torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])
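For context, the two new constructor arguments map directly onto Accelerate's tracking API. A rough standalone sketch, assuming an accelerate version that still accepts logging_dir (later releases renamed it to project_dir):

# Minimal, self-contained illustration of the Accelerate calls the trainer builds on.
from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", logging_dir="runs")
accelerator.init_trackers(project_name="trlx", config={"lr": 1e-5})
accelerator.log({"loss": 0.42}, step=1)  # metrics land in the tensorboard run under logging_dir
accelerator.end_training()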

@@ -78,19 +78,32 @@ def __init__(self, config, **kwargs):
dist_config = get_distributed_config(self.accelerator)
config_dict["distributed"] = dist_config
init_trackers_kwargs = {}
if "wandb" in config.train.trackers:
# HACK: Tensorboard doesn't like nested dict as hyperparams
config_dict_flat = {a:b for (k,v) in config_dict.items() for (a,b) in v.items() if not isinstance(b, dict)}
Collaborator:

You can use trlx.utils.modeling.flatten_dict here

Contributor Author:

Thanks for the suggestion! I replaced the dict comprehension with a call to flatten_dict(); since tensorboard also doesn't like lists, I added a couple of lines to split the optimizer betas.
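For readers following along, a generic flatten_dict looks roughly like the sketch below; the actual trlx.utils.modeling.flatten_dict may differ in signature and separator, so treat this as illustrative only:

# Illustrative flattening helper, not the exact trlx implementation.
def flatten_dict(d, parent_key="", sep="/"):
    """Flatten nested dicts, e.g. {"train": {"lr": 1e-5}} -> {"train/lr": 1e-5}."""
    items = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            items.update(flatten_dict(value, new_key, sep=sep))
        else:
            items[new_key] = value
    return items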


+ if config.train.tracker not in ("wandb", "tensorboard"):
+ raise ValueError(f"Only supported trackers are wandb and tensorboard, got {config.train.tracker}")

+ if config.train.tracker == "wandb":
init_trackers_kwargs["wandb"] = {
"name": run_name,
"entity": self.config.train.entity_name,
"group": self.config.train.group_name,
"tags": ["/".join(get_git_tag())],
"mode": "disabled" if os.environ.get("debug", False) else "online",
}
self.accelerator.init_trackers(
project_name=self.config.train.project_name,
config=config_dict,
init_kwargs=init_trackers_kwargs,
)
"name": run_name,
"entity": self.config.train.entity_name,
"group": self.config.train.group_name,
"tags": ["/".join(get_git_tag())],
"mode": "disabled" if os.environ.get("debug", False) else "online",
}

self.accelerator.init_trackers(
project_name=self.config.train.project_name,
config=config_dict,
init_kwargs=init_trackers_kwargs,
)
else:
Collaborator (@cat-state, Jan 23, 2023):

Thanks!
Could you add back a comment explaining what this branch is for and the flattening? Aside from that it LGTM

Contributor Author:

Sure!

  • I had run into problems running trlx without wandb (I don't have an account as of now), and found an already opened issue precisely on this. The branch has minor modifications (which don't change the previous interface for wandb users) to allow tensorboard tracking.
  • The only tricky part is that wandb is pretty fancy and takes nested dicts as logging params; this is not the case for tensorboard, hence the experiment config is fully flattened, and the only list is simply split apart (for the same reason).

Do let me know if anything is not clear or if I should add comments in the tensorboard-specific logging.

Collaborator:

Oh sorry, by "branch" I meant the else branch - i.e. a short comment like
else: # tracker == 'tensorboard' and # flatten config for tensorboard, flatten lists in hparams into flattened config

Contributor Author:

Got it, I was reading the comment in between my regular work and got fully confused, thanks for the feedback

+ self.accelerator.init_trackers(
+ project_name=self.config.train.project_name,
+ config=config_dict_flat,
+ )


def setup_model(self):
"""
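Putting the review thread together, the else branch shown above (flatten the config for tensorboard and split list-valued hyperparameters such as the optimizer betas) would end up looking roughly like the following. This is a sketch assuming the flatten_dict helper sketched earlier; the exact merged code may differ:

# Sketch: prepare a nested config for TensorBoard hparams, which only accept scalars.
config_dict = {"optimizer": {"kwargs": {"lr": 1.0e-5, "betas": [0.9, 0.95]}}}
config_dict_flat = flatten_dict(config_dict)  # {"optimizer/kwargs/lr": 1e-05, "optimizer/kwargs/betas": [0.9, 0.95]}
for name, value in list(config_dict_flat.items()):
    if isinstance(value, (list, tuple)):      # e.g. the optimizer betas
        config_dict_flat.pop(name)
        for i, item in enumerate(value):
            config_dict_flat[f"{name}[{i}]"] = item  # one scalar entry per list element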