1
- from dataclasses import dataclass
2
- from typing import Any , Dict , Optional , Set
1
+ from dataclasses import dataclass , field
2
+ from typing import Any , Dict , Optional , Set , Tuple
3
3
4
4
import yaml
5
5
@@ -56,7 +56,7 @@ class OptimizerConfig:
56
56
"""
57
57
58
58
name : str
59
- kwargs : Dict [str , Any ] = None
59
+ kwargs : Dict [str , Any ] = field ( default_factory = dict )
60
60
61
61
@classmethod
62
62
def from_dict (cls , config : Dict [str , Any ]):
@@ -76,7 +76,7 @@ class SchedulerConfig:
76
76
"""
77
77
78
78
name : str
79
- kwargs : Dict [str , Any ] = None
79
+ kwargs : Dict [str , Any ] = field ( default_factory = dict )
80
80
81
81
@classmethod
82
82
def from_dict (cls , config : Dict [str , Any ]):
@@ -100,6 +100,9 @@ class TrainConfig:
100
100
:param batch_size: Batch size for training
101
101
:type batch_size: int
102
102
103
+ :param trackers: Tuple of trackers to use for logging. Default: ("wandb",)
104
+ :type trackers: Tuple[str]
105
+
103
106
:param checkpoint_interval: Save model every checkpoint_interval steps
104
107
:type checkpoint_interval: int
105
108
@@ -123,7 +126,8 @@ class TrainConfig:
123
126
:param checkpoint_dir: Directory to save checkpoints
124
127
:type checkpoint_dir: str
125
128
126
- :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOTrainer.
129
+ :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation.
130
+ Only used by AcceleratePPOTrainer.
127
131
:type rollout_logging_dir: Optional[str]
128
132
129
133
:param seed: Random seed
@@ -148,6 +152,7 @@ class TrainConfig:
148
152
checkpoint_dir : str = "ckpts"
149
153
rollout_logging_dir : Optional [str ] = None
150
154
155
+ trackers : Tuple [str ] = ("wandb" ,)
151
156
seed : int = 1000
152
157
153
158
@classmethod
0 commit comments