Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions python/ray/llm/_internal/batch/stages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,8 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]
for idx, row in enumerate(inputs):
row[self.IDX_IN_BATCH_COLUMN] = idx

# Always stream the outputs one by one to better overlap
# batches. For example, when the output batch size is 64, Ray Data
# will collect 64 outputs, and 1) send the batch of 64 to the next stage,
# 2) get the next batch of this stage. Assuming the input batch size
# is 63 and we yield all 63 results at once, then Ray Data will wait
# for 2 batches (63 + 63 > 64) to continue proceeding. On the other hand,
# if we stream outputs one-by-one, Ray Data can form a batch of 64 before
# the second batch is done.
# Collect all outputs first, then return them in the original order.
# This is a requirement set by https://github.com/ray-project/ray/pull/54190/
not_outputed_rows = set(range(len(inputs)))
async for output in self.udf(inputs):
if self.IDX_IN_BATCH_COLUMN not in output:
Expand All @@ -186,11 +180,13 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]
# Add stage outputs to the data column of the row.
inputs[idx_in_batch].pop(self.IDX_IN_BATCH_COLUMN)
inputs[idx_in_batch].update(output)
yield {self.data_column: [inputs[idx_in_batch]]}

if not_outputed_rows:
raise ValueError(f"The rows {not_outputed_rows} are not outputed.")

# Return all updated inputs in the original order
yield {self.data_column: inputs}

def validate_inputs(self, inputs: List[Dict[str, Any]]):
"""Validate the inputs to make sure the required keys are present.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ async def test_chat_template_udf_basic(mock_tokenizer_setup):

results = []
async for result in udf(batch):
results.append(result)
results.extend(result["__data"])

assert len(results) == 1
assert results[0]["__data"][0]["prompt"] == "<chat>Hello AI</chat>"
assert results[0]["prompt"] == "<chat>Hello AI</chat>"
mock_tokenizer.apply_chat_template.assert_called_once()


Expand Down Expand Up @@ -83,9 +83,9 @@ async def test_chat_template_udf_multiple_messages(mock_tokenizer_setup):
async for result in udf(batch):
results.append(result)

assert len(results) == 2
assert len(results) == 1
assert results[0]["__data"][0]["prompt"] == "<chat>Hello AI</chat>"
assert results[1]["__data"][0]["prompt"] == "<chat>How are you?</chat>"
assert results[0]["__data"][1]["prompt"] == "<chat>How are you?</chat>"
assert mock_tokenizer.apply_chat_template.call_count == 2


Expand Down Expand Up @@ -123,14 +123,12 @@ async def test_chat_template_udf_assistant_prefill(mock_tokenizer_setup):

results = []
async for result in udf(batch):
results.append(result)
results.extend(result["__data"])

assert len(results) == 2
assert mock_tokenizer.apply_chat_template.call_count == 2
assert (
results[0]["__data"][0]["prompt"] == "<chat>Hello AI<assistant><think>\n</chat>"
)
assert results[1]["__data"][0]["prompt"] == "<chat>Hello AI</chat>"
assert results[0]["prompt"] == "<chat>Hello AI<assistant><think>\n</chat>"
assert results[1]["prompt"] == "<chat>Hello AI</chat>"
# check if kwargs were set properly
call_args_list = mock_tokenizer.apply_chat_template.call_args_list
args1, kwargs1 = call_args_list[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def test_http_request_udf_with_qps(mock_session):

results = []
async for result in udf(batch):
results.append(result)
results.extend(result["__data"])

assert len(results) == 2
assert mock_sleep.called # Should have called sleep for QPS limiting
Expand Down Expand Up @@ -113,7 +113,7 @@ async def test_http_request_udf_with_retry(mock_response):
with patch("asyncio.sleep") as mock_sleep:
results = []
async for result in udf(batch):
results.append(result)
results.extend(result["__data"])

assert len(results) == 2
mock_sleep.assert_called()
Expand Down
5 changes: 2 additions & 3 deletions python/ray/llm/tests/batch/cpu/stages/test_stage_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,10 @@ async def test_basic_processing(self):

results = []
async for result in udf(batch):
results.append(result)
results.extend(result["__data"])

assert len(results) == 2
for result in results:
data = result["__data"][0]
for data in results:
val = data["value"]
assert data["processed"] == val * 2
assert data["extra"] == 10 * val
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def test_tokenize_udf_basic(mock_tokenizer_setup):

results = []
async for result in udf(batch):
results.append(result["__data"][0])
results.extend(result["__data"])

assert len(results) == 2
assert all(result["tokenized_prompt"] == [1, 2, 3] for result in results)
Expand Down Expand Up @@ -64,7 +64,7 @@ async def test_detokenize_udf_basic(mock_tokenizer_setup):

results = []
async for result in udf(batch):
results.append(result["__data"][0])
results.extend(result["__data"])

assert len(results) == 2
assert results[0]["generated_text"] == "Hello"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def test_sglang_engine_udf_basic(mock_sglang_wrapper, model_llama_3_2_216M

responses = []
async for response in udf(batch):
responses.append(response["__data"][0])
responses.extend(response["__data"])

assert len(responses) == 2
assert all("batch_uuid" in r for r in responses)
Expand Down