Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions examples/scaffolding/run_majority_vote_aime24.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import argparse
import asyncio
import json

from tensorrt_llm.scaffolding import (MajorityVoteController,
from tensorrt_llm.scaffolding import (GenerationTokenCounter,
MajorityVoteController,
NativeGenerationController,
ScaffoldingLlm, TRTLLMWorker,
ScaffoldingBenchRequest, ScaffoldingLlm,
TRTLLMWorker, async_scaffolding_benchmark,
extract_answer_from_boxed)


Expand All @@ -19,6 +22,8 @@ def parse_arguments():
parser.add_argument('--jsonl_file', type=str, default='./test.jsonl')
parser.add_argument('--threshold', type=float, default=None)
parser.add_argument('--sample_num', type=int, default=10)
parser.add_argument('--concurrency', type=int, default=None)
parser.add_argument('--static_with_benchmark', action='store_true')
args = parser.parse_args()
return args

Expand Down Expand Up @@ -70,7 +75,27 @@ def main():
test_case = test_dataset[i]
prompts.append(test_case["problem"])

results = llm.generate(prompts)
if args.static_with_benchmark or args.concurrency:
if args.concurrency == None:
args.concurrency = 1

if args.static_with_benchmark:
task_collection_types = {"token_counter": GenerationTokenCounter}

requests = [
ScaffoldingBenchRequest(prompt=prompt) for prompt in prompts
]

results, requests_execution_time, total_time = asyncio.run(
async_scaffolding_benchmark(llm, task_collection_types, requests,
args.concurrency))
else:
results = llm.generate(prompts)

print(f'main shutting down...')
llm.shutdown()
llm_worker.shutdown()
print(f'main shut down done')

for i in range(len(results)):
result = results[i]
Expand All @@ -95,10 +120,17 @@ def main():
assert correct_count >= args.threshold * total_count, \
f'Accuracy check failed with {correct_count}/{total_count} < {args.threshold}'
print(f'Accuracy check passed with threshold={args.threshold}')
print(f'main shutting down...')
llm.shutdown()
llm_worker.shutdown()
print(f'main shut down done')

if args.static_with_benchmark:
print(f'Total time: {total_time}')
print(
f'Average requests execution time: {sum(requests_execution_time) / len(requests_execution_time)}'
)
total_token_count = 0
for result in results:
total_token_count += result.task_collections[
"token_counter"].generation_token_count
print(f'Average output token count: {total_token_count / len(results)}')


if __name__ == '__main__':
Expand Down
3 changes: 3 additions & 0 deletions tensorrt_llm/scaffolding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .benchmark import ScaffoldingBenchRequest, async_scaffolding_benchmark
from .controller import (BestOfNController, Controller, MajorityVoteController,
NativeGenerationController, NativeRewardController,
ParallelProcess, ScaffoldingOutput)
Expand Down Expand Up @@ -32,4 +33,6 @@
"TaskCollection",
"with_task_collection",
"GenerationTokenCounter",
"async_scaffolding_benchmark",
"ScaffoldingBenchRequest",
]
101 changes: 101 additions & 0 deletions tensorrt_llm/scaffolding/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import asyncio
import time
from typing import List, Mapping, Tuple, Type

from pydantic import BaseModel

from tensorrt_llm.scaffolding.scaffolding_llm import (ScaffoldingLlm,
ScaffoldingResult)
from tensorrt_llm.scaffolding.task_collection import (TaskCollection,
with_task_collection)


class ScaffoldingBenchRequest(BaseModel):
    """A single benchmark request; currently only carries the prompt text."""

    # Prompt forwarded verbatim to ScaffoldingLlm.generate_async().
    prompt: str


async def enqueue_requests(input_queue, requests):
    """Feed every request into *input_queue*, then a ``None`` sentinel.

    The trailing ``None`` tells the consumer side (``run_scaffolding_llm``)
    that no further requests will arrive.
    """
    # Append the end-of-stream sentinel after the real requests.
    for item in list(requests) + [None]:
        await input_queue.put(item)


async def process_request(scaffolding_llm, request, output_queue, semaphore):
    """Run one request under the concurrency *semaphore* and record timing.

    Awaits the scaffolding result, then puts the tuple
    ``(result, wall_clock_seconds)`` on *output_queue*.
    """
    async with semaphore:
        started = time.time()
        scaffolding_result = scaffolding_llm.generate_async(request.prompt)
        await scaffolding_result.aresult()
        elapsed = time.time() - started
        await output_queue.put((scaffolding_result, elapsed))


