
Commit 018bd8d

use model.generation_config (#1850)
1 parent: 38b114a

File tree

12 files changed: +86 lines, -80 lines


docs/source/LLM/命令行参数.md

Lines changed: 10 additions & 10 deletions
@@ -125,11 +125,11 @@
 - `--save_safetensors`: Default is `True`.
 - `--include_num_input_tokens_seen`: Default is `False`. Tracks the number of input tokens seen over the whole training run.
 - `--max_new_tokens`: Default is `2048`. This parameter only takes effect when `predict_with_generate` is set to True.
-- `--do_sample`: Default is `True`. This parameter only takes effect when `predict_with_generate` is set to True.
-- `--temperature`: Default is `0.3`. This parameter only takes effect when `predict_with_generate` is set to True.
-- `--top_k`: Default is `20`. This parameter only takes effect when `predict_with_generate` is set to True.
-- `--top_p`: Default is `0.7`. This parameter only takes effect when `predict_with_generate` is set to True.
-- `--repetition_penalty`: Default is `1.`. This parameter only takes effect when `predict_with_generate` is set to True.
+- `--do_sample`: Reference documentation: [https://huggingface.co/docs/transformers/main_classes/text_generation](https://huggingface.co/docs/transformers/main_classes/text_generation). Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `predict_with_generate` is set to True.
+- `--temperature`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `predict_with_generate` is set to True.
+- `--top_k`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `predict_with_generate` is set to True.
+- `--top_p`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `predict_with_generate` is set to True.
+- `--repetition_penalty`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `predict_with_generate` is set to True.
 - `--num_beams`: Default is `1`. This parameter only takes effect when `predict_with_generate` is set to True.
 - `--gpu_memory_fraction`: Default is `None`. This parameter runs training while capping the maximum fraction of GPU memory that may be used; it is intended for stress testing.
 - `--train_dataset_mix_ratio`: Default is `0.`. This parameter defines how datasets are mixed for training. When specified, the training set is mixed with `train_dataset_mix_ratio` times its size of the general-knowledge dataset specified by `train_dataset_mix_ds`. This parameter is deprecated; please use `--dataset` to mix datasets.
@@ -327,11 +327,11 @@ RLHF parameters inherit the sft parameters; in addition, the following parameters are added:
 - `--bnb_4bit_use_double_quant`: Default is `True`. See the `sft command line arguments` for details. If `quantization_bit` is set to 0, this parameter has no effect.
 - `--bnb_4bit_quant_storage`: Default is `True`. See the `sft command line arguments` for details. If `quantization_bit` is set to 0, this parameter has no effect.
 - `--max_new_tokens`: Maximum number of new tokens to generate; default is `2048`.
-- `--do_sample`: Whether to use greedy generation or sampling-based generation; default is `True`.
-- `--temperature`: Default is `0.3`. This parameter only takes effect when `do_sample` is set to True. It is also used as the default value in the deployment parameters.
-- `--top_k`: Default is `20`. This parameter only takes effect when `do_sample` is set to True. It is also used as the default value in the deployment parameters.
-- `--top_p`: Default is `0.7`. This parameter only takes effect when `do_sample` is set to True. It is also used as the default value in the deployment parameters.
-- `--repetition_penalty`: Default is `1.`. It is also used as the default value in the deployment parameters.
+- `--do_sample`: Reference documentation: [https://huggingface.co/docs/transformers/main_classes/text_generation](https://huggingface.co/docs/transformers/main_classes/text_generation). Default is `None`, inheriting the model's generation_config.
+- `--temperature`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. It is also used as the default value in the deployment parameters.
+- `--top_k`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. It is also used as the default value in the deployment parameters.
+- `--top_p`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. It is also used as the default value in the deployment parameters.
+- `--repetition_penalty`: Default is `None`, inheriting the model's generation_config. It is also used as the default value in the deployment parameters.
 - `--num_beams`: Default is `1`.
 - `--use_flash_attn`: Default is `None`, i.e. 'auto'. See the `sft command line arguments` for details.
 - `--ignore_args_error`: Default is `False`. See the `sft command line arguments` for details.

docs/source_en/LLM/Command-line-parameters.md

Lines changed: 10 additions & 10 deletions
@@ -126,11 +126,11 @@
 - `--save_safetensors`: Default is `True`.
 - `--include_num_input_tokens_seen`: Default is `False`. Tracks the number of input tokens seen throughout training.
 - `--max_new_tokens`: Default is `2048`. This parameter only takes effect when `predict_with_generate` is set to True.
-- `--do_sample`: Default is `True`. This parameter only takes effect when `predict_with_generate` is set to True.
-- `--temperature`: Default is `0.3`. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
-- `--top_k`: Default is `20`. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
-- `--top_p`: Default is `0.7`. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
-- `--repetition_penalty`: Default is `1.`. This parameter will be used as default value in deployment parameters.
+- `--do_sample`: Reference document: [https://huggingface.co/docs/transformers/main_classes/text_generation](https://huggingface.co/docs/transformers/main_classes/text_generation). Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `predict_with_generate` is set to True.
+- `--temperature`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
+- `--top_k`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
+- `--top_p`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
+- `--repetition_penalty`: Default is `None`, inheriting the model's generation_config. This parameter will be used as default value in deployment parameters.
 - `--num_beams`: Default is `1`. This parameter only takes effect when `predict_with_generate` is set to True.
 - `--gpu_memory_fraction`: Default is `None`. This parameter aims to run training under a specified maximum available GPU memory percentage, used for extreme testing.
 - `--train_dataset_mix_ratio`: Default is `0.`. This parameter defines how to mix datasets for training. When this parameter is specified, it will mix the training dataset with a multiple of `train_dataset_mix_ratio` of the general knowledge dataset specified by `train_dataset_mix_ds`. This parameter has been deprecated, please use `--dataset {dataset_name}#{dataset_sample}` to mix datasets.
@@ -329,11 +329,11 @@ RLHF parameters are an extension of the sft parameters, with the addition of the
 - `--bnb_4bit_use_double_quant`: Default is `True`. See `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect.
 - `--bnb_4bit_quant_storage`: Default value `None`. See `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect.
 - `--max_new_tokens`: Maximum number of new tokens to generate, default is `2048`.
-- `--do_sample`: Whether to use greedy generation or sampling generation, default is `True`.
-- `--temperature`: Default is `0.3`. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
-- `--top_k`: Default is `20`. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
-- `--top_p`: Default is `0.7`. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
-- `--repetition_penalty`: Default is `1.`. This parameter will be used as default value in deployment parameters.
+- `--do_sample`: Reference document: [https://huggingface.co/docs/transformers/main_classes/text_generation](https://huggingface.co/docs/transformers/main_classes/text_generation). Default is `None`, inheriting the model's generation_config.
+- `--temperature`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
+- `--top_k`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
+- `--top_p`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters.
+- `--repetition_penalty`: Default is `None`, inheriting the model's generation_config. This parameter will be used as default value in deployment parameters.
 - `--num_beams`: Default is `1`.
 - `--use_flash_attn`: Default is `None`, i.e. 'auto'. See `sft command line arguments` for parameter details.
 - `--ignore_args_error`: Default is `False`, see `sft command line arguments` for parameter details.
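In practice, leaving these flags unset now defers to whatever the checkpoint's generation_config declares, while passing them explicitly still overrides it. A minimal sketch of both modes through the Python entry point (assuming ms-swift's `InferArguments`/`infer_main` API; the `model_type` value is a placeholder, not taken from this commit):

# Sketch only: illustrates the new None-means-inherit defaults.
from swift.llm import InferArguments, infer_main

# All sampling arguments left as None -> inherited from the model's generation_config.
args_inherit = InferArguments(model_type='qwen-7b-chat')

# Explicit values still win over the inherited generation_config, as before.
args_override = InferArguments(
    model_type='qwen-7b-chat',
    do_sample=True,
    temperature=0.3,
    top_k=20,
    top_p=0.7,
    repetition_penalty=1.0)

if __name__ == '__main__':
    infer_main(args_inherit)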

swift/llm/infer.py

Lines changed: 1 addition & 1 deletion
@@ -203,8 +203,8 @@ def prepare_model_template(args: InferArguments,
         num_beams=args.num_beams,
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id)
-    logger.info(f'generation_config: {generation_config}')
     set_generation_config(model, generation_config)
+    logger.info(f'model.generation_config: {model.generation_config}')
 
     if model.max_model_len is None:
         model.max_model_len = args.max_model_len

swift/llm/rlhf.py

Lines changed: 1 addition & 1 deletion
@@ -126,8 +126,8 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
         num_beams=args.num_beams,
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id)
-    logger.info(f'generation_config: {generation_config}')
     set_generation_config(model, generation_config)
+    logger.info(f'model.generation_config: {model.generation_config}')
 
     # Preparing LoRA
     model, _ = prepare_model(model, args)

swift/llm/rome.py

Lines changed: 1 addition & 1 deletion
@@ -35,8 +35,8 @@ def rome_infer(args: RomeArguments) -> None:
         num_beams=args.num_beams,
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id)
-    logger.info(f'generation_config: {generation_config}')
     set_generation_config(model, generation_config)
+    logger.info(f'model.generation_config: {model.generation_config}')
     if args.overwrite_generation_config:
         generation_config.save_pretrained(args.ckpt_dir)

swift/llm/sft.py

Lines changed: 2 additions & 2 deletions
@@ -235,9 +235,9 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         num_beams=args.num_beams,
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id)
-    logger.info(f'generation_config: {generation_config}')
     set_generation_config(model, generation_config)
-    training_args.generation_config = generation_config
+    logger.info(f'model.generation_config: {model.generation_config}')
+    training_args.generation_config = model.generation_config
 
     if use_torchacc():
         import torchacc as ta
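The second changed line is what makes the merge visible to the trainer: with `predict_with_generate`, evaluation now uses the merged `model.generation_config` (command-line overrides plus the checkpoint's own defaults) rather than the config built purely from the arguments. A rough sketch of the mechanism with plain transformers objects (illustrative values, not the swift code):

# Sketch: Seq2SeqTrainingArguments.generation_config drives generation during
# evaluation when predict_with_generate=True (illustrative values).
from transformers import GenerationConfig, Seq2SeqTrainingArguments

merged = GenerationConfig(max_new_tokens=2048, temperature=0.9, top_p=0.95)
training_args = Seq2SeqTrainingArguments(
    output_dir='output',
    predict_with_generate=True,
    generation_config=merged)  # previously the pre-merge config was assigned here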

swift/llm/utils/argument.py

Lines changed: 15 additions & 17 deletions
@@ -64,6 +64,11 @@ def __post_init__(self) -> None:
                 self.device_map_config = json.load(f)
         else:  # json str
             self.device_map_config = json.loads(self.device_map_config)
+        _, local_rank, _, local_world_size = get_dist_setting()
+        if local_world_size > 1 and isinstance(self.device_map_config, dict) and local_rank > 0:
+            for k, v in self.device_map_config.items():
+                if isinstance(v, int):
+                    self.device_map_config[k] += local_rank
 
     @classmethod
     def _check_path(cls,
@@ -130,13 +135,6 @@ def check_flash_attn(self: Union['SftArguments', 'InferArguments']) -> None:
     def handle_generation_config(self: Union['SftArguments', 'InferArguments']) -> None:
         if self.temperature == 0:
             self.do_sample = False
-        if self.do_sample is False:
-            # fix warning
-            self.temperature = 1.
-            self.top_p = 1.
-            self.top_k = 50
-            logger.info('Due to do_sample=False, the following settings are applied: args.temperature: '
-                        f'{self.temperature}, args.top_p: {self.top_p}, args.top_k: {self.top_k}.')
 
     def select_dtype(self: Union['SftArguments', 'InferArguments']) -> Tuple[Optional[Dtype], bool, bool]:
         if not is_torch_cuda_available() and not is_torch_npu_available():
@@ -825,11 +823,11 @@ class SftArguments(ArgumentsBase):
 
     # generation config
     max_new_tokens: int = 2048
-    do_sample: bool = True
-    temperature: float = 0.3
-    top_k: int = 20
-    top_p: float = 0.7
-    repetition_penalty: float = 1.
+    do_sample: Optional[bool] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
     num_beams: int = 1
 
     # fsdp option
@@ -1336,11 +1334,11 @@ class InferArguments(ArgumentsBase):
     bnb_4bit_quant_storage: Optional[str] = None
 
     max_new_tokens: int = 2048
-    do_sample: bool = True
-    temperature: float = 0.3
-    top_k: int = 20
-    top_p: float = 0.7
-    repetition_penalty: float = 1.
+    do_sample: Optional[bool] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
     num_beams: int = 1
     stop_words: List[str] = field(default_factory=list)
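Two separate behaviors change here. The `do_sample=False` workaround (resetting temperature/top_p/top_k) leaves argument parsing and reappears in `_prepare_inputs` (see the `swift/llm/utils/utils.py` diff below), and a user-supplied `device_map_config` is now shifted per local rank so that, under multi-process launches, each process lands on its own GPU. A rough standalone sketch of the offset logic (illustrative values, not the swift code itself):

# Sketch of the per-rank device_map shift (standalone, illustrative).
def shift_device_map(device_map: dict, local_rank: int, local_world_size: int) -> dict:
    if local_world_size > 1 and local_rank > 0:
        for k, v in device_map.items():
            if isinstance(v, int):
                device_map[k] = v + local_rank
    return device_map

# With --device_map_config '{"": 0}' and 4 processes per node:
# rank 0 keeps {'': 0}, rank 1 gets {'': 1}, rank 2 gets {'': 2}, rank 3 gets {'': 3}.
print(shift_device_map({'': 0}, local_rank=1, local_world_size=4))  # {'': 1}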

swift/llm/utils/client_utils.py

Lines changed: 5 additions & 2 deletions
@@ -19,9 +19,12 @@
 
 
 def _get_request_kwargs(api_key: Optional[str] = None) -> Dict[str, Any]:
+    timeout = float(os.getenv('TIMEOUT', '60'))
+    request_kwargs = {'timeout': timeout}
     if api_key is None:
-        return {}
-    return {'headers': {'Authorization': f'Bearer {api_key}'}}
+        return request_kwargs
+    request_kwargs['headers'] = {'Authorization': f'Bearer {api_key}'}
+    return request_kwargs
 
 
 def get_model_list_client(host: str = '127.0.0.1', port: str = '8000', api_key: str = 'EMPTY', **kwargs) -> ModelList:
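The kwargs built here are forwarded to the HTTP client calls, so the `TIMEOUT` environment variable (default 60 seconds) now bounds every request. A quick sketch of the effect (`_get_request_kwargs` is an internal helper, shown here only for illustration):

import os
os.environ['TIMEOUT'] = '120'  # allow slow servers up to 120 seconds

from swift.llm.utils.client_utils import _get_request_kwargs

print(_get_request_kwargs(api_key='my-key'))
# {'timeout': 120.0, 'headers': {'Authorization': 'Bearer my-key'}}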

swift/llm/utils/lmdeploy_utils.py

Lines changed: 19 additions & 12 deletions
@@ -137,6 +137,7 @@ def __init__(
             stop_words = []
         if max_new_tokens is None:
             max_new_tokens = 64
+        self._temperature = temperature
         super().__init__(
             max_new_tokens=max_new_tokens,
             temperature=temperature,
@@ -149,6 +150,17 @@
             skip_special_tokens=skip_special_tokens,
             **kwargs)
 
+    def __setattr__(self, key: str, value: str) -> None:
+        if key == 'do_sample':
+            assert value in {True, False}
+            super().__setattr__('temperature', self._temperature if value else 0)
+        elif key == 'max_length':
+            raise ValueError('`max_length` is not supported, please use `max_new_tokens` for setting.')
+        else:
+            if key == 'temperature':
+                self._temperature = value
+            super().__setattr__(key, value)
+
 
 def _add_stop_word(stop_words: List[int], token: Union[List[int], int, str, None], tokenizer=None) -> None:
     if token is None:
@@ -443,21 +455,16 @@ def prepare_lmdeploy_engine_template(args: InferArguments) -> Tuple[Union[AsyncE
         model_id_or_path=model_id_or_path)
     tokenizer = lmdeploy_engine.hf_tokenizer
 
-    if not args.do_sample:
-        args.temperature = 0
-
     stop_words = []
     for stop_word in args.stop_words:
         _add_stop_word(stop_words, stop_word, tokenizer=tokenizer)
-    generation_config = LmdeployGenerationConfig(
-        max_new_tokens=args.max_new_tokens,
-        temperature=args.temperature,
-        top_k=args.top_k,
-        top_p=args.top_p,
-        stop_words=stop_words,
-        repetition_penalty=args.repetition_penalty)
-    logger.info(f'generation_config: {generation_config}')
-    lmdeploy_engine.generation_config = generation_config
+    setattr(lmdeploy_engine.generation_config, 'max_new_tokens', args.max_new_tokens)
+    for k in ['temperature', 'do_sample', 'top_k', 'top_p', 'repetition_penalty']:
+        val = getattr(args, k, None)
+        if val is not None:
+            setattr(lmdeploy_engine.generation_config, k, val)
+    logger.info(f'lmdeploy_engine.generation_config: {lmdeploy_engine.generation_config}')
+
     template: Template = get_template(
         args.template_type,
         tokenizer,
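The `__setattr__` hook is what lets the generic `do_sample` argument drive lmdeploy, which has no such flag: turning sampling off forces `temperature` to 0 (greedy decoding), and turning it back on restores the stored temperature. A sketch of the resulting behavior (requires lmdeploy to be installed; constructor arguments and values are illustrative):

from swift.llm.utils.lmdeploy_utils import LmdeployGenerationConfig

config = LmdeployGenerationConfig(max_new_tokens=128, temperature=0.3)

config.do_sample = False     # greedy decoding: temperature forced to 0
print(config.temperature)    # 0

config.do_sample = True      # sampling again: the stored temperature is restored
print(config.temperature)    # 0.3

config.temperature = 0.7     # assigning temperature also updates the stored value
config.do_sample = True
print(config.temperature)    # 0.7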

swift/llm/utils/utils.py

Lines changed: 10 additions & 6 deletions
@@ -599,7 +599,10 @@ def _prepare_inputs(model: PreTrainedModel,
     if 'token_type_ids' in inputs:
         inputs['token_type_ids'] = torch.tensor(inputs['token_type_ids'])[None]
     model.eval()
-
+    if not generation_config.do_sample:
+        generation_config.temperature = 1.
+        generation_config.top_p = 1.
+        generation_config.top_k = 50
     if tokenizer.eos_token_id is not None:
         generation_config.eos_token_id = tokenizer.eos_token_id
     if tokenizer.pad_token_id is not None:
@@ -918,11 +921,12 @@ def set_generation_config(model: Module, generation_config: GenerationConfig) ->
     old_generation_config = getattr(model, 'generation_config', None)
     old_generation_priority_config = ['no_repeat_ngram_size']
     if old_generation_config is not None:
-        for k, v in old_generation_config.__dict__.items():
-            if k in old_generation_priority_config:
-                setattr(generation_config, k, v)
-            if k not in generation_config.__dict__:
-                setattr(generation_config, k, v)
+        for k, old_v in old_generation_config.__dict__.items():
+            if k.startswith('_'):
+                continue
+            v = getattr(generation_config, k, None)
+            if k in old_generation_priority_config or old_v is not None and v is None:
+                setattr(generation_config, k, old_v)
     model.generation_config = generation_config
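Two pieces move here: the warning fix for `do_sample=False` now happens right before generation in `_prepare_inputs`, and `set_generation_config` gains a real merge rule. Concretely, a value set on the incoming config wins, a value left as `None` falls back to the model's existing generation_config, private attributes are ignored, and `no_repeat_ngram_size` is always taken from the model. A standalone sketch of that merge with plain transformers objects (illustrative values):

from transformers import GenerationConfig

# What the checkpoint already carries (model.generation_config):
old = GenerationConfig(temperature=0.9, top_p=0.95, no_repeat_ngram_size=4)
# What swift builds from the command-line arguments (None = not set):
new = GenerationConfig(max_new_tokens=2048, temperature=None, top_k=20)

priority = ['no_repeat_ngram_size']
for k, old_v in old.__dict__.items():
    if k.startswith('_'):
        continue  # skip private bookkeeping attributes
    v = getattr(new, k, None)
    if k in priority or (old_v is not None and v is None):
        setattr(new, k, old_v)

# new.temperature          -> 0.9   (was None, inherited from the model)
# new.top_k                -> 20    (explicitly set, kept)
# new.no_repeat_ngram_size -> 4     (priority key, always taken from the model)
# new.max_new_tokens       -> 2048  (explicitly set, kept)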