|
27 | 27 | ) |
28 | 28 |
|
29 | 29 | from fastchat.constants import CPU_ISA |
30 | | -from fastchat.modules.gptq import GptqConfig, load_gptq_quantized |
31 | | -from fastchat.modules.awq import AWQConfig, load_awq_quantized |
32 | 30 | from fastchat.conversation import Conversation, get_conv_template |
33 | 31 | from fastchat.model.compression import load_compress_model |
34 | 32 | from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense |
35 | 33 | from fastchat.model.model_chatglm import generate_stream_chatglm |
36 | 34 | from fastchat.model.model_codet5p import generate_stream_codet5p |
37 | 35 | from fastchat.model.model_falcon import generate_stream_falcon |
| 36 | +from fastchat.model.model_exllama import generate_stream_exllama |
38 | 37 | from fastchat.model.monkey_patch_non_inplace import ( |
39 | 38 | replace_llama_attn_with_non_inplace_operations, |
40 | 39 | ) |
| 40 | +from fastchat.modules.awq import AWQConfig, load_awq_quantized |
| 41 | +from fastchat.modules.exllama import ExllamaConfig, load_exllama_model |
| 42 | +from fastchat.modules.gptq import GptqConfig, load_gptq_quantized |
41 | 43 | from fastchat.utils import get_gpu_memory |
42 | 44 |
|
43 | 45 | # Check an environment variable to check if we should be sharing Peft model |
@@ -155,6 +157,7 @@ def load_model( |
155 | 157 | cpu_offloading: bool = False, |
156 | 158 | gptq_config: Optional[GptqConfig] = None, |
157 | 159 | awq_config: Optional[AWQConfig] = None, |
| 160 | + exllama_config: Optional[ExllamaConfig] = None, |
158 | 161 | revision: str = "main", |
159 | 162 | debug: bool = False, |
160 | 163 | ): |
@@ -279,6 +282,9 @@ def load_model( |
279 | 282 | else: |
280 | 283 | model.to(device) |
281 | 284 | return model, tokenizer |
| 285 | + elif exllama_config: |
| 286 | + model, tokenizer = load_exllama_model(model_path, exllama_config) |
| 287 | + return model, tokenizer |
282 | 288 | kwargs["revision"] = revision |
283 | 289 |
|
284 | 290 | if dtype is not None: # Overwrite dtype if it is provided in the arguments. |
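For context, here is a minimal usage sketch (not part of this diff) of the new `exllama_config` branch in `load_model`. The `ExllamaConfig` field names (`max_seq_len`, `gpu_split`) are assumptions that mirror the CLI flags added further down in this diff, and the model path is a hypothetical GPTQ-quantized checkpoint.

```python
# Hedged usage sketch (not part of this diff). Assumes ExllamaConfig exposes
# max_seq_len and gpu_split fields mirroring the new CLI flags, and that the
# model path points at a GPTQ-quantized checkpoint that exllamaV2 can read.
from fastchat.model.model_adapter import load_model
from fastchat.modules.exllama import ExllamaConfig

exllama_config = ExllamaConfig(
    max_seq_len=4096,  # assumed field, mirrors --exllama-max-seq-len
    gpu_split=None,    # assumed field, mirrors --exllama-gpu-split
)

# When exllama_config is set, load_model short-circuits into load_exllama_model
# and returns before the standard Hugging Face loading path is reached.
model, tokenizer = load_model(
    "TheBloke/Llama-2-7B-GPTQ",  # hypothetical model path
    device="cuda",
    num_gpus=1,
    exllama_config=exllama_config,
)
```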
@@ -325,13 +331,17 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str): |
325 | 331 | is_falcon = "rwforcausallm" in model_type |
326 | 332 | is_codet5p = "codet5p" in model_type |
327 | 333 | is_peft = "peft" in model_type |
| 334 | + is_exllama = "exllama" in model_type |
328 | 335 |
|
329 | 336 | if is_chatglm: |
330 | 337 | return generate_stream_chatglm |
331 | 338 | elif is_falcon: |
332 | 339 | return generate_stream_falcon |
333 | 340 | elif is_codet5p: |
334 | 341 | return generate_stream_codet5p |
| 342 | + elif is_exllama: |
| 343 | + return generate_stream_exllama |
| 344 | + |
335 | 345 | elif peft_share_base_weights and is_peft: |
336 | 346 | # Return a curried stream function that loads the right adapter |
337 | 347 | # according to the model_name available in this context. This ensures |
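To illustrate the dispatch above: `get_generate_stream_function` keys off a lowercased string of the loaded model's type, so a model returned by `load_exllama_model` (whose wrapper class name contains "exllama") is routed to `generate_stream_exllama`. A hedged sketch, with an assumed model path and assumed config field names:

```python
# Hedged sketch (not part of this diff): a model loaded through the exllama path
# has "exllama" in its class name, so the dispatcher returns generate_stream_exllama.
from fastchat.model.model_adapter import get_generate_stream_function
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model

model_path = "TheBloke/Llama-2-7B-GPTQ"  # hypothetical GPTQ checkpoint
model, tokenizer = load_exllama_model(
    model_path,
    ExllamaConfig(max_seq_len=4096, gpu_split=None),  # assumed field names
)

stream_func = get_generate_stream_function(model, model_path)
assert stream_func is generate_stream_exllama
```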
@@ -453,6 +463,23 @@ def add_model_args(parser): |
453 | 463 | default=-1, |
454 | 464 | help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.", |
455 | 465 | ) |
| 466 | + parser.add_argument( |
| 467 | + "--enable-exllama", |
| 468 | + action="store_true", |
| 469 | + help="Used for exllamav2. Enable the exllamaV2 inference framework.", |
| 470 | + ) |
| 471 | + parser.add_argument( |
| 472 | + "--exllama-max-seq-len", |
| 473 | + type=int, |
| 474 | + default=4096, |
| 475 | + help="Used for exllamav2. Max sequence length to use for the exllamav2 framework; defaults to 4096.", |
| 476 | + ) |
| 477 | + parser.add_argument( |
| 478 | + "--exllama-gpu-split", |
| 479 | + type=str, |
| 480 | + default=None, |
| 481 | + help="Used for exllamav2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7", |
| 482 | + ) |
456 | 483 |
|
457 | 484 |
|
458 | 485 | def remove_parent_directory_name(model_path): |
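The three flags above only collect values; a caller still has to fold them into an `ExllamaConfig`. Below is a rough glue-code sketch, assuming the config's field names mirror the flags; the real wiring in FastChat's CLI/worker entry points may differ.

```python
# Hedged glue-code sketch (not part of this diff): build an ExllamaConfig only
# when --enable-exllama is passed, otherwise leave it None so load_model falls
# through to the normal loading path.
import argparse

from fastchat.model.model_adapter import add_model_args, load_model
from fastchat.modules.exllama import ExllamaConfig

parser = argparse.ArgumentParser()
add_model_args(parser)
args = parser.parse_args(
    ["--model-path", "TheBloke/Llama-2-7B-GPTQ",  # hypothetical checkpoint
     "--enable-exllama", "--exllama-gpu-split", "20,7,7"]
)

exllama_config = (
    ExllamaConfig(
        max_seq_len=args.exllama_max_seq_len,  # assumed field name
        gpu_split=args.exllama_gpu_split,      # assumed field name
    )
    if args.enable_exllama
    else None
)

model, tokenizer = load_model(
    args.model_path, device="cuda", num_gpus=1, exllama_config=exllama_config
)
```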
@@ -1625,6 +1652,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation: |
1625 | 1652 | return get_conv_template("phind") |
1626 | 1653 |
|
1627 | 1654 |
|
| 1655 | +class Llama2ChangAdapter(Llama2Adapter): |
| 1656 | + """The model adapter for Llama2-ko-chang (e.g., lcw99/llama2-ko-chang-instruct-chat)""" |
| 1657 | + |
| 1658 | + def match(self, model_path: str): |
| 1659 | + return "llama2-ko-chang" in model_path.lower() |
| 1660 | + |
| 1661 | + def get_default_conv_template(self, model_path: str) -> Conversation: |
| 1662 | + return get_conv_template("polyglot_changgpt") |
| 1663 | + |
| 1664 | + |
1628 | 1665 | # Note: the registration order matters. |
1629 | 1666 | # The one registered earlier has a higher matching priority. |
1630 | 1667 | register_model_adapter(PeftModelAdapter) |
@@ -1684,6 +1721,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation: |
1684 | 1721 | register_model_adapter(ReaLMAdapter) |
1685 | 1722 | register_model_adapter(PhindCodeLlamaAdapter) |
1686 | 1723 | register_model_adapter(CodeLlamaAdapter) |
| 1724 | +register_model_adapter(Llama2ChangAdapter) |
1687 | 1725 |
|
1688 | 1726 | # After all adapters, try the default base adapter. |
1689 | 1727 | register_model_adapter(BaseModelAdapter) |
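Finally, a quick sanity check of what the new adapter registration buys: because `Llama2ChangAdapter` is registered before the catch-all `BaseModelAdapter`, a matching model path should resolve to the `polyglot_changgpt` conversation template. A small sketch using the public helper:

```python
# Hedged sketch (not part of this diff): adapter matching walks the registration
# list in order, so the llama2-ko-chang path below should hit Llama2ChangAdapter
# and return its default conversation template.
from fastchat.model.model_adapter import get_conversation_template

conv = get_conversation_template("lcw99/llama2-ko-chang-instruct-chat")
print(conv.name)  # expected: "polyglot_changgpt"
```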