Skip to content

Commit 8fbde58

Browse files
xuangu-fangWinstonLiyt
authored and committed
feat: enable to set different version of idea-proposal for multi traces (#895)
* fix the logic of kb-inject, allow different versions * set more flexible proposal-version change for multi-trace * auto-lint * fix the divide-by-zero bug in a trivial way * keep the dumb impl. first, update in next version * use get_sub_trace_count() to get trace_num_count * fix the corner-case bug of divide-by-zero * update corner case * fix the bug * auto-lint * fix the bug * fix the logic bug in max_sota_filter * fix bug of old version of self.exp_gen.gen * update the reset_exp_gen_version * use get_parent_exps to replace all collect_all_ancestors * auto lint * fix the bug of reset_exp_gen_version * fix bug: update V3's old hypothesis_rank * trivial patch on gap of V3 & V2 * make dumb patch to unify proposal_V3's identify_problems * auto-lint * fix the bug of sub_trace_count
1 parent 9fc6f0e commit 8fbde58

File tree

5 files changed

+135
-29
lines changed

5 files changed

+135
-29
lines changed

rdagent/app/data_science/conf.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,13 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
9191
# inject diverse when start a new sub-trace
9292
enable_inject_diverse: bool = False
9393

94-
# inject diverse at the root of the trace
94+
# inject knowledge at the root of the trace
9595
enable_inject_knowledge_at_root: bool = False
9696

97+
# enable different version of DSExpGen for multi-trace
98+
enable_multi_version_exp_gen: bool = False
99+
exp_gen_version_list: str = "v3,v2"
100+
97101
#### multi-trace: time for final multi-trace merge
98102
merge_hours: int = 2
99103
"""The time for merge"""

rdagent/scenarios/data_science/proposal/exp_gen/merge.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,21 +304,40 @@ def __init__(self, *args, **kwargs):
304304
)
305305
self.flag_start_merge = False
306306

307-
def gen(self, trace: DSTrace) -> DSExperiment:
307+
def reset_exp_gen_version(self, version: str = "v2"):
    """Switch the proposal version and rebuild the wrapped experiment generator.

    Writes ``version`` to ``DS_RD_SETTING.proposal_version`` and then replaces
    ``self.exp_gen`` with the object returned by
    ``DataScienceRDLoop._get_exp_gen`` for ``self.scen``.

    Args:
        version: Proposal version to activate (defaults to ``"v2"``).
    """
    # The setting is written before the generator is rebuilt; presumably
    # `_get_exp_gen` / `DSExpGen` read it to choose the implementation — verify.
    DS_RD_SETTING.proposal_version = version
    logger.info(f"ExpGen2TraceAndMergeV2: Resetting proposal version to {version}")
    # Plain string literal: the dotted path has no placeholders, so no f-prefix.
    self.exp_gen = DataScienceRDLoop._get_exp_gen(
        "rdagent.scenarios.data_science.proposal.exp_gen.DSExpGen", self.scen
    )
313+
314+
def gen(self, trace: DSTrace, selection: tuple[int, ...] = (-1,)) -> DSExperiment:
308315
timer: RDAgentTimer = RD_Agent_TIMER_wrapper.timer
309316
logger.info(f"Remain time: {timer.remain_time_duration}")
310317

311318
if timer.remain_time_duration >= timedelta(hours=DS_RD_SETTING.merge_hours):
312319

313320
if DS_RD_SETTING.enable_inject_knowledge_at_root:
321+
if DS_RD_SETTING.knowledge_base_path is not None and DS_RD_SETTING.idea_pool_json_path is not None:
322+
if len(trace.hist) == 0:
323+
# set the knowledge base option to True for the first trace
324+
DS_RD_SETTING.enable_knowledge_base = True
325+
326+
if DS_RD_SETTING.enable_multi_version_exp_gen:
327+
exp_gen_version_list = DS_RD_SETTING.exp_gen_version_list.split(",")
328+
for version in exp_gen_version_list:
329+
assert version in ["v3", "v2", "v1"]
314330

