
Mixtral 8x22B quantization fails with 2 issues #35

@qingquansong

Description


Describe the bug

Hey team, I'm trying to quantize Mixtral 8x22B with the W8A8 recipe and it fails with two different issues, depending on the version. The first, on the latest main branch:

File "/home/jobuser/llm-compressor/src/llmcompressor/utils/pytorch/module.py", line 166, in get_layers
    return match_layers_params(targets, module)
  File "/home/jobuser/llm-compressor/src/llmcompressor/utils/pytorch/module.py", line 160, in match_layers_params
    raise ValueError(f"Could not find targets {missed} in module {module}")
ValueError: Could not find targets ['re:.*gate_proj'] in module MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32768, 6144)
    (layers): ModuleList(
      (0-55): 56 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear(in_features=6144, out_features=6144, bias=False)
          (k_proj): Linear(in_features=6144, out_features=1024, bias=False)
          (v_proj): Linear(in_features=6144, out_features=1024, bias=False)
          (o_proj): Linear(in_features=6144, out_features=6144, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=6144, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear(in_features=6144, out_features=16384, bias=False)
              (w2): Linear(in_features=16384, out_features=6144, bias=False)
              (w3): Linear(in_features=6144, out_features=16384, bias=False)
              (act_fn): SiLU()
            )
          )
        )
        (input_layernorm): MixtralRMSNorm()
        (post_attention_layernorm): MixtralRMSNorm()
      )
    )
    (norm): MixtralRMSNorm()
  )
  (lm_head): Linear(in_features=6144, out_features=32768, bias=False)
)

This happens with the latest main branch. I suspect a regex/target-matching issue, because when I used the main branch 1-2 weeks ago I did not see it. Has anything changed? A quick way to check which module names the default regexes actually match is sketched below.
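For reference, the following sketch (not part of the failing run; it assumes `model_path` points at the Mixtral checkpoint used in the repro below) counts regex matches without loading any weights:

import re

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Instantiate the Mixtral architecture on the meta device (no weights) and
# count how many submodule names match each regex of interest.
config = AutoConfig.from_pretrained(model_path)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

for pattern in (r".*gate_proj", r".*up_proj", r".*w1", r".*w3"):
    matches = [name for name, _ in model.named_modules() if re.match(pattern, name)]
    print(pattern, "->", len(matches), "matches")

# Mixtral has no gate_proj/up_proj modules; its expert MLPs are named w1/w2/w3
# (56 layers x 8 experts = 448 matches each), so 're:.*gate_proj' resolves to
# nothing and get_layers raises the ValueError above.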

I think I can work around it by manually changing the default SmoothQuant mapping to a Mixtral-specific one, but I'm wondering whether there is a better solution here and why this did not happen before.
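Concretely, what I mean by changing the mapping manually is something like the following (a sketch, assuming `SmoothQuantModifier` accepts a `mappings` list of (balance_layers, smooth_layer) pairs in the same shape as the LLaMA-style default, and that Mixtral's expert `w1`/`w3` projections play the role of `gate_proj`/`up_proj`):

from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

# Same structure as the default mapping, but the MLP branch points at the
# Mixtral expert projections (w1/w3) instead of gate_proj/up_proj.
mixtral_mappings = [
    [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
    [["re:.*w1", "re:.*w3"], "re:.*post_attention_layernorm"],
]

recipe = [
    SmoothQuantModifier(smoothing_strength=0.8, mappings=mixtral_mappings),
    # GPTQModifier unchanged from the repro below
]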

The second issue: previously, although I did not hit the mapping error above, there was an OOM after 3-4 layers. (I changed the device map to "auto", which works well for Llama 3 70B but not for Mixtral 8x22B, which is larger.) So CPU offloading, or a better per-block cleanup scheme, is probably needed; a sketch of the offloading idea follows.
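As a rough illustration of the offloading idea (hypothetical memory caps, and it assumes `oneshot()` also accepts an already-loaded model object instead of a path):

import torch
from transformers import AutoModelForCausalLM

# Cap per-GPU usage so accelerate dispatches whatever does not fit to CPU RAM,
# instead of letting calibration push a single A100 over its 80 GB limit.
max_memory = {i: "70GiB" for i in range(8)}  # leave headroom on each of the 8 GPUs
max_memory["cpu"] = "800GiB"                 # allow spill-over to host memory

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    max_memory=max_memory,
)
# then: oneshot(model=model, ...) instead of oneshot(model=model_path, ...)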

Expected behavior
Expect the run to finish on a single node with 8x A100 GPUs.

Environment

  1. OS: Linux Mariner
  2. Python version: 3.10
  3. LLM Compressor version or commit hash: latest main branch (as of 2024-07-23)
  4. ML framework version(s): torch 2.3.1+cu118
  5. Other Python package versions:
  6. Other relevant environment information: CUDA 11.8

To Reproduce
Exact steps to reproduce the behavior:

from datasets import Dataset
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot

recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]
tokenized_ids_dataset = Dataset.from_dict(tokenized_ids)  # tokenized `garage-bAInd___open-platypus`, 8 sequences
oneshot(
    model=model_path,  # Mixtral 8x22B checkpoint
    dataset=tokenized_ids_dataset,
    recipe=recipe,
    save_compressed=True,
    output_dir=output_model_path,
    oneshot_device="auto",
    overwrite_output_dir=True,
    max_seq_length=model_max_length,
    num_calibration_samples=num_calibration_samples,
)

Errors

2024-07-24T06:48:58.529779+0000 | intialize_model_from_path | WARNING - Moving /shared/public/models/Mixtral-8x22B-Instruct-v0.1 to device auto for One-Shot
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████| 59/59 [01:50<00:00,  1.87s/it]
2024-07-24T06:51:22.896705+0000 | _check_create_state | INFO - State created for compression lifecycle
2024-07-24T06:51:22.898697+0000 | pre_initialize_structure | INFO - Compression lifecycle structure pre-initialized for 0 modifiers
2024-07-24T06:51:22.899583+0000 | pre_initialize_structure | INFO - Compression lifecycle structure pre-initialized for 0 modifiers
2024-07-24T06:51:22.913111+0000 | one_shot | INFO - *** One Shot ***
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
2024-07-24T06:51:35.520914+0000 | from_modifiers | INFO - Creating recipe from modifiers
/home/jobuser/.local/lib/python3.10/site-packages/pydantic/main.py:364: UserWarning: Pydantic serializer warnings:
  Expected `tuple[any, ...]` but got `list` - serialized value may not be as expected
  Expected `tuple[any, ...]` but got `list` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(
2024-07-24T06:51:35.524997+0000 | create_instance | WARNING - Could not process input as a file path or zoo stub, attempting to process it as a string.
2024-07-24T06:51:35.600929+0000 | _check_compile_recipe | INFO - Recipe compiled and 1 modifiers created
Traceback (most recent call last):
  File "/home/jobuser/test_large_quant.py", line 35, in <module>
    oneshot(model=model_path,
  File "/home/jobuser/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 76, in oneshot
    main(model_args, data_args, training_args)
  File "/home/jobuser/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 358, in main
    stage_runner.one_shot()
  File "/home/jobuser/llm-compressor/src/llmcompressor/transformers/finetune/runner.py", line 157, in one_shot
    self.trainer.one_shot(calib_data, stage=stage)
  File "/home/jobuser/llm-compressor/src/llmcompressor/transformers/finetune/session_mixin.py", line 399, in one_shot
    apply(
  File "/home/jobuser/llm-compressor/src/llmcompressor/core/session_functions.py", line 184, in apply
    return active_session().apply(
  File "/home/jobuser/llm-compressor/src/llmcompressor/core/session.py", line 210, in apply
    self.initialize(**kwargs)
  File "/home/jobuser/llm-compressor/src/llmcompressor/core/session.py", line 156, in initialize
    mod_data = self._lifecycle.initialize(
  File "/home/jobuser/llm-compressor/src/llmcompressor/core/lifecycle.py", line 126, in initialize
    data = mod.initialize(state=self.state, **extras)
  File "/home/jobuser/llm-compressor/src/llmcompressor/modifiers/stage.py", line 124, in initialize
    modifier.initialize(state, **kwargs)
  File "/home/jobuser/llm-compressor/src/llmcompressor/modifiers/modifier.py", line 118, in initialize
    initialized = self.on_initialize(state=state, **kwargs)
  File "/home/jobuser/llm-compressor/src/llmcompressor/modifiers/smoothquant/base.py", line 127, in on_initialize
    self.resolved_mappings_ = self._resolve_mappings(state.model)
  File "/home/jobuser/llm-compressor/src/llmcompressor/modifiers/smoothquant/base.py", line 184, in _resolve_mappings
    _, balance_layer = get_matching_layer(
  File "/home/jobuser/llm-compressor/src/llmcompressor/utils/pytorch/module.py", line 311, in get_matching_layer
    potential_matches = get_layers(target, module)
  File "/home/jobuser/llm-compressor/src/llmcompressor/utils/pytorch/module.py", line 166, in get_layers
    return match_layers_params(targets, module)
  File "/home/jobuser/llm-compressor/src/llmcompressor/utils/pytorch/module.py", line 160, in match_layers_params
    raise ValueError(f"Could not find targets {missed} in module {module}")
ValueError: Could not find targets ['re:.*gate_proj'] in module MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32768, 6144)
    (layers): ModuleList(
      (0-55): 56 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear(in_features=6144, out_features=6144, bias=False)
          (k_proj): Linear(in_features=6144, out_features=1024, bias=False)
          (v_proj): Linear(in_features=6144, out_features=1024, bias=False)
          (o_proj): Linear(in_features=6144, out_features=6144, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=6144, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear(in_features=6144, out_features=16384, bias=False)
              (w2): Linear(in_features=16384, out_features=6144, bias=False)
              (w3): Linear(in_features=6144, out_features=16384, bias=False)
              (act_fn): SiLU()
            )
          )
        )
        (input_layernorm): MixtralRMSNorm()
        (post_attention_layernorm): MixtralRMSNorm()
      )
    )
    (norm): MixtralRMSNorm()
  )
  (lm_head): Linear(in_features=6144, out_features=32768, bias=False)
)

For the second (OOM) issue I cannot find the log, but it happens after about 4 layers.

