
DeepSeek: Fix Hessian Estimation #157


Merged
merged 4 commits into main on Sep 11, 2024
Conversation

@Satrat (Contributor) commented on Sep 10, 2024

SUMMARY:
We were running out of memory when using GPTQ on DeepSeek because our Hessian memory estimation code assumes all decoder layers have the same shape, which is not the case for DeepSeek. Updating the function to find the largest layer across all decoders fixed the issue.
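Below is a minimal sketch of that estimation idea, for illustration only and not the actual helper in llm-compressor: it walks every decoder layer, sizes the Hessians from each Linear module, and reserves memory for the largest layer rather than extrapolating from the first one. The function name estimate_max_hessian_memory and the model.model.layers layout are assumptions.

import torch

def estimate_max_hessian_memory(model, dtype=torch.float32):
    # Hypothetical sketch: GPTQ keeps one (in_features x in_features) Hessian
    # per Linear module, so sum those per decoder layer and take the max
    # across layers instead of assuming every decoder layer is the same shape.
    bytes_per_elem = torch.finfo(dtype).bits // 8
    max_bytes = 0
    for decoder_layer in model.model.layers:  # assumes a DeepSeek/LLaMA-style layout
        layer_bytes = sum(
            module.in_features ** 2 * bytes_per_elem
            for module in decoder_layer.modules()
            if isinstance(module, torch.nn.Linear)
        )
        max_bytes = max(max_bytes, layer_bytes)
    return max_bytes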

TEST PLAN:

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

# select a Mixture of Experts model for quantization
MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Instruct"

# adjust based on the number of desired GPUs
# if not enough memory is available, some layers will automatically be offloaded to cpu
device_map = calculate_offload_device_map(
    MODEL_ID,
    reserve_for_hessians=True,
    num_gpus=2,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# define a llmcompressor recipe for W8A8 quantization
# since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W8A8",
        ignore=["lm_head", "re:.*mlp.gate$"],
        sequential_update=True,
    ),
]

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    save_compressed=True,
)

@Satrat requested a review from dsikka on September 10, 2024 14:36

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

@dsikka (Collaborator) left a comment

LGTM - this was tested on the full deepseekv2 model?

@Satrat merged commit 3ee99bd into main on Sep 11, 2024
6 of 7 checks passed
@Satrat deleted the sa/deepseek_fix branch on September 11, 2024 14:04
@Satrat (Contributor, Author) commented on Sep 11, 2024

> LGTM - this was tested on the full deepseekv2 model?

I let it run to completion on the lite model; the full model is really slow, so I killed it halfway through.

dbarbuzzi pushed a commit to dbarbuzzi/llm-compressor that referenced this pull request Sep 11, 2024
* loop through all decoder layers to figure out hessians

* remove print
markmc pushed a commit to markmc/llm-compressor that referenced this pull request Nov 13, 2024
* remove weight and input details

* add default test

* PR comments