Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion opencompass/models/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class OpenAI(BaseAPIModel):
the request
think_tag (str, optional): The tag to use for reasoning content.
Defaults to '</think>'.
max_workers (int, optional): Maximum number of worker threads for
concurrent API requests. For I/O-intensive API calls, recommended
value is 10-20. Defaults to None (uses CPU count * 2).
"""

is_api: bool = True
Expand All @@ -95,6 +98,7 @@ def __init__(
extra_body: Optional[Dict] = None,
verbose: bool = False,
think_tag: str = '</think>',
max_workers: Optional[int] = None,
):

super().__init__(
Expand All @@ -119,6 +123,12 @@ def __init__(
self.extra_body = extra_body
self.think_tag = think_tag

if max_workers is None:
cpu_count = os.cpu_count() or 1
self.max_workers = min(32, (cpu_count + 5) * 2)
else:
self.max_workers = max_workers

if isinstance(key, str):
if key == 'ENV':
if 'OPENAI_API_KEY' not in os.environ:
Expand Down Expand Up @@ -175,7 +185,7 @@ def generate(
if self.temperature is not None:
temperature = self.temperature

with ThreadPoolExecutor() as executor:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
results = list(
tqdm(
executor.map(
Expand Down
4 changes: 4 additions & 0 deletions opencompass/runners/rjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def _run_task(self, task_name, log_path, poll_interval=60):

found_dict = False
for line in output.splitlines():
if 'Starting' in line:
status = 'Starting'
found_dict = True
break
if '{' in line and '}' in line:
try:
d = ast.literal_eval(
Expand Down