@@ -2,8 +2,8 @@
 import copy
 import pickle
 from contextlib import contextmanager, suppress
-from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
-                    Union, overload)
+from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
+                    Optional, Union, overload)
 
 import cloudpickle
 import zmq
@@ -12,6 +12,7 @@
 from zmq.asyncio import Socket
 
 from vllm import PoolingParams
+from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
@@ -27,14 +28,16 @@
                                           RPCUProfileRequest)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptType
+from vllm.inputs import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import deprecate_kwargs
+from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
+                        random_uuid)
 
 logger = init_logger(__name__)
 
@@ -441,6 +444,104 @@ def generate(
                                      lora_request, trace_headers,
                                      prompt_adapter_request, priority)
 
+    async def beam_search(
+        self,
+        prompt: Union[PromptType, List[int]],
+        request_id: str,
+        params: BeamSearchParams,
+    ) -> AsyncGenerator[RequestOutput, None]:
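+        """Beam search implemented client-side on top of generate().
+
+        Each step issues one single-token generate() call per live beam,
+        expands every beam with its top candidate tokens, and keeps the
+        best beam_width sequences until max_tokens steps or EOS.
+        """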
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+        length_penalty = params.length_penalty
+
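+        # Tokenize the prompt once; beams are tracked as growing
+        # token-id lists from here on.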
+        tokenizer = await self.get_tokenizer(lora_request=None)
+        tokenized_prompt = prompt if isinstance(
+            prompt, list) else tokenizer.encode(prompt)
+        tokenized_length = len(tokenized_prompt)
+
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id, length_penalty)
+
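+        # Each generate() call produces exactly one new token; request
+        # 2 * beam_width logprobs so enough candidates survive once
+        # EOS expansions are filtered out.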
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
+        all_beams = [
+            BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0)
+        ]
+        completed = []
+
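+        # Grow every beam by one token per iteration, for at most
+        # max_tokens iterations.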
+        for _ in range(max_tokens):
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
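+            # Fan out one single-token generation request per live beam
+            # and run them concurrently.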
+            tasks = []
+
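+            # A fresh id per step keeps the per-beam request ids unique
+            # across iterations, distinct from the caller's request_id.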
+            batch_request_id = f"beam_search-{random_uuid()}"
+            for i, individual_prompt in enumerate(prompts_batch):
+                request_id_item = f"{batch_request_id}-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.generate(individual_prompt, beam_search_params,
+                                      request_id_item)))
+                tasks.append(task)
+
+            output = await asyncio.gather(*tasks)
+
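+            # With max_tokens == 1 each collected stream holds a single
+            # final RequestOutput, so keep just that element.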
+            output = [x[0] for x in output]
+
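+            # Expand every beam with each of its top candidate tokens;
+            # EOS expansions move to completed unless ignore_eos is set.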
+            new_beams = []
+            for i, current_beam in enumerate(all_beams):
+                result = output[i]
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
+                        if token_id == tokenizer.eos_token_id and \
+                                not ignore_eos:
+                            completed.append(new_beam)
+                        else:
+                            new_beams.append(new_beam)
+
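+            # Keep only the best beam_width expansions for the next step.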
+            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
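+        # Beams still alive after max_tokens steps also count as complete.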
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
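+        # Decode only the generated suffix, skipping the prompt tokens.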
+        for beam in best_beams:
+            beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
+
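+        # Package the surviving beams as one finished RequestOutput. The
+        # prompt text is only attached when the caller passed a string.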
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt if isinstance(prompt, str) else None,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
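+                    # Per-token logprobs are not tracked for finished
+                    # beams, and CompletionOutput.logprobs expects a
+                    # sequence of logprob dicts rather than a float.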
+                    logprobs=None,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=tokenized_prompt,
+            prompt_logprobs=None)
+
+        yield beam_search_output
+
     @overload  # DEPRECATED
     def encode(
         self,
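A minimal usage sketch of the new method, assuming `client` is an
already-constructed engine client exposing `beam_search` (construction and
shutdown are elided; the prompt and request id are illustrative):

```python
import asyncio

from vllm.sampling_params import BeamSearchParams


async def demo(client):
    # Ask for the 4 best continuations, at most 32 new tokens each.
    params = BeamSearchParams(beam_width=4, max_tokens=32)
    async for output in client.beam_search(
            prompt="The capital of France is",
            request_id="beam-demo-0",
            params=params):
        for completion in output.outputs:
            print(completion.cumulative_logprob, completion.text)


# e.g. asyncio.run(demo(client)) once a client has been constructed.
```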