315331
if len(trace.hist) == 0:
316-
# set the knowledge base option to True for the first trace
317-
DS_RD_SETTING.enable_knowledge_base = True
332+
# set the proposal version for the first sub-trace
333+
self.reset_exp_gen_version(version=exp_gen_version_list[0])
334+
elif len(trace.get_current_selection()) == 0 and trace.sub_trace_count > 0:
335+
# reset the proposal version at the start of other sub-trace
336+
if trace.sub_trace_count - 1 < len(exp_gen_version_list):
337+
self.reset_exp_gen_version(version=exp_gen_version_list[trace.sub_trace_count - 1])
338+
else:
339+
self.reset_exp_gen_version(version=exp_gen_version_list[-1])
318340

319-
else:
320-
# set the knowledge base option back to False for the other traces
321-
DS_RD_SETTING.enable_knowledge_base = False
322341
return self.exp_gen.gen(trace)
323342

324343
else:

rdagent/scenarios/data_science/proposal/exp_gen/prompts_selector.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ auto_sota_selector:
1616
"explanation": "A brief explanation text for your selection."
1717
}
1818
19+
If you cannot make a selection (for example, there are no SOTA experiments or feedback, or the gap is too small), return
20+
{
21+
"selected_SOTA_idx": null,
22+
"explanation": "No SOTA experiments and feedbacks"
23+
}
24+
1925
user: |-
2026
# SOTA Experiments and Feedback
2127
{{ historical_sota_exp_with_desc_and_scores }}

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,46 @@ def get_all_hypotheses(self, problem_dict: dict, hypothesis_dict: dict) -> list[
10301030
)
10311031
return result
10321032

1033+
# FIXME: remove this; it is a dumb stopgap solution and should be merged into identify_problem in V2
1034+
def identify_problems_v3(
    self, trace: DSTrace, scenario_desc: str, sota_exp_desc: str, exp_feedback_list_desc: str
) -> Dict:
    """Collect labelled problems, choosing the sources by how long the trace is.

    Three stages are selected from ``len(trace.hist)``:

    * ``<= 3`` (exploration): scenario problems only, multiplier 3.
    * ``<= 6`` (mixed): scenario problems first, then feedback problems,
      multiplier 2.
    * otherwise (refinement): feedback problems only, multiplier 1.

    Each collected problem dict is tagged in place with a ``"label"`` key
    (``"SCENARIO_PROBLEM"`` or ``"FEEDBACK_PROBLEM"``). As a side effect,
    ``self.scen_prob_multiplier`` is set to the stage's multiplier after the
    problems have been collected.

    Args:
        trace: Trace whose ``hist`` length selects the stage.
        scenario_desc: Scenario description forwarded to both identifiers.
        sota_exp_desc: SOTA experiment description forwarded to both identifiers.
        exp_feedback_list_desc: Feedback description forwarded to the feedback
            identifier.

    Returns:
        Mapping from problem name to its labelled problem dict. When a name is
        produced by both identifiers, the feedback problem replaces the
        scenario problem.
    """
    # NOTE(review): the stage is chosen from the full trace history, not from
    # the current sub-trace — confirm this is intended for multi-trace runs.
    trace_length = len(trace.hist)
    if trace_length <= 3:
        # Stage 1 — exploration: focus on scenario problems.
        use_scenario, use_feedback, multiplier = True, False, 3
    elif trace_length <= 6:
        # Stage 2 — mixed: scenario problems take priority, feedback is also considered.
        use_scenario, use_feedback, multiplier = True, True, 2
    else:
        # Stage 3 — refinement: focus on feedback problems.
        use_scenario, use_feedback, multiplier = False, True, 1

    all_problems = {}

    if use_scenario:
        scen_problems = self.identify_scenario_problem(scenario_desc, sota_exp_desc)
        for problem_name, problem in scen_problems.items():
            problem["label"] = "SCENARIO_PROBLEM"
            all_problems[problem_name] = problem

    # Feedback problems are merged second so they win on a name collision.
    if use_feedback:
        fb_problems = self.identify_feedback_problem(scenario_desc, exp_feedback_list_desc, sota_exp_desc)
        for problem_name, problem in fb_problems.items():
            problem["label"] = "FEEDBACK_PROBLEM"
            all_problems[problem_name] = problem

    # Assigned only after the identifiers returned, so a failing identifier
    # leaves the previous multiplier untouched.
    self.scen_prob_multiplier = multiplier

    return all_problems