async def run_scaffolding_llm(scaffolding_llm, input_queue, output_queue,
                              concurrency):
    """Consume requests from *input_queue* and fan them out as tasks.

    At most *concurrency* requests run simultaneously, enforced by a shared
    semaphore inside ``process_request``.  Once the ``None`` sentinel is
    read, waits for every in-flight task and forwards the sentinel to
    *output_queue* so the result collector can stop.
    """
    gate = asyncio.Semaphore(concurrency)
    in_flight = set()

    while (request := await input_queue.get()) is not None:
        worker = asyncio.create_task(
            process_request(scaffolding_llm, request, output_queue, gate))
        in_flight.add(worker)
        # Drop finished tasks so the tracking set does not grow unboundedly.
        worker.add_done_callback(in_flight.discard)

    await asyncio.gather(*in_flight)
    # Propagate end-of-stream to the consumer of output_queue.
    await output_queue.put(None)


def wrapper_prototype_controller_with_task_collection(scaffolding_llm,
                                                      task_collection_types):
    """Instrument the llm's prototype controller with task collections.

    Instantiates one TaskCollection per entry directly onto the prototype
    controller and enables copying them onto each result.
    """
    prototype_controller_type = type(scaffolding_llm.prototype_controller)
    controller_type_with_task_collection = prototype_controller_type

    for name, task_collection_type in task_collection_types.items():
        # Attach a live collection instance to the prototype controller.
        scaffolding_llm.prototype_controller.task_collections[
            name] = task_collection_type()
        # NOTE(review): the wrapped controller type built below is never
        # used afterwards (not returned, not assigned back onto
        # scaffolding_llm) -- confirm whether this wrapping is still needed
        # or is dead code left over from an earlier design.
        controller_type_with_task_collection = with_task_collection(
            name, task_collection_type)(controller_type_with_task_collection)

    # Make ScaffoldingLlm attach task_collections to every result.
    scaffolding_llm.enable_output_task_collection()


async def async_scaffolding_benchmark(
    scaffolding_llm: ScaffoldingLlm,
    task_collection_types: Mapping[str, Type[TaskCollection]],
    requests: List[ScaffoldingBenchRequest],
    concurrency: int) -> Tuple[List[ScaffoldingResult], List[float], float]:
    """Run *requests* through *scaffolding_llm* with bounded concurrency.

    Args:
        scaffolding_llm: The llm under benchmark; its prototype controller
            is instrumented with *task_collection_types* before the run.
        task_collection_types: Mapping of name -> TaskCollection subclass
            to attach to every request's controller.
        requests: Benchmark requests to execute.
        concurrency: Maximum number of requests in flight at once.

    Returns:
        Tuple of (results, per-request execution times in seconds,
        total wall-clock time in seconds).
    """
    wrapper_prototype_controller_with_task_collection(scaffolding_llm,
                                                      task_collection_types)

    input_queue = asyncio.Queue()
    output_queue = asyncio.Queue()

    start_time = time.time()
    results = []
    requests_execution_time = []
    enqueue_task = asyncio.create_task(enqueue_requests(input_queue, requests))

    run_scaffolding_llm_task = asyncio.create_task(
        run_scaffolding_llm(scaffolding_llm, input_queue, output_queue,
                            concurrency))

    # Collect results until the producer's None sentinel arrives.  A plain
    # await (instead of the former 1-second wait_for polling loop) blocks
    # without periodic wakeups and adds no shutdown latency.
    # NOTE(review): if a worker task raises, the sentinel never arrives and
    # this await would hang; consider asyncio.wait(..., FIRST_EXCEPTION) if
    # worker failures need to abort the benchmark early.
    while (item := await output_queue.get()) is not None:
        result, request_execution_time = item
        results.append(result)
        requests_execution_time.append(request_execution_time)

    total_time = time.time() - start_time

    # Await the tasks (rather than calling .result()) so any exception they
    # raised is re-raised here, and so we never risk InvalidStateError from
    # querying a task that has not been marked done yet.
    await enqueue_task
    await run_scaffolding_llm_task

    return results, requests_execution_time, total_time
