Skip to content

Commit e62aefa

Browse files
RolandMinruipeteryang1peteryangmsXu
authored
feat: idea pool integrated to exp_gen & add timer to RD-Agent & pause-resume to RD-loops (#795)
* update all code * update all code * dump knowledge base * rename the tag * add timer to RD-Agent * fix CI * fix CI * use batch embedding * fix a small bug * fix prompt bug * feat: add pause resume to handle K8S cluster pause (#804) * add resume to cluster running * fix non-pickle problem * fix a small bug * fix a small bug * avoid shutil move error * refine the logic * move knowledge base out of session * avoid mistake information to pipeline coding * avoid load and dump in steps * archive the right folder * small improvement * avoid restart when timer is already started * fix CI --------- Co-authored-by: Xu Yang <[email protected]> --------- Co-authored-by: Xu Yang <[email protected]> Co-authored-by: Xu Yang <[email protected]> Co-authored-by: Xu <[email protected]>
1 parent 6f8cdd9 commit e62aefa

File tree

15 files changed

+576
-83
lines changed

15 files changed

+576
-83
lines changed

rdagent/app/data_science/conf.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,18 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
4141
enable_doc_dev: bool = False
4242
model_dump_check_level: Literal["medium", "high"] = "medium"
4343

44+
### knowledge base
45+
enable_knowledge_base: bool = False
46+
knowledge_base_version: str = "v1"
47+
knowledge_base_path: str | None = None
48+
idea_pool_json_path: str | None = None
49+
50+
### archive log folder after each loop
51+
enable_log_archive: bool = True
52+
log_archive_path: str | None = None
53+
log_archive_temp_path: str | None = (
54+
None  # Stores the intermediate tar file: writing the tar to local storage first and then copying it to the target storage is preferred
55+
)
56+
4457

4558
DS_RD_SETTING = DataScienceBasePropSetting()

rdagent/app/data_science/loop.py

Lines changed: 75 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import shutil
2+
import subprocess
3+
from datetime import datetime
14
from pathlib import Path
2-
from typing import Any
5+
from typing import Any, Optional, Union
36

47
import fire
58

@@ -28,10 +31,8 @@
2831
from rdagent.scenarios.data_science.dev.runner import DSCoSTEERRunner
2932
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
3033
from rdagent.scenarios.data_science.proposal.exp_gen import DSExpGen, DSTrace
31-
from rdagent.scenarios.data_science.proposal.exp_gen.select import (
32-
LatestCKPSelector,
33-
SOTAJumpCKPSelector,
34-
)
34+
from rdagent.scenarios.data_science.proposal.exp_gen.idea_pool import DSKnowledgeBase
35+
from rdagent.scenarios.data_science.proposal.exp_gen.select import LatestCKPSelector
3536
from rdagent.scenarios.kaggle.kaggle_crawler import download_data
3637

3738

@@ -42,13 +43,6 @@ def __init__(self, PROP_SETTING: BasePropSetting):
4243
logger.log_object(PROP_SETTING.competition, tag="competition")
4344
scen: Scenario = import_class(PROP_SETTING.scen)(PROP_SETTING.competition)
4445

45-
### shared components in the workflow # TODO: check if
46-
knowledge_base = (
47-
import_class(PROP_SETTING.knowledge_base)(PROP_SETTING.knowledge_base_path, scen)
48-
if PROP_SETTING.knowledge_base != ""
49-
else None
50-
)
51-
5246
# 1) task generation from scratch
5347
# self.scratch_gen: tuple[HypothesisGen, Hypothesis2Experiment] = DummyHypothesisGen(scen),
5448

@@ -70,8 +64,13 @@ def __init__(self, PROP_SETTING: BasePropSetting):
7064
# self.summarizer: Experiment2Feedback = import_class(PROP_SETTING.summarizer)(scen)
7165
# logger.log_object(self.summarizer, tag="summarizer")
7266

73-
# self.trace = KGTrace(scen=scen, knowledge_base=knowledge_base)
74-
self.trace = DSTrace(scen=scen)
67+
if DS_RD_SETTING.enable_knowledge_base and DS_RD_SETTING.knowledge_base_version == "v1":
68+
knowledge_base = DSKnowledgeBase(
69+
path=DS_RD_SETTING.knowledge_base_path, idea_pool_json_path=DS_RD_SETTING.idea_pool_json_path
70+
)
71+
self.trace = DSTrace(scen=scen, knowledge_base=knowledge_base)
72+
else:
73+
self.trace = DSTrace(scen=scen)
7574
self.summarizer = DSExperiment2Feedback(scen)
7675
super(RDLoop, self).__init__()
7776

@@ -166,10 +165,70 @@ def record(self, prev_out: dict[str, Any]):
166165
self.trace = DSTrace(scen=self.trace.scen, knowledge_base=self.trace.knowledge_base)
167166
logger.log_object(self.trace, tag="trace")
168167
logger.log_object(self.trace.sota_experiment(), tag="SOTA experiment")
168+
if DS_RD_SETTING.enable_knowledge_base and DS_RD_SETTING.knowledge_base_version == "v1":
169+
logger.log_object(self.trace.knowledge_base, tag="knowledge_base")
170+
self.trace.knowledge_base.dump()
171+
172+
if (
173+
DS_RD_SETTING.enable_log_archive
174+
and DS_RD_SETTING.log_archive_path is not None
175+
and Path(DS_RD_SETTING.log_archive_path).is_dir()
176+
):
177+
start_archive_datetime = datetime.now()
178+
logger.info(f"Archiving log folder after loop {self.loop_idx}")
179+
tar_path = (
180+
Path(
181+
DS_RD_SETTING.log_archive_temp_path
182+
if DS_RD_SETTING.log_archive_temp_path
183+
else DS_RD_SETTING.log_archive_path
184+
)
185+
/ "mid_log.tar"
186+
)
187+
subprocess.run(["tar", "-cf", str(tar_path), "-C", (Path().cwd() / "log"), "."], check=True)
188+
if DS_RD_SETTING.log_archive_temp_path is not None:
189+
shutil.move(tar_path, Path(DS_RD_SETTING.log_archive_path) / "mid_log.tar")
190+
tar_path = Path(DS_RD_SETTING.log_archive_path) / "mid_log.tar"
191+
shutil.copy(
192+
tar_path, Path(DS_RD_SETTING.log_archive_path) / "mid_log_bak.tar"
193+
) # backup when upper code line is killed when running
194+
self.timer.add_duration(datetime.now() - start_archive_datetime)
195+
196+
@classmethod
197+
def load(
198+
cls, path: Union[str, Path], output_path: Optional[Union[str, Path]] = None, do_truncate: bool = False
199+
) -> "LoopBase":
200+
session = super().load(path, output_path, do_truncate)
201+
if (
202+
DS_RD_SETTING.enable_knowledge_base
203+
and DS_RD_SETTING.knowledge_base_version == "v1"
204+
and Path(DS_RD_SETTING.knowledge_base_path).exists()
205+
):
206+
knowledge_base = DSKnowledgeBase(path=DS_RD_SETTING.knowledge_base_path)
207+
session.trace.knowledge_base = knowledge_base
208+
return session
209+
210+
def dump(self, path: str | Path) -> None:
211+
"""
212+
Since knowledge_base is big and we don't want to dump it every time
213+
So we remove it from the trace before dumping and restore it after.
214+
"""
215+
backup_knowledge_base = None
216+
if self.trace.knowledge_base is not None:
217+
backup_knowledge_base = self.trace.knowledge_base
218+
self.trace.knowledge_base = None
219+
super().dump(path)
220+
if backup_knowledge_base is not None:
221+
self.trace.knowledge_base = backup_knowledge_base
169222

170223

171224
def main(
172-
path=None, output_path=None, step_n=None, loop_n=None, competition="bms-molecular-translation", do_truncate=True
225+
path=None,
226+
output_path=None,
227+
step_n=None,
228+
loop_n=None,
229+
competition="bms-molecular-translation",
230+
do_truncate=True,
231+
timeout=None,
173232
):
174233
"""
175234
@@ -213,7 +272,7 @@ def main(
213272
kaggle_loop = DataScienceRDLoop(DS_RD_SETTING)
214273
else:
215274
kaggle_loop = DataScienceRDLoop.load(path, output_path, do_truncate)
216-
kaggle_loop.run(step_n=step_n, loop_n=loop_n)
275+
kaggle_loop.run(step_n=step_n, loop_n=loop_n, all_duration=timeout)
217276

218277

219278
if __name__ == "__main__":

rdagent/components/knowledge_management/graph.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@
1313
cosine,
1414
)
1515
from rdagent.core.knowledge_base import KnowledgeBase
16+
from rdagent.log import rdagent_logger as logger
1617
from rdagent.oai.llm_utils import APIBackend
1718

1819
Node = KnowledgeMetaData
1920

2021

2122
class UndirectedNode(Node):
22-
def __init__(self, content: str = "", label: str = "", embedding: Any = None) -> None:
23+
def __init__(self, content: str = "", label: str = "", embedding: Any = None, appendix: Any = None) -> None:
2324
super().__init__(content, label, embedding)
2425
self.neighbors: set[UndirectedNode] = set()
26+
self.appendix = appendix # appendix stores any additional information
2527
assert isinstance(content, str), "content must be a string"
2628

2729
def add_neighbor(self, node: UndirectedNode) -> None:
@@ -86,6 +88,10 @@ def batch_embedding(nodes: list[Node]) -> list[Node]:
8688
size = 16
8789
embeddings = []
8890
for i in range(0, len(contents), size):
91+
logger.info(
92+
f"Creating embedding for index {i} to {i + size} with {len(contents)} contents",
93+
tag="batch embedding",
94+
)
8995
embeddings.extend(
9096
APIBackend().create_embedding(input_content=contents[i : i + size]),
9197
)
@@ -270,7 +276,7 @@ def semantic_search(
270276
self,
271277
node: UndirectedNode | str,
272278
similarity_threshold: float = 0.0,
273-
topk_k: int = 5,
279+
topk_k: int = None,
274280
constraint_labels: list[str] | None = None,
275281
) -> list[UndirectedNode]:
276282
"""

rdagent/components/knowledge_management/vector_base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def add(self, document: Union[Document, List[Document]]):
8787
"""
8888
pass
8989

90-
def search(self, content: str, topk_k: int = 5, similarity_threshold: float = 0) -> List[Document]:
90+
def search(self, content: str, topk_k: int | None = None, similarity_threshold: float = 0) -> List[Document]:
9191
"""
9292
search vector_df by node
9393
Parameters
@@ -156,7 +156,11 @@ def add(self, document: Union[Document, List[Document]]):
156156
self.add(document=doc)
157157

158158
def search(
159-
self, content: str, topk_k: int = 5, similarity_threshold: float = 0, constraint_labels: list[str] | None = None
159+
self,
160+
content: str,
161+
topk_k: int | None = None,
162+
similarity_threshold: float = 0,
163+
constraint_labels: list[str] | None = None,
160164
) -> Tuple[List[Document], List]:
161165
"""
162166
Search vector by node's embedding.
@@ -192,7 +196,9 @@ def search(
192196
lambda x: 1 - cosine(x, document.embedding)
193197
) # cosine is cosine distance, 1-similarity
194198

195-
searched_similarities = similarities[similarities > similarity_threshold].nlargest(topk_k)
199+
searched_similarities = similarities[similarities > similarity_threshold]
200+
if topk_k is not None:
201+
searched_similarities = searched_similarities.nlargest(topk_k)
196202
most_similar_docs = filtered_df.loc[searched_similarities.index]
197203

198204
docs = []

rdagent/log/timer.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import re
2+
from datetime import datetime, timedelta
3+
4+
from rdagent.core.utils import SingletonBaseClass
5+
from rdagent.log import rdagent_logger as logger
6+
7+
8+
class RDAgentTimer:
9+
def __init__(self) -> None:
10+
self.started: bool = False
11+
self.target_time: datetime | None = None
12+
self.all_duration: timedelta | None = None
13+
self.remain_time_duration: timedelta | None = None
14+
15+
def reset(self, all_duration: str | timedelta) -> None:
16+
if isinstance(all_duration, str):
17+
pattern = re.compile(r"^\s*(\d*\.?\d+)\s*([smhd]?)\s*$")
18+
19+
match = pattern.match(all_duration)
20+
if not match:
21+
return None
22+
value = float(match.group(1))
23+
unit = match.group(2)
24+
if unit == "s":
25+
self.all_duration = timedelta(seconds=value)
26+
elif unit == "m":
27+
self.all_duration = timedelta(minutes=value)
28+
elif unit == "h":
29+
self.all_duration = timedelta(hours=value)
30+
elif unit == "d":
31+
self.all_duration = timedelta(days=value)
32+
else:
33+
self.all_duration = timedelta(seconds=value)
34+
elif isinstance(all_duration, timedelta):
35+
self.all_duration = all_duration
36+
self.target_time = datetime.now() + self.all_duration
37+
logger.info(f"Timer set to {self.all_duration} seconds and counting down.")
38+
self.started = True
39+
return None
40+
41+
def restart_by_remain_time(self) -> None:
42+
if self.remain_time_duration is not None:
43+
self.target_time = datetime.now() + self.remain_time_duration
44+
self.started = True
45+
logger.info(f"Timer restarted with remaining time: {self.remain_time_duration}")
46+
else:
47+
logger.warning("No remaining time to restart the timer.")
48+
return None
49+
50+
def add_duration(self, duration: timedelta) -> None:
51+
if self.started and self.target_time is not None:
52+
logger.info(f"Adding {duration} to the timer. Currently {self.remain_time()} remains.")
53+
self.target_time = self.target_time + duration
54+
self.update_remain_time()
55+
56+
def is_timeout(self) -> bool:
57+
if self.started and self.target_time is not None:
58+
self.update_remain_time()
59+
if datetime.now() > self.target_time:
60+
return True
61+
return False
62+
63+
def update_remain_time(self) -> None:
64+
if self.started and self.target_time is not None:
65+
self.remain_time_duration = self.target_time - datetime.now()
66+
return None
67+
68+
def remain_time(self) -> timedelta | None:
69+
if self.started:
70+
self.update_remain_time()
71+
return self.remain_time_duration
72+
return None
73+
74+
75+
class RDAgentTimerWrapper(SingletonBaseClass):
76+
def __init__(self) -> None:
77+
self.timer: RDAgentTimer = RDAgentTimer()
78+
79+
def replace_timer(self, timer: RDAgentTimer) -> None:
80+
self.timer = timer
81+
logger.info("Timer replaced successfully.")
82+
83+
84+
RD_Agent_TIMER_wrapper = RDAgentTimerWrapper()

rdagent/oai/backend/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import uuid
88
from abc import ABC, abstractmethod
99
from copy import deepcopy
10+
from datetime import datetime
1011
from pathlib import Path
1112
from typing import Any, Optional, cast
1213

@@ -15,6 +16,7 @@
1516
from rdagent.core.utils import LLM_CACHE_SEED_GEN, SingletonBaseClass
1617
from rdagent.log import LogColors
1718
from rdagent.log import rdagent_logger as logger
19+
from rdagent.log.timer import RD_Agent_TIMER_wrapper
1820
from rdagent.oai.llm_conf import LLM_SETTINGS
1921
from rdagent.utils import md5_hash
2022

@@ -330,6 +332,7 @@ def _try_create_chat_completion_or_embedding( # type: ignore[no-untyped-def]
330332
max_retry = LLM_SETTINGS.max_retry if LLM_SETTINGS.max_retry is not None else max_retry
331333
timeout_count = 0
332334
for i in range(max_retry):
335+
API_start_time = datetime.now()
333336
try:
334337
if embedding:
335338
return self._create_embedding_with_cache(*args, **kwargs)
@@ -361,6 +364,8 @@ def _try_create_chat_completion_or_embedding( # type: ignore[no-untyped-def]
361364
raise e
362365
else:
363366
time.sleep(self.retry_wait_seconds)
367+
if RD_Agent_TIMER_wrapper.timer.started and not isinstance(e, json.decoder.JSONDecodeError):
368+
RD_Agent_TIMER_wrapper.timer.add_duration(datetime.now() - API_start_time)
364369
logger.warning(str(e))
365370
logger.warning(f"Retrying {i+1}th time...")
366371
error_message = f"Failed to create chat completion after {max_retry} retries."

rdagent/oai/backend/litellm.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,16 @@ def _create_embedding_inner_function(
5151
"""
5252
Call the embedding function
5353
"""
54-
response_list = []
55-
for input_content_iter in input_content_list:
56-
model_name = LITELLM_SETTINGS.embedding_model
57-
logger.info(f"{LogColors.GREEN}Using emb model{LogColors.END} {model_name}", tag="debug_litellm_emb")
58-
logger.info(f"Creating embedding for: {input_content_iter}", tag="debug_litellm_emb")
59-
if not isinstance(input_content_iter, str):
60-
raise ValueError("Input content must be a string")
61-
response = embedding(
62-
model=model_name,
63-
input=input_content_iter,
64-
*args,
65-
**kwargs,
66-
)
67-
response_list.append(response.data[0]["embedding"])
54+
model_name = LITELLM_SETTINGS.embedding_model
55+
logger.info(f"{LogColors.GREEN}Using emb model{LogColors.END} {model_name}", tag="debug_litellm_emb")
56+
logger.info(f"Creating embedding for: {input_content_list}", tag="debug_litellm_emb")
57+
response = embedding(
58+
model=model_name,
59+
input=input_content_list,
60+
*args,
61+
**kwargs,
62+
)
63+
response_list = [data["embedding"] for data in response.data]
6864
return response_list
6965

7066
def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # noqa: C901, PLR0912, PLR0915

0 commit comments

Comments
 (0)