Commit 30eb23c

kylesayrs authored and aireilly committed
[Examples] Standardize AWQ example (vllm-project#1412)
## Purpose ##

* Standardize the AWQ example to follow the same format as the other examples

## Changes ##

* Rearrange code to match the format of other examples
* Use the chat template to match the format of other examples
* Do not load 100x more samples than are needed; instead, only load the number of samples that is required
* Do not manually truncate input ids; instead, utilize the truncation provided by the tokenizer

## Testing ##

* Ran the example to completion and confirmed good generation

---------

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 6f8e4c3 commit 30eb23c
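
In practical terms, the preprocessing change described above (apply the chat template, and let the tokenizer handle truncation rather than slicing input ids by hand) boils down to the pattern below. This is a condensed sketch using the model stub and constants from the example script further down, not an addition to the commit itself:

```python
from transformers import AutoTokenizer

# Values mirroring the example script in this commit.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
MAX_SEQUENCE_LENGTH = 512

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def preprocess(example):
    # Wrap each raw calibration text in the model's chat template.
    return {
        "text": tokenizer.apply_chat_template(
            [{"role": "user", "content": example["text"]}],
            tokenize=False,
        )
    }


def tokenize(sample):
    # Let the tokenizer truncate to MAX_SEQUENCE_LENGTH instead of
    # slicing input ids by hand as the old example did.
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )
```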

File tree

2 files changed: +121 -41 lines changed


examples/awq/README.md

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Quantizing Models with Activation-Aware Quantization (AWQ) #

Activation-Aware Quantization (AWQ) is a state-of-the-art technique for quantizing the weights of large language models using a small calibration dataset. The AWQ algorithm uses the calibration data to derive scaling factors that reduce the dynamic range of the weights while minimizing accuracy loss on the most salient weight values.

The AWQ implementation found in LLM Compressor is derived from the pioneering work of [AutoAWQ](https://github.com/casper-hansen/AutoAWQ) and was developed with assistance from its original maintainer, [@casper-hansen](https://github.com/casper-hansen).

## AWQ Recipe ##

The AWQ recipe is interfaced as follows, where the `AWQModifier` adjusts model scales ahead of efficient weight quantization by the `QuantizationModifier`:

```python
recipe = [
    AWQModifier(bits=4, symmetric=False),
    QuantizationModifier(
        ignore=["lm_head"],
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(
                    num_bits=4,
                    type=QuantizationType.INT,
                    dynamic=False,
                    symmetric=False,
                    strategy=QuantizationStrategy.GROUP,
                    group_size=128,
                ),
            )
        },
    ),
]
```

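For context when copying this recipe into your own script, it assumes imports along the following lines. The `llmcompressor` import paths appear in the example script in this commit; the `compressed_tensors.quantization` path for the quantization dataclasses is an assumption about where they are typically exposed, so verify it against your installed version:

```python
# Quantization dataclasses and enums; the compressed_tensors.quantization
# path is an assumption, not shown in this commit's diff.
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)

# Modifiers used in the recipe; these imports appear in the example script.
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import QuantizationModifier
```
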
## Compressing Your Own Model ##

To use your own model, start with an existing example and change the `model_id` to match your own model stub.

```python
model_id = "path/to/your/model"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)
```

## Adding Mappings ##

In order to target weight and activation scaling locations within the model, the `AWQModifier` must be provided an AWQ mapping. For example, the AWQ mapping for the Llama family of models looks like this:

```python
[
    AWQMapping(
        "re:.*input_layernorm",
        ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
    ),
    AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
    AWQMapping(
        "re:.*post_attention_layernorm",
        ["re:.*gate_proj", "re:.*up_proj"],
    ),
    AWQMapping(
        "re:.*up_proj",
        ["re:.*down_proj"],
    ),
]
```

To support other model families, you can supply your own mappings via the `mappings` argument when instantiating the `AWQModifier`, or you can add them to the registry [here](/src/llmcompressor/modifiers/awq/mappings.py) (contributions are welcome!).
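
As a minimal sketch of the first option, assuming `AWQMapping` is importable from `llmcompressor.modifiers.awq` alongside `AWQModifier` and that the regexes below match your model's module names:

```python
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier

# Hypothetical mappings for a Llama-style decoder block; adjust the
# module-name regexes to the layer names used by your model family.
custom_mappings = [
    AWQMapping(
        "re:.*input_layernorm",
        ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
    ),
    AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
]

# Supply the mappings when instantiating the modifier, as described above.
awq_modifier = AWQModifier(bits=4, symmetric=False, mappings=custom_mappings)
```

The resulting modifier can then be dropped into a recipe exactly as in the AWQ Recipe section above.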

examples/awq/awq_one_shot.py renamed to examples/awq/llama_example.py

Lines changed: 56 additions & 41 deletions
```diff
@@ -5,28 +5,61 @@
     QuantizationStrategy,
     QuantizationType,
 )
+from datasets import load_dataset
 from lm_eval.utils import make_table
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.awq import AWQModifier
 from llmcompressor.modifiers.quantization import QuantizationModifier
 
-# This example demonstrates how to:
-# 1) Run the `llm-compressor` implementation of AWQ
-# 2) Evaluate the compressed model with the lm_eval framework
+# Select model and load it.
+MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, device_map="auto", torch_dtype="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
+# Select calibration dataset.
 MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
 DATASET_ID = "mit-han-lab/pile-val-backup"
 DATASET_SPLIT = "validation"
+
+# Select number of samples. 256 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
 NUM_CALIBRATION_SAMPLES = 256
 MAX_SEQUENCE_LENGTH = 512
-OUTPUT_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"
 
-#
-# 1) Run LLM Compressor AWQ implementation
-#
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            [{"role": "user", "content": example["text"]}],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
 
+
+# Configure the quantization algorithm to run.
 recipe = [
     AWQModifier(bits=4, symmetric=False),
     QuantizationModifier(
@@ -47,54 +80,36 @@
     ),
 ]
 
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID, device_map="auto", torch_dtype="auto"
-)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
-
-
-def get_calib_dataset(tokenizer):
-    from datasets import load_dataset
-
-    ds = load_dataset(
-        DATASET_ID,
-        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*100}]",
-    )
-
-    def preprocess(example):
-        return {
-            "input_ids": tokenizer.encode(example["text"].strip())[:MAX_SEQUENCE_LENGTH]
-        }
-
-    ds = (
-        ds.shuffle(seed=42)
-        .map(preprocess, remove_columns=ds.column_names)
-        .filter(lambda example: len(example["input_ids"]) >= MAX_SEQUENCE_LENGTH)
-        .select(range(NUM_CALIBRATION_SAMPLES))
-    )
-
-    return ds
-
-
+# Apply algorithms.
 oneshot(
     model=model,
-    dataset=get_calib_dataset(tokenizer=tokenizer),
+    dataset=ds,
     recipe=recipe,
-    output_dir=OUTPUT_DIR,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
 )
 
-print("Done! model saved to", OUTPUT_DIR)
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
 
 #
 # 2) Evaluate model on wikitext perplexity
 #
 
 results = lm_eval.simple_evaluate(
-    model="vllm",
+    model="hf",
     model_args={
-        "pretrained": OUTPUT_DIR,
+        "pretrained": SAVE_DIR,
         "add_bos_token": True,
         "dtype": "bfloat16",
         "gpu_memory_utilization": 0.5,
```
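
The hunk above cuts off inside the `lm_eval` call. For orientation only, here is a minimal sketch of how evaluating the saved checkpoint on wikitext perplexity typically looks with `lm_eval.simple_evaluate` and `make_table`; the task list, batch size, and save directory below are assumptions, not lines from this diff:

```python
import lm_eval
from lm_eval.utils import make_table

# Assumed save directory, matching SAVE_DIR in the script above.
SAVE_DIR = "Meta-Llama-3-8B-Instruct-awq-asym"

# Evaluate the compressed checkpoint with the Hugging Face backend,
# mirroring the model="hf" change in the diff.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args={
        "pretrained": SAVE_DIR,
        "add_bos_token": True,
        "dtype": "bfloat16",
    },
    tasks=["wikitext"],
    batch_size=8,
)
print(make_table(results))
```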