Commit 36d0212

[python] parse input only when new requests are received (#2155)
1 parent 823563f commit 36d0212

11 files changed (+49 lines, -34 lines)
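In rolling-batch mode the frontend resends the whole batch on every step, with the already-active requests first and any new requests appended at the end (this ordering is what the old get_new_requests in rolling_batch.py below relied on). Previously the Python handler re-parsed every request on every step; with this commit, parsing starts where the active requests end. A minimal sketch of that windowing idea, using hypothetical stand-in names:

# Minimal sketch, assuming the incoming batch always lists the cached active
# requests first and any new requests after them (as the old slicing logic did).
def parse_window(batch, num_active):
    # Everything before num_active is already parsed and cached; only the
    # tail of the batch still needs input parsing.
    return batch[num_active:]

assert parse_window(["r0", "r1", "r2"], num_active=2) == ["r2"]  # one new request
assert parse_window(["r0", "r1"], num_active=2) == []            # nothing new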

engines/python/setup/djl_python/huggingface.py

Lines changed: 1 addition & 1 deletion

@@ -226,7 +226,7 @@ def inference(self, inputs: Input) -> Output:
                                                   **self.input_format_args)
         requests = parsed_input.requests
         errors = parsed_input.errors
-        if len(requests) == 0:
+        if errors and len(parsed_input.batch) == len(errors):
             for i in range(len(parsed_input.batch)):
                 err = errors.get(i)
                 if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
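With parsing now limited to new requests, parsed_input.requests is legitimately empty on any step that brings no new work, so the old len(requests) == 0 test would have mis-flagged healthy steps as total failures. The new predicate fires only when every item in the batch produced a parsing error. A small sketch of the condition, assuming errors maps batch index to error message as errors.get(i) above suggests:

# Hedged sketch of the all-failed check; `errors` is assumed to be a dict
# keyed by batch index.
def all_requests_failed(batch, errors):
    return bool(errors) and len(batch) == len(errors)

assert all_requests_failed(["a", "b"], {0: "bad input", 1: "bad input"})
assert not all_requests_failed(["a", "b"], {})           # no errors at all
assert not all_requests_failed(["a", "b"], {0: "bad"})   # partial failure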

engines/python/setup/djl_python/input_parser.py

Lines changed: 20 additions & 1 deletion

@@ -30,6 +30,23 @@ class ParsedInput:
     batch: List = field(default_factory=lambda: [])


+def get_batch_start_id(batch, **kwargs):
+    if kwargs.get("is_rolling_batch"):
+        # For rolling batch, we only need to parse the new requests, as the active requests are kept in the cache.
+        rolling_batch = kwargs.get("rolling_batch")
+        active_requests_len = len(rolling_batch.active_requests)
+        batch_size = len(batch)
+        if batch_size > active_requests_len:
+            # If batch_size > active_requests_len, new requests were received.
+            return active_requests_len
+        else:
+            # No new requests were received, so return batch_size; nothing will be parsed.
+            return batch_size
+    else:
+        # For non-rolling batch, the python process only receives new requests.
+        return 0
+
+
 def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
     """
     Preprocessing function that extracts information from Input objects.

@@ -44,7 +61,9 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
         kwargs["is_rolling_batch"] = is_rolling_batch_enabled(
             kwargs.get("configs").rolling_batch)
     request_id_counter = get_req_id_counter(kwargs)
-    for i, input_item in enumerate(batch):
+    start_batch_id = get_batch_start_id(batch, **kwargs)
+    for i in range(start_batch_id, len(batch)):
+        input_item = batch[i]
         try:
             request_id = request_id_counter.next_id(
             ) if request_id_counter else i
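For illustration, a hedged usage sketch of get_batch_start_id, using types.SimpleNamespace to stand in for a real rolling batch object (only its active_requests attribute is consulted):

from types import SimpleNamespace

rb = SimpleNamespace(active_requests=["r0", "r1"])  # two requests already cached

# Three items in the batch, two already active: parsing starts at index 2.
assert get_batch_start_id(["r0", "r1", "r2"], is_rolling_batch=True,
                          rolling_batch=rb) == 2
# No new requests: the start index equals the batch size, so nothing is parsed.
assert get_batch_start_id(["r0", "r1"], is_rolling_batch=True,
                          rolling_batch=rb) == 2
# Without rolling batch, every item in the batch is new and gets parsed.
assert get_batch_start_id(["r0", "r1"], is_rolling_batch=False) == 0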

engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py

Lines changed: 4 additions & 3 deletions

@@ -143,22 +143,23 @@ def translate_lmi_dist_params(self, parameters: dict):
         return parameters

     @stop_on_any_exception
-    def inference(self, requests: List[Request]) -> List:
+    def inference(self, new_requests: List[Request]) -> List:
         """
         Adds new requests and gets output tokens from the backend.

-        :param requests: List of requests
+        :param new_requests: List of requests

         :return results: List of dictionaries, one for each request, that contain output tokens and other data.
         """
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)
         # step 0: register new requests to engine
         for request in new_requests:
             request_id = str(request.id)
             params = self.translate_lmi_dist_params(request.parameters)
             request_params = RequestParams(**params)
             lora_request_params = get_lora_request_params(
                 request, self.lora_ids)
+            # Constructing Request in lmi-dist library
             lmi_dist_request = Request(
                 id=request_id,
                 prompt=request.input_text,

engines/python/setup/djl_python/rolling_batch/neuron_rolling_batch.py

Lines changed: 3 additions & 3 deletions

@@ -98,17 +98,17 @@ def append_speculated_generations(self, generation, request, req_ids):
         speculated_generation = generation.speculated_generations.dequeue()

     @stop_on_any_exception
-    def inference(self, requests: List[Request]) -> list:
+    def inference(self, new_requests: List[Request]) -> list:
         """
         Loads new requests and gets output tokens from all currently active requests from
         the Neuron backend.

-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests

         :return: generated batch decoded tokens - list of dictionaries, one for
         each request, that contain output tokens and other data.
         """
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)
         if len(new_requests) > 0:
             generations = self.scheduler.prefill(new_requests)
         else:

engines/python/setup/djl_python/rolling_batch/rolling_batch.py

Lines changed: 4 additions & 11 deletions

@@ -93,30 +93,23 @@ def get_tokenizer(self):
         raise RuntimeError("get_tokenizer function not supported")

     @abstractmethod
-    def inference(self, requests: List[Request]) -> List:
+    def inference(self, new_requests: List[Request]) -> List:
         """
         Performs prefill and decode operations for the batch.

-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests

         :return: generated batch decoded tokens
         """
         pass

-    def get_new_requests(self, requests: List[Request]) -> List[Request]:
+    def add_new_requests(self, requests: List[Request]):
         """
         Adds requests to the batch when there is availability

         :param requests: List[Request] List of requests
-
-        :return: list of current active requests (including those that have just been added)
         """
-        total_req_len = len(self.active_requests)
-        batch_size = len(requests)
-        if batch_size > total_req_len:
-            for i in range(total_req_len, batch_size):
-                self.active_requests.append(requests[i])
-        return self.active_requests[total_req_len:]
+        self.active_requests.extend(requests)

     @abstractmethod
     def preprocess_requests(self, requests: List[Request]):
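The base-class change flips the calling contract: get_new_requests took the full batch (actives first) and returned only the newly appended slice, while add_new_requests assumes the caller already isolated the new requests, which parse_input_with_formatter now does via get_batch_start_id, and simply appends them. A side-by-side sketch with hypothetical minimal classes:

# Before: handlers passed the whole batch; the slice past the active
# requests was stored and returned as "new".
class OldContract:
    def __init__(self):
        self.active_requests = []

    def get_new_requests(self, requests):
        n = len(self.active_requests)
        if len(requests) > n:
            self.active_requests.extend(requests[n:])
        return self.active_requests[n:]

# After: handlers pass only the new requests; no slicing is needed.
class NewContract:
    def __init__(self):
        self.active_requests = []

    def add_new_requests(self, requests):
        self.active_requests.extend(requests)

old, new = OldContract(), NewContract()
assert old.get_new_requests(["r0", "r1"]) == ["r0", "r1"]  # first step
assert old.get_new_requests(["r0", "r1", "r2"]) == ["r2"]  # one new arrival
new.add_new_requests(["r0", "r1"])
new.add_new_requests(["r2"])
assert old.active_requests == new.active_requests == ["r0", "r1", "r2"]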

engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py

Lines changed: 3 additions & 3 deletions

@@ -69,14 +69,14 @@ def __init__(self, model_id_or_path: str, properties: dict,
         self._init_scheduler()

     @stop_on_any_exception
-    def inference(self, requests: List) -> List:
+    def inference(self, new_requests: List) -> List:
         """
         Performs prefill and decode operations for the batch.

-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests
         :return: generated batch decoded tokens
         """
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)

         preprocessed_new_requests = self.preprocess_requests(new_requests)
         self._prefill_and_decode(preprocessed_new_requests)

engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py

Lines changed: 3 additions & 3 deletions

@@ -87,20 +87,20 @@ def translate_triton_params(self, parameters: dict) -> dict:
         return parameters

     @stop_on_any_exception
-    def inference(self, requests: List[Request]) -> List:
+    def inference(self, new_requests: List[Request]) -> List:
         """
         Loads new requests into the batch when there is availability, and gets output tokens from the backend
         asynchronously.

-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests
         :param input_data: List of input prompts.
        :param parameters: List of settings pertaining to each request.
         :param adapters: List of adapters inputs for each request in a batch

         :return results: List of dictionaries, one for each request, that contain output tokens and other data.
         """
         # add pending requests to active requests list
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)
         # step 0: register new active requests
         for request in new_requests:
             param = self.translate_triton_params(request.parameters)

engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py

Lines changed: 3 additions & 3 deletions

@@ -107,15 +107,15 @@ def translate_vllm_params(self, parameters: dict) -> dict:
         return parameters

     @stop_on_any_exception
-    def inference(self, requests: List[Request]) -> List:
+    def inference(self, new_requests: List[Request]) -> List:
         """
         Adds new requests and gets output tokens from the backend.

-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests

         :return results: List of dictionaries, one for each request, that contain output tokens and other data.
         """
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)
         # step 0: register new requests to engine
         for request in new_requests:
             request_id = random_uuid()

engines/python/setup/djl_python/tensorrt_llm.py

Lines changed: 2 additions & 1 deletion

@@ -64,7 +64,8 @@ def inference(self, inputs: Input) -> Output:

         parsed_input = parse_input_with_formatter(inputs,
                                                   **self.input_format_args)
-        if len(parsed_input.requests) == 0:
+        if parsed_input.errors and len(parsed_input.requests) == len(
+                parsed_input.errors):
             for i in range(len(parsed_input.batch)):
                 err = parsed_input.errors.get(i)
                 err = {"data": "", "last": True, "code": 424, "error": err}

engines/python/setup/djl_python/tensorrt_llm_python.py

Lines changed: 2 additions & 1 deletion

@@ -115,7 +115,8 @@ def inference(self, inputs: Input) -> Output:

         parsed_input = parse_input_with_formatter(inputs,
                                                   **self.input_format_args)
-        if len(parsed_input.requests) == 0:
+        if parsed_input.errors and len(parsed_input.requests) == len(
+                parsed_input.errors):
             for i in range(len(parsed_input.batch)):
                 err = parsed_input.errors.get(i)
                 outputs.add(err, key="data", batch_index=i)
