
Commit 393bc18

Merge branch 'master' into SD2-1118-make-eden-ai-chat-api-open-ai-compatible
2 parents d623583 + d3dfcf5

15 files changed: +354 −296 lines

edenai_apis/features/text/generation/generation_args.py

Lines changed: 1 addition & 0 deletions
@@ -12,5 +12,6 @@ def generation_arguments(provider_name: str):
             "mistral": "large-latest",
             "ai21labs": "j2-ultra",
             "meta": "llama3-1-70b-instruct-v1:0",
+            "xai": "grok-2-latest",
         },
     }
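
The new entry gives the xai provider a default model of grok-2-latest. A minimal usage sketch, assuming the package is installed and the import path matches the file path above; the shape of the returned arguments dict beyond this hunk is not shown, so treat the printed structure as an assumption:

from edenai_apis.features.text.generation.generation_args import generation_arguments

# Per the mapping above, the "xai" entry resolves to "grok-2-latest".
args = generation_arguments("xai")
print(args)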

edenai_apis/interface.py

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ def get_async_job_result(
     provider_name: str,
     feature: str,
     subfeature: str,
-    async_job_id: AsyncLaunchJobResponseType,
+    async_job_id: str,
     phase: str = "",
     fake: bool = False,
     user_email=None,
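
The async job id is now a plain string rather than an AsyncLaunchJobResponseType wrapper. A hedged sketch of a call under the new signature; the provider, feature, and job id values below are illustrative, not taken from this diff:

from edenai_apis.interface import get_async_job_result

result = get_async_job_result(
    provider_name="amazon",              # illustrative provider
    feature="audio",                     # illustrative feature
    subfeature="speech_to_text_async",   # illustrative subfeature
    async_job_id="0123-example-job-id",  # now a plain string
)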

edenai_apis/llmengine/llm_engine.py

Lines changed: 52 additions & 50 deletions
@@ -1,69 +1,70 @@
-import os
-import uuid
-import json
 import base64
+import json
 import mimetypes
+import os
+import re
+import uuid
 from io import BytesIO
-from typing import List, Literal, Optional, Union, Dict, Type
+from typing import Dict, List, Literal, Optional, Type, Union

 import httpx
+from loaders.data_loader import ProviderDataEnum
+from loaders.loaders import load_provider
 from pydantic import BaseModel
-from edenai_apis.utils.upload_s3 import upload_file_bytes_to_s3
-from edenai_apis.llmengine.types.response_types import (
-    ResponseModel,
+
+from edenai_apis.features.image import (
+    ExplicitContentDataClass,
+    GeneratedImageDataClass,
+    GenerationDataClass,
+    LogoDetectionDataClass,
+    QuestionAnswerDataClass,
 )
-from edenai_apis.llmengine.clients import LLM_COMPLETION_CLIENTS
-from edenai_apis.llmengine.clients.completion import CompletionClient
-from edenai_apis.llmengine.mapping import Mappings
-from edenai_apis.utils.types import ResponseType
-from edenai_apis.utils.exception import ProviderException
-from edenai_apis.features.translation import (
-    AutomaticTranslationDataClass,
-    LanguageDetectionDataClass,
+from edenai_apis.features.multimodal.chat import (
+    ChatDataClass as ChatMultimodalDataClass,
+)
+from edenai_apis.features.multimodal.chat import (
+    ChatStreamResponse as ChatMultimodalStreamResponse,
 )
+from edenai_apis.features.multimodal.chat import StreamChat as StreamMultimodalChat
 from edenai_apis.features.text import (
-    SummarizeDataClass,
-    TopicExtractionDataClass,
-    SpellCheckDataClass,
-    SentimentAnalysisDataClass,
-    KeywordExtractionDataClass,
     AnonymizationDataClass,
-    NamedEntityRecognitionDataClass,
     CodeGenerationDataClass,
+    CustomClassificationDataClass,
+    CustomNamedEntityRecognitionDataClass,
     EmbeddingDataClass,
     EmbeddingsDataClass,
+    KeywordExtractionDataClass,
     ModerationDataClass,
+    NamedEntityRecognitionDataClass,
+    SentimentAnalysisDataClass,
+    SpellCheckDataClass,
+    SummarizeDataClass,
     TextModerationItem,
-    CustomClassificationDataClass,
-    CustomNamedEntityRecognitionDataClass,
-)
-from edenai_apis.features.text.moderation.category import (
-    CategoryType as CategoryTypeModeration,
-)
-from edenai_apis.utils.conversion import standardized_confidence_score
-from edenai_apis.features.image import (
-    LogoDetectionDataClass,
-    QuestionAnswerDataClass,
-    ExplicitContentDataClass,
-    GeneratedImageDataClass,
-    GenerationDataClass,
+    TopicExtractionDataClass,
 )
 from edenai_apis.features.text.chat import ChatDataClass, ChatMessageDataClass
 from edenai_apis.features.text.chat.chat_dataclass import (
-    StreamChat,
     ChatStreamResponse,
+    StreamChat,
     ToolCall,
 )
-from edenai_apis.features.multimodal.chat import (
-    ChatDataClass as ChatMultimodalDataClass,
-    StreamChat as StreamMultimodalChat,
-    ChatStreamResponse as ChatMultimodalStreamResponse,
+from edenai_apis.features.text.moderation.category import (
+    CategoryType as CategoryTypeModeration,
+)
+from edenai_apis.features.translation import (
+    AutomaticTranslationDataClass,
+    LanguageDetectionDataClass,
 )
+from edenai_apis.llmengine.clients import LLM_COMPLETION_CLIENTS
+from edenai_apis.llmengine.clients.completion import CompletionClient
+from edenai_apis.llmengine.mapping import Mappings
 from edenai_apis.llmengine.prompts import BasePrompt
+from edenai_apis.llmengine.types.response_types import ResponseModel
 from edenai_apis.llmengine.utils.moderation import moderate
-from loaders.data_loader import ProviderDataEnum
-from loaders.loaders import load_provider
-import re
+from edenai_apis.utils.conversion import standardized_confidence_score
+from edenai_apis.utils.exception import ProviderException
+from edenai_apis.utils.types import ResponseType
+from edenai_apis.utils.upload_s3 import upload_file_bytes_to_s3


 class LLMEngine:
@@ -184,7 +185,7 @@ def chat(
                 usage=response.usage,
             )
         else:
-            stream = (
+            stream_response = (
                 ChatStreamResponse(
                     text=chunk.choices[0].delta.content or "",
                     blocked=False,
@@ -195,7 +196,8 @@ def chat(
             )

         return ResponseType[StreamChat](
-            original_response=None, standardized_response=StreamChat(stream=stream)
+            original_response=None,
+            standardized_response=StreamChat(stream=stream_response),
         )

     @moderate
@@ -238,8 +240,8 @@ def multimodal_chat(
         args["response_format"] = response_format
         args["drop_invalid_params"] = True
         response = self.completion_client.completion(**args, **kwargs)
-        response = ResponseModel.model_validate(response)
         if stream is False:
+            response = ResponseModel.model_validate(response)
             generated_text = (
                 response.choices[0].message.content or "" if response.choices else ""
             )
@@ -256,9 +258,9 @@ def multimodal_chat(
             )

         else:
-            stream = (
+            stream_response = (
                 ChatMultimodalStreamResponse(
-                    text=chunk["choices"][0]["delta"].get("content", ""),
+                    text=chunk["choices"][0]["delta"].get("content") or "",
                     blocked=not chunk["choices"][0].get("finish_reason")
                     in (None, "stop"),
                     provider=self.provider_name,
@@ -269,7 +271,7 @@ def multimodal_chat(

         return ResponseType[StreamMultimodalChat](
             original_response=None,
-            standardized_response=StreamMultimodalChat(stream=stream),
+            standardized_response=StreamMultimodalChat(stream=stream_response),
         )

     def summarize(
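
The switch from .get("content", "") to .get("content") or "" in the multimodal stream matters because a chunk's delta can hold an explicit None, and dict.get only applies its default when the key is absent. A short illustration of the difference:

delta = {"content": None}   # key present, value None
delta.get("content", "")    # -> None: the default is ignored because the key exists
delta.get("content") or ""  # -> "": also coerces an explicit None to an empty string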
@@ -820,12 +822,12 @@ def completion(
             os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = location
         elif is_gemini:
             api_settings = load_provider(
-                ProviderDataEnum.KEY, provider_name, api_keys=api_key
+                ProviderDataEnum.KEY, provider_name=provider_name, api_keys=api_key
             )
             api_key = api_settings["genai_api_key"]
         else:
             api_settings = load_provider(
-                ProviderDataEnum.KEY, provider_name, api_keys=api_key
+                ProviderDataEnum.KEY, provider_name=provider_name, api_keys=api_key
             )
             api_key = api_settings["api_key"]
         try:
Lines changed: 91 additions & 35 deletions
@@ -1,56 +1,108 @@
-import sys
 from pprint import pprint
+from time import sleep
+from typing import Tuple

 import requests
-from edenai_apis.interface import compute_output
+
+from edenai_apis.interface import compute_output, get_async_job_result
 from edenai_apis.loaders.data_loader import FeatureDataEnum
 from edenai_apis.loaders.loaders import load_feature

-HOURLY = "hourly"

-if __name__ == "__main__":
-    interval = sys.argv[1]
+def process_async_get_result(
+    provider: str,
+    feature: str,
+    subfeature: str,
+    phase: str,
+    async_job_id: str,
+    max_time=300,
+    sleep_time=5,
+):
+    while max_time > 0:
+        res = get_async_job_result(
+            provider_name=provider,
+            feature=feature,
+            subfeature=subfeature,
+            async_job_id=async_job_id,
+            phase=phase,
+        )
+        if res["status"] == "failed":
+            raise Exception(res["error"])
+        elif res["status"] == "succeeded":
+            return
+        sleep(sleep_time)
+        max_time -= sleep_time
+
+    raise TimeoutError(f"Async job timed out after {max_time} seconds")
+
+
+def process_provider(provider_info: Tuple[str, str, str, str]):
+    provider, feature, subfeature, phase = provider_info
+    if phase == "create_project":
+        return None
+    try:
+        arguments = load_feature(
+            FeatureDataEnum.SAMPLES_ARGS,
+            feature=feature,
+            subfeature=subfeature,
+            phase=phase,
+        )
+    except NotImplementedError:
+        return None
+    try:
+        res = compute_output(
+            provider_name=provider,
+            feature=feature,
+            subfeature=subfeature,
+            args=arguments,
+            phase=phase,
+        )
+        if res["status"] == "fail":
+            raise Exception(res["error"])
+
+        # poll for result if async job
+        if "provider_job_id" in res:
+            process_async_get_result(
+                provider=provider,
+                feature=feature,
+                subfeature=subfeature,
+                phase=phase,
+                async_job_id=res["provider_job_id"],
+            )
+
+        return (provider, feature, subfeature, None)
+    except Exception as exc:
+        return (provider, feature, subfeature, exc)
+
+
+def fetch_provider_subfeatures():
+    url = "https://api.edenai.run/v2/info/provider_subfeatures"
+    response = requests.get(url)
+    return response.json()
+
+
+def main():
     not_working = []
-    query_is_working = "?is_working=False" if interval == HOURLY else ""
-    provider_subfeatures = requests.get(
-        url=f"https://api.edenai.run/v2/info/provider_subfeatures{query_is_working}"
-    ).json()
+    provider_subfeatures = fetch_provider_subfeatures()
     all_providers = [
         (
             provider["provider"]["name"],
             provider["feature"]["name"],
             provider["subfeature"]["name"],
-            provider.get("phase", ""),
+            provider.get("phase") or "",
         )
         for provider in provider_subfeatures
     ]
-    for provider, feature, subfeature, phase in all_providers:
-        if phase == "create_project":
-            continue
-        try:
-            arguments = load_feature(
-                FeatureDataEnum.SAMPLES_ARGS,
-                feature=feature,
-                subfeature=subfeature,
-                phase=phase,
-            )
-        except NotImplementedError:
-            continue
-        try:
-            res = compute_output(
-                provider_name=provider,
-                feature=feature,
-                subfeature=subfeature,
-                args=arguments,
-                phase=phase,
-            )
-            if res["status"] == "fail":
-                raise Exception(res["error"])

-        except Exception as exc:
+    for provider_info in all_providers:
+        result = process_provider(provider_info)
+        if result is None:
+            continue
+        provider, feature, subfeature, error = result
+        if error is not None:
             print(provider, feature, subfeature)
-            print(exc)
-            not_working.append((provider, feature, subfeature, exc))
+            print(error)
+            not_working.append((provider, feature, subfeature, error))

     print("=================================")
     print("NOT WORKING PROVIDERS WITH ERRORS")
@@ -59,3 +111,7 @@

     if not_working:
         raise Exception
+
+
+if __name__ == "__main__":
+    main()
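
process_async_get_result polls get_async_job_result until the job reports succeeded, raises on failed, and gives up after roughly max_time seconds. A hedged usage sketch; all argument values below are placeholders, not taken from this diff:

# Assumes the functions above are importable; values are illustrative only.
process_async_get_result(
    provider="amazon",
    feature="audio",
    subfeature="speech_to_text_async",
    phase="",
    async_job_id="0123-example-job-id",
    max_time=60,   # give up after about a minute
    sleep_time=5,  # poll every 5 seconds
)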
