|
1 | 1 | import argparse
|
2 | 2 | import dataclasses
|
| 3 | +import json |
3 | 4 | from dataclasses import dataclass
|
4 | 5 | from typing import List, Optional, Tuple, Union
|
5 | 6 |
|
@@ -49,6 +50,7 @@ class EngineArgs:
|
49 | 50 | disable_log_stats: bool = False
|
50 | 51 | revision: Optional[str] = None
|
51 | 52 | code_revision: Optional[str] = None
|
| 53 | + rope_scaling: Optional[dict] = None |
52 | 54 | tokenizer_revision: Optional[str] = None
|
53 | 55 | quantization: Optional[str] = None
|
54 | 56 | enforce_eager: bool = False
|
@@ -330,6 +332,11 @@ def add_cli_args(
|
330 | 332 | 'None, we assume the model weights are not '
|
331 | 333 | 'quantized and use `dtype` to determine the data '
|
332 | 334 | 'type of the weights.')
|
| 335 | + parser.add_argument('--rope-scaling', |
| 336 | + default=None, |
| 337 | + type=json.loads, |
| 338 | + help='RoPE scaling configuration in JSON format. ' |
| 339 | + 'For example, {"type":"dynamic","factor":2.0}') |
333 | 340 | parser.add_argument('--enforce-eager',
|
334 | 341 | action='store_true',
|
335 | 342 | help='Always use eager-mode PyTorch. If False, '
|
@@ -548,11 +555,12 @@ def create_engine_config(self, ) -> EngineConfig:
|
548 | 555 | model_config = ModelConfig(
|
549 | 556 | self.model, self.tokenizer, self.tokenizer_mode,
|
550 | 557 | self.trust_remote_code, self.dtype, self.seed, self.revision,
|
551 |
| - self.code_revision, self.tokenizer_revision, self.max_model_len, |
552 |
| - self.quantization, self.quantization_param_path, |
553 |
| - self.enforce_eager, self.max_context_len_to_capture, |
554 |
| - self.max_seq_len_to_capture, self.max_logprobs, |
555 |
| - self.skip_tokenizer_init, self.served_model_name) |
| 558 | + self.code_revision, self.rope_scaling, self.tokenizer_revision, |
| 559 | + self.max_model_len, self.quantization, |
| 560 | + self.quantization_param_path, self.enforce_eager, |
| 561 | + self.max_context_len_to_capture, self.max_seq_len_to_capture, |
| 562 | + self.max_logprobs, self.skip_tokenizer_init, |
| 563 | + self.served_model_name) |
556 | 564 | cache_config = CacheConfig(self.block_size,
|
557 | 565 | self.gpu_memory_utilization,
|
558 | 566 | self.swap_space, self.kv_cache_dtype,
|
|
0 commit comments