
Commit c0db198 (1 parent: 7f0141c)

- build a default logger in BaseModelWorker.__init__ to avoid duplicated loggers
- add 3 abstract methods
- run format.sh

File tree: 14 files changed, +302 −22 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -185,6 +185,7 @@ This requires 8-bit compression to be enabled and the bitsandbytes package to be

 #### More Platforms and Quantization
 - For AMD GPU users, please install ROCm and [the ROCm version of PyTorch](https://pytorch.org/get-started/locally/) before you install FastChat. See also this [post](https://github.com/lm-sys/FastChat/issues/104#issuecomment-1613791563).
+- FastChat supports ExLlama V2. See [docs/exllama_v2.md](/docs/exllama_v2.md).
 - FastChat supports GPTQ 4bit inference with [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). See [docs/gptq.md](/docs/gptq.md).
 - FastChat supports AWQ 4bit inference with [mit-han-lab/llm-awq](https://github.com/mit-han-lab/llm-awq). See [docs/awq.md](/docs/awq.md).
 - [MLC LLM](https://mlc.ai/mlc-llm/), backed by [TVM Unity](https://github.com/apache/tvm/tree/unity) compiler, deploys Vicuna natively on phones, consumer-class GPUs and web browsers via Vulkan, Metal, CUDA and WebGPU.

docker/Dockerfile

Lines changed: 2 additions & 1 deletion

@@ -3,4 +3,5 @@ FROM nvidia/cuda:11.7.1-runtime-ubuntu20.04
 RUN apt-get update -y && apt-get install -y python3.9 python3.9-distutils curl
 RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
 RUN python3.9 get-pip.py
-RUN pip3 install fschat
+RUN pip3 install fschat
+RUN pip3 install fschat[model_worker,webui] pydantic==1.10.1

docs/exllama_v2.md

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@

# ExLlamaV2 GPTQ Inference Framework

FastChat integrates the customized [ExllamaV2](https://github.com/turboderp/exllamav2) kernels to provide **faster** GPTQ inference.

**Note: ExLlama does not yet support the embedding REST API.**

## Install ExLlamaV2

Set up the environment (see [this link](https://github.com/turboderp/exllamav2#how-to) for more details):
```bash
git clone https://github.com/turboderp/exllamav2
cd exllamav2
pip install -e .
```

Chat with the CLI:
```bash
python3 -m fastchat.serve.cli \
    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
    --enable-exllama
```

Start a model worker:
```bash
# Download the quantized model from Hugging Face.
# Make sure you have git-lfs installed (https://git-lfs.com).
git lfs install
git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g

# Load the model with the default configuration (max sequence length 4096, no GPU split).
python3 -m fastchat.serve.model_worker \
    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
    --enable-exllama

# Load the model with a max sequence length of 2048, allocating 18 GB to CUDA:0 and 24 GB to CUDA:1.
python3 -m fastchat.serve.model_worker \
    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
    --enable-exllama \
    --exllama-max-seq-len 2048 \
    --exllama-gpu-split 18,24
```

## Performance

Reference: https://github.com/turboderp/exllamav2#performance

| Model     | Mode         | Size | grpsz | act | V1: 3090Ti | V1: 4090 | V2: 3090Ti | V2: 4090    |
|-----------|--------------|------|-------|-----|------------|----------|------------|-------------|
| Llama     | GPTQ         | 7B   | 128   | no  | 143 t/s    | 173 t/s  | 175 t/s    | **195** t/s |
| Llama     | GPTQ         | 13B  | 128   | no  | 84 t/s     | 102 t/s  | 105 t/s    | **110** t/s |
| Llama     | GPTQ         | 33B  | 128   | yes | 37 t/s     | 45 t/s   | 45 t/s     | **48** t/s  |
| OpenLlama | GPTQ         | 3B   | 128   | yes | 194 t/s    | 226 t/s  | 295 t/s    | **321** t/s |
| CodeLlama | EXL2 4.0 bpw | 34B  | -     | -   | -          | -        | 42 t/s     | **48** t/s  |
| Llama2    | EXL2 3.0 bpw | 7B   | -     | -   | -          | -        | 195 t/s    | **224** t/s |
| Llama2    | EXL2 4.0 bpw | 7B   | -     | -   | -          | -        | 164 t/s    | **197** t/s |
| Llama2    | EXL2 5.0 bpw | 7B   | -     | -   | -          | -        | 144 t/s    | **160** t/s |
| Llama2    | EXL2 2.5 bpw | 70B  | -     | -   | -          | -        | 30 t/s     | **35** t/s  |
| TinyLlama | EXL2 3.0 bpw | 1.1B | -     | -   | -          | -        | 536 t/s    | **635** t/s |
| TinyLlama | EXL2 4.0 bpw | 1.1B | -     | -   | -          | -        | 509 t/s    | **590** t/s |

fastchat/model/compression.py

Lines changed: 1 addition & 1 deletion

@@ -147,7 +147,7 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="mai
     # We don't necessarily need to download the model's repo again if there is a cache.
     # So check the default huggingface cache first.
     model_path_temp = os.path.join(
-        os.getenv("HOME"),
+        os.path.expanduser("~"),
         ".cache/huggingface/hub",
         "models--" + model_path.replace("/", "--"),
         "snapshots/",

fastchat/model/model_adapter.py

Lines changed: 40 additions & 2 deletions

@@ -27,17 +27,19 @@
 )

 from fastchat.constants import CPU_ISA
-from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
-from fastchat.modules.awq import AWQConfig, load_awq_quantized
 from fastchat.conversation import Conversation, get_conv_template
 from fastchat.model.compression import load_compress_model
 from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
 from fastchat.model.model_chatglm import generate_stream_chatglm
 from fastchat.model.model_codet5p import generate_stream_codet5p
 from fastchat.model.model_falcon import generate_stream_falcon
+from fastchat.model.model_exllama import generate_stream_exllama
 from fastchat.model.monkey_patch_non_inplace import (
     replace_llama_attn_with_non_inplace_operations,
 )
+from fastchat.modules.awq import AWQConfig, load_awq_quantized
+from fastchat.modules.exllama import ExllamaConfig, load_exllama_model
+from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
 from fastchat.utils import get_gpu_memory

 # Check an environment variable to check if we should be sharing Peft model

@@ -155,6 +157,7 @@ def load_model(
     cpu_offloading: bool = False,
     gptq_config: Optional[GptqConfig] = None,
     awq_config: Optional[AWQConfig] = None,
+    exllama_config: Optional[ExllamaConfig] = None,
     revision: str = "main",
     debug: bool = False,
 ):

@@ -279,6 +282,9 @@ def load_model(
         else:
             model.to(device)
         return model, tokenizer
+    elif exllama_config:
+        model, tokenizer = load_exllama_model(model_path, exllama_config)
+        return model, tokenizer
     kwargs["revision"] = revision

     if dtype is not None:  # Overwrite dtype if it is provided in the arguments.

@@ -325,13 +331,17 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
     is_falcon = "rwforcausallm" in model_type
     is_codet5p = "codet5p" in model_type
     is_peft = "peft" in model_type
+    is_exllama = "exllama" in model_type

     if is_chatglm:
         return generate_stream_chatglm
     elif is_falcon:
         return generate_stream_falcon
     elif is_codet5p:
         return generate_stream_codet5p
+    elif is_exllama:
+        return generate_stream_exllama
+
     elif peft_share_base_weights and is_peft:
         # Return a curried stream function that loads the right adapter
         # according to the model_name available in this context. This ensures

@@ -453,6 +463,23 @@ def add_model_args(parser):
         default=-1,
         help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.",
     )
+    parser.add_argument(
+        "--enable-exllama",
+        action="store_true",
+        help="Used for exllamav2. Enable the exllamaV2 inference framework.",
+    )
+    parser.add_argument(
+        "--exllama-max-seq-len",
+        type=int,
+        default=4096,
+        help="Used for exllamav2. Max sequence length to use for the exllamav2 framework; default 4096.",
+    )
+    parser.add_argument(
+        "--exllama-gpu-split",
+        type=str,
+        default=None,
+        help="Used for exllamav2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7",
+    )


 def remove_parent_directory_name(model_path):

@@ -1625,6 +1652,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("phind")


+class Llama2ChangAdapter(Llama2Adapter):
+    """The model adapter for Llama2-ko-chang (e.g., lcw99/llama2-ko-chang-instruct-chat)"""
+
+    def match(self, model_path: str):
+        return "llama2-ko-chang" in model_path.lower()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("polyglot_changgpt")
+
+
 # Note: the registration order matters.
 # The one registered earlier has a higher matching priority.
 register_model_adapter(PeftModelAdapter)

@@ -1684,6 +1721,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(ReaLMAdapter)
 register_model_adapter(PhindCodeLlamaAdapter)
 register_model_adapter(CodeLlamaAdapter)
+register_model_adapter(Llama2ChangAdapter)

 # After all adapters, try the default base adapter.
 register_model_adapter(BaseModelAdapter)
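For context, a hedged sketch of how the new flags above can be turned into an `ExllamaConfig` and handed to `load_model`. The real glue code lives in the serving entry points (for example `fastchat.serve.model_worker`) and may differ in detail; the argument names below come from `add_model_args`.

```python
import argparse

from fastchat.model.model_adapter import add_model_args, load_model
from fastchat.modules.exllama import ExllamaConfig

parser = argparse.ArgumentParser()
add_model_args(parser)
args = parser.parse_args()

# Build an ExllamaConfig only when --enable-exllama is given; otherwise
# load_model falls through to the regular Hugging Face loading path.
exllama_config = (
    ExllamaConfig(
        max_seq_len=args.exllama_max_seq_len,
        gpu_split=args.exllama_gpu_split,
    )
    if args.enable_exllama
    else None
)

model, tokenizer = load_model(
    args.model_path,
    device=args.device,
    num_gpus=args.num_gpus,
    exllama_config=exllama_config,
)
```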

fastchat/model/model_exllama.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@

import gc
import sys
from typing import Dict

import torch


def generate_stream_exllama(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    try:
        from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    prompt = params["prompt"]

    generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
    settings = ExLlamaV2Sampler.Settings()

    settings.temperature = float(params.get("temperature", 0.85))
    settings.top_k = int(params.get("top_k", 50))
    settings.top_p = float(params.get("top_p", 0.8))
    settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
    settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])

    max_new_tokens = int(params.get("max_new_tokens", 256))

    generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
    echo = bool(params.get("echo", True))

    input_ids = generator.tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]
    generator.begin_stream(input_ids, settings)

    generated_tokens = 0
    if echo:
        output = prompt
    else:
        output = ""
    while True:
        chunk, eos, _ = generator.stream()
        output += chunk
        generated_tokens += 1
        if generated_tokens == max_new_tokens:
            finish_reason = "length"
            break
        elif eos:
            # Stop condition reached before the length limit.
            finish_reason = "stop"
            break
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": generated_tokens,
                "total_tokens": prompt_tokens + generated_tokens,
            },
            "finish_reason": None,
        }

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generated_tokens,
            "total_tokens": prompt_tokens + generated_tokens,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()
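For illustration, a hedged sketch of driving this generator, assuming a `(model, tokenizer)` pair returned by `load_exllama_model` from the next file; the prompt text and sampling values are placeholders.

```python
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model

config = ExllamaConfig(max_seq_len=4096, gpu_split=None)
model, tokenizer = load_exllama_model("models/vicuna-7B-1.1-GPTQ-4bit-128g", config)

params = {
    "prompt": "USER: Tell me a joke. ASSISTANT:",
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 128,
    "echo": False,  # do not prepend the prompt to the streamed text
}

# Each yielded dict contains the text accumulated so far plus token counts;
# the final dict carries a non-None finish_reason ("length" or "stop").
last_chunk = None
for last_chunk in generate_stream_exllama(
    model, tokenizer, params, device="cuda", context_len=4096
):
    pass

print(last_chunk["text"])
print(last_chunk["usage"], last_chunk["finish_reason"])
```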

fastchat/modules/exllama.py

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@

from dataclasses import dataclass, field
import sys


@dataclass
class ExllamaConfig:
    max_seq_len: int
    gpu_split: str = None


class ExllamaModel:
    def __init__(self, exllama_model, exllama_cache):
        self.model = exllama_model
        self.cache = exllama_cache
        self.config = self.model.config


def load_exllama_model(model_path, exllama_config: ExllamaConfig):
    try:
        from exllamav2 import (
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
            ExLlamaV2,
            ExLlamaV2Cache,
        )
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    exllamav2_config = ExLlamaV2Config()
    exllamav2_config.model_dir = model_path
    exllamav2_config.prepare()
    exllamav2_config.max_seq_len = exllama_config.max_seq_len

    exllama_model = ExLlamaV2(exllamav2_config)
    tokenizer = ExLlamaV2Tokenizer(exllamav2_config)

    split = None
    if exllama_config.gpu_split:
        split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
    exllama_model.load(split)

    exllama_cache = ExLlamaV2Cache(exllama_model)
    model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)

    return model, tokenizer
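A brief usage note: `gpu_split` stays a raw comma-separated string (exactly what `--exllama-gpu-split` passes in) and is only parsed into per-GPU gigabyte allocations inside `load_exllama_model`. A small sketch of that round trip:

```python
from fastchat.modules.exllama import ExllamaConfig

# Mirrors --exllama-gpu-split 18,24: roughly 18 GB on CUDA:0 and 24 GB on CUDA:1.
config = ExllamaConfig(max_seq_len=2048, gpu_split="18,24")

# The same parsing that load_exllama_model applies before calling model.load(split).
split = [float(alloc) for alloc in config.gpu_split.split(",")]
assert split == [18.0, 24.0]
```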

fastchat/serve/base_model_worker.py

Lines changed: 25 additions & 6 deletions

@@ -2,6 +2,7 @@
 import asyncio
 import threading
 import requests
+import uuid
 from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
 from fastchat.conversation import Conversation
 from fastchat.utils import pretty_print_semaphore, build_logger

@@ -10,7 +11,7 @@
 from typing import List


-worker_id = None
+worker_id = str(uuid.uuid4())[:8]
 worker = None
 logger = None


@@ -34,6 +35,8 @@ def __init__(
         limit_worker_concurrency: int,
         conv_template: str = None,
     ):
+        global logger
+
         self.controller_addr = controller_addr
         self.worker_addr = worker_addr
         self.worker_id = worker_id

@@ -50,14 +53,17 @@ def __init__(

         self.heart_beat_thread = None

+        if logger is None:
+            logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+
     def make_conv_template(
         self,
         conv_template: str = None,
         model_path: str = None,
-    )-> Conversation:
-        '''
+    ) -> Conversation:
+        """
         Can be overridden to customize the conversation template for different model workers.
-        '''
+        """
         from fastchat.conversation import get_conv_template
         from fastchat.model.model_adapter import get_conversation_template


@@ -140,8 +146,12 @@ def get_status(self):

     def count_token(self, params):
         prompt = params["prompt"]
-        input_ids = self.tokenizer(prompt).input_ids
-        input_echo_len = len(input_ids)
+
+        try:
+            input_ids = self.tokenizer(prompt).input_ids
+            input_echo_len = len(input_ids)
+        except TypeError:
+            input_echo_len = self.tokenizer.num_tokens(prompt)

         ret = {
             "count": input_echo_len,

@@ -152,6 +162,15 @@ def count_token(self, params):
     def get_conv_template(self):
         return {"conv": self.conv}

+    def generate_stream_gate(self, params):
+        raise NotImplementedError
+
+    def generate_gate(self, params):
+        raise NotImplementedError
+
+    def get_embeddings(self, params):
+        raise NotImplementedError
+

 def release_worker_semaphore():
     worker.semaphore.release()
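To make the intent of the three new stubs concrete, a minimal hypothetical subclass (not from this commit). Real workers override these hooks with actual inference and usually serialize each streamed chunk; the sketch only shows which methods a new worker is now expected to implement.

```python
from fastchat.serve.base_model_worker import BaseModelWorker


class EchoWorker(BaseModelWorker):
    """Toy worker that echoes the prompt back; illustrates the three hooks."""

    def generate_stream_gate(self, params):
        # Stream a single chunk; real workers yield many partial results.
        yield {"text": params["prompt"], "error_code": 0}

    def generate_gate(self, params):
        # Non-streaming variant: exhaust the stream and return the last chunk.
        chunk = None
        for chunk in self.generate_stream_gate(params):
            pass
        return chunk

    def get_embeddings(self, params):
        # This toy worker has no embedding model (ExLlama workers also lack one).
        raise NotImplementedError
```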
