
Commit fa6a9af

fvisin and HuanzhiMao authored
[BFCL] Move logic to eval a task in a separate function. (#933)
This change exposes an API for third-party code to run the evaluation of a single task, while also improving code readability.

Co-authored-by: Huanzhi (Hans) Mao <[email protected]>
1 parent 1c603e9 commit fa6a9af
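
For context, a minimal usage sketch (not part of the commit) of how third-party code might call the evaluate_task API this change introduces. The import path mirrors the file touched below; the model name, test category, directories, and result-file name are placeholder assumptions, and get_handler/load_file are assumed to be importable from the same module, as the diff suggests.

# Illustrative sketch only; not part of this commit.
# Assumed: evaluate_task, get_handler, and load_file are importable from the
# module touched below; the model name, category, and paths are placeholders.
from pathlib import Path

from bfcl.eval_checker.eval_runner import evaluate_task, get_handler, load_file

model_name = "gpt-4o-2024-08-06"      # placeholder key in HANDLER_MAP
test_category = "simple"              # placeholder test category
result_dir = Path("result")           # placeholder result folder
score_dir = Path("score")             # placeholder score output folder
model_result_json = result_dir / model_name / "result.json"  # placeholder file

handler = get_handler(model_name)
model_result = load_file(model_result_json, sort_by_id=True)

# Mutable evaluation state, mirroring the dict initialized in runner().
state = dict(
    api_tested=False,
    api_status_error_rest=None,
    api_status_error_executable=None,
    leaderboard_table={},
    executable_test_categories_have_run=[],
)

# Evaluate a single task; the returned state carries the accumulated
# leaderboard table and API-status flags for subsequent calls.
state = evaluate_task(
    test_category,
    api_sanity_check=False,  # skip the live API sanity check in this sketch
    result_dir=result_dir,
    score_dir=score_dir,
    model_result=model_result,
    model_name=model_name,
    handler=handler,
    state=state,
)
print(state["leaderboard_table"])

Because evaluate_task threads all mutable bookkeeping through the state dict instead of module-level globals, repeated calls across categories can share one state object, exactly as runner() now does.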

File tree

1 file changed: +159 -133 lines


berkeley-function-call-leaderboard/bfcl/eval_checker/eval_runner.py

Lines changed: 159 additions & 133 deletions
@@ -28,9 +28,11 @@
 from dotenv import load_dotenv
 from tqdm import tqdm

-# A dictionary to store the evaluation scores.
-# Key is model name, value is a dictionary with keys as test category and values as a dictionary with accuracy and total count
-LEADERBOARD_TABLE = {}
+
+def get_handler(model_name):
+    return HANDLER_MAP[model_name](
+        model_name, temperature=0
+    )  # Temperature doesn't matter for evaluation


 def multi_turn_runner(
@@ -452,17 +454,26 @@ def ast_file_runner(
 #### Main runner function ####
 def runner(model_names, test_categories, api_sanity_check, result_dir, score_dir):

-    # A flag to indicate if the API has been tested.
-    # We should always test the API with ground truth first before running the executable tests.
-    # Sometimes the API may not be working as expected and we want to catch that before running the evaluation to ensure the results are accurate.
-    API_TESTED = False
-    API_STATUS_ERROR_REST = None
-    API_STATUS_ERROR_EXECUTABLE = None
-
-    # Before running the executable evaluation, we need to get the expected output from the ground truth.
-    # So we need a list of all the test categories that we have ran the ground truth evaluation on.
-    # We only get the expected output once for each test category.
-    EXECUTABLE_TEST_CATEGORIES_HAVE_RUN = []
+    # State updated by each eval subtask.
+    state = dict(
+        # Flags to indicate if the API has been tested.
+        # We should always test the API with ground truth first before running
+        # the executable tests. Sometimes the API may not be working as expected
+        # and we want to catch that before running the evaluation to ensure the
+        # results are accurate.
+        api_tested=False,
+        api_status_error_rest=None,
+        api_status_error_executable=None,
+        # A dictionary to store the evaluation scores.
+        # Key is model name, value is a dictionary with keys as test category
+        # and values as a dictionary with accuracy and total count.
+        leaderboard_table={},
+        # Before running the executable evaluation, we need to get the expected
+        # output from the ground truth. So we need a list of all the test
+        # categories that we have run the ground truth evaluation on. We only
+        # get the expected output once for each test category.
+        executable_test_categories_have_run=[],
+    )

     # Get a list of all entries in the folder
     entries = result_dir.iterdir()
@@ -489,135 +500,156 @@ def runner(model_names, test_categories, api_sanity_check, result_dir, score_dir

             handler = get_handler(model_name_escaped)

-            # We don't evaluate chatable and SQL models in our current leaderboard
+            # We don't evaluate chatable and SQL models in our current
+            # leaderboard.
             if is_chatable(test_category) or is_sql(test_category):
                 continue

-            language = "Python"
-            if is_java(test_category):
-                language = "Java"
-            if is_js(test_category):
-                language = "JavaScript"
-
-            print(f"🔍 Running test: {test_category}")
-
             model_result = load_file(model_result_json, sort_by_id=True)
-            record_cost_latency(LEADERBOARD_TABLE, model_name, model_result)
-
-            # Find the corresponding test file
-            prompt_file = find_file_with_suffix(PROMPT_PATH, test_category)
-            prompt = load_file(prompt_file, sort_by_id=True)
-
-            if is_relevance_or_irrelevance(test_category):
-                accuracy, total_count = relevance_file_runner(
-                    handler, model_result, prompt, model_name, test_category, score_dir
-                )
-                record_result(
-                    LEADERBOARD_TABLE, model_name, test_category, accuracy, total_count
-                )
-                print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")
-                continue
-
-            if is_executable(test_category):
-                # We only test the API with ground truth once
-                if not API_TESTED and api_sanity_check:
-                    print("---- Sanity checking API status ----")
-                    try:
-                        api_status_sanity_check_rest()
-                    except BadAPIStatusError as e:
-                        API_STATUS_ERROR_REST = e
-
-                    try:
-                        api_status_sanity_check_executable()
-                    except BadAPIStatusError as e:
-                        API_STATUS_ERROR_EXECUTABLE = e
-
-                    display_api_status_error(
-                        API_STATUS_ERROR_REST,
-                        API_STATUS_ERROR_EXECUTABLE,
-                        display_success=True,
-                    )
-                    print("Continuing evaluation...")
-
-                    API_TESTED = True
-
-                if (
-                    test_category not in EXECUTABLE_TEST_CATEGORIES_HAVE_RUN
-                    and not is_rest(test_category)
-                ):
-                    print(
-                        f"---- Getting real-time execution result from ground truth for {test_category} ----"
-                    )
-                    get_executable_expected_output(prompt_file)
-                    print(
-                        f"---- Ground truth real-time execution result obtained for {test_category} 🌟 ----"
-                    )
-                    EXECUTABLE_TEST_CATEGORIES_HAVE_RUN.append(test_category)
-                    # Need to re-load the prompt file after getting the expected output, as the prompt file has been updated
-                    prompt = load_file(prompt_file, sort_by_id=True)
-
-                accuracy, total_count = executable_file_runner(
-                    handler, model_result, prompt, model_name, test_category, score_dir
-                )
-                record_result(
-                    LEADERBOARD_TABLE, model_name, test_category, accuracy, total_count
-                )
-                print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")
-
-                continue

-            # Find the corresponding possible answer file
-            possible_answer_file = find_file_with_suffix(
-                POSSIBLE_ANSWER_PATH, test_category
+            state = evaluate_task(
+                test_category,
+                api_sanity_check,
+                result_dir,
+                score_dir,
+                model_result,
+                model_name,
+                handler,
+                state,
             )
-            possible_answer = load_file(possible_answer_file, sort_by_id=True)
-
-            if is_multi_turn(test_category):
-                accuracy, total_count = multi_turn_runner(
-                    handler,
-                    model_result,
-                    prompt,
-                    possible_answer,
-                    model_name,
-                    test_category,
-                    score_dir,
-                )
-                record_result(
-                    LEADERBOARD_TABLE, model_name, test_category, accuracy, total_count
-                )
-                print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")
-            # Single turn test
-            else:
-                accuracy, total_count = ast_file_runner(
-                    handler,
-                    model_result,
-                    prompt,
-                    possible_answer,
-                    language,
-                    test_category,
-                    model_name,
-                    score_dir,
-                )
-                record_result(
-                    LEADERBOARD_TABLE, model_name, test_category, accuracy, total_count
-                )
-                print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")

-    # This function reads all the score files from local folder and updates the leaderboard table.
-    # This is helpful when you only want to run the evaluation for a subset of models and test categories.
-    update_leaderboard_table_with_local_score_file(LEADERBOARD_TABLE, score_dir)
+    # This function reads all the score files from local folder and updates the
+    # leaderboard table. This is helpful when you only want to run the
+    # evaluation for a subset of models and test categories.
+    update_leaderboard_table_with_local_score_file(state["leaderboard_table"], score_dir)
     # Write the leaderboard table to a file
-    generate_leaderboard_csv(LEADERBOARD_TABLE, score_dir, model_names, test_categories)
+    generate_leaderboard_csv(
+        state["leaderboard_table"], score_dir, model_names, test_categories
+    )

     # Clean up the executable expected output files
     # They should be re-generated the next time the evaluation is run
-    clean_up_executable_expected_output(PROMPT_PATH, EXECUTABLE_TEST_CATEGORIES_HAVE_RUN)
+    clean_up_executable_expected_output(
+        PROMPT_PATH, state["executable_test_categories_have_run"]
+    )

     display_api_status_error(
-        API_STATUS_ERROR_REST, API_STATUS_ERROR_EXECUTABLE, display_success=False
+        state["api_status_error_rest"],
+        state["api_status_error_executable"],
+        display_success=False,
     )


+def evaluate_task(
+    test_category,
+    api_sanity_check,
+    result_dir,
+    score_dir,
+    model_result,
+    model_name,
+    handler,
+    state,
+):
+
+    language = "Python"
+    if is_java(test_category):
+        language = "Java"
+    if is_js(test_category):
+        language = "JavaScript"
+
+    print(f"🔍 Running test: {test_category}")
+
+    record_cost_latency(state["leaderboard_table"], model_name, model_result)
+
+    # Find the corresponding test file.
+    prompt_file = find_file_with_suffix(PROMPT_PATH, test_category)
+    prompt = load_file(prompt_file, sort_by_id=True)
+
+    if is_relevance_or_irrelevance(test_category):
+        accuracy, total_count = relevance_file_runner(
+            handler, model_result, prompt, model_name, test_category, score_dir
+        )
+
+    elif is_executable(test_category):
+        # We only test the API with ground truth once.
+        if not state["api_tested"] and api_sanity_check:
+            print("---- Sanity checking API status ----")
+            try:
+                api_status_sanity_check_rest()
+            except BadAPIStatusError as e:
+                state["api_status_error_rest"] = e
+
+            try:
+                api_status_sanity_check_executable()
+            except BadAPIStatusError as e:
+                state["api_status_error_executable"] = e
+
+            display_api_status_error(
+                state["api_status_error_rest"],
+                state["api_status_error_executable"],
+                display_success=True,
+            )
+            print("Continuing evaluation...")
+
+            state["api_tested"] = True
+
+        if (
+            test_category not in state["executable_test_categories_have_run"]
+            and not is_rest(test_category)
+        ):
+            print(
+                f"---- Getting real-time execution result from ground truth"
+                f" for {test_category} ----"
+            )
+            get_executable_expected_output(prompt_file)
+            print(
+                f"---- Ground truth real-time execution result obtained for"
+                f" {test_category} 🌟 ----"
+            )
+            state["executable_test_categories_have_run"].append(test_category)
+            # Need to re-load the prompt file after getting the expected
+            # output, as the prompt file has been updated.
+            prompt = load_file(prompt_file, sort_by_id=True)
+
+        accuracy, total_count = executable_file_runner(
+            handler, model_result, prompt, model_name, test_category, score_dir
+        )
+
+    else:
+        # Find the corresponding possible answer file
+        possible_answer_file = find_file_with_suffix(POSSIBLE_ANSWER_PATH, test_category)
+        possible_answer = load_file(possible_answer_file, sort_by_id=True)
+
+        if is_multi_turn(test_category):
+            accuracy, total_count = multi_turn_runner(
+                handler,
+                model_result,
+                prompt,
+                possible_answer,
+                model_name,
+                test_category,
+                score_dir,
+            )
+
+        # Single turn test
+        else:
+            accuracy, total_count = ast_file_runner(
+                handler,
+                model_result,
+                prompt,
+                possible_answer,
+                language,
+                test_category,
+                model_name,
+                score_dir,
+            )
+
+    record_result(state, model_name, test_category, accuracy, total_count)
+    print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")
+
+    return state
+
+
 def main(model, test_categories, api_sanity_check, result_dir, score_dir):
     if result_dir is None:
         result_dir = RESULT_PATH
@@ -674,12 +706,6 @@ def main(model, test_categories, api_sanity_check, result_dir, score_dir):
     )


-def get_handler(model_name):
-    return HANDLER_MAP[model_name](
-        model_name, temperature=0
-    )  # Temperature doesn't matter for evaluation
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Process two lists of strings.")
