@@ -1551,7 +1551,7 @@ def test_fp16_full_eval(self):
         a = torch.ones(1000, bs) + 0.001
         b = torch.ones(1000, bs) - 0.001

-        # 1. with mem metrics enabled
+        # 1. with fp16_full_eval disabled
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         del trainer
@@ -1572,7 +1572,7 @@ def test_fp16_full_eval(self):
         # perfect world: fp32_eval == close to zero
         self.assertLess(fp32_eval, 5_000)

-        # 2. with mem metrics disabled
+        # 2. with fp16_full_eval enabled
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, fp16_full_eval=True, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         fp16_init = metrics["init_mem_gpu_alloc_delta"]
@@ -1611,7 +1611,7 @@ def test_bf16_full_eval(self):
         a = torch.ones(1000, bs) + 0.001
         b = torch.ones(1000, bs) - 0.001

-        # 1. with mem metrics enabled
+        # 1. with bf16_full_eval disabled
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         del trainer
@@ -1632,7 +1632,7 @@ def test_bf16_full_eval(self):
         # perfect world: fp32_eval == close to zero
         self.assertLess(fp32_eval, 5_000)

-        # 2. with mem metrics disabled
+        # 2. with bf16_full_eval enabled
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, bf16_full_eval=True, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         bf16_init = metrics["init_mem_gpu_alloc_delta"]
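
Note: the comments are renamed because both branches keep skip_memory_metrics=False; what actually changes between them is whether fp16_full_eval / bf16_full_eval is enabled. A minimal sketch of that pattern, assuming the get_regression_trainer helper defined earlier in this test file and a CUDA device (metric keys as used above):

# Sketch only: compare GPU memory deltas for fp32 vs. fp16 evaluation.
# Assumes get_regression_trainer from this test file and a CUDA device.
import torch

bs, eval_len = 16, 64  # illustrative sizes, not necessarily the test's values
a = torch.ones(1000, bs) + 0.001
b = torch.ones(1000, bs) - 0.001

# 1. fp16_full_eval disabled: weights stay in fp32 during evaluate()
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
fp32_metrics = trainer.evaluate()
fp32_eval = fp32_metrics["eval_mem_gpu_alloc_delta"]
del trainer

# 2. fp16_full_eval enabled: weights are cast to half precision for evaluate(),
#    so the evaluation-time GPU allocation delta is expected to be smaller
trainer = get_regression_trainer(
    a=a, b=b, eval_len=eval_len, fp16_full_eval=True, skip_memory_metrics=False
)
fp16_metrics = trainer.evaluate()
fp16_eval = fp16_metrics["eval_mem_gpu_alloc_delta"]

print(f"fp32 eval alloc delta: {fp32_eval}, fp16 eval alloc delta: {fp16_eval}")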