Skip to content

Commit 119dc22

Browse files
Gemini API and gpt-image-1 (#18)
* init gemini support * openai gpt-image-1 * update gemini api * misc
1 parent 8f1b33c commit 119dc22

File tree

20 files changed

+1587
-582
lines changed

20 files changed

+1587
-582
lines changed

.github/workflows/python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
env:
3636
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
3737
OPENAI_ORG_ID: ${{ secrets.OPENAI_ORG_ID }}
38-
SKIP_TESTS_NAAI: "tests/llm/chat_completion tests/llm/embedding tests/data"
38+
SKIP_TESTS_NAAI: "tests/llm/chat_completion tests/llm/embedding tests/llm/image_gen tests/data"
3939
run: uv run nox -s test-${{ matrix.python-version }}
4040
quality:
4141
runs-on: ubuntu-24.04

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ all machines that use the project, both during development and in production.
9595
To install all dependencies into an isolated virtual environment:
9696

9797
```shell
98-
uv sync --all-extras
98+
uv sync --all-extras --all-groups
9999
```
100100

101101
To upgrade all dependencies to their latest versions:

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test(s: Session) -> None:
2626

2727
# Skip tests in directories specified by the SKIP_TESTS_NAAI environment variable.
2828
skip_tests = os.getenv("SKIP_TESTS_NAAI", "")
29-
skip_tests += " tests/llm/chat_completion/ tests/llm/embedding/"
29+
skip_tests += " tests/llm/chat_completion/ tests/llm/embedding/ tests/llm/image_gen/"
3030
skip_args = [f"--ignore={dir}" for dir in skip_tests.split()] if skip_tests else []
3131

3232
s.run(

pyproject.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "not-again-ai"
3-
version = "0.18.0"
3+
version = "0.19.0"
44
description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place."
55
authors = [
66
{ name = "DaveCoDev", email = "[email protected]" }
@@ -23,8 +23,8 @@ classifiers = [
2323
]
2424
requires-python = ">=3.11"
2525
dependencies = [
26-
"loguru>=0.7",
27-
"pydantic>=2.10",
26+
"loguru>=0.7,<1.0",
27+
"pydantic>=2.11,<3.0",
2828
]
2929

3030
[project.urls]
@@ -38,10 +38,11 @@ data = [
3838
"pytest-playwright>=0.7,<1.0",
3939
]
4040
llm = [
41-
"anthropic>=0.49,<1.0",
41+
"anthropic>=0.50,<1.0",
4242
"azure-identity>=1.21,<2.0",
43+
"google-genai>1.12,<2.0",
4344
"ollama>=0.4,<1.0",
44-
"openai>=1.68,<2.0",
45+
"openai>=1.76,<2.0",
4546
"python-liquid>=2.0,<3.0",
4647
"tiktoken>=0.9,<1.0"
4748
]

src/not_again_ai/llm/chat_completion/interface.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any
33

44
from not_again_ai.llm.chat_completion.providers.anthropic_api import anthropic_chat_completion
5+
from not_again_ai.llm.chat_completion.providers.gemini_api import gemini_chat_completion
56
from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion, ollama_chat_completion_stream
67
from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion, openai_chat_completion_stream
78
from not_again_ai.llm.chat_completion.types import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse
@@ -16,6 +17,8 @@ def chat_completion(
1617
- `openai` - OpenAI
1718
- `azure_openai` - Azure OpenAI
1819
- `ollama` - Ollama
20+
- `anthropic` - Anthropic
21+
- `gemini` - Gemini
1922
2023
Args:
2124
request: Request parameter object
@@ -31,6 +34,8 @@ def chat_completion(
3134
return ollama_chat_completion(request, client)
3235
elif provider == "anthropic":
3336
return anthropic_chat_completion(request, client)
37+
elif provider == "gemini":
38+
return gemini_chat_completion(request, client)
3439
else:
3540
raise ValueError(f"Provider {provider} not supported")
3641

@@ -43,8 +48,6 @@ async def chat_completion_stream(
4348
"""Stream a chat completion response from the given provider. Currently supported providers:
4449
- `openai` - OpenAI
4550
- `azure_openai` - Azure OpenAI
46-
- `ollama` - Ollama
47-
- `anthropic` - Anthropic
4851
4952
Args:
5053
request: Request parameter object

src/not_again_ai/llm/chat_completion/providers/anthropic_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@ def anthropic_chat_completion(request: ChatCompletionRequest, client: Callable[.
103103
elif tool_choice_value in ["auto", "any"]:
104104
tool_choice["type"] = "auto"
105105
if kwargs.get("parallel_tool_calls") is not None:
106-
tool_choice["disable_parallel_tool_use"] = str(not kwargs["parallel_tool_calls"])
106+
tool_choice["disable_parallel_tool_use"] = not kwargs["parallel_tool_calls"] # type: ignore
107107
else:
108108
tool_choice["name"] = tool_choice_value
109109
tool_choice["type"] = "tool"
110110
if kwargs.get("parallel_tool_calls") is not None:
111-
tool_choice["disable_parallel_tool_use"] = str(not kwargs["parallel_tool_calls"])
111+
tool_choice["disable_parallel_tool_use"] = not kwargs["parallel_tool_calls"] # type: ignore
112112
kwargs["tool_choice"] = tool_choice
113113
kwargs.pop("parallel_tool_calls", None)
114114

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
import base64
2+
from collections.abc import Callable
3+
import os
4+
import time
5+
from typing import Any
6+
7+
from google import genai
8+
from google.genai import types
9+
from google.genai.types import FunctionCall, FunctionCallingConfigMode, GenerateContentResponse
10+
11+
from not_again_ai.llm.chat_completion.types import (
12+
AssistantMessage,
13+
ChatCompletionChoice,
14+
ChatCompletionRequest,
15+
ChatCompletionResponse,
16+
Function,
17+
ImageContent,
18+
Role,
19+
TextContent,
20+
ToolCall,
21+
)
22+
23+
# This should be all of the options we want to support in types.GenerateContentConfig, that are not handled otherwise
# Maps this library's request parameter names to their google-genai GenerateContentConfig field names.
GEMINI_PARAMETER_MAP = {
    "max_completion_tokens": "max_output_tokens",
    "temperature": "temperature",
    "top_p": "top_p",
    "top_k": "top_k",
}

# Maps Gemini finish-reason names to this library's normalized finish_reason strings
# (ChatCompletionChoice.finish_reason). Unknown reasons fall back to "other" at the call site.
GEMINI_FINISH_REASON_MAP = {
    "STOP": "stop",
    "MAX_TOKENS": "max_tokens",
    "SAFETY": "safety",
    "RECITATION": "recitation",
    "LANGUAGE": "language",
    "OTHER": "other",
    "BLOCKLIST": "blocklist",
    "PROHIBITED_CONTENT": "prohibited_content",
    "SPII": "spii",
    "MALFORMED_FUNCTION_CALL": "malformed_function_call",
    "IMAGE_SAFETY": "image_safety",
}
44+
45+
46+
def gemini_chat_completion(request: ChatCompletionRequest, client: Callable[..., Any]) -> ChatCompletionResponse:
    """Experimental Gemini chat completion function.

    Converts the provider-agnostic ChatCompletionRequest into google-genai
    ``generate_content`` keyword arguments, invokes ``client`` with them, and
    maps the provider response back into a ChatCompletionResponse.

    Args:
        request: Request parameter object.
        client: Callable that forwards kwargs to a genai ``generate_content``
            call (see ``gemini_client``).

    Returns:
        ChatCompletionResponse: Normalized response containing a single choice.
    """
    # Handle messages
    # Any system messages need to be removed from messages and concatenated into a single string (in order)
    system = ""
    contents = []
    for message in request.messages:
        if message.role == "system":
            # Handle both string content and structured content
            if isinstance(message.content, str):
                system += message.content + "\n"
            else:
                # If it's a list of content parts, extract text content
                for part in message.content:
                    if hasattr(part, "text"):
                        system += part.text + "\n"
        elif message.role == "tool":
            # Tool results go back to Gemini as a function_response part on a user turn.
            tool_name = message.name if message.name is not None else ""
            function_response_part = types.Part.from_function_response(
                name=tool_name,
                response={"result": message.content},
            )
            contents.append(
                types.Content(
                    role="user",
                    parts=[function_response_part],
                )
            )
        elif message.role == "assistant":
            if message.content and isinstance(message.content, str):
                contents.append(types.Content(role="model", parts=[types.Part(text=message.content)]))
            # Replay any prior tool calls so the model sees the full function-calling history.
            function_parts = []
            if isinstance(message, AssistantMessage) and message.tool_calls:
                for tool_call in message.tool_calls:
                    function_call_part = types.Part(
                        function_call=FunctionCall(
                            id=tool_call.id,
                            name=tool_call.function.name,
                            args=tool_call.function.arguments,
                        )
                    )
                    function_parts.append(function_call_part)
            if function_parts:
                contents.append(types.Content(role="model", parts=function_parts))
        elif message.role == "user":
            if isinstance(message.content, str):
                contents.append(types.Content(role="user", parts=[types.Part(text=message.content)]))
            elif isinstance(message.content, list):
                parts = []
                for part in message.content:
                    if isinstance(part, TextContent):
                        parts.append(types.Part(text=part.text))
                    elif isinstance(part, ImageContent):
                        # Extract MIME type and data from data URI, e.g. "data:image/png;base64,<data>"
                        uri_parts = part.image_url.url.split(",", 1)
                        if len(uri_parts) == 2:
                            mime_type = uri_parts[0].split(":")[1].split(";")[0]
                            base64_data = uri_parts[1]
                            image_data = base64.b64decode(base64_data)
                            parts.append(types.Part.from_bytes(mime_type=mime_type, data=image_data))
                contents.append(types.Content(role="user", parts=parts))

    kwargs: dict[str, Any] = {}
    kwargs["contents"] = contents
    kwargs["model"] = request.model
    config: dict[str, Any] = {}
    config["system_instruction"] = system.rstrip()
    # This wrapper surfaces tool calls to the caller rather than auto-executing them.
    config["automatic_function_calling"] = {"disable": True}

    # Handle the possible tool choice options
    if request.tool_choice:
        tool_choice = request.tool_choice
        if tool_choice == "auto":
            config["tool_config"] = types.FunctionCallingConfig(mode=FunctionCallingConfigMode.AUTO)
        elif tool_choice == "any":
            config["tool_config"] = types.FunctionCallingConfig(mode=FunctionCallingConfigMode.ANY)
        elif tool_choice == "none":
            config["tool_config"] = types.FunctionCallingConfig(mode=FunctionCallingConfigMode.NONE)
        elif isinstance(tool_choice, list):
            config["tool_config"] = types.FunctionCallingConfig(
                mode=FunctionCallingConfigMode.ANY, allowed_function_names=tool_choice
            )
        else:
            # Any other (truthy) value is a single function name the model must call.
            # The prior branches already exclude "auto"/"any"/"none" and lists, so no
            # extra membership check is needed here.
            config["tool_config"] = types.FunctionCallingConfig(
                mode=FunctionCallingConfigMode.ANY, allowed_function_names=[tool_choice]
            )

    # Handle tools
    tools = []
    for tool in request.tools or []:
        tools.append(types.Tool(function_declarations=[tool]))  # type: ignore
    if tools:
        config["tools"] = tools

    # Everything else defined in GEMINI_PARAMETER_MAP goes into kwargs["config"]
    request_kwargs = request.model_dump(mode="json", exclude_none=True)
    for key, value in GEMINI_PARAMETER_MAP.items():
        # Map values are always set, so only the presence of the source key matters.
        if key in request_kwargs:
            config[value] = request_kwargs.pop(key)

    kwargs["config"] = types.GenerateContentConfig(**config)

    start_time = time.time()
    response: GenerateContentResponse = client(**kwargs)
    end_time = time.time()
    response_duration = round(end_time - start_time, 4)

    # NOTE(review): assumes str(...) of the candidate's finish_reason yields the bare
    # enum name (e.g. "STOP") so it matches GEMINI_FINISH_REASON_MAP keys — confirm
    # against the installed google-genai FinishReason representation.
    finish_reason = "other"
    if response.candidates and response.candidates[0].finish_reason:
        finish_reason_str = str(response.candidates[0].finish_reason)
        finish_reason = GEMINI_FINISH_REASON_MAP.get(finish_reason_str, "other")

    tool_calls: list[ToolCall] = []
    tool_call_objs = response.function_calls
    if tool_call_objs:
        for tool_call_obj in tool_call_objs:
            tool_call_id = tool_call_obj.id if tool_call_obj.id else ""
            tool_calls.append(
                ToolCall(
                    id=tool_call_id,
                    function=Function(
                        name=tool_call_obj.name if tool_call_obj.name is not None else "",
                        arguments=tool_call_obj.args if tool_call_obj.args is not None else {},
                    ),
                )
            )

    # Only the first text part of the first candidate is surfaced as the assistant message.
    assistant_message = ""
    if (
        response.candidates
        and response.candidates[0].content
        and response.candidates[0].content.parts
        and response.candidates[0].content.parts[0].text
    ):
        assistant_message = response.candidates[0].content.parts[0].text

    choice = ChatCompletionChoice(
        message=AssistantMessage(
            role=Role.ASSISTANT,
            content=assistant_message,
            tool_calls=tool_calls,
        ),
        finish_reason=finish_reason,
    )

    # Completion tokens count both "thinking" tokens and the visible candidate tokens.
    completion_tokens = 0
    if response.usage_metadata is not None:
        if response.usage_metadata.thoughts_token_count:
            completion_tokens = response.usage_metadata.thoughts_token_count
        if response.usage_metadata.candidates_token_count:
            completion_tokens += response.usage_metadata.candidates_token_count

    # Set safe default for prompt_tokens when usage metadata is missing.
    prompt_tokens = 0
    if response.usage_metadata is not None and response.usage_metadata.prompt_token_count:
        prompt_tokens = response.usage_metadata.prompt_token_count

    chat_completion_response = ChatCompletionResponse(
        choices=[choice],
        completion_tokens=completion_tokens,
        prompt_tokens=prompt_tokens,
        response_duration=response_duration,
    )
    return chat_completion_response
211+
212+
213+
def create_client_callable(client_class: type[genai.Client], **client_args: Any) -> Callable[..., Any]:
    """Build a callable that constructs a Google genai client and runs a completion.

    Args:
        client_class: The Google genai client class to instantiate
        **client_args: Arguments to pass to the client constructor

    Returns:
        A callable that creates a client and returns completion results
    """
    # Drop unset constructor arguments so the client falls back to its own defaults.
    filtered_args = {key: value for key, value in client_args.items() if value is not None}

    def client_callable(**kwargs: Any) -> Any:
        # A fresh client is instantiated on every invocation.
        return client_class(**filtered_args).models.generate_content(**kwargs)

    return client_callable
231+
232+
233+
def gemini_client(api_key: str | None = None) -> Callable[..., Any]:
    """Return a completion callable backed by a Google genai client.

    Falls back to the GEMINI_API_KEY environment variable when no key is given.
    """
    resolved_key = api_key if api_key else os.environ.get("GEMINI_API_KEY")
    return create_client_callable(genai.Client, api_key=resolved_key)

src/not_again_ai/llm/chat_completion/types.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,7 @@ class ChatCompletionRequest(BaseModel):
138138

139139
class ChatCompletionChoice(BaseModel):
140140
message: AssistantMessage
141-
finish_reason: Literal[
142-
"stop", "length", "tool_calls", "content_filter", "end_turn", "max_tokens", "stop_sequence", "tool_use"
143-
]
144-
141+
finish_reason: str
145142
json_message: dict[str, Any] | None = Field(default=None)
146143
logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = Field(default=None)
147144

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from not_again_ai.llm.image_gen.interface import create_image
2+
from not_again_ai.llm.image_gen.types import ImageGenRequest
3+
4+
__all__ = ["ImageGenRequest", "create_image"]
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from collections.abc import Callable
2+
from typing import Any
3+
4+
from not_again_ai.llm.image_gen.providers.openai_api import openai_create_image
5+
from not_again_ai.llm.image_gen.types import ImageGenRequest, ImageGenResponse
6+
7+
8+
def create_image(request: ImageGenRequest, provider: str, client: Callable[..., Any]) -> ImageGenResponse:
    """Get an image generation response from the given provider. Currently supported providers:
    - `openai` - OpenAI
    - `azure_openai` - Azure OpenAI

    Args:
        request: Request parameter object
        provider: The supported provider name
        client: Client information, see the provider's implementation for what can be provided

    Returns:
        ImageGenResponse: The image generation response.

    Raises:
        ValueError: If the provider name is not one of the supported providers.
    """
    if provider in ("openai", "azure_openai"):
        return openai_create_image(request, client)
    raise ValueError(f"Provider {provider} not supported")

0 commit comments

Comments
 (0)