Skip to content

Commit a15f4fc

Browse files
committed
update inference_multi_turn logic
1 parent 8ec2113 commit a15f4fc

File tree

1 file changed

+117
-42
lines changed

1 file changed

+117
-42
lines changed

berkeley-function-call-leaderboard/bfcl/model_handler/base_handler.py

Lines changed: 117 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import json
2-
import time
32
from copy import deepcopy
3+
from pathlib import Path
4+
from typing import TYPE_CHECKING
45

56
from bfcl.constants.category_mapping import VERSION_PREFIX
67
from bfcl.constants.default_prompts import (
@@ -9,15 +10,24 @@
910
MAXIMUM_STEP_LIMIT,
1011
)
1112
from bfcl.constants.eval_config import RESULT_PATH
12-
from bfcl.eval_checker.multi_turn_eval.multi_turn_utils import (
13+
from bfcl.constants.executable_backend_config import (
14+
OMIT_STATE_INFO_CLASSES,
1315
STATELESS_CLASSES,
16+
)
17+
from bfcl.eval_checker.multi_turn_eval.multi_turn_utils import (
1418
execute_multi_turn_func_call,
1519
is_empty_execute_response,
1620
)
1721
from bfcl.model_handler.model_style import ModelStyle
18-
from bfcl.utils import load_file, make_json_serializable, sort_key
22+
from bfcl.model_handler.utils import add_memory_instruction_system_prompt
23+
from bfcl.utils import *
1924
from overrides import final
2025

26+
if TYPE_CHECKING:
27+
from bfcl.eval_checker.multi_turn_eval.func_source_code.memory_api_metaclass import (
28+
MemoryAPI,
29+
)
30+
2131

2232
class BaseHandler:
2333
model_name: str
@@ -30,35 +40,47 @@ def __init__(self, model_name, temperature) -> None:
3040
self.model_name_underline_replaced = (
3141
model_name.replace("/", "_").replace("-", "_").replace(".", "_")
3242
)
43+
# The directory name for the model
44+
self.model_name_dir = model_name.replace("/", "_")
3345
self.temperature = temperature
3446
self.is_fc_model = False # Whether the model is a function calling model
3547

36-
def inference(self, test_entry: dict, include_input_log: bool, exclude_state_log: bool):
48+
def inference(
49+
self,
50+
test_entry: dict,
51+
include_input_log: bool,
52+
exclude_state_log: bool,
53+
result_dir=RESULT_PATH,
54+
):
3755
# This method is used to retrieve the model response for each model.
3856

3957
# FC model
4058
# TODO: Let all models have the is_fc_model attribute and remove the "FC" check
4159
if "FC" in self.model_name or self.is_fc_model:
42-
if "multi_turn" in test_entry["id"]:
60+
if contain_multi_turn_interaction(test_entry["id"]):
4361
return self.inference_multi_turn_FC(
44-
test_entry, include_input_log, exclude_state_log
62+
test_entry, include_input_log, exclude_state_log, result_dir
4563
)
4664
else:
4765
return self.inference_single_turn_FC(test_entry, include_input_log)
4866
# Prompting model
4967
else:
50-
if "multi_turn" in test_entry["id"]:
68+
if contain_multi_turn_interaction(test_entry["id"]):
5169
return self.inference_multi_turn_prompting(
52-
test_entry, include_input_log, exclude_state_log
70+
test_entry, include_input_log, exclude_state_log, result_dir
5371
)
5472
else:
5573
return self.inference_single_turn_prompting(test_entry, include_input_log)
5674

5775
@final
5876
def inference_multi_turn_FC(
59-
self, test_entry: dict, include_input_log: bool, exclude_state_log: bool
77+
self,
78+
test_entry: dict,
79+
include_input_log: bool,
80+
exclude_state_log: bool,
81+
result_dir: Path,
6082
) -> tuple[list[list], dict]:
61-
initial_config: dict = test_entry["initial_config"]
83+
initial_config: dict = test_entry.get("initial_config", {})
6284
involved_classes: list = test_entry["involved_classes"]
6385
test_entry_id: str = test_entry["id"]
6486
test_category: str = test_entry_id.rsplit("_", 1)[0]
@@ -79,22 +101,32 @@ def inference_multi_turn_FC(
79101
force_quit = False # Whether the model has been forced to quit. If True, this whole entry will be failed.
80102

81103
all_reasoning_content: list[list] = []
104+
82105
# Execute an empty list of function calls just to obtain references to all the involved instances; they are needed to log the initial state and, for the memory category, to build the system prompt.
83-
if not exclude_state_log:
84-
_, involved_instances = execute_multi_turn_func_call(
85-
[],
86-
initial_config,
87-
involved_classes,
88-
self.model_name_underline_replaced,
89-
test_entry_id,
90-
long_context=(
91-
"long_context" in test_category or "composite" in test_category
92-
),
93-
is_evaL_run=False,
106+
_, involved_instances = execute_multi_turn_func_call(
107+
[],
108+
initial_config,
109+
involved_classes,
110+
self.model_name_underline_replaced,
111+
test_entry_id,
112+
long_context=("long_context" in test_category or "composite" in test_category),
113+
is_evaL_run=False,
114+
)
115+
116+
if is_memory(test_category):
117+
assert (
118+
len(involved_instances) == 1
119+
), "Memory category should only involve one class."
120+
121+
memory_instance: "MemoryAPI" = list(involved_instances.values())[0]
122+
test_entry["question"] = add_memory_instruction_system_prompt(
123+
test_entry["question"], test_category, test_entry["scenario"], memory_instance
94124
)
125+
126+
if not exclude_state_log:
95127
state_log = []
96128
for class_name, class_instance in involved_instances.items():
97-
if class_name in STATELESS_CLASSES:
129+
if class_name in STATELESS_CLASSES or class_name in OMIT_STATE_INFO_CLASSES:
98130
continue
99131
# Avoid modification in future turns
100132
class_instance = deepcopy(class_instance)
@@ -109,7 +141,8 @@ def inference_multi_turn_FC(
109141
},
110142
}
111143
)
112-
all_inference_log.append(state_log)
144+
if len(state_log) > 0:
145+
all_inference_log.append(state_log)
113146

114147
inference_data: dict = {}
115148
inference_data = self._pre_query_processing_FC(inference_data, test_entry)
@@ -126,6 +159,8 @@ def inference_multi_turn_FC(
126159
assert (
127160
len(current_turn_message) == 0
128161
), "Holdout turn should not have user message."
162+
# TODO: Move this to before pre_query_processing_FC.
163+
# Shouldn't be happening in the inference loop.
129164
current_turn_message = [
130165
{
131166
"role": "user",
@@ -284,7 +319,10 @@ def inference_multi_turn_FC(
284319
if not exclude_state_log:
285320
state_log = []
286321
for class_name, class_instance in involved_instances.items():
287-
if class_name in STATELESS_CLASSES:
322+
if (
323+
class_name in STATELESS_CLASSES
324+
or class_name in OMIT_STATE_INFO_CLASSES
325+
):
288326
continue
289327
# Avoid modification in future turns
290328
class_instance = deepcopy(class_instance)
@@ -299,11 +337,21 @@ def inference_multi_turn_FC(
299337
},
300338
}
301339
)
302-
all_inference_log.append(state_log)
340+
if len(state_log) > 0:
341+
all_inference_log.append(state_log)
303342

304343
if force_quit:
305344
break
306345

346+
# Special handling for the memory category
347+
# Need to flush the memory to local file at the end of the conversation
348+
if is_memory_prereq(test_entry_id):
349+
assert (
350+
len(involved_instances) == 1
351+
), "Memory category should only involve one class."
352+
memory_instance: "MemoryAPI" = list(involved_instances.values())[0]
353+
memory_instance._flush_memory_to_local_file()
354+
307355
metadata = {
308356
"input_token_count": total_input_token_count,
309357
"output_token_count": total_output_token_count,
@@ -321,9 +369,13 @@ def inference_multi_turn_FC(
321369

322370
@final
323371
def inference_multi_turn_prompting(
324-
self, test_entry: dict, include_input_log: bool, exclude_state_log: bool
372+
self,
373+
test_entry: dict,
374+
include_input_log: bool,
375+
exclude_state_log: bool,
376+
result_dir: Path,
325377
) -> tuple[list[list], dict]:
326-
initial_config: dict = test_entry["initial_config"]
378+
initial_config: dict = test_entry.get("initial_config", {})
327379
involved_classes: list = test_entry["involved_classes"]
328380
test_entry_id: str = test_entry["id"]
329381
test_category: str = test_entry_id.rsplit("_", 1)[0]
@@ -344,21 +396,30 @@ def inference_multi_turn_prompting(
344396
force_quit = False # Whether the model has been forced to quit. If True, this whole entry will be failed.
345397

346398
# Execute an empty list of function calls just to obtain references to all the involved instances; they are needed to log the initial state and, for the memory category, to build the system prompt.
347-
if not exclude_state_log:
348-
_, involved_instances = execute_multi_turn_func_call(
349-
[],
350-
initial_config,
351-
involved_classes,
352-
self.model_name_underline_replaced,
353-
test_entry_id,
354-
long_context=(
355-
"long_context" in test_category or "composite" in test_category
356-
),
357-
is_evaL_run=False,
399+
_, involved_instances = execute_multi_turn_func_call(
400+
[],
401+
initial_config,
402+
involved_classes,
403+
self.model_name_underline_replaced,
404+
test_entry_id,
405+
long_context=("long_context" in test_category or "composite" in test_category),
406+
is_evaL_run=False,
407+
)
408+
409+
if is_memory(test_category):
410+
assert (
411+
len(involved_instances) == 1
412+
), "Memory category should only involve one class."
413+
414+
memory_instance: "MemoryAPI" = list(involved_instances.values())[0]
415+
test_entry["question"] = add_memory_instruction_system_prompt(
416+
test_entry["question"], test_category, test_entry["scenario"], memory_instance
358417
)
418+
419+
if not exclude_state_log:
359420
state_log = []
360421
for class_name, class_instance in involved_instances.items():
361-
if class_name in STATELESS_CLASSES:
422+
if class_name in STATELESS_CLASSES or class_name in OMIT_STATE_INFO_CLASSES:
362423
continue
363424
# Avoid modification in future turns
364425
class_instance = deepcopy(class_instance)
@@ -373,7 +434,8 @@ def inference_multi_turn_prompting(
373434
},
374435
}
375436
)
376-
all_inference_log.append(state_log)
437+
if len(state_log) > 0:
438+
all_inference_log.append(state_log)
377439

378440
inference_data: dict = self._pre_query_processing_prompting(test_entry)
379441

@@ -544,7 +606,10 @@ def inference_multi_turn_prompting(
544606
if not exclude_state_log:
545607
state_log = []
546608
for class_name, class_instance in involved_instances.items():
547-
if class_name in STATELESS_CLASSES:
609+
if (
610+
class_name in STATELESS_CLASSES
611+
or class_name in OMIT_STATE_INFO_CLASSES
612+
):
548613
continue
549614
# Avoid modification in future turns
550615
class_instance = deepcopy(class_instance)
@@ -559,11 +624,21 @@ def inference_multi_turn_prompting(
559624
},
560625
}
561626
)
562-
all_inference_log.append(state_log)
627+
if len(state_log) > 0:
628+
all_inference_log.append(state_log)
563629

564630
if force_quit:
565631
break
566632

633+
# Special handling for the memory category
634+
# Need to flush the memory to local file at the end of the conversation
635+
if is_memory_prereq(test_entry_id):
636+
assert (
637+
len(involved_instances) == 1
638+
), "Memory category should only involve one class."
639+
memory_instance: "MemoryAPI" = list(involved_instances.values())[0]
640+
memory_instance._flush_memory_to_local_file()
641+
567642
metadata = {
568643
"input_token_count": total_input_token_count,
569644
"output_token_count": total_output_token_count,

0 commit comments

Comments
 (0)