
Commit 8a07c25

[vllm, lmi-dist] add support for top_n_tokens
1 parent bc328ec commit 8a07c25

6 files changed: +283 -53 lines changed

engines/python/setup/djl_python/output_formatter.py

Lines changed: 38 additions & 31 deletions
@@ -13,7 +13,7 @@
 import json
 import logging
 import time
-from typing import Union, Callable
+from typing import Union, Callable, Optional, Dict

 from typing_extensions import deprecated

@@ -43,6 +43,9 @@ def get_sequence_details(request_output: RequestOutput,
     if parameters.get("decoder_input_details"):
         sequence_details["prefill"] = request_output.get_prompt_tokens_as_dict(
         )
+    if parameters.get("top_n_tokens", 0) > 0:
+        sequence_details["top_tokens"] = request_output.get_top_tokens_as_dict(
+            sequence_index)
     return sequence_details

@@ -106,31 +109,45 @@ def _json_output_formatter(request_output: RequestOutput):
         json_encoded_str = f"[{json_encoded_str}"
     json_encoded_str = f"{json_encoded_str}{json.dumps(next_token.text, ensure_ascii=False)[1:-1]}"
     if last_token:
-        if parameters.get("details", tgi_compat):
-            final_dict = {
-                "finish_reason": best_sequence.finish_reason,
-                "generated_tokens": len(best_sequence.tokens),
-                "inputs": request_output.input.input_text,
-                "tokens": request_output.get_tokens_as_dict(),
-            }
-
-            if parameters.get("decoder_input_details"):
-                final_dict[
-                    "prefill"] = request_output.get_prompt_tokens_as_dict()
-            details_str = f"\"details\": {json.dumps(final_dict, ensure_ascii=False)}"
-            json_encoded_str = f"{json_encoded_str}\", {details_str}}}"
-        elif best_sequence.finish_reason == "error":
-            final_dict = {"finish_reason": best_sequence.finish_reason}
-            details_str = f"\"details\": {json.dumps(final_dict, ensure_ascii=False)}"
+        details_dict = get_details_dict(request_output, include_tokens=True)
+        if details_dict:
+            details_str = f"\"details\": {json.dumps(details_dict, ensure_ascii=False)}"
             json_encoded_str = f"{json_encoded_str}\", {details_str}}}"
         else:
             json_encoded_str = f"{json_encoded_str}\"}}"
         if tgi_compat:
             json_encoded_str = f"{json_encoded_str}]"
-
     return json_encoded_str


+def get_details_dict(request_output: RequestOutput,
+                     include_tokens: bool = True) -> Optional[Dict]:
+    parameters = request_output.input.parameters
+    best_sequence = request_output.sequences[
+        request_output.best_sequence_index]
+    if parameters.get("details", request_output.input.tgi_compat):
+        final_dict = {
+            "finish_reason": best_sequence.finish_reason,
+            "generated_tokens": len(best_sequence.tokens),
+            "inputs": request_output.input.input_text,
+        }
+
+        if include_tokens:
+            final_dict["tokens"] = request_output.get_tokens_as_dict()
+
+        if parameters.get("decoder_input_details"):
+            final_dict["prefill"] = request_output.get_prompt_tokens_as_dict()
+        if parameters.get("top_n_tokens", 0) > 0:
+            final_dict["top_tokens"] = request_output.get_top_tokens_as_dict(
+                request_output.best_sequence_index)
+
+        return final_dict
+    elif best_sequence.finish_reason == "error":
+        return {"finish_reason": best_sequence.finish_reason}
+    else:
+        return None
+
+
 def _jsonlines_output_formatter(request_output: RequestOutput):
     """
     jsonlines output formatter
@@ -148,19 +165,9 @@ def _jsonlines_output_formatter(request_output: RequestOutput):
     if last_token:
         generated_text = get_generated_text(best_sequence, request_output)
         final_dict["generated_text"] = generated_text
-        if parameters.get("details", tgi_compat):
-            final_dict["details"] = {
-                "finish_reason": best_sequence.finish_reason,
-                "generated_tokens": len(best_sequence.tokens),
-                "inputs": request_output.input.input_text,
-            }
-            if parameters.get("decoder_input_details"):
-                final_dict["details"][
-                    "prefill"] = request_output.get_prompt_tokens_as_dict()
-        elif best_sequence.finish_reason == "error":
-            final_dict["details"] = {
-                "finish_reason": best_sequence.finish_reason
-            }
+        details_dict = get_details_dict(request_output, include_tokens=False)
+        if details_dict:
+            final_dict["details"] = details_dict
     json_encoded_str = json.dumps(final_dict, ensure_ascii=False) + "\n"
     return json_encoded_str
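
To illustrate what the change adds to the response, here is a hypothetical "details" payload as get_details_dict would assemble it when a client sends details=True and top_n_tokens=2. This is a sketch only: the structure mirrors the code above, the key names assume Token.as_dict exposes id/text/log_prob, and every value is invented.

# Hypothetical example of the "details" dict produced above with top_n_tokens=2.
# Values are made up; only the nesting reflects the committed code.
example_details = {
    "finish_reason": "length",
    "generated_tokens": 2,
    "inputs": "Hello",
    "tokens": [
        {"id": 318, "text": " world", "log_prob": -0.31},
        {"id": 11, "text": ",", "log_prob": -1.02},
    ],
    # one inner list of candidate tokens per generated position
    "top_tokens": [
        [
            {"id": 318, "text": " world", "log_prob": -0.31},
            {"id": 612, "text": " there", "log_prob": -1.87},
        ],
        [
            {"id": 11, "text": ",", "log_prob": -1.02},
            {"id": 0, "text": "!", "log_prob": -1.55},
        ],
    ],
}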

engines/python/setup/djl_python/request_io.py

Lines changed: 18 additions & 0 deletions
@@ -237,3 +237,21 @@ def get_prompt_tokens_as_dict(self):
             else:
                 tokens.append(token.as_dict())
         return tokens
+
+    def get_top_tokens_as_dict(self, sequence_index=0):
+        """Returns the top tokens of the given sequence index as a dictionary.
+        If not given, returns the top tokens of the first sequence index as a dictionary.
+
+        :param sequence_index: index of the sequence to get the top tokens from.
+        :return: top tokens of the given sequence index as a dictionary.
+        """
+        top_tokens = []
+        for top_token in self.sequences[sequence_index].top_tokens:
+            top_token_list = []
+            for token in top_token:
+                if self.input.tgi_compat:
+                    top_token_list.append(token.as_tgi_dict())
+                else:
+                    top_token_list.append(token.as_dict())
+            top_tokens.append(top_token_list)
+        return top_tokens
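
A minimal standalone sketch of the nesting get_top_tokens_as_dict returns: an outer list with one entry per generated position, each holding that position's candidate tokens. The Token dataclass below is a stand-in used for illustration, not djl_python's actual Token class.

from dataclasses import dataclass, asdict
from typing import List

@dataclass
class Token:  # stand-in for djl_python's Token, illustration only
    id: int
    text: str
    log_prob: float

    def as_dict(self):
        return asdict(self)

def top_tokens_as_dict(top_tokens: List[List[Token]]) -> List[List[dict]]:
    # one inner list per generated token, mirroring get_top_tokens_as_dict above
    return [[t.as_dict() for t in candidates] for candidates in top_tokens]

# e.g. two generated positions with two candidates each
sample = [[Token(5, "Hi", -0.2), Token(7, "Hey", -1.4)],
          [Token(9, " there", -0.6), Token(3, " all", -2.1)]]
print(top_tokens_as_dict(sample))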

engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py

Lines changed: 8 additions & 1 deletion
@@ -128,7 +128,14 @@ def translate_lmi_dist_params(self, parameters: dict):
             parameters["use_beam_search"] = True
         if parameters.pop("decoder_input_details", False):
            parameters["prompt_logprobs"] = 1
-        parameters["logprobs"] = parameters.get("logprobs", 1)
+        if "best_of" in parameters.keys():
+            # if n is not explicitly set, we return `best_of` values sequences.
+            if "n" not in "best_of":
+                parameters["n"] = parameters["best_of"]
+        if "top_n_tokens" in parameters.keys():
+            parameters["logprobs"] = parameters.pop("top_n_tokens")
+        else:
+            parameters["logprobs"] = parameters.get("logprobs", 1)
         parameters = filter_unused_generation_params(
             parameters,
             LMI_DIST_GENERATION_PARAMS,
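
As a rough illustration of what this translation does for the new option (hypothetical parameter values; only the logprobs-related lines of the hunk are reproduced): a request carrying top_n_tokens asks the engine for that many log probabilities per position, while requests without it keep the previous default of 1.

# Hypothetical request parameters, before translation.
parameters = {"max_new_tokens": 64, "top_n_tokens": 3}

# The logprobs-related part of the translation above:
if "top_n_tokens" in parameters:
    parameters["logprobs"] = parameters.pop("top_n_tokens")
else:
    parameters["logprobs"] = parameters.get("logprobs", 1)

print(parameters)  # {'max_new_tokens': 64, 'logprobs': 3}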

engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py

Lines changed: 17 additions & 6 deletions
@@ -63,6 +63,9 @@ def update_request_cache_with_output(request_cache: OrderedDict,
         request_output.best_sequence_index = vllm_request_output.outputs[
             0].index
         request_cache.pop(request_id)
+        for i in range(1, len(vllm_request_output.outputs)):
+            index = vllm_request_output.outputs[i].index
+            request_output.other_sequences_indices.append(index)

     return request_cache

@@ -105,17 +108,21 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
         output_token_texts = [text] * len(
             new_token_ids) if not output_token_texts else output_token_texts

+        top_tokens = []
         # calculate log probs
         if completion_output.logprobs:
             new_logprobs_list = completion_output.logprobs[
                 prev_len:
                 cur_len] if prev_len < cur_len else completion_output.logprobs
-            new_logprobs = [
-                # NOTE: vLLM 0.4.1 changed logprob type
-                logprobs[token_id] if isinstance(logprobs[token_id], float)
-                else logprobs[token_id].logprob
-                for token_id, logprobs in zip(new_token_ids, new_logprobs_list)
-            ]
+            new_logprobs = []
+            for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
+                for token_id_key, logprob in logprobs.items():
+                    new_logprobs.append(logprobs[token_id].logprob)
+                    top_tokens.append(
+                        Token(id=token_id_key,
+                              text=logprob.decoded_token,
+                              log_prob=logprob.logprob))
+
         else:
             new_logprobs = [None] * len(new_token_ids)

@@ -139,6 +146,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
             is_last_token = finish_reason is not None
             request_output.sequences[sequence_index].set_next_token(
                 token, is_last_token)
+            top_tokens.append(token)
+
+        request_output.sequences[sequence_index].set_next_top_tokens(
+            top_tokens)

         cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
             completion_output.text)
engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py

Lines changed: 11 additions & 2 deletions
@@ -83,7 +83,7 @@ def translate_vllm_params(self, parameters: dict) -> dict:
             if not (is_beam_search and is_best_of):
                 # if temperature is zero, vLLM does greedy sampling
                 parameters['temperature'] = 0
-        elif parameters.pop("do_sample"):
+        elif not parameters.pop("do_sample", False):
             parameters["temperature"] = 0
         if "stop_sequences" in parameters.keys():
             parameters["stop"] = parameters.pop("stop_sequences")
@@ -94,7 +94,16 @@ def translate_vllm_params(self, parameters: dict) -> dict:
             parameters["use_beam_search"] = True
         if parameters.pop("decoder_input_details", False):
             parameters["prompt_logprobs"] = 1
-        parameters["logprobs"] = parameters.get("logprobs", 1)
+
+        # if n is not explicitly set when best_of is set, we return `best_of` values sequences for tgi compatibility.
+        if "best_of" in parameters.keys():
+            if "n" not in "best_of":
+                parameters["n"] = parameters["best_of"]
+
+        if "top_n_tokens" in parameters.keys():
+            parameters["logprobs"] = parameters.pop("top_n_tokens")
+        else:
+            parameters["logprobs"] = parameters.get("logprobs", 1)
         parameters = filter_unused_generation_params(parameters,
                                                      VLLM_GENERATION_PARAMS,
                                                      "vllm",

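Finally, a hypothetical client payload that would exercise the new option end to end. The exact endpoint and accepted top-level fields depend on the serving configuration, so treat this only as a sketch of where top_n_tokens sits relative to the other generation parameters consumed by the formatters above.

import json

# Hypothetical request body; field names follow the TGI-style schema the
# formatters above consume, but the values are placeholders.
payload = {
    "inputs": "Deep learning is",
    "parameters": {
        "max_new_tokens": 32,
        "details": True,
        "top_n_tokens": 2,   # return the 2 most likely candidates per position
    },
}
print(json.dumps(payload, indent=2))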