1072+
10331073
def gen(self, trace: DSTrace) -> DSExperiment:
10341074
pipeline = DS_RD_SETTING.coder_on_whole_pipeline
10351075
if not pipeline and (draft_exp := draft_exp_in_decomposition(self.scen, trace)):
@@ -1067,26 +1107,43 @@ def gen(self, trace: DSTrace) -> DSExperiment:
10671107
pipeline=pipeline,
10681108
)
10691109

1110+
if DS_RD_SETTING.enable_inject_diverse and len(trace.hist) > 0:
1111+
if len(trace.current_selection) == 0:
1112+
# start a new sub-trace, and inject diverse problems.
1113+
inject_diverse = True
1114+
logger.info("Start a new sub-trace, and inject diverse problems.")
1115+
else:
1116+
inject_diverse = False
1117+
else:
1118+
inject_diverse = False
10701119
# Step 1: Identify problems
10711120
all_problems = {}
1072-
if len(trace.hist) >= 3:
1073-
fb_problems = self.identify_feedback_problem(
1074-
scenario_desc=scenario_desc,
1075-
exp_feedback_list_desc=exp_feedback_list_desc,
1076-
sota_exp_desc=sota_exp_desc,
1077-
)
1078-
for problem_name in fb_problems:
1079-
fb_problems[problem_name]["label"] = "FEEDBACK_PROBLEM"
1080-
all_problems[problem_name] = fb_problems[problem_name]
10811121

1082-
if len(trace.hist) < 9:
1083-
scen_problems = self.identify_scenario_problem(
1084-
scenario_desc=scenario_desc,
1085-
sota_exp_desc=sota_exp_desc,
1086-
)
1087-
for problem_name in scen_problems:
1088-
scen_problems[problem_name]["label"] = "SCENARIO_PROBLEM"
1089-
all_problems[problem_name] = scen_problems[problem_name]
1122+
all_problems = self.identify_problems_v3(
1123+
trace=trace,
1124+
scenario_desc=scenario_desc,
1125+
sota_exp_desc=sota_exp_desc,
1126+
exp_feedback_list_desc=exp_feedback_list_desc,
1127+
)
1128+
1129+
# if len(trace.hist) > 3:
1130+
# fb_problems = self.identify_feedback_problem(
1131+
# scenario_desc=scenario_desc,
1132+
# exp_feedback_list_desc=exp_feedback_list_desc,
1133+
# sota_exp_desc=sota_exp_desc,
1134+
# )
1135+
# for problem_name in fb_problems:
1136+
# fb_problems[problem_name]["label"] = "FEEDBACK_PROBLEM"
1137+
# all_problems[problem_name] = fb_problems[problem_name]
1138+
1139+
# if len(trace.hist) < 9:
1140+
# scen_problems = self.identify_scenario_problem(
1141+
# scenario_desc=scenario_desc,
1142+
# sota_exp_desc=sota_exp_desc,
1143+
# )
1144+
# for problem_name in scen_problems:
1145+
# scen_problems[problem_name]["label"] = "SCENARIO_PROBLEM"
1146+
# all_problems[problem_name] = scen_problems[problem_name]
10901147

10911148
# Step 1.5: Sample ideas from idea pool
10921149
if DS_RD_SETTING.enable_knowledge_base:
@@ -1128,7 +1185,6 @@ def gen(self, trace: DSTrace) -> DSExperiment:
11281185
pickled_problem_name, new_hypothesis = self.hypothesis_rank(
11291186
hypothesis_dict=hypothesis_dict,
11301187
problem_dict=all_problems,
1131-
trace=trace,
11321188
)
11331189
# Step 3.5: Update knowledge base with the picked problem
11341190
if DS_RD_SETTING.enable_knowledge_base:

