
Commit d5e4b27

SebastianBodza authored
Added settings vllm (#2599)
Co-authored-by: bodza <[email protected]>
1 parent 7a31d3b commit d5e4b27

File tree

4 files changed: +36 -0 lines changed

fastchat/protocol/api_protocol.py
fastchat/protocol/openai_api_protocol.py
fastchat/serve/openai_api_server.py
fastchat/serve/vllm_worker.py

fastchat/protocol/api_protocol.py

Lines changed: 4 additions & 0 deletions
@@ -53,12 +53,15 @@ class APIChatCompletionRequest(BaseModel):
     messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
+    top_k: Optional[int] = -1
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     user: Optional[str] = None
     repetition_penalty: Optional[float] = 1.0
+    frequency_penalty: Optional[float] = 0.0
+    presence_penalty: Optional[float] = 0.0


 class ChatMessage(BaseModel):
@@ -130,6 +133,7 @@ class CompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     top_p: Optional[float] = 1.0
+    top_k: Optional[int] = -1
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0

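With these protocol changes, a client can set the new sampling knobs directly in the request body. A minimal sketch; the host, port, and model name are placeholders, not part of this commit:

# Hedged example: POST a chat-completion request that exercises the new fields.
import requests

payload = {
    "model": "vicuna-7b-v1.5",      # placeholder model name
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "top_p": 1.0,
    "top_k": 40,                    # new: -1 disables top-k filtering
    "frequency_penalty": 0.5,       # new
    "presence_penalty": 0.2,        # new
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())
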
fastchat/protocol/openai_api_protocol.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,7 @@ class ChatCompletionRequest(BaseModel):
     messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
+    top_k: Optional[int] = -1
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
@@ -146,6 +147,7 @@ class CompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     top_p: Optional[float] = 1.0
+    top_k: Optional[int] = -1
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0

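Because top_k defaults to -1 in the pydantic models, existing clients that never send it keep the old behaviour. A minimal sketch, assuming FastChat at this commit is importable and that model and prompt are the only required fields:

# Hedged sketch: omitting top_k falls back to the new default of -1 (top-k disabled).
from fastchat.protocol.openai_api_protocol import CompletionRequest

req = CompletionRequest(model="vicuna-7b-v1.5", prompt="Once upon a time")  # placeholder model
assert req.top_k == -1               # default added by this commit
assert req.presence_penalty == 0.0   # pre-existing default, unchanged
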
fastchat/serve/openai_api_server.py

Lines changed: 23 additions & 0 deletions
@@ -199,6 +199,11 @@ def check_requests(request) -> Optional[JSONResponse]:
             ErrorCode.PARAM_OUT_OF_RANGE,
             f"{request.top_p} is greater than the maximum of 1 - 'temperature'",
         )
+    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
+        )
     if request.stop is not None and (
         not isinstance(request.stop, str) and not isinstance(request.stop, list)
     ):
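
The new check mirrors the existing top_p validation: None is allowed, -1 explicitly disables top-k sampling, and any value strictly between -1 and 1 (i.e. 0 for an integer field) is rejected with PARAM_OUT_OF_RANGE. A standalone sketch of the rule:

# Hedged sketch of the top_k rule added to check_requests() above.
def top_k_is_valid(top_k):
    return top_k is None or not (-1 < top_k < 1)

assert top_k_is_valid(None)    # field omitted
assert top_k_is_valid(-1)      # top-k filtering disabled
assert top_k_is_valid(40)      # ordinary top-k sampling
assert not top_k_is_valid(0)   # rejected: "Either set top_k to -1 or >=1."
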
@@ -240,6 +245,9 @@ async def get_gen_params(
     *,
     temperature: float,
     top_p: float,
+    top_k: Optional[int],
+    presence_penalty: Optional[float],
+    frequency_penalty: Optional[float],
     max_tokens: Optional[int],
     echo: Optional[bool],
     stop: Optional[Union[str, List[str]]],
@@ -284,6 +292,9 @@ async def get_gen_params(
         "prompt": prompt,
         "temperature": temperature,
         "top_p": top_p,
+        "top_k": top_k,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
         "max_new_tokens": max_tokens,
         "echo": echo,
         "stop_token_ids": conv.stop_token_ids,
@@ -366,6 +377,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
         request.messages,
         temperature=request.temperature,
         top_p=request.top_p,
+        top_k=request.top_k,
+        presence_penalty=request.presence_penalty,
+        frequency_penalty=request.frequency_penalty,
         max_tokens=request.max_tokens,
         echo=False,
         stop=request.stop,
@@ -498,6 +512,9 @@ async def create_completion(request: CompletionRequest):
         text,
         temperature=request.temperature,
         top_p=request.top_p,
+        top_k=request.top_k,
+        frequency_penalty=request.frequency_penalty,
+        presence_penalty=request.presence_penalty,
         max_tokens=request.max_tokens,
         echo=request.echo,
         stop=request.stop,
@@ -552,6 +569,9 @@ async def generate_completion_stream_generator(
         text,
         temperature=request.temperature,
         top_p=request.top_p,
+        top_k=request.top_k,
+        presence_penalty=request.presence_penalty,
+        frequency_penalty=request.frequency_penalty,
         max_tokens=request.max_tokens,
         echo=request.echo,
         stop=request.stop,
@@ -731,6 +751,9 @@ async def create_chat_completion(request: APIChatCompletionRequest):
         request.messages,
         temperature=request.temperature,
         top_p=request.top_p,
+        top_k=request.top_k,
+        presence_penalty=request.presence_penalty,
+        frequency_penalty=request.frequency_penalty,
         max_tokens=request.max_tokens,
         echo=False,
         stop=request.stop,

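Taken together, the server-side changes mean the generation payload assembled by get_gen_params() and forwarded to the model worker now carries the three extra sampling keys. Roughly, the sampling-related slice of that dict looks like the following (values illustrative only):

# Hedged sketch of the worker payload built by the OpenAI API server.
gen_params = {
    "prompt": "Hello!",
    "temperature": 0.7,
    "top_p": 1.0,
    "top_k": 40,                  # new
    "presence_penalty": 0.2,      # new
    "frequency_penalty": 0.5,     # new
    "max_new_tokens": 256,
    "echo": False,
}
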
fastchat/serve/vllm_worker.py

Lines changed: 7 additions & 0 deletions
@@ -68,6 +68,9 @@ async def generate_stream(self, params):
         request_id = params.pop("request_id")
         temperature = float(params.get("temperature", 1.0))
         top_p = float(params.get("top_p", 1.0))
+        top_k = params.get("top_k", -1.0)
+        presence_penalty = float(params.get("presence_penalty", 0.0))
+        frequency_penalty = float(params.get("frequency_penalty", 0.0))
         max_new_tokens = params.get("max_new_tokens", 256)
         stop_str = params.get("stop", None)
         stop_token_ids = params.get("stop_token_ids", None) or []
@@ -92,13 +95,17 @@ async def generate_stream(self, params):
         top_p = max(top_p, 1e-5)
         if temperature <= 1e-5:
             top_p = 1.0
+
         sampling_params = SamplingParams(
             n=1,
             temperature=temperature,
             top_p=top_p,
             use_beam_search=use_beam_search,
             stop=list(stop),
             max_tokens=max_new_tokens,
+            top_k=top_k,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
             best_of=best_of,
         )
         results_generator = engine.generate(context, sampling_params, request_id)

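On the worker side, the new values are passed straight through to vLLM's sampling configuration. A minimal sketch, assuming a vLLM release contemporary with this commit:

# Hedged sketch: the SamplingParams the worker now constructs, with the new arguments.
from vllm import SamplingParams

sampling_params = SamplingParams(
    n=1,
    temperature=0.7,
    top_p=1.0,
    top_k=40,                 # new: forwarded from the request
    presence_penalty=0.2,     # new
    frequency_penalty=0.5,    # new
    stop=["</s>"],            # illustrative stop string
    max_tokens=256,
)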