Commit a134ef6

Support eos_token_id from generation_config.json (#4182)
1 parent 8a7a3e4 commit a134ef6
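
For context, a Hugging Face model repo may ship a generation_config.json whose eos_token_id field is either a single int or a list of ints; this commit makes vLLM pick those ids up as extra stop tokens. A minimal sketch of the relevant shape (field values are illustrative, not taken from any particular model):

import json

raw = '''
{
  "bos_token_id": 1,
  "eos_token_id": [2, 32000],
  "temperature": 0.7
}
'''
config = json.loads(raw)
print(config["eos_token_id"])  # a single int or a list of ints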

File tree

2 files changed: +30 −3 lines changed


vllm/engine/llm_engine.py

Lines changed: 17 additions & 2 deletions
@@ -1,7 +1,7 @@
 import time
 from typing import Iterable, List, Optional, Type, Union

-from transformers import PreTrainedTokenizer
+from transformers import GenerationConfig, PreTrainedTokenizer

 import vllm
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
@@ -34,6 +34,17 @@
 _LOCAL_LOGGING_INTERVAL_SEC = 5


+def _load_generation_config_dict(model_config: ModelConfig):
+    try:
+        return GenerationConfig.from_pretrained(
+            model_config.model,
+            revision=model_config.revision,
+        ).to_diff_dict()
+    except OSError:
+        # Not found.
+        return {}
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.

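A quick standalone illustration of what the new _load_generation_config_dict helper returns (the model name below is a placeholder): GenerationConfig.from_pretrained raises OSError when the repo has no generation_config.json, and to_diff_dict() keeps only the fields that differ from transformers' defaults.

from transformers import GenerationConfig

try:
    fields = GenerationConfig.from_pretrained("some-org/some-model").to_diff_dict()
except OSError:
    # No generation_config.json in the repo.
    fields = {}
print(fields.get("eos_token_id"))  # int, list of ints, or None
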
@@ -124,6 +135,8 @@ def __init__(
         self._init_tokenizer()
         self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()
+        self.generation_config_fields = _load_generation_config_dict(
+            model_config)

         self.model_executor = executor_class(
             model_config=model_config,
@@ -391,6 +404,8 @@ def add_request(
         # inject the eos token id into the sampling_params to support min_tokens
         # processing
         sampling_params.eos_token_id = seq.eos_token_id
+        sampling_params.update_from_generation_config(
+            self.generation_config_fields)

         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
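
Aside on the surrounding comment: eos_token_id and the merged stop_token_ids both feed stop detection, and min_tokens processing needs them precisely so it can suppress early termination. A toy sketch of that logic (not vLLM's actual stop checker):

def should_stop(last_token_id: int, num_generated: int, params) -> bool:
    # Below min_tokens, eos/stop token ids must not terminate the sequence.
    if num_generated < params.min_tokens:
        return False
    return (last_token_id == params.eos_token_id
            or last_token_id in params.stop_token_ids)
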
@@ -435,7 +450,7 @@ def _process_model_outputs(
             scheduled_seq_groups: List[SequenceGroup],
             ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
         """Apply the model output to the sequences in the scheduled seq groups.
-
+
         Returns RequestOutputs that can be returned to the client.
         """

vllm/sampling_params.py

Lines changed: 13 additions & 1 deletion
@@ -2,7 +2,7 @@
 import copy
 from enum import IntEnum
 from functools import cached_property
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from pydantic import Field
@@ -271,6 +271,18 @@ def _verify_greedy_sampling(self) -> None:
             raise ValueError("best_of must be 1 when using greedy sampling."
                              f"Got {self.best_of}.")

+    def update_from_generation_config(
+            self, generation_config: Dict[str, Any]) -> None:
+        """Update if there are non-default values from generation_config"""
+        # Update eos_token_id for generation
+        if eos_ids := generation_config.get("eos_token_id"):
+            # it can be either int or list of int
+            if isinstance(eos_ids, int):
+                eos_ids = [eos_ids]
+            original_stop_token_ids = set(self.stop_token_ids)
+            original_stop_token_ids.update(eos_ids)
+            self.stop_token_ids = list(original_stop_token_ids)
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.use_beam_search: