from dotenv import load_dotenv
from tqdm import tqdm

-# A dictionary to store the evaluation scores.
-# Key is model name, value is a dictionary with keys as test category and values as a dictionary with accuracy and total count
-LEADERBOARD_TABLE = {}
+
+def get_handler(model_name):
+    return HANDLER_MAP[model_name](
+        model_name, temperature=0
+    )  # Temperature doesn't matter for evaluation


def multi_turn_runner(
@@ -452,17 +454,26 @@ def ast_file_runner(
#### Main runner function ####
def runner(model_names, test_categories, api_sanity_check, result_dir, score_dir):

-    # A flag to indicate if the API has been tested.
-    # We should always test the API with ground truth first before running the executable tests.
-    # Sometimes the API may not be working as expected and we want to catch that before running the evaluation to ensure the results are accurate.
-    API_TESTED = False
-    API_STATUS_ERROR_REST = None
-    API_STATUS_ERROR_EXECUTABLE = None
-
-    # Before running the executable evaluation, we need to get the expected output from the ground truth.
-    # So we need a list of all the test categories that we have ran the ground truth evaluation on.
-    # We only get the expected output once for each test category.
-    EXECUTABLE_TEST_CATEGORIES_HAVE_RUN = []
+    # State updated by each eval subtask.
+    state = dict(
+        # Flags to indicate if the API has been tested.
+        # We should always test the API with ground truth first before running
+        # the executable tests. Sometimes the API may not be working as expected
+        # and we want to catch that before running the evaluation to ensure the
+        # results are accurate.
+        api_tested=False,
+        api_status_error_rest=None,
+        api_status_error_executable=None,
+        # A dictionary to store the evaluation scores.
+        # Key is model name, value is a dictionary with keys as test category
+        # and values as a dictionary with accuracy and total count.
+        leaderboard_table={},
+        # Before running the executable evaluation, we need to get the expected
+        # output from the ground truth. So we need a list of all the test
+        # categories that we have run the ground truth evaluation on. We only
+        # get the expected output once for each test category.
+        executable_test_categories_have_run=[],
+    )

    # Get a list of all entries in the folder
    entries = result_dir.iterdir()
@@ -489,135 +500,156 @@ def runner(model_names, test_categories, api_sanity_check, result_dir, score_dir):

            handler = get_handler(model_name_escaped)

-            # We don't evaluate chatable and SQL models in our current leaderboard
+            # We don't evaluate chatable and SQL models in our current
+            # leaderboard.
            if is_chatable(test_category) or is_sql(test_category):
                continue

-            language = "Python"
-            if is_java(test_category):
-                language = "Java"
-            if is_js(test_category):
-                language = "JavaScript"
-
-            print(f"🔍 Running test: {test_category}")
-
            model_result = load_file(model_result_json, sort_by_id=True)
-            record_cost_latency(LEADERBOARD_TABLE, model_name, model_result)
-
-            # Find the corresponding test file
-            prompt_file = find_file_with_suffix(PROMPT_PATH, test_category)
-            prompt = load_file(prompt_file, sort_by_id=True)
-
-            if is_relevance_or_irrelevance(test_category):
-                accuracy, total_count = relevance_file_runner(
-                    handler, model_result, prompt, model_name, test_category, score_dir
-                )
-                record_result(
-                    LEADERBOARD_TABLE, model_name, test_category, accuracy, total_count
-                )
-                print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")
-                continue
-
-            if is_executable(test_category):
-                # We only test the API with ground truth once
-                if not API_TESTED and api_sanity_check:
-                    print("---- Sanity checking API status ----")
-                    try:
-                        api_status_sanity_check_rest()
-                    except BadAPIStatusError as e:
-                        API_STATUS_ERROR_REST = e
-
-                    try:
-                        api_status_sanity_check_executable()
-                    except BadAPIStatusError as e:
-                        API_STATUS_ERROR_EXECUTABLE = e
-
-                    display_api_status_error(
-                        API_STATUS_ERROR_REST,
-                        API_STATUS_ERROR_EXECUTABLE,
-                        display_success=True,
-                    )
-                    print("Continuing evaluation...")
-
-                    API_TESTED = True
-
-                if (
-                    test_category not in EXECUTABLE_TEST_CATEGORIES_HAVE_RUN
-                    and not is_rest(test_category)
-                ):
-                    print(
-                        f"---- Getting real-time execution result from ground truth for {test_category} ----"
-                    )
-                    get_executable_expected_output(prompt_file)
-                    print(
-                        f"---- Ground truth real-time execution result obtained for {test_category} 🌟 ----"
-                    )
-                    EXECUTABLE_TEST_CATEGORIES_HAVE_RUN.append(test_category)
-                    # Need to re-load the prompt file after getting the expected output, as the prompt file has been updated
-                    prompt = load_file(prompt_file, sort_by_id=True)
-
-                accuracy, total_count = executable_file_runner(
-                    handler, model_result, prompt, model_name, test_category, score_dir
-                )
-                record_result(
-                    LEADERBOARD_TABLE, model_name, test_category, accuracy, total_count
-                )
-                print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")
-
-                continue

-            # Find the corresponding possible answer file
-            possible_answer_file = find_file_with_suffix(
-                POSSIBLE_ANSWER_PATH, test_category
+            state = evaluate_task(
+                test_category,
+                api_sanity_check,
+                result_dir,
+                score_dir,
+                model_result,
+                model_name,
+                handler,
+                state,
            )
-            possible_answer = load_file(possible_answer_file, sort_by_id=True)
-
-            if is_multi_turn(test_category):
-                accuracy, total_count = multi_turn_runner(
-                    handler,
-                    model_result,
-                    prompt,
-                    possible_answer,
-                    model_name,
-                    test_category,
-                    score_dir,
-                )
-                record_result(
-                    LEADERBOARD_TABLE, model_name, test_category, accuracy, total_count
-                )
-                print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")
-            # Single turn test
-            else:
-                accuracy, total_count = ast_file_runner(
-                    handler,
-                    model_result,
-                    prompt,
-                    possible_answer,
-                    language,
-                    test_category,
-                    model_name,
-                    score_dir,
-                )
-                record_result(
-                    LEADERBOARD_TABLE, model_name, test_category, accuracy, total_count
-                )
-                print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")

-    # This function reads all the score files from local folder and updates the leaderboard table.
-    # This is helpful when you only want to run the evaluation for a subset of models and test categories.
-    update_leaderboard_table_with_local_score_file(LEADERBOARD_TABLE, score_dir)
+    # This function reads all the score files from local folder and updates the
+    # leaderboard table. This is helpful when you only want to run the
+    # evaluation for a subset of models and test categories.
+    update_leaderboard_table_with_local_score_file(state["leaderboard_table"], score_dir)
    # Write the leaderboard table to a file
-    generate_leaderboard_csv(LEADERBOARD_TABLE, score_dir, model_names, test_categories)
+    generate_leaderboard_csv(
+        state["leaderboard_table"], score_dir, model_names, test_categories
+    )

    # Clean up the executable expected output files
    # They should be re-generated the next time the evaluation is run
-    clean_up_executable_expected_output(PROMPT_PATH, EXECUTABLE_TEST_CATEGORIES_HAVE_RUN)
+    clean_up_executable_expected_output(
+        PROMPT_PATH, state["executable_test_categories_have_run"]
+    )

    display_api_status_error(
-        API_STATUS_ERROR_REST, API_STATUS_ERROR_EXECUTABLE, display_success=False
+        state["api_status_error_rest"],
+        state["api_status_error_executable"],
+        display_success=False,
    )


+def evaluate_task(
+    test_category,
+    api_sanity_check,
+    result_dir,
+    score_dir,
+    model_result,
+    model_name,
+    handler,
+    state,
+):
+
+    language = "Python"
+    if is_java(test_category):
+        language = "Java"
+    if is_js(test_category):
+        language = "JavaScript"
+
+    print(f"🔍 Running test: {test_category}")
+
+    record_cost_latency(state["leaderboard_table"], model_name, model_result)
+
+    # Find the corresponding test file.
+    prompt_file = find_file_with_suffix(PROMPT_PATH, test_category)
+    prompt = load_file(prompt_file, sort_by_id=True)
+
+    if is_relevance_or_irrelevance(test_category):
+        accuracy, total_count = relevance_file_runner(
+            handler, model_result, prompt, model_name, test_category, score_dir
+        )
+
+    elif is_executable(test_category):
+        # We only test the API with ground truth once.
+        if not state["api_tested"] and api_sanity_check:
+            print("---- Sanity checking API status ----")
+            try:
+                api_status_sanity_check_rest()
+            except BadAPIStatusError as e:
+                state["api_status_error_rest"] = e
+
+            try:
+                api_status_sanity_check_executable()
+            except BadAPIStatusError as e:
+                state["api_status_error_executable"] = e
+
+            display_api_status_error(
+                state["api_status_error_rest"],
+                state["api_status_error_executable"],
+                display_success=True,
+            )
+            print("Continuing evaluation...")
+
+            state["api_tested"] = True
+
+        if (
+            test_category not in state["executable_test_categories_have_run"]
+            and not is_rest(test_category)
+        ):
+            print(
+                f"---- Getting real-time execution result from ground truth"
+                f" for {test_category} ----"
+            )
+            get_executable_expected_output(prompt_file)
+            print(
+                f"---- Ground truth real-time execution result obtained for"
+                f" {test_category} 🌟 ----"
+            )
+            state["executable_test_categories_have_run"].append(test_category)
+            # Need to re-load the prompt file after getting the expected
+            # output, as the prompt file has been updated.
+            prompt = load_file(prompt_file, sort_by_id=True)
+
+        accuracy, total_count = executable_file_runner(
+            handler, model_result, prompt, model_name, test_category, score_dir
+        )
+
+    else:
+        # Find the corresponding possible answer file
+        possible_answer_file = find_file_with_suffix(POSSIBLE_ANSWER_PATH, test_category)
+        possible_answer = load_file(possible_answer_file, sort_by_id=True)
+
+        if is_multi_turn(test_category):
+            accuracy, total_count = multi_turn_runner(
+                handler,
+                model_result,
+                prompt,
+                possible_answer,
+                model_name,
+                test_category,
+                score_dir,
+            )
+
+        # Single turn test
+        else:
+            accuracy, total_count = ast_file_runner(
+                handler,
+                model_result,
+                prompt,
+                possible_answer,
+                language,
+                test_category,
+                model_name,
+                score_dir,
+            )
+
+    record_result(state, model_name, test_category, accuracy, total_count)
+    print(f"✅ Test completed: {test_category}. 🎯 Accuracy: {accuracy}")
+
+    return state
+
+
def main(model, test_categories, api_sanity_check, result_dir, score_dir):
    if result_dir is None:
        result_dir = RESULT_PATH
@@ -674,12 +706,6 @@ def main(model, test_categories, api_sanity_check, result_dir, score_dir):
    )


-def get_handler(model_name):
-    return HANDLER_MAP[model_name](
-        model_name, temperature=0
-    )  # Temperature doesn't matter for evaluation
-
-
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process two lists of strings.")
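Note on the refactor shown above: the commit replaces the module-level globals (API_TESTED, API_STATUS_ERROR_REST, API_STATUS_ERROR_EXECUTABLE, LEADERBOARD_TABLE, EXECUTABLE_TEST_CATEGORIES_HAVE_RUN) with a single state dict that runner() threads through evaluate_task() once per test category and reads back afterwards. The sketch below isolates only that state-threading pattern; the names run_one_category and run_all are hypothetical and not part of the repository, and the body is an illustration rather than the project's actual evaluation logic.

    def run_one_category(category, state):
        # Each subtask reads and mutates the shared evaluation state, then returns it.
        if not state["api_tested"]:
            state["api_tested"] = True  # e.g., perform the API sanity check only once
        state["leaderboard_table"].setdefault(category, {"accuracy": 0.0, "total_count": 0})
        return state

    def run_all(categories):
        # Same state keys as the diff's state = dict(...) block.
        state = dict(
            api_tested=False,
            api_status_error_rest=None,
            api_status_error_executable=None,
            leaderboard_table={},
            executable_test_categories_have_run=[],
        )
        for category in categories:
            state = run_one_category(category, state)
        return state

    print(run_all(["simple", "multiple_function"]))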