rdagent/scenarios/data_science/proposal/exp_gen/sota_exp_select.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_sota_exp_to_submit(self, trace: Trace) -> DSExperiment | None:
4747
sota_exp_fb_list = trace.experiment_and_feedback_list_after_init(
4848
return_type="sota", search_type="all", max_retrieve_num=DS_RD_SETTING.max_sota_retrieved_num
4949
)
50-
50+
logger.info(f"Auto SOTA selector: Found {len(sota_exp_fb_list)} SOTA experiments")
5151
if len(sota_exp_fb_list) == 0:
5252
logger.info("Auto SOTA selector: No SOTA in trace yet")
5353
return None
@@ -69,6 +69,8 @@ def get_sota_exp_to_submit(self, trace: Trace) -> DSExperiment | None:
6969
leaves: list[int] = trace.get_leaves()
7070

7171
if len(leaves) >= 2:
72+
73+
logger.info(f"Auto SOTA selector: Multiple traces found, collecting SOTA experiments from each trace")
7274
# multiple trace case, collect the latest SOTA experiments from each trace
7375
new_sota_exp_fb_list: list[tuple[DSExperiment, ExperimentFeedback]] = []
7476
# calculate the number of SOTA experiments to retrieve from each trace
@@ -81,11 +83,26 @@ def get_sota_exp_to_submit(self, trace: Trace) -> DSExperiment | None:
8183
selection=(leaf,),
8284
max_retrieve_num=max_sota_retrieved_num_per_trace,
8385
)
86+
logger.info(
87+
f"Auto SOTA selector: Collected {len(sota_exp_fb_list_per_trace)} SOTA experiments from trace with leaf #. {leaf}"
88+
)
8489

8590
new_sota_exp_fb_list.extend(sota_exp_fb_list_per_trace)
8691

8792
sota_exp_fb_list = new_sota_exp_fb_list
8893

94+
if len(sota_exp_fb_list) == 0:
95+
logger.info("Auto SOTA selector: No SOTA in trace yet")
96+
return None
97+
98+
elif len(sota_exp_fb_list) == 1:
99+
logger.info("Auto SOTA selector: Only one SOTA in trace, using it")
100+
return sota_exp_fb_list[0][0]
101+
else:
102+
logger.info(
103+
f"Auto SOTA selector: {len(sota_exp_fb_list)} SOTA experiments found in all traces, calling LLM to select the best one"
104+
)
105+
89106
for i, (exp, ef) in enumerate(sota_exp_fb_list):
90107
if exp:
91108
current_final_score = pd.DataFrame(exp.result).loc["ensemble"].iloc[0]
@@ -115,7 +132,7 @@ def get_sota_exp_to_submit(self, trace: Trace) -> DSExperiment | None:
115132

116133
sota_submit_idx = response_dict.get("selected_SOTA_idx", None)
117134

118-
if sota_submit_idx is not None:
135+
if sota_submit_idx and int(sota_submit_idx) - 1 < len(sota_exp_fb_list):
119136
sota_submit = sota_exp_fb_list[int(sota_submit_idx) - 1]
120137
sota_idx_in_trace = trace.hist.index(sota_submit)
121138
logger.info(
@@ -124,8 +141,12 @@ def get_sota_exp_to_submit(self, trace: Trace) -> DSExperiment | None:
124141
return sota_submit[0]
125142
else:
126143
# no SOTA experiment to submit, using the latest SOTA experiment
127-
logger.info("Auto SOTA selector: No SOTA experiment to submit, using the latest SOTA experiment")
128-
return sota_exp_fb_list[-1][0]
144+
if len(sota_exp_fb_list) > 0:
145+
logger.info("Auto SOTA selector: No SOTA experiment to submit, using the latest SOTA experiment")
146+
return sota_exp_fb_list[-1][0]
147+
else:
148+
logger.info("Auto SOTA selector: No SOTA experiment in trace yet")
149+
return None
129150

130151

131152
class BestValidSelector(SOTAexpSelector):

0 commit comments

Comments
 (0)