 import re
 import sys
 import unittest
+from typing import Tuple
 from unittest.mock import patch

 from parameterized import parameterized
-from transformers.integrations import is_fairscale_available
+from transformers import AutoModel
 from transformers.testing_utils import (
     CaptureStderr,
     ExtendSysPath,
     TestCasePlus,
     execute_subprocess_async,
     get_gpu_count,
     get_torch_dist_unique_port,
+    require_apex,
+    require_bitsandbytes,
+    require_fairscale,
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
-from transformers.utils import is_apex_available


 bindir = os.path.abspath(os.path.dirname(__file__))
 MBART_TINY = "sshleifer/tiny-mbart"


-# a candidate for testing_utils
-def require_fairscale(test_case):
-    """
-    Decorator marking a test that requires fairscale
-    """
-    if not is_fairscale_available():
-        return unittest.skip("test requires fairscale")(test_case)
-    else:
-        return test_case
-
-
-# a candidate for testing_utils
-def require_apex(test_case):
-    """
-    Decorator marking a test that requires apex
-    """
-    if not is_apex_available():
-        return unittest.skip("test requires apex")(test_case)
-    else:
-        return test_case
-
-
 @require_torch
 class TestTrainerExt(TestCasePlus):
     def run_seq2seq_quick(
@@ -193,7 +174,7 @@ def test_trainer_log_level_replica(self, experiment_id):
         self.assertEqual(n_matches, data["n_matches"])

     @slow
-    def test_run_seq2seq_slow(self):
+    def test_run_seq2seq(self):
         output_dir = self.run_trainer(
             eval_steps=2,
             max_len=128,
@@ -218,6 +199,88 @@ def test_run_seq2seq_slow(self):
         assert "generated_predictions.txt" in contents
         assert "predict_results.json" in contents

+    @slow
+    @require_bitsandbytes
+    def test_run_seq2seq_bnb(self):
+        from transformers.training_args import OptimizerNames
+
+        def train_and_return_metrics(optim: str) -> Tuple[int, int, float]:
+            from pathlib import Path
+
+            extra_args = (
+                f"--skip_memory_metrics 0 --optim {optim} --do_eval False "
+                "--do_predict False --adafactor False --log_level debug"
+            )
+
+            output_dir = self.run_trainer(
+                eval_steps=2,
+                max_len=128,
+                model_name=MARIAN_MODEL,
+                learning_rate=3e-4,
+                num_train_epochs=1,
+                distributed=True,  # force run in a new process
+                extra_args_str=extra_args,
+                do_eval=False,
+                do_predict=False,
+            )
+
+            # Check metrics
+            logs = TrainerState.load_from_json(Path(output_dir, "trainer_state.json")).log_history
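+            # With --skip_memory_metrics 0 the Trainer reports memory stats in the logs:
+            # *_alloc_delta is the increase in allocated GPU memory over the train stage,
+            # *_peaked_delta is the extra transient memory that peaked above that and was freed.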
+            gpu_peak_mem = logs[0]["train_mem_gpu_peaked_delta"]
+            gpu_alloc_mem = logs[0]["train_mem_gpu_alloc_delta"]
+
+            loss = logs[0]["train_loss"]
+            return gpu_peak_mem, gpu_alloc_mem, loss
+
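+        # Run the same training twice: once with the regular torch AdamW and once with the
+        # bitsandbytes 8-bit optimizer, then compare GPU memory usage and the final loss.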
+        gpu_peak_mem_orig, gpu_alloc_mem_orig, loss_orig = train_and_return_metrics(OptimizerNames.ADAMW_TORCH.value)
+        gpu_peak_mem_bnb, gpu_alloc_mem_bnb, loss_bnb = train_and_return_metrics(OptimizerNames.ADAMW_BNB.value)
+
+        gpu_peak_mem_diff_bytes = gpu_peak_mem_orig - gpu_peak_mem_bnb
+        gpu_peak_mem_diff_percent = gpu_peak_mem_diff_bytes / gpu_peak_mem_bnb
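+        # note: despite the "percent" name these are ratios relative to the bnb run,
+        # e.g. a value of 10 means the original run used ~11x the bnb peak memory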
+
+        gpu_total_mem_orig = gpu_peak_mem_orig + gpu_alloc_mem_orig
+        gpu_total_mem_bnb = gpu_peak_mem_bnb + gpu_alloc_mem_bnb
+
+        gpu_total_mem_diff_bytes = gpu_total_mem_orig - gpu_total_mem_bnb
+        gpu_total_mem_diff_percent = gpu_total_mem_diff_bytes / gpu_total_mem_bnb
+
+        # leave these in for now in case CI gets very different results
+        # print(f"{gpu_alloc_mem_orig=:010d} {gpu_peak_mem_orig=:010d} {gpu_alloc_mem_orig+gpu_peak_mem_orig=:010d}")
+        # print(f" {gpu_alloc_mem_bnb=:010d} {gpu_peak_mem_bnb=:010d} {gpu_alloc_mem_bnb+gpu_peak_mem_bnb=:010d}")
+        # print(f"{gpu_peak_mem_diff_bytes=}, {gpu_peak_mem_diff_percent=}")
+        # print(f"{gpu_total_mem_orig=}, {gpu_total_mem_bnb=}")
+        # print(f"{gpu_total_mem_diff_bytes=}, {gpu_total_mem_diff_percent=}")
+
+        self.assertGreater(
+            gpu_peak_mem_diff_percent,
+            10,  # basically a huge difference - got ~30x on my desktop
+            "should use very little peak gpu memory with BNB, compared to without it, "
+            f"but got gpu_peak_mem_orig={gpu_peak_mem_orig} and gpu_peak_mem_bnb={gpu_peak_mem_bnb}",
+        )
+
+        self.assertGreater(
+            gpu_total_mem_diff_percent,
+            0.20,  # could easily be 0.50, but let's stay on the safe side
+            "Using BNB should use less total GPU memory than without it, "
+            f"but got gpu_total_mem_orig={gpu_total_mem_orig} and gpu_total_mem_bnb={gpu_total_mem_bnb}",
+        )
+
+        self.assertEqual(
+            loss_orig, loss_bnb, f"loss should be the same, but got loss_orig={loss_orig}, loss_bnb={loss_bnb}"
+        )
+
+        # Additionally let's test that the absolute gpu memory difference is larger or about the
+        # same as the expected saving coming from BNB (6 bytes per param)
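+        # (regular AdamW keeps two fp32 optimizer states per param, ~8 bytes/param; the 8-bit bnb
+        # optimizer keeps them quantized, ~2 bytes/param, hence the ~6 bytes/param saving)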
+        model = AutoModel.from_pretrained(MARIAN_MODEL)
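+        # keying by data_ptr() de-duplicates tied/shared weights so each tensor is counted once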
+        total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
+        bnb_saved_bytes = total_numel * 6  # 324MB
+
+        self.assertGreater(
+            gpu_total_mem_diff_bytes,
+            bnb_saved_bytes * 0.8,  # add a safety margin, if it saved slightly less
+            f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
+        )
+
     def run_trainer(
         self,
         eval_steps: int,
@@ -300,6 +363,8 @@ def run_trainer(
                 {self.examples_dir_str}/pytorch/translation/run_translation.py
             """.split()
             cmd = [sys.executable] + distributed_args + args
+            # keep for quick debug
+            # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] + cmd)); die
             execute_subprocess_async(cmd, env=self.get_env())
         else:
             testargs = ["run_translation.py"] + args