Question about quantized model with zero3 #30663

@mxjmtxrm

System Info

  • transformers version: 4.41.0.dev0
  • Platform: Linux-5.15.0-92-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.21.4
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.0a0+81ea7a4 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

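import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# `args` is defined elsewhere in the script (not shown here).
# 4-bit NF4 quantization; the storage dtype matches torch_dtype below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_quant_storage="bfloat16",
)
model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
)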

I found that when fine-tuning a quantized model using Trainer with ZeRO-3, the whole quantized model is first loaded onto the GPU, and only then are the parameters partitioned across the data-parallel processes. What if there is not enough GPU memory to load the whole quantized model?
The code that loads the whole quantized model is in deepspeed/runtime/engine.py, at about line 262:

 self._configure_distributed_model(model)

It is reached from the transformers Trainer, in inner_train_loop, at about line 1082:

 model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
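
To make the memory concern concrete, here is a rough back-of-the-envelope sketch. It only counts the 4-bit weights themselves and ignores quantization constants, layers left unquantized, activations, and optimizer state. The parameter count and GPU count are hypothetical values chosen for illustration, not taken from my setup, and the helper names are made up for this sketch.

```python
def quantized_weight_bytes(num_params: int, bits_per_param: int = 4) -> int:
    """Approximate bytes needed for the quantized weights alone."""
    return num_params * bits_per_param // 8


def per_gpu_weight_bytes(num_params: int, world_size: int,
                         partitioned: bool, bits_per_param: int = 4) -> int:
    """Weight bytes held on one GPU.

    partitioned=False: every rank holds the full quantized model
    (what I observe before the engine partitions it).
    partitioned=True: weights are split evenly across ranks
    (what I would like to happen already at load time).
    """
    total = quantized_weight_bytes(num_params, bits_per_param)
    if partitioned:
        return -(-total // world_size)  # ceiling division
    return total


# Hypothetical example: a 70B-parameter model on 8 GPUs.
num_params = 70_000_000_000
world_size = 8

full = per_gpu_weight_bytes(num_params, world_size, partitioned=False)
shard = per_gpu_weight_bytes(num_params, world_size, partitioned=True)

print(f"full model on each GPU: {full / 1e9:.3f} GB")
print(f"partitioned at load:    {shard / 1e9:.3f} GB")
```

Under these assumptions the full-load case needs the whole quantized model to fit on every GPU at once, while partitioning at load time would need only a 1/world_size share per GPU.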

Expected behavior

How can the parameters be partitioned while loading in from_pretrained, rather than later in the Trainer, as happens when loading a float (unquantized) model?
