1
1
import json
2
- import time
3
2
from copy import deepcopy
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING
4
5
5
6
from bfcl .constants .category_mapping import VERSION_PREFIX
6
7
from bfcl .constants .default_prompts import (
9
10
MAXIMUM_STEP_LIMIT ,
10
11
)
11
12
from bfcl .constants .eval_config import RESULT_PATH
12
- from bfcl .eval_checker .multi_turn_eval .multi_turn_utils import (
13
+ from bfcl .constants .executable_backend_config import (
14
+ OMIT_STATE_INFO_CLASSES ,
13
15
STATELESS_CLASSES ,
16
+ )
17
+ from bfcl .eval_checker .multi_turn_eval .multi_turn_utils import (
14
18
execute_multi_turn_func_call ,
15
19
is_empty_execute_response ,
16
20
)
17
21
from bfcl .model_handler .model_style import ModelStyle
18
- from bfcl .utils import load_file , make_json_serializable , sort_key
22
+ from bfcl .model_handler .utils import add_memory_instruction_system_prompt
23
+ from bfcl .utils import *
19
24
from overrides import final
20
25
26
+ if TYPE_CHECKING :
27
+ from bfcl .eval_checker .multi_turn_eval .func_source_code .memory_api_metaclass import (
28
+ MemoryAPI ,
29
+ )
30
+
21
31
22
32
class BaseHandler :
23
33
model_name : str
@@ -30,35 +40,47 @@ def __init__(self, model_name, temperature) -> None:
30
40
self .model_name_underline_replaced = (
31
41
model_name .replace ("/" , "_" ).replace ("-" , "_" ).replace ("." , "_" )
32
42
)
43
+ # The directory name for the model
44
+ self .model_name_dir = model_name .replace ("/" , "_" )
33
45
self .temperature = temperature
34
46
self .is_fc_model = False # Whether the model is a function calling model
35
47
36
def inference(
    self,
    test_entry: dict,
    include_input_log: bool,
    exclude_state_log: bool,
    result_dir=RESULT_PATH,
):
    """Dispatch one benchmark entry to the appropriate inference routine.

    Routes the entry along two axes: function-calling (FC) vs. prompting
    (based on the model name / ``is_fc_model`` flag), and multi-turn vs.
    single-turn (based on the entry id).

    Args:
        test_entry: The benchmark entry; its ``"id"`` field determines
            whether the entry involves multi-turn interaction.
        include_input_log: Whether to include the model-input log in the
            returned inference log.
        exclude_state_log: Whether to omit per-turn API state snapshots
            from the inference log (only used by the multi-turn variants).
        result_dir: Directory where results are written; defaults to the
            project-level ``RESULT_PATH``. Only forwarded to the
            multi-turn variants.

    Returns:
        Whatever the selected ``inference_*`` method returns; for the
        multi-turn variants this is a ``(inference_logs, metadata)`` tuple.
    """
    # This method is used to retrieve the model response for each model.

    # FC model
    # TODO: Let all models have the is_fc_model attribute and remove the "FC" check
    if "FC" in self.model_name or self.is_fc_model:
        if contain_multi_turn_interaction(test_entry["id"]):
            return self.inference_multi_turn_FC(
                test_entry, include_input_log, exclude_state_log, result_dir
            )
        return self.inference_single_turn_FC(test_entry, include_input_log)

    # Prompting model
    if contain_multi_turn_interaction(test_entry["id"]):
        return self.inference_multi_turn_prompting(
            test_entry, include_input_log, exclude_state_log, result_dir
        )
    return self.inference_single_turn_prompting(test_entry, include_input_log)
56
74
57
75
@final
58
76
def inference_multi_turn_FC (
59
- self , test_entry : dict , include_input_log : bool , exclude_state_log : bool
77
+ self ,
78
+ test_entry : dict ,
79
+ include_input_log : bool ,
80
+ exclude_state_log : bool ,
81
+ result_dir : Path ,
60
82
) -> tuple [list [list ], dict ]:
61
- initial_config : dict = test_entry [ "initial_config" ]
83
+ initial_config : dict = test_entry . get ( "initial_config" , {})
62
84
involved_classes : list = test_entry ["involved_classes" ]
63
85
test_entry_id : str = test_entry ["id" ]
64
86
test_category : str = test_entry_id .rsplit ("_" , 1 )[0 ]
@@ -79,22 +101,32 @@ def inference_multi_turn_FC(
79
101
force_quit = False # Whether the model has been forced to quit. If True, this whole entry will be failed.
80
102
81
103
all_reasoning_content : list [list ] = []
104
+
82
105
# Execute no function call, but just to get a reference to all the instances to get the initial state for logging purpose
83
- if not exclude_state_log :
84
- _ , involved_instances = execute_multi_turn_func_call (
85
- [],
86
- initial_config ,
87
- involved_classes ,
88
- self .model_name_underline_replaced ,
89
- test_entry_id ,
90
- long_context = (
91
- "long_context" in test_category or "composite" in test_category
92
- ),
93
- is_evaL_run = False ,
106
+ _ , involved_instances = execute_multi_turn_func_call (
107
+ [],
108
+ initial_config ,
109
+ involved_classes ,
110
+ self .model_name_underline_replaced ,
111
+ test_entry_id ,
112
+ long_context = ("long_context" in test_category or "composite" in test_category ),
113
+ is_evaL_run = False ,
114
+ )
115
+
116
+ if is_memory (test_category ):
117
+ assert (
118
+ len (involved_instances ) == 1
119
+ ), "Memory category should only involve one class."
120
+
121
+ memory_instance : "MemoryAPI" = list (involved_instances .values ())[0 ]
122
+ test_entry ["question" ] = add_memory_instruction_system_prompt (
123
+ test_entry ["question" ], test_category , test_entry ["scenario" ], memory_instance
94
124
)
125
+
126
+ if not exclude_state_log :
95
127
state_log = []
96
128
for class_name , class_instance in involved_instances .items ():
97
- if class_name in STATELESS_CLASSES :
129
+ if class_name in STATELESS_CLASSES or class_name in OMIT_STATE_INFO_CLASSES :
98
130
continue
99
131
# Avoid modification in future turns
100
132
class_instance = deepcopy (class_instance )
@@ -109,7 +141,8 @@ def inference_multi_turn_FC(
109
141
},
110
142
}
111
143
)
112
- all_inference_log .append (state_log )
144
+ if len (state_log ) > 0 :
145
+ all_inference_log .append (state_log )
113
146
114
147
inference_data : dict = {}
115
148
inference_data = self ._pre_query_processing_FC (inference_data , test_entry )
@@ -126,6 +159,8 @@ def inference_multi_turn_FC(
126
159
assert (
127
160
len (current_turn_message ) == 0
128
161
), "Holdout turn should not have user message."
162
+ # TODO: Move this to before pre_query_processing_FC.
163
+ # Shouldn't be happening in the inference loop.
129
164
current_turn_message = [
130
165
{
131
166
"role" : "user" ,
@@ -284,7 +319,10 @@ def inference_multi_turn_FC(
284
319
if not exclude_state_log :
285
320
state_log = []
286
321
for class_name , class_instance in involved_instances .items ():
287
- if class_name in STATELESS_CLASSES :
322
+ if (
323
+ class_name in STATELESS_CLASSES
324
+ or class_name in OMIT_STATE_INFO_CLASSES
325
+ ):
288
326
continue
289
327
# Avoid modification in future turns
290
328
class_instance = deepcopy (class_instance )
@@ -299,11 +337,21 @@ def inference_multi_turn_FC(
299
337
},
300
338
}
301
339
)
302
- all_inference_log .append (state_log )
340
+ if len (state_log ) > 0 :
341
+ all_inference_log .append (state_log )
303
342
304
343
if force_quit :
305
344
break
306
345
346
+ # Special handling for the memory category
347
+ # Need to flush the memory to local file at the end of the conversation
348
+ if is_memory_prereq (test_entry_id ):
349
+ assert (
350
+ len (involved_instances ) == 1
351
+ ), "Memory category should only involve one class."
352
+ memory_instance : "MemoryAPI" = list (involved_instances .values ())[0 ]
353
+ memory_instance ._flush_memory_to_local_file ()
354
+
307
355
metadata = {
308
356
"input_token_count" : total_input_token_count ,
309
357
"output_token_count" : total_output_token_count ,
@@ -321,9 +369,13 @@ def inference_multi_turn_FC(
321
369
322
370
@final
323
371
def inference_multi_turn_prompting (
324
- self , test_entry : dict , include_input_log : bool , exclude_state_log : bool
372
+ self ,
373
+ test_entry : dict ,
374
+ include_input_log : bool ,
375
+ exclude_state_log : bool ,
376
+ result_dir : Path ,
325
377
) -> tuple [list [list ], dict ]:
326
- initial_config : dict = test_entry [ "initial_config" ]
378
+ initial_config : dict = test_entry . get ( "initial_config" , {})
327
379
involved_classes : list = test_entry ["involved_classes" ]
328
380
test_entry_id : str = test_entry ["id" ]
329
381
test_category : str = test_entry_id .rsplit ("_" , 1 )[0 ]
@@ -344,21 +396,30 @@ def inference_multi_turn_prompting(
344
396
force_quit = False # Whether the model has been forced to quit. If True, this whole entry will be failed.
345
397
346
398
# Execute no function call, but just to get a reference to all the instances to get the initial state for logging purpose
347
- if not exclude_state_log :
348
- _ , involved_instances = execute_multi_turn_func_call (
349
- [],
350
- initial_config ,
351
- involved_classes ,
352
- self .model_name_underline_replaced ,
353
- test_entry_id ,
354
- long_context = (
355
- "long_context" in test_category or "composite" in test_category
356
- ),
357
- is_evaL_run = False ,
399
+ _ , involved_instances = execute_multi_turn_func_call (
400
+ [],
401
+ initial_config ,
402
+ involved_classes ,
403
+ self .model_name_underline_replaced ,
404
+ test_entry_id ,
405
+ long_context = ("long_context" in test_category or "composite" in test_category ),
406
+ is_evaL_run = False ,
407
+ )
408
+
409
+ if is_memory (test_category ):
410
+ assert (
411
+ len (involved_instances ) == 1
412
+ ), "Memory category should only involve one class."
413
+
414
+ memory_instance : "MemoryAPI" = list (involved_instances .values ())[0 ]
415
+ test_entry ["question" ] = add_memory_instruction_system_prompt (
416
+ test_entry ["question" ], test_category , test_entry ["scenario" ], memory_instance
358
417
)
418
+
419
+ if not exclude_state_log :
359
420
state_log = []
360
421
for class_name , class_instance in involved_instances .items ():
361
- if class_name in STATELESS_CLASSES :
422
+ if class_name in STATELESS_CLASSES or class_name in OMIT_STATE_INFO_CLASSES :
362
423
continue
363
424
# Avoid modification in future turns
364
425
class_instance = deepcopy (class_instance )
@@ -373,7 +434,8 @@ def inference_multi_turn_prompting(
373
434
},
374
435
}
375
436
)
376
- all_inference_log .append (state_log )
437
+ if len (state_log ) > 0 :
438
+ all_inference_log .append (state_log )
377
439
378
440
inference_data : dict = self ._pre_query_processing_prompting (test_entry )
379
441
@@ -544,7 +606,10 @@ def inference_multi_turn_prompting(
544
606
if not exclude_state_log :
545
607
state_log = []
546
608
for class_name , class_instance in involved_instances .items ():
547
- if class_name in STATELESS_CLASSES :
609
+ if (
610
+ class_name in STATELESS_CLASSES
611
+ or class_name in OMIT_STATE_INFO_CLASSES
612
+ ):
548
613
continue
549
614
# Avoid modification in future turns
550
615
class_instance = deepcopy (class_instance )
@@ -559,11 +624,21 @@ def inference_multi_turn_prompting(
559
624
},
560
625
}
561
626
)
562
- all_inference_log .append (state_log )
627
+ if len (state_log ) > 0 :
628
+ all_inference_log .append (state_log )
563
629
564
630
if force_quit :
565
631
break
566
632
633
+ # Special handling for the memory category
634
+ # Need to flush the memory to local file at the end of the conversation
635
+ if is_memory_prereq (test_entry_id ):
636
+ assert (
637
+ len (involved_instances ) == 1
638
+ ), "Memory category should only involve one class."
639
+ memory_instance : "MemoryAPI" = list (involved_instances .values ())[0 ]
640
+ memory_instance ._flush_memory_to_local_file ()
641
+
567
642
metadata = {
568
643
"input_token_count" : total_input_token_count ,
569
644
"output_token_count" : total_output_token_count ,
0 commit comments