1
+ import shutil
2
+ import subprocess
3
+ from datetime import datetime
1
4
from pathlib import Path
2
- from typing import Any
5
+ from typing import Any , Optional , Union
3
6
4
7
import fire
5
8
28
31
from rdagent .scenarios .data_science .dev .runner import DSCoSTEERRunner
29
32
from rdagent .scenarios .data_science .experiment .experiment import DSExperiment
30
33
from rdagent .scenarios .data_science .proposal .exp_gen import DSExpGen , DSTrace
31
- from rdagent .scenarios .data_science .proposal .exp_gen .select import (
32
- LatestCKPSelector ,
33
- SOTAJumpCKPSelector ,
34
- )
34
+ from rdagent .scenarios .data_science .proposal .exp_gen .idea_pool import DSKnowledgeBase
35
+ from rdagent .scenarios .data_science .proposal .exp_gen .select import LatestCKPSelector
35
36
from rdagent .scenarios .kaggle .kaggle_crawler import download_data
36
37
37
38
@@ -42,13 +43,6 @@ def __init__(self, PROP_SETTING: BasePropSetting):
42
43
logger .log_object (PROP_SETTING .competition , tag = "competition" )
43
44
scen : Scenario = import_class (PROP_SETTING .scen )(PROP_SETTING .competition )
44
45
45
- ### shared components in the workflow # TODO: check if
46
- knowledge_base = (
47
- import_class (PROP_SETTING .knowledge_base )(PROP_SETTING .knowledge_base_path , scen )
48
- if PROP_SETTING .knowledge_base != ""
49
- else None
50
- )
51
-
52
46
# 1) task generation from scratch
53
47
# self.scratch_gen: tuple[HypothesisGen, Hypothesis2Experiment] = DummyHypothesisGen(scen),
54
48
@@ -70,8 +64,13 @@ def __init__(self, PROP_SETTING: BasePropSetting):
70
64
# self.summarizer: Experiment2Feedback = import_class(PROP_SETTING.summarizer)(scen)
71
65
# logger.log_object(self.summarizer, tag="summarizer")
72
66
73
- # self.trace = KGTrace(scen=scen, knowledge_base=knowledge_base)
74
- self .trace = DSTrace (scen = scen )
67
+ if DS_RD_SETTING .enable_knowledge_base and DS_RD_SETTING .knowledge_base_version == "v1" :
68
+ knowledge_base = DSKnowledgeBase (
69
+ path = DS_RD_SETTING .knowledge_base_path , idea_pool_json_path = DS_RD_SETTING .idea_pool_json_path
70
+ )
71
+ self .trace = DSTrace (scen = scen , knowledge_base = knowledge_base )
72
+ else :
73
+ self .trace = DSTrace (scen = scen )
75
74
self .summarizer = DSExperiment2Feedback (scen )
76
75
super (RDLoop , self ).__init__ ()
77
76
@@ -166,10 +165,70 @@ def record(self, prev_out: dict[str, Any]):
166
165
self .trace = DSTrace (scen = self .trace .scen , knowledge_base = self .trace .knowledge_base )
167
166
logger .log_object (self .trace , tag = "trace" )
168
167
logger .log_object (self .trace .sota_experiment (), tag = "SOTA experiment" )
168
+ if DS_RD_SETTING .enable_knowledge_base and DS_RD_SETTING .knowledge_base_version == "v1" :
169
+ logger .log_object (self .trace .knowledge_base , tag = "knowledge_base" )
170
+ self .trace .knowledge_base .dump ()
171
+
172
+ if (
173
+ DS_RD_SETTING .enable_log_archive
174
+ and DS_RD_SETTING .log_archive_path is not None
175
+ and Path (DS_RD_SETTING .log_archive_path ).is_dir ()
176
+ ):
177
+ start_archive_datetime = datetime .now ()
178
+ logger .info (f"Archiving log folder after loop { self .loop_idx } " )
179
+ tar_path = (
180
+ Path (
181
+ DS_RD_SETTING .log_archive_temp_path
182
+ if DS_RD_SETTING .log_archive_temp_path
183
+ else DS_RD_SETTING .log_archive_path
184
+ )
185
+ / "mid_log.tar"
186
+ )
187
+ subprocess .run (["tar" , "-cf" , str (tar_path ), "-C" , (Path ().cwd () / "log" ), "." ], check = True )
188
+ if DS_RD_SETTING .log_archive_temp_path is not None :
189
+ shutil .move (tar_path , Path (DS_RD_SETTING .log_archive_path ) / "mid_log.tar" )
190
+ tar_path = Path (DS_RD_SETTING .log_archive_path ) / "mid_log.tar"
191
+ shutil .copy (
192
+ tar_path , Path (DS_RD_SETTING .log_archive_path ) / "mid_log_bak.tar"
193
+ ) # backup when upper code line is killed when running
194
+ self .timer .add_duration (datetime .now () - start_archive_datetime )
195
+
196
+ @classmethod
197
+ def load (
198
+ cls , path : Union [str , Path ], output_path : Optional [Union [str , Path ]] = None , do_truncate : bool = False
199
+ ) -> "LoopBase" :
200
+ session = super ().load (path , output_path , do_truncate )
201
+ if (
202
+ DS_RD_SETTING .enable_knowledge_base
203
+ and DS_RD_SETTING .knowledge_base_version == "v1"
204
+ and Path (DS_RD_SETTING .knowledge_base_path ).exists ()
205
+ ):
206
+ knowledge_base = DSKnowledgeBase (path = DS_RD_SETTING .knowledge_base_path )
207
+ session .trace .knowledge_base = knowledge_base
208
+ return session
209
+
210
+ def dump (self , path : str | Path ) -> None :
211
+ """
212
+ Since knowledge_base is big and we don't want to dump it every time
213
+ So we remove it from the trace before dumping and restore it after.
214
+ """
215
+ backup_knowledge_base = None
216
+ if self .trace .knowledge_base is not None :
217
+ backup_knowledge_base = self .trace .knowledge_base
218
+ self .trace .knowledge_base = None
219
+ super ().dump (path )
220
+ if backup_knowledge_base is not None :
221
+ self .trace .knowledge_base = backup_knowledge_base
169
222
170
223
171
224
def main (
172
- path = None , output_path = None , step_n = None , loop_n = None , competition = "bms-molecular-translation" , do_truncate = True
225
+ path = None ,
226
+ output_path = None ,
227
+ step_n = None ,
228
+ loop_n = None ,
229
+ competition = "bms-molecular-translation" ,
230
+ do_truncate = True ,
231
+ timeout = None ,
173
232
):
174
233
"""
175
234
@@ -213,7 +272,7 @@ def main(
213
272
kaggle_loop = DataScienceRDLoop (DS_RD_SETTING )
214
273
else :
215
274
kaggle_loop = DataScienceRDLoop .load (path , output_path , do_truncate )
216
- kaggle_loop .run (step_n = step_n , loop_n = loop_n )
275
+ kaggle_loop .run (step_n = step_n , loop_n = loop_n , all_duration = timeout )
217
276
218
277
219
278
if __name__ == "__main__" :
0 commit comments