Skip to content

Commit 14ace70

Browse files
kouroshHakha authored and ccmao1130 committed
[data.llm] Return a batch of rows in the udf instead of row by row (ray-project#54329)
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: ChanChan Mao <[email protected]>
1 parent 6cd0a1c commit 14ace70

File tree

6 files changed

+19
-26
lines changed

6 files changed

+19
-26
lines changed

python/ray/llm/_internal/batch/stages/base.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,8 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]
160160
for idx, row in enumerate(inputs):
161161
row[self.IDX_IN_BATCH_COLUMN] = idx
162162

163-
# Always stream the outputs one by one to better overlapping
164-
# batches. For example, when the output batch size is 64, Ray Data
165-
# will collect 64 outputs, and 1) send the batch of 64 to the next stage,
166-
# 2) get the next batch of this stage. Assuming the input batch size
167-
# is 63 and we yield all 63 results at once, then Ray Data will wait
168-
# for 2 batches (63 + 63 > 64) to continue proceeding. On the other hand,
169-
# if we stream outputs one-by-one, Ray Data can form a batch of 64 before
170-
# the second batch is done.
163+
# Collect all outputs first, then return them in the original order
164+
# This is a requirement set by https://github.com/ray-project/ray/pull/54190/
171165
not_outputed_rows = set(range(len(inputs)))
172166
async for output in self.udf(inputs):
173167
if self.IDX_IN_BATCH_COLUMN not in output:
@@ -186,11 +180,13 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]
186180
# Add stage outputs to the data column of the row.
187181
inputs[idx_in_batch].pop(self.IDX_IN_BATCH_COLUMN)
188182
inputs[idx_in_batch].update(output)
189-
yield {self.data_column: [inputs[idx_in_batch]]}
190183

191184
if not_outputed_rows:
192185
raise ValueError(f"The rows {not_outputed_rows} are not outputed.")
193186

187+
# Return all updated inputs in the original order
188+
yield {self.data_column: inputs}
189+
194190
def validate_inputs(self, inputs: List[Dict[str, Any]]):
195191
"""Validate the inputs to make sure the required keys are present.
196192

python/ray/llm/tests/batch/cpu/stages/test_chat_template_stage.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ async def test_chat_template_udf_basic(mock_tokenizer_setup):
4343

4444
results = []
4545
async for result in udf(batch):
46-
results.append(result)
46+
results.extend(result["__data"])
4747

4848
assert len(results) == 1
49-
assert results[0]["__data"][0]["prompt"] == "<chat>Hello AI</chat>"
49+
assert results[0]["prompt"] == "<chat>Hello AI</chat>"
5050
mock_tokenizer.apply_chat_template.assert_called_once()
5151

5252

@@ -83,9 +83,9 @@ async def test_chat_template_udf_multiple_messages(mock_tokenizer_setup):
8383
async for result in udf(batch):
8484
results.append(result)
8585

86-
assert len(results) == 2
86+
assert len(results) == 1
8787
assert results[0]["__data"][0]["prompt"] == "<chat>Hello AI</chat>"
88-
assert results[1]["__data"][0]["prompt"] == "<chat>How are you?</chat>"
88+
assert results[0]["__data"][1]["prompt"] == "<chat>How are you?</chat>"
8989
assert mock_tokenizer.apply_chat_template.call_count == 2
9090

9191

@@ -123,14 +123,12 @@ async def test_chat_template_udf_assistant_prefill(mock_tokenizer_setup):
123123

124124
results = []
125125
async for result in udf(batch):
126-
results.append(result)
126+
results.extend(result["__data"])
127127

128128
assert len(results) == 2
129129
assert mock_tokenizer.apply_chat_template.call_count == 2
130-
assert (
131-
results[0]["__data"][0]["prompt"] == "<chat>Hello AI<assistant><think>\n</chat>"
132-
)
133-
assert results[1]["__data"][0]["prompt"] == "<chat>Hello AI</chat>"
130+
assert results[0]["prompt"] == "<chat>Hello AI<assistant><think>\n</chat>"
131+
assert results[1]["prompt"] == "<chat>Hello AI</chat>"
134132
# check if kwargs were set properly
135133
call_args_list = mock_tokenizer.apply_chat_template.call_args_list
136134
args1, kwargs1 = call_args_list[0]

python/ray/llm/tests/batch/cpu/stages/test_http_request_stage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ async def test_http_request_udf_with_qps(mock_session):
7373

7474
results = []
7575
async for result in udf(batch):
76-
results.append(result)
76+
results.extend(result["__data"])
7777

7878
assert len(results) == 2
7979
assert mock_sleep.called # Should have called sleep for QPS limiting
@@ -113,7 +113,7 @@ async def test_http_request_udf_with_retry(mock_response):
113113
with patch("asyncio.sleep") as mock_sleep:
114114
results = []
115115
async for result in udf(batch):
116-
results.append(result)
116+
results.extend(result["__data"])
117117

118118
assert len(results) == 2
119119
mock_sleep.assert_called()

python/ray/llm/tests/batch/cpu/stages/test_stage_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,10 @@ async def test_basic_processing(self):
7373

7474
results = []
7575
async for result in udf(batch):
76-
results.append(result)
76+
results.extend(result["__data"])
7777

7878
assert len(results) == 2
79-
for result in results:
80-
data = result["__data"][0]
79+
for data in results:
8180
val = data["value"]
8281
assert data["processed"] == val * 2
8382
assert data["extra"] == 10 * val

python/ray/llm/tests/batch/cpu/stages/test_tokenize_stage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ async def test_tokenize_udf_basic(mock_tokenizer_setup):
3535

3636
results = []
3737
async for result in udf(batch):
38-
results.append(result["__data"][0])
38+
results.extend(result["__data"])
3939

4040
assert len(results) == 2
4141
assert all(result["tokenized_prompt"] == [1, 2, 3] for result in results)
@@ -64,7 +64,7 @@ async def test_detokenize_udf_basic(mock_tokenizer_setup):
6464

6565
results = []
6666
async for result in udf(batch):
67-
results.append(result["__data"][0])
67+
results.extend(result["__data"])
6868

6969
assert len(results) == 2
7070
assert results[0]["generated_text"] == "Hello"

python/ray/llm/tests/batch/gpu/stages/test_sglang_engine_stage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ async def test_sglang_engine_udf_basic(mock_sglang_wrapper, model_llama_3_2_216M
168168

169169
responses = []
170170
async for response in udf(batch):
171-
responses.append(response["__data"][0])
171+
responses.extend(response["__data"])
172172

173173
assert len(responses) == 2
174174
assert all("batch_uuid" in r for r in responses)

0 commit comments

Comments (0)