Skip to content

Commit 3c8c837

Browse files
authored
Add entity name config for wandb logging (#78)
1 parent 2f123cd commit 3c8c837

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

trlx/data/configs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Any, Dict, Tuple
2+
from typing import Any, Dict, Optional, Tuple
33

44
import yaml
55

@@ -80,6 +80,9 @@ class TrainConfig:
8080
8181
:param project_name: Project name for wandb
8282
:type project_name: str
83+
84+
:param entity_name: Entity name for wandb
85+
:type entity_name: str
8386
"""
8487

8588
total_steps: int
@@ -102,6 +105,7 @@ class TrainConfig:
102105

103106
checkpoint_dir: str = "ckpts"
104107
project_name: str = "trlx"
108+
entity_name: Optional[str] = None
105109
seed: int = 1000
106110

107111
@classmethod

trlx/model/accelerate_base_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(self, config, train_mode=True):
7070
init_kwargs={
7171
"wandb": {
7272
"name": f"{config.model.model_path}",
73+
"entity": self.config.train.entity_name,
7374
"mode": "disabled"
7475
if os.environ.get("debug", False)
7576
else "online",

0 commit comments

Comments
 (0)