13 changes: 13 additions & 0 deletions tensorrt_llm/scaffolding/scaffolding_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@ def __init__(self):
self._done = False
self.aqueue = asyncio.Queue()
self.output = None
self.task_collections = None

def set_output(self, output: ScaffoldingOutput):
    # Hand the finished output to the awaiting consumer via the async queue.
    self.aqueue.put_nowait(output)

def set_task_collections(self, task_collections: Mapping[str,
                                                         "TaskCollection"]):
    # Expose the controller's task collections on this result so callers
    # (e.g. the scaffolding benchmark) can read per-request statistics.
    self.task_collections = task_collections

async def aresult_step(self):
# TODO: error handling or raise exception?
self.output = await self.aqueue.get()
Expand Down Expand Up @@ -76,6 +81,8 @@ def __init__(
self.max_parallel_requests = 64
self.pending_queue = deque()

self.output_task_collection = False

def __enter__(self):
return self

Expand Down Expand Up @@ -121,6 +128,9 @@ async def handle_single_request(request: ScaffoldingRequest):
def controller_generator_wrapper(request: ScaffoldingRequest):
    # Drive the controller's generator to completion for this request.
    scaffolding_output = yield from request.controller.generate(
        request.prompt, **request.kwargs)
    if self.output_task_collection:
        # Attach collections before set_output so the consumer woken by
        # the output already sees task_collections populated.
        request.result.set_task_collections(
            request.controller.task_collections)
    request.result.set_output(scaffolding_output)

try:
Expand Down Expand Up @@ -212,6 +222,9 @@ def generate(

return scaffolding_results[0] if unbatched else scaffolding_results

def enable_output_task_collection(self):
    # When enabled, each result has the controller's task_collections
    # attached after generation finishes (see controller_generator_wrapper).
    self.output_task_collection = True

def shutdown(self, shutdown_workers=False):

def shutdown_workers():
Expand Down
59 changes: 59 additions & 0 deletions tests/unittest/scaffolding/test_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import asyncio
from typing import List

from tensorrt_llm.scaffolding import (GenerationTask,
NativeGenerationController,
ScaffoldingBenchRequest, ScaffoldingLlm,
Task, TaskCollection, TaskStatus, Worker,
async_scaffolding_benchmark)

OUTPUT_STR = "Yes."


class DummyWorker(Worker):
    """Worker stub that answers every GenerationTask with a fixed string."""

    async def dummy_generation_handler(self, task: GenerationTask):
        # Pretend generation succeeded and produced OUTPUT_STR.
        task.output_str = OUTPUT_STR
        return TaskStatus.SUCCESS

    # Route GenerationTask to the dummy handler above.
    task_handlers = {GenerationTask: dummy_generation_handler}


class DummyTaskCollection(TaskCollection):
    """Collection that records the length of the first task's output."""

    def __init__(self):
        super().__init__()
        # Length of tasks[0].output_str seen in the most recent after_yield.
        self.output_len = 0

    def before_yield(self, tasks: List[Task]):
        # No pre-processing needed for this dummy collection.
        pass

    def after_yield(self, tasks: List[Task]):
        self.output_len = len(tasks[0].output_str)


def test_scaffolding_benchmark():
    """End-to-end check of async_scaffolding_benchmark with a dummy worker."""
    task_collection_types = {"bench_dummy_collection": DummyTaskCollection}

    workers = {
        NativeGenerationController.WorkerTag.GENERATION: DummyWorker()
    }
    scaffolding_llm = ScaffoldingLlm(NativeGenerationController(), workers)

    requests_num = 100
    concurrency = 10
    requests = [
        ScaffoldingBenchRequest(prompt="Is today a nice day?")
        for _ in range(requests_num)
    ]

    results, requests_execution_time, total_time = asyncio.run(
        async_scaffolding_benchmark(scaffolding_llm, task_collection_types,
                                    requests, concurrency))

    # One result and one timing entry per request.
    assert len(results) == requests_num
    assert len(requests_execution_time) == requests_num
    # The dummy worker always emits OUTPUT_STR ...
    assert results[0].output.output_str == OUTPUT_STR
    # ... and the dummy collection recorded its length.
    assert results[0].task_collections[
        "bench_dummy_collection"].output_len == len(OUTPUT_STR)