[BFCL] Fix Hanging Inference for OSS Models on GPU Platforms #663

Merged: 4 commits, Oct 5, 2024
Changes from 2 commits
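Background on the fix: a child process started with `stdout=subprocess.PIPE` / `stderr=subprocess.PIPE` blocks as soon as the OS pipe buffer fills if nothing on the parent side keeps reading, and vLLM's startup and request logging on GPU machines can fill that buffer quickly, which is the usual way this kind of hang appears. The change below therefore starts two reader threads that keep draining the server's stdout and stderr, and merely silences them via a `threading.Event` once the server is up, rather than closing the pipes afterwards. A minimal, self-contained sketch of that pattern (the `vllm serve` arguments and the placeholder wait step are illustrative assumptions, not the handler's exact command or readiness check):

# Minimal sketch of the drain-the-pipes pattern adopted below. The "vllm serve"
# arguments and the placeholder wait step are illustrative assumptions, not the
# handler's exact command or readiness check.
import subprocess
import threading

def drain(pipe, stop_event):
    # Keep consuming lines so the child can never block on a full pipe buffer;
    # once stop_event is set, keep draining but stop echoing to the terminal.
    for line in iter(pipe.readline, ""):
        if not stop_event.is_set():
            print(line, end="")
    pipe.close()

process = subprocess.Popen(
    ["vllm", "serve", "some-model"],  # placeholder command
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
)
stop_event = threading.Event()
readers = [
    threading.Thread(target=drain, args=(process.stdout, stop_event)),
    threading.Thread(target=drain, args=(process.stderr, stop_event)),
]
for t in readers:
    t.start()

try:
    pass  # poll the server until it is ready, then send the benchmark requests
finally:
    process.terminate()
    process.wait()
    stop_event.set()
    for t in readers:
        t.join()

Unlike the handler's reader (which breaks out of its loop once the event is set), this sketch keeps draining silently until EOF, so the server can never block on a full pipe even if it logs heavily later. The cleanup order matches the diff: terminate the process first so the pipes reach EOF, then set the event and join the readers; joining before the process exits could otherwise block on a reader stuck in `readline`.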
@@ -1,4 +1,5 @@
 import subprocess
+import threading
 import time

 import requests
@@ -34,24 +35,23 @@ def inference(self, test_entry: dict, include_debugging_log: bool):
             "OSS Models should call the batch_inference method instead."
         )

-    def _format_prompt(self, messages, function):
-        raise NotImplementedError(
-            "OSS Models should implement their own prompt formatting."
-        )
-
     def decode_ast(self, result, language="Python"):
         return default_decode_ast_prompting(result, language)

     def decode_execute(self, result):
         return default_decode_execute_prompting(result)

     def batch_inference(
-        self, test_entries: list[dict], num_gpus: int, gpu_memory_utilization: float, include_debugging_log: bool
+        self,
+        test_entries: list[dict],
+        num_gpus: int,
+        gpu_memory_utilization: float,
+        include_debugging_log: bool,
     ):
         """
         Batch inference for OSS models.
         """

         process = subprocess.Popen(
             [
                 "vllm",
@@ -72,6 +72,30 @@ def batch_inference(
             text=True, # To get the output as text instead of bytes
         )

+        stop_event = (
+            threading.Event()
+        ) # Event to signal threads to stop; no need to see vllm logs after server is ready
+
+        def log_subprocess_output(pipe, stop_event):
+            # Read lines until stop event is set
+            for line in iter(pipe.readline, ""):
+                if stop_event.is_set():
+                    break
+                else:
+                    print(line, end="")
+            pipe.close()
+            print("vllm server log tracking thread stopped successfully.")
+
+        # Start threads to read and print stdout and stderr
+        stdout_thread = threading.Thread(
+            target=log_subprocess_output, args=(process.stdout, stop_event)
+        )
+        stderr_thread = threading.Thread(
+            target=log_subprocess_output, args=(process.stderr, stop_event)
+        )
+        stdout_thread.start()
+        stderr_thread.start()
+
         try:
             # Wait for the server to be ready
             server_ready = False
@@ -95,24 +119,27 @@ def batch_inference(
                     # If the connection is not ready, wait and try again
                     time.sleep(1)

-            # After the server is ready, stop capturing the output, otherwise the terminal looks messy
-            process.stdout.close()
-            process.stderr.close()
-            process.stdout = subprocess.DEVNULL
-            process.stderr = subprocess.DEVNULL
+            # Signal threads to stop reading output
+            stop_event.set()

             # Once the server is ready, make the completion requests
             for test_entry in tqdm(test_entries, desc="Generating results"):
                 try:
                     if "multi_turn" in test_entry["id"]:
-                        model_responses, metadata = self.inference_multi_turn_prompting(test_entry, include_debugging_log)
+                        model_responses, metadata = self.inference_multi_turn_prompting(
+                            test_entry, include_debugging_log
+                        )
                     else:
-                        model_responses, metadata = self.inference_single_turn_prompting(test_entry, include_debugging_log)
+                        model_responses, metadata = self.inference_single_turn_prompting(
+                            test_entry, include_debugging_log
+                        )
                 except Exception as e:
-                    print(f"Error during inference for test entry {test_entry['id']}: {str(e)}")
+                    print(
+                        f"Error during inference for test entry {test_entry['id']}: {str(e)}"
+                    )
                     model_responses = f"Error during inference: {str(e)}"
                     metadata = {}

                 result_to_write = {
                     "id": test_entry["id"],
                     "result": model_responses,
@@ -136,8 +163,18 @@ def batch_inference(
                 process.wait() # Wait again to ensure it's fully terminated
                 print("Process killed.")

+            # Wait for the output threads to finish
+            stop_event.set()
+            stdout_thread.join()
+            stderr_thread.join()
+
     #### Prompting methods ####

+    def _format_prompt(self, messages, function):
+        raise NotImplementedError(
+            "OSS Models should implement their own prompt formatting."
+        )
+
     def _query_prompting(self, inference_data: dict):
         # We use the OpenAI Completions API with vLLM
         function: list[dict] = inference_data["function"]
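Most of the "wait for the server to be ready" loop is collapsed in this view; what remains visible is the retry with `time.sleep(1)` when the connection is not ready. The usual shape of such a loop is to poll the local OpenAI-compatible endpoint until it answers. A rough, self-contained sketch under assumed details (port 8000, the `/v1/models` path, the timeout, and the liveness check on the subprocess are illustrative choices, not necessarily what the handler does):

# Rough sketch of a readiness poll for a vLLM OpenAI-compatible server.
# The port, the /v1/models path, the timeout, and the liveness check are
# illustrative assumptions, not values taken from the handler.
import time

import requests

def wait_until_ready(process, base_url="http://localhost:8000", timeout=600):
    deadline = time.time() + timeout
    while time.time() < deadline:
        if process.poll() is not None:
            raise RuntimeError("vLLM server exited before becoming ready")
        try:
            # /v1/models is a cheap endpoint on OpenAI-compatible servers
            if requests.get(f"{base_url}/v1/models", timeout=5).status_code == 200:
                return
        except requests.exceptions.RequestException:
            pass  # not accepting connections yet; retry below
        time.sleep(1)
    raise TimeoutError("vLLM server did not become ready in time")

Checking `process.poll()` on every iteration keeps the loop from waiting forever if the server crashes while loading the model.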