Commit 706ea7d

Refactoring the function to accept a list of metric names instead of a dictionary of metrics (#938)

Squashed commits:
* ..
* undoing prev commit
* Refactoring the function to accept a list of metric names instead of a dictionary
* ..
* ..
* ..
* ..
1 parent 15ee0ac commit 706ea7d
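
The gist of the change: `add_metrics_to_eval_loaders` now receives the metric names themselves rather than the `Dict[str, Metric]` that keys them, so call sites extract the keys before calling. A minimal before/after sketch of the call sites touched by this commit:

    # Before this commit: the helper received the full name -> Metric mapping.
    evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)

    # After this commit: the helper receives only the metric names.
    evaluators = add_metrics_to_eval_loaders(evaluators, list(train_metrics.keys()))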

File tree

4 files changed: +6 additions, −12 deletions


llmfoundry/utils/builders.py

Lines changed: 1 addition & 3 deletions
@@ -28,7 +28,6 @@
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from torch.optim.optimizer import Optimizer
-from torchmetrics import Metric
 from transformers import AutoTokenizer, PreTrainedTokenizerBase

 from llmfoundry.callbacks import (AsyncEval, EvalGauntlet, FDiffMetrics,
@@ -108,9 +107,8 @@ def build_eval_loaders(

 def add_metrics_to_eval_loaders(
     evaluators: List[Evaluator],
-    metrics: Dict[str, Metric],
+    metric_names: List[str],
 ) -> List[Evaluator]:
-    metric_names = list(metrics.keys())
     eval_loaders, other_evaluators = [], []
     for evaluator in evaluators:
         if evaluator.metric_names == []:
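
The hunk above shows only the top of the refactored function. For context, here is a sketch of the likely full body, inferred from the visible lines and the test below rather than taken verbatim from the repository: evaluators with an empty `metric_names` are treated as plain eval loaders and receive the training metric names; everything else passes through unchanged.

    def add_metrics_to_eval_loaders(
        evaluators: List[Evaluator],
        metric_names: List[str],
    ) -> List[Evaluator]:
        eval_loaders, other_evaluators = [], []
        for evaluator in evaluators:
            if evaluator.metric_names == []:
                # Plain eval loaders carry no metrics yet; assign the train metric names.
                evaluator.metric_names = metric_names
                eval_loaders.append(evaluator)
            else:
                # Evaluators that already declare metrics pass through unchanged.
                other_evaluators.append(evaluator)
        # Returning the eval loaders first matches the test's expectation that
        # the metric-less 'second' evaluator ends up at index 0.
        return eval_loaders + other_evaluators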

scripts/eval/eval.py

Lines changed: 2 additions & 1 deletion
@@ -184,7 +184,8 @@ def evaluate_model(
     # Now add the eval metrics
     if eval_loader_config is not None:
         train_metrics = composer_model.get_metrics(is_train=True)
-        evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)
+        evaluators = add_metrics_to_eval_loaders(evaluators,
+                                                 list(train_metrics.keys()))

     if eval_gauntlet_df is None and eval_gauntlet_callback is not None:
         eval_gauntlet_df = pd.DataFrame(
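
Composer's `ComposerModel.get_metrics(is_train=...)` returns a mapping from metric name to `torchmetrics.Metric`, so the call site now forwards just the keys; the metric objects themselves are no longer needed by the helper, which is what lets builders.py drop its `torchmetrics` import above. Illustratively (the metric names here are examples, not guaranteed keys):

    train_metrics = composer_model.get_metrics(is_train=True)
    # e.g. {'LanguageCrossEntropy': <Metric>, 'LanguagePerplexity': <Metric>}
    evaluators = add_metrics_to_eval_loaders(evaluators, list(train_metrics.keys()))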

scripts/train/train.py

Lines changed: 2 additions & 1 deletion
@@ -544,7 +544,8 @@ def main(cfg: DictConfig) -> Trainer:
     # Now add the eval metrics
     if eval_loader_config is not None and not use_async_eval:
         train_metrics = model.get_metrics(is_train=True)
-        evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)
+        evaluators = add_metrics_to_eval_loaders(evaluators,
+                                                 list(train_metrics.keys()))

     # Build the Trainer
     log.info('Building trainer...')

tests/utils/test_builders.py

Lines changed: 1 addition & 7 deletions
@@ -335,13 +335,7 @@ def test_add_metrics_to_eval_loaders():
         )
     ]

-    new_evaluators = add_metrics_to_eval_loaders(
-        evaluators,
-        {
-            'new1': 'foo',
-            'new2': 'bar'
-        },  # type: ignore
-    )
+    new_evaluators = add_metrics_to_eval_loaders(evaluators, ['new1', 'new2'])
     assert len(new_evaluators) == 3

     assert new_evaluators[0].label == 'second'
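
Since the helper only reads and writes `metric_names` (and the test checks `label`), its contract can be exercised with lightweight stand-ins. A hypothetical sketch, separate from the repository's actual fixtures and assuming the reordering sketched earlier:

    from types import SimpleNamespace

    # Stand-ins for composer.core.Evaluator: only .label and .metric_names matter here.
    evaluators = [
        SimpleNamespace(label='first', metric_names=['existing_metric']),
        SimpleNamespace(label='second', metric_names=[]),
    ]
    new_evaluators = add_metrics_to_eval_loaders(evaluators, ['new1', 'new2'])
    assert new_evaluators[0].label == 'second'
    assert new_evaluators[0].metric_names == ['new1', 'new2']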
