Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 60 additions & 10 deletions berkeley-function-call-leaderboard/bfcl/_llm_response_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@

from bfcl._apply_function_credential_config import apply_function_credential_config
from bfcl.constant import (
DOTENV_PATH,
MULTI_TURN_FUNC_DOC_FILE_MAPPING,
MULTI_TURN_FUNC_DOC_PATH,
PROMPT_PATH,
RESULT_PATH,
TEST_COLLECTION_MAPPING,
TEST_FILE_MAPPING,
)
from bfcl.eval_checker.eval_runner_helper import is_executable
from bfcl.eval_checker.eval_runner_helper import (
is_executable,
is_multi_turn,
load_file,
)
from bfcl.model_handler.handler_map import HANDLER_MAP
from bfcl.model_handler.model_style import ModelStyle
from dotenv import load_dotenv
from tqdm import tqdm

RETRY_LIMIT = 3
Expand Down Expand Up @@ -89,12 +93,12 @@ def parse_test_category_argument(test_category_args):
return sorted(list(test_name_total)), sorted(list(test_filename_total))


def collect_test_cases(test_filename_total, model_name):
def collect_test_cases(test_name_total, test_filename_total, model_name):
model_name_dir = model_name.replace("/", "_")
model_result_dir = RESULT_PATH / model_name_dir

test_cases_total = []
for file_to_open in test_filename_total:
for test_category, file_to_open in zip(test_name_total, test_filename_total):
test_cases = []
with open(PROMPT_PATH / file_to_open) as f:
for line in f:
Expand All @@ -108,13 +112,50 @@ def collect_test_cases(test_filename_total, model_name):
existing_result.append(json.loads(line))

existing_ids = [entry["id"] for entry in existing_result]
test_cases_total.extend(
[test_case for test_case in test_cases if test_case["id"] not in existing_ids]
test_cases_to_generate = [
test_case for test_case in test_cases if test_case["id"] not in existing_ids
]
test_cases_to_generate = process_multi_turn_test_case(
test_cases_to_generate, test_category
)

test_cases_total.extend(test_cases_to_generate)

return sorted(test_cases_total, key=sort_key)


def process_multi_turn_test_case(test_cases, test_category):
    """
    Attach function docs to multi-turn test cases, which don't carry them in the prompt.

    When `is_multi_turn(test_category)` is false, `test_cases` is returned untouched.
    Otherwise every entry is updated in place:

    - `entry["function"]` is rebuilt from the doc files that
      `MULTI_TURN_FUNC_DOC_FILE_MAPPING` associates with each name in
      `entry["involved_classes"]`.
    - If the entry has a `"missed_function"` field, every function name listed
      under a turn index is replaced by the matching function doc, and that doc
      is removed from `entry["function"]`. Names with no matching doc are dropped.

    Returns the same `test_cases` list that was passed in.
    """
    if not is_multi_turn(test_category):
        return test_cases

    for entry in test_cases:
        class_names = entry["involved_classes"]
        entry["function"] = []
        available_docs = entry["function"]

        for class_name in class_names:
            doc_file = MULTI_TURN_FUNC_DOC_FILE_MAPPING[class_name]
            # Each doc file loads as a list of function-doc dicts.
            available_docs.extend(load_file(MULTI_TURN_FUNC_DOC_PATH / doc_file))

        # Only the Miss Func category carries held-out functions.
        if "missed_function" not in entry:
            continue

        held_out_by_turn = entry["missed_function"]
        for turn_index, missed_names in held_out_by_turn.items():
            # Reassigning an existing key is safe while iterating items();
            # `missed_names` still refers to the original list of names.
            held_out_docs = []
            held_out_by_turn[turn_index] = held_out_docs

            for missed_name in missed_names:
                # Locate the first doc with this name, if any.
                position = next(
                    (
                        index
                        for index, doc in enumerate(available_docs)
                        if doc["name"] == missed_name
                    ),
                    None,
                )
                if position is not None:
                    # Move the doc from the visible list to the held-out list.
                    held_out_docs.append(available_docs.pop(position))

    return test_cases


def multi_threaded_inference(handler, test_case, include_debugging_log):

assert type(test_case["function"]) is list
Expand All @@ -123,7 +164,9 @@ def multi_threaded_inference(handler, test_case, include_debugging_log):

while True:
try:
result, metadata = handler.inference(copy.deepcopy(test_case), include_debugging_log)
result, metadata = handler.inference(
copy.deepcopy(test_case), include_debugging_log
)
break # Success, exit the loop
except Exception as e:
# TODO: It might be better to handle the exception in the handler itself rather than in a universal catch block here, as each handler uses a different way to call the endpoint.
Expand Down Expand Up @@ -186,7 +229,12 @@ def generate_results(args, model_name, test_cases_total):
) as pbar:

for test_case in test_cases_total:
future = executor.submit(multi_threaded_inference, handler, test_case, args.include_debugging_log)
future = executor.submit(
multi_threaded_inference,
handler,
test_case,
args.include_debugging_log,
)
futures.append(future)

for future in futures:
Expand Down Expand Up @@ -218,7 +266,9 @@ def main(args):
):
model_name = model_name + "-optimized"

test_cases_total = collect_test_cases(test_filename_total, model_name)
test_cases_total = collect_test_cases(
test_name_total, test_filename_total, model_name
)

if len(test_cases_total) == 0:
print(
Expand Down
13 changes: 13 additions & 0 deletions berkeley-function-call-leaderboard/bfcl/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# NOTE: These paths are relative to the `bfcl` directory where this script is located.
RESULT_PATH = "../result/"
PROMPT_PATH = "../data/"
MULTI_TURN_FUNC_DOC_PATH = "../data/multi_turn_func_doc/"
POSSIBLE_ANSWER_PATH = "../data/possible_answer/"
SCORE_PATH = "../score/"
DOTENV_PATH = "../.env"
Expand Down Expand Up @@ -175,11 +176,23 @@
],
}

# Maps a multi-turn class name to the file name of its function doc.
# File names are looked up relative to MULTI_TURN_FUNC_DOC_PATH.
MULTI_TURN_FUNC_DOC_FILE_MAPPING = {
    "GorillaFileSystem": "gorilla_file_system.json",
    "MathAPI": "math_api.json",
    "MessageAPI": "message_api.json",
    "TwitterAPI": "posting_api.json",
    "TicketAPI": "ticket_api.json",
    "TradingBot": "trading_bot.json",
    "TravelAPI": "travel_booking.json",
    "VehicleControlAPI": "vehicle_control.json",
}


# Construct the full path to use by other scripts
script_dir = Path(__file__).parent
RESULT_PATH = (script_dir / RESULT_PATH).resolve()
PROMPT_PATH = (script_dir / PROMPT_PATH).resolve()
MULTI_TURN_FUNC_DOC_PATH = (script_dir / MULTI_TURN_FUNC_DOC_PATH).resolve()
POSSIBLE_ANSWER_PATH = (script_dir / POSSIBLE_ANSWER_PATH).resolve()
SCORE_PATH = (script_dir / SCORE_PATH).resolve()
DOTENV_PATH = (script_dir / DOTENV_PATH).resolve()
Expand Down
Loading