Commit 6d2e29d

manuelciosici authored and stas00 committed
Add support for bitsandbytes (huggingface#15622)
* Add initial BNB integration
* fixup! Add initial BNB integration
* Add bnb test decorator
* Update Adamw8bit option name
* Use the full bnb package name
* Overide bnb for all embedding layers
* Fix package name
* Formatting
* Remove unnecessary import
* Update src/transformers/trainer.py
* Rename AdamwBNB optimizer option
* Add training test checking that bnb memory utilization is lower
* fix merge
* fix merge; fix + extend new test
* cleanup
* expand bnb
* move all require_* candidates to testing_utils.py

Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
1 parent ff85469 commit 6d2e29d

File tree

7 files changed: +194 -29 lines changed


src/transformers/testing_utils.py

Lines changed: 39 additions & 1 deletion
@@ -31,8 +31,16 @@
 from transformers import logging as transformers_logging

 from .deepspeed import is_deepspeed_available
-from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
+from .integrations import (
+    is_fairscale_available,
+    is_optuna_available,
+    is_ray_available,
+    is_sigopt_available,
+    is_wandb_available,
+)
 from .utils import (
+    is_apex_available,
+    is_bitsandbytes_available,
     is_detectron2_available,
     is_faiss_available,
     is_flax_available,
@@ -638,6 +646,36 @@ def require_deepspeed(test_case):
     return test_case


+def require_fairscale(test_case):
+    """
+    Decorator marking a test that requires fairscale
+    """
+    if not is_fairscale_available():
+        return unittest.skip("test requires fairscale")(test_case)
+    else:
+        return test_case
+
+
+def require_apex(test_case):
+    """
+    Decorator marking a test that requires apex
+    """
+    if not is_apex_available():
+        return unittest.skip("test requires apex")(test_case)
+    else:
+        return test_case
+
+
+def require_bitsandbytes(test_case):
+    """
+    Decorator for bits and bytes (bnb) dependency
+    """
+    if not is_bitsandbytes_available():
+        return unittest.skip("test requires bnb")(test_case)
+    else:
+        return test_case
+
+
 def require_phonemizer(test_case):
     """
     Decorator marking a test that requires phonemizer
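
A minimal usage sketch (not part of the diff) of the new decorator from testing_utils: it simply skips the test when bitsandbytes is not installed. The test class and method names below are illustrative assumptions.

import unittest

from transformers.testing_utils import require_bitsandbytes, require_torch_gpu


class BnbAvailabilityTest(unittest.TestCase):
    @require_bitsandbytes
    @require_torch_gpu
    def test_adam8bit_importable(self):
        # only runs when bitsandbytes (and a GPU) are available; otherwise the test is skipped
        from bitsandbytes.optim import Adam8bit

        self.assertIsNotNone(Adam8bit)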

src/transformers/trainer.py

Lines changed: 17 additions & 0 deletions
@@ -867,6 +867,15 @@ def create_optimizer(self):
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+                if optimizer_cls.__name__ == "Adam8bit":
+                    import bitsandbytes
+
+                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                    for module in self.model.modules():
+                        if isinstance(module, nn.Embedding):
+                            manager.register_module_override(module, "weight", {"optim_bits": 32})
+                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)
@@ -917,6 +926,14 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+        elif args.optim == OptimizerNames.ADAMW_BNB:
+            try:
+                from bitsandbytes.optim import Adam8bit
+
+                optimizer_cls = Adam8bit
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
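
For illustration, a standalone sketch of the pattern create_optimizer now applies: build bnb's 8-bit Adam for the whole model, then register a 32-bit override for embedding weights via GlobalOptimManager. The helper name and learning rate are illustrative assumptions, not part of this commit.

import bitsandbytes as bnb
import torch.nn as nn


def adam8bit_with_fp32_embeddings(model: nn.Module, lr: float = 3e-4):
    # same order as Trainer.create_optimizer: instantiate the optimizer, then override embeddings
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=lr)
    manager = bnb.optim.GlobalOptimManager.get_instance()
    for module in model.modules():
        if isinstance(module, nn.Embedding):
            # keep the optimizer state of embedding weights in 32-bit, as the commit does
            manager.register_module_override(module, "weight", {"optim_bits": 32})
    return optimizer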

src/transformers/training_args.py

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH_XLA = "adamw_torch_xla"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
+    ADAMW_BNB = "adamw_bnb_8bit"


 @dataclass
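
With the new enum value, the 8-bit optimizer can be selected by name. A small usage sketch (the output_dir is a placeholder):

from transformers import TrainingArguments
from transformers.training_args import OptimizerNames

# the enum's string value is what --optim accepts on the command line
assert OptimizerNames.ADAMW_BNB.value == "adamw_bnb_8bit"

args = TrainingArguments(output_dir="tmp_trainer", optim=OptimizerNames.ADAMW_BNB)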

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@
     DummyObject,
     _LazyModule,
     is_apex_available,
+    is_bitsandbytes_available,
     is_coloredlogs_available,
     is_datasets_available,
     is_detectron2_available,

src/transformers/utils/import_utils.py

Lines changed: 4 additions & 0 deletions
@@ -400,6 +400,10 @@ def is_apex_available():
     return importlib.util.find_spec("apex") is not None


+def is_bitsandbytes_available():
+    return importlib.util.find_spec("bitsandbytes") is not None
+
+
 def is_faiss_available():
     return _faiss_available

tests/extended/test_trainer_ext.py

Lines changed: 90 additions & 25 deletions
@@ -17,17 +17,21 @@
 import re
 import sys
 import unittest
+from typing import Tuple
 from unittest.mock import patch

 from parameterized import parameterized
-from transformers.integrations import is_fairscale_available
+from transformers import AutoModel
 from transformers.testing_utils import (
     CaptureStderr,
     ExtendSysPath,
     TestCasePlus,
     execute_subprocess_async,
     get_gpu_count,
     get_torch_dist_unique_port,
+    require_apex,
+    require_bitsandbytes,
+    require_fairscale,
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -36,7 +40,6 @@
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
-from transformers.utils import is_apex_available


 bindir = os.path.abspath(os.path.dirname(__file__))
@@ -49,28 +52,6 @@
 MBART_TINY = "sshleifer/tiny-mbart"


-# a candidate for testing_utils
-def require_fairscale(test_case):
-    """
-    Decorator marking a test that requires fairscale
-    """
-    if not is_fairscale_available():
-        return unittest.skip("test requires fairscale")(test_case)
-    else:
-        return test_case
-
-
-# a candidate for testing_utils
-def require_apex(test_case):
-    """
-    Decorator marking a test that requires apex
-    """
-    if not is_apex_available():
-        return unittest.skip("test requires apex")(test_case)
-    else:
-        return test_case
-
-
 @require_torch
 class TestTrainerExt(TestCasePlus):
     def run_seq2seq_quick(
@@ -193,7 +174,7 @@ def test_trainer_log_level_replica(self, experiment_id):
         self.assertEqual(n_matches, data["n_matches"])

     @slow
-    def test_run_seq2seq_slow(self):
+    def test_run_seq2seq(self):
         output_dir = self.run_trainer(
             eval_steps=2,
             max_len=128,
@@ -218,6 +199,88 @@ def test_run_seq2seq_slow(self):
         assert "generated_predictions.txt" in contents
         assert "predict_results.json" in contents

+    @slow
+    @require_bitsandbytes
+    def test_run_seq2seq_bnb(self):
+        from transformers.training_args import OptimizerNames
+
+        def train_and_return_metrics(optim: str) -> Tuple[int, float]:
+            from pathlib import Path
+
+            extra_args = (
+                f"--skip_memory_metrics 0 --optim {optim} --do_eval False --do_predict "
+                "False --adafactor False --log_level debug"
+            )
+
+            output_dir = self.run_trainer(
+                eval_steps=2,
+                max_len=128,
+                model_name=MARIAN_MODEL,
+                learning_rate=3e-4,
+                num_train_epochs=1,
+                distributed=True,  # force run in a new process
+                extra_args_str=extra_args,
+                do_eval=False,
+                do_predict=False,
+            )
+
+            # Check metrics
+            logs = TrainerState.load_from_json(Path(output_dir, "trainer_state.json")).log_history
+            gpu_peak_mem = logs[0]["train_mem_gpu_peaked_delta"]
+            gpu_alloc_mem = logs[0]["train_mem_gpu_alloc_delta"]
+
+            loss = logs[0]["train_loss"]
+            return gpu_peak_mem, gpu_alloc_mem, loss
+
+        gpu_peak_mem_orig, gpu_alloc_mem_orig, loss_orig = train_and_return_metrics(OptimizerNames.ADAMW_TORCH.value)
+        gpu_peak_mem_bnb, gpu_alloc_mem_bnb, loss_bnb = train_and_return_metrics(OptimizerNames.ADAMW_BNB.value)
+
+        gpu_peak_mem_diff_bytes = gpu_peak_mem_orig - gpu_peak_mem_bnb
+        gpu_peak_mem_diff_percent = gpu_peak_mem_diff_bytes / gpu_peak_mem_bnb
+
+        gpu_total_mem_orig = gpu_peak_mem_orig + gpu_alloc_mem_orig
+        gpu_total_mem_bnb = gpu_peak_mem_bnb + gpu_alloc_mem_bnb
+
+        gpu_total_mem_diff_bytes = gpu_total_mem_orig - gpu_total_mem_bnb
+        gpu_total_mem_diff_percent = gpu_total_mem_diff_bytes / gpu_total_mem_bnb
+
+        # leave this for now if CI gets very different results
+        # print(f"{gpu_alloc_mem_orig=:010d} {gpu_peak_mem_orig=:010d} {gpu_alloc_mem_orig+gpu_peak_mem_orig=:010d}" )
+        # print(f" {gpu_alloc_mem_bnb=:010d} {gpu_peak_mem_bnb=:010d} {gpu_alloc_mem_bnb+gpu_peak_mem_bnb=:010d}")
+        # print(f"{gpu_peak_mem_diff_bytes=}, {gpu_peak_mem_diff_percent=}")
+        # print(f"{gpu_total_mem_orig=}, {gpu_total_mem_bnb=}")
+        # print(f"{gpu_total_mem_diff_bytes=}, {gpu_total_mem_diff_percent=}")
+
+        self.assertGreater(
+            gpu_peak_mem_diff_percent,
+            10,  # basically a huge difference - got ~30x on my desktop
+            "should use very little peak gpu memory with BNB, compared to without it"
+            f"but got gpu_peak_mem_orig={gpu_peak_mem_orig} and gpu_peak_mem_bnb={gpu_peak_mem_bnb}",
+        )
+
+        self.assertGreater(
+            gpu_total_mem_diff_percent,
+            0.20,  # could easily be 0.50, but let's stay on the safe side
+            "Using BNB should use less total GPU memory than without it"
+            f"but got gpu_total_mem_orig={gpu_total_mem_orig} and gpu_total_mem_bnb={gpu_total_mem_bnb}",
+        )
+
+        self.assertEqual(
+            loss_orig, loss_bnb, "loss should be the same, but got loss_orig={loss_orig}, loss_bnb={loss_bnb}"
+        )
+
+        # Additionally let's test that the absolute gpu memory difference is larger or about the
+        # same as the expected saving coming from BNB (6 bytes per param)
+        model = AutoModel.from_pretrained(MARIAN_MODEL)
+        total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
+        bnb_saved_bytes = total_numel * 6  # 324MB
+
+        self.assertGreater(
+            gpu_total_mem_diff_bytes,
+            bnb_saved_bytes * 0.8,  # add a safety margin, if it saved slightly less
+            f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
+        )
+
     def run_trainer(
         self,
         eval_steps: int,
@@ -300,6 +363,8 @@ def run_trainer(
                 {self.examples_dir_str}/pytorch/translation/run_translation.py
             """.split()
             cmd = [sys.executable] + distributed_args + args
+            # keep for quick debug
+            # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
             execute_subprocess_async(cmd, env=self.get_env())
         else:
             testargs = ["run_translation.py"] + args

tests/trainer/test_trainer.py

Lines changed: 42 additions & 3 deletions
@@ -65,7 +65,7 @@
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.training_args import OptimizerNames
-from transformers.utils import WEIGHTS_NAME, is_apex_available
+from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
 from transformers.utils.hp_naming import TrialShortNamer


@@ -1870,6 +1870,7 @@ def hp_name(trial):
             },
         ),
     ]
+
     if is_apex_available():
         import apex

@@ -1881,6 +1882,17 @@ def hp_name(trial):
             )
         )

+    if is_bitsandbytes_available():
+        import bitsandbytes as bnb
+
+        optim_test_params.append(
+            (
+                OptimizerNames.ADAMW_BNB,
+                bnb.optim.Adam8bit,
+                default_adam_kwargs,
+            )
+        )
+

 @require_torch
 class TrainerOptimizerChoiceTest(unittest.TestCase):
@@ -1905,8 +1917,8 @@ def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):

     def test_fused_adam(self):
         # Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
-        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
-        # class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
+        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
+        # class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
         # the test to run without requiring an apex installation.
         mock = Mock()
         modules = {
@@ -1930,6 +1942,33 @@ def test_fused_adam_no_apex(self):
         with self.assertRaises(ValueError):
             Trainer.get_optimizer_cls_and_kwargs(args)

+    def test_bnb_adam8bit(self):
+        # Pretend that Bits and Bytes is installed and mock bnb.optim.Adam8bit exists.
+        # Trainer.get_optimizer_cls_and_kwargs does not use Adam8bit. It only has to return the
+        # class given, so mocking bnb.optim.Adam8bit should be fine for testing and allow
+        # the test to run without requiring a bnb installation.
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                OptimizerNames.ADAMW_BNB,
+                default_adam_kwargs,
+                mock.optim.Adam8bit,
+            )
+
+    def test_bnb_adam8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if bnb is installed.
+        with patch.dict("sys.modules", {"bnb.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+

 @require_torch
 @require_wandb
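
Both new tests rely on the same sys.modules patching trick as the apex tests: mapping a module name to a Mock makes imports inside the code under test resolve to the mock, while mapping it to None makes the import fail. A self-contained sketch of the mechanism, independent of Trainer:

from unittest.mock import Mock, patch

mock = Mock()
with patch.dict("sys.modules", {"bitsandbytes": mock, "bitsandbytes.optim": mock.optim}):
    from bitsandbytes.optim import Adam8bit  # resolves to the mock; no real install needed

    assert Adam8bit is mock.optim.Adam8bit

with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
    try:
        from bitsandbytes.optim import Adam8bit
    except ImportError:
        print("import blocked: a None entry in sys.modules makes the import fail")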
