2 changes: 1 addition & 1 deletion .github/workflows/python.yml
@@ -35,7 +35,7 @@ jobs:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
OPENAI_ORG_ID: ${{ secrets.OPENAI_ORG_ID }}
-SKIP_TESTS_NAAI: "tests/llm/chat_completion tests/llm/embedding tests/data"
+SKIP_TESTS_NAAI: "tests/llm/chat_completion tests/llm/embedding tests/llm/image_gen tests/data"
run: uv run nox -s test-${{ matrix.python-version }}
quality:
runs-on: ubuntu-24.04
2 changes: 1 addition & 1 deletion README.md
@@ -95,7 +95,7 @@ all machines that use the project, both during development and in production.
To install all dependencies into an isolated virtual environment:

```shell
-uv sync --all-extras
+uv sync --all-extras --all-groups
```

To upgrade all dependencies to their latest versions:
2 changes: 1 addition & 1 deletion noxfile.py
@@ -26,7 +26,7 @@ def test(s: Session) -> None:

# Skip tests in directories specified by the SKIP_TESTS_NAAI environment variable.
skip_tests = os.getenv("SKIP_TESTS_NAAI", "")
-skip_tests += " tests/llm/chat_completion/ tests/llm/embedding/"
+skip_tests += " tests/llm/chat_completion/ tests/llm/embedding/ tests/llm/image_gen/"
skip_args = [f"--ignore={dir}" for dir in skip_tests.split()] if skip_tests else []

s.run(
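
An illustrative aside (not part of the diff): each whitespace-separated path in the skip list becomes one pytest `--ignore` flag, so the new entry translates directly into an ignored directory.

```python
# Worked example of the skip-list mechanism in noxfile.py shown above.
skip_tests = "tests/llm/chat_completion/ tests/llm/embedding/ tests/llm/image_gen/"
skip_args = [f"--ignore={path}" for path in skip_tests.split()] if skip_tests else []
print(skip_args)
# ['--ignore=tests/llm/chat_completion/',
#  '--ignore=tests/llm/embedding/',
#  '--ignore=tests/llm/image_gen/']
```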
11 changes: 6 additions & 5 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "not-again-ai"
version = "0.18.0"
version = "0.19.0"
description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place."
authors = [
{ name = "DaveCoDev", email = "[email protected]" }
@@ -23,8 +23,8 @@ classifiers = [
]
requires-python = ">=3.11"
dependencies = [
"loguru>=0.7",
"pydantic>=2.10",
"loguru>=0.7,<1.0",
"pydantic>=2.11,<3.0",
]

[project.urls]
@@ -38,10 +38,11 @@ data = [
"pytest-playwright>=0.7,<1.0",
]
llm = [
"anthropic>=0.49,<1.0",
"anthropic>=0.50,<1.0",
"azure-identity>=1.21,<2.0",
"google-genai>1.12,<2.0",
"ollama>=0.4,<1.0",
"openai>=1.68,<2.0",
"openai>=1.76,<2.0",
"python-liquid>=2.0,<3.0",
"tiktoken>=0.9,<1.0"
]
7 changes: 5 additions & 2 deletions src/not_again_ai/llm/chat_completion/interface.py
@@ -2,6 +2,7 @@
from typing import Any

from not_again_ai.llm.chat_completion.providers.anthropic_api import anthropic_chat_completion
+from not_again_ai.llm.chat_completion.providers.gemini_api import gemini_chat_completion
from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion, ollama_chat_completion_stream
from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion, openai_chat_completion_stream
from not_again_ai.llm.chat_completion.types import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse
@@ -16,6 +17,8 @@ def chat_completion(
- `openai` - OpenAI
- `azure_openai` - Azure OpenAI
- `ollama` - Ollama
+- `anthropic` - Anthropic
+- `gemini` - Gemini

Args:
request: Request parameter object
@@ -31,6 +34,8 @@
return ollama_chat_completion(request, client)
elif provider == "anthropic":
return anthropic_chat_completion(request, client)
elif provider == "gemini":
return gemini_chat_completion(request, client)
else:
raise ValueError(f"Provider {provider} not supported")

@@ -43,8 +48,6 @@ async def chat_completion_stream(
"""Stream a chat completion response from the given provider. Currently supported providers:
- `openai` - OpenAI
- `azure_openai` - Azure OpenAI
-- `ollama` - Ollama
-- `anthropic` - Anthropic

Args:
request: Request parameter object
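
For orientation, a minimal sketch of calling the new Gemini path through this interface. It assumes the interface takes `(request, provider, client)` like the `image_gen` interface later in this diff, that `UserMessage` exists in `chat_completion.types`, and a placeholder model name; none of these are shown in the diff itself.

```python
# Hedged usage sketch -- UserMessage and the model name are assumptions.
from not_again_ai.llm.chat_completion.interface import chat_completion
from not_again_ai.llm.chat_completion.providers.gemini_api import gemini_client
from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, UserMessage

client = gemini_client()  # falls back to the GEMINI_API_KEY environment variable
request = ChatCompletionRequest(
    model="gemini-2.0-flash",  # assumed model name
    messages=[UserMessage(content="Say hello.")],
    max_completion_tokens=64,
)
response = chat_completion(request, provider="gemini", client=client)
print(response.choices[0].message.content)
```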
src/not_again_ai/llm/chat_completion/providers/anthropic_api.py
@@ -103,12 +103,12 @@ def anthropic_chat_completion(request: ChatCompletionRequest, client: Callable[.
elif tool_choice_value in ["auto", "any"]:
tool_choice["type"] = "auto"
if kwargs.get("parallel_tool_calls") is not None:
tool_choice["disable_parallel_tool_use"] = str(not kwargs["parallel_tool_calls"])
tool_choice["disable_parallel_tool_use"] = not kwargs["parallel_tool_calls"] # type: ignore
else:
tool_choice["name"] = tool_choice_value
tool_choice["type"] = "tool"
if kwargs.get("parallel_tool_calls") is not None:
tool_choice["disable_parallel_tool_use"] = str(not kwargs["parallel_tool_calls"])
tool_choice["disable_parallel_tool_use"] = not kwargs["parallel_tool_calls"] # type: ignore
kwargs["tool_choice"] = tool_choice
kwargs.pop("parallel_tool_calls", None)

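
A quick illustration of why the change above matters, assuming Anthropic's `disable_parallel_tool_use` expects a real boolean: `str(not x)` always yields a non-empty string, which is truthy regardless of the intended value.

```python
# Old vs. new behavior of the line changed above (illustrative).
parallel_tool_calls = True
old_value = str(not parallel_tool_calls)  # "False" -- a non-empty string, hence truthy
new_value = not parallel_tool_calls       # False -- an actual boolean
assert bool(old_value) is True   # the stringified flag always reads as enabled
assert new_value is False
```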
237 changes: 237 additions & 0 deletions src/not_again_ai/llm/chat_completion/providers/gemini_api.py
@@ -0,0 +1,237 @@
import base64
from collections.abc import Callable
import os
import time
from typing import Any

from google import genai
from google.genai import types
from google.genai.types import FunctionCall, FunctionCallingConfigMode, GenerateContentResponse

from not_again_ai.llm.chat_completion.types import (
AssistantMessage,
ChatCompletionChoice,
ChatCompletionRequest,
ChatCompletionResponse,
Function,
ImageContent,
Role,
TextContent,
ToolCall,
)

# All options in types.GenerateContentConfig that we support and that are not handled elsewhere
GEMINI_PARAMETER_MAP = {
"max_completion_tokens": "max_output_tokens",
"temperature": "temperature",
"top_p": "top_p",
"top_k": "top_k",
}

GEMINI_FINISH_REASON_MAP = {
"STOP": "stop",
"MAX_TOKENS": "max_tokens",
"SAFETY": "safety",
"RECITATION": "recitation",
"LANGUAGE": "language",
"OTHER": "other",
"BLOCKLIST": "blocklist",
"PROHIBITED_CONTENT": "prohibited_content",
"SPII": "spii",
"MALFORMED_FUNCTION_CALL": "malformed_function_call",
"IMAGE_SAFETY": "image_safety",
}


def gemini_chat_completion(request: ChatCompletionRequest, client: Callable[..., Any]) -> ChatCompletionResponse:
"""Experimental Gemini chat completion function."""
# Handle messages
# Any system messages need to be removed from messages and concatenated into a single string (in order)
system = ""
contents = []
for message in request.messages:
if message.role == "system":
# Handle both string content and structured content
if isinstance(message.content, str):
system += message.content + "\n"
else:
# If it's a list of content parts, extract text content
for part in message.content:
if hasattr(part, "text"):
system += part.text + "\n"
elif message.role == "tool":
tool_name = message.name if message.name is not None else ""
function_response_part = types.Part.from_function_response(
name=tool_name,
response={"result": message.content},
)
contents.append(
types.Content(
role="user",
parts=[function_response_part],
)
)
elif message.role == "assistant":
if message.content and isinstance(message.content, str):
contents.append(types.Content(role="model", parts=[types.Part(text=message.content)]))
function_parts = []
if isinstance(message, AssistantMessage) and message.tool_calls:
for tool_call in message.tool_calls:
function_call_part = types.Part(
function_call=FunctionCall(
id=tool_call.id,
name=tool_call.function.name,
args=tool_call.function.arguments,
)
)
function_parts.append(function_call_part)
if function_parts:
contents.append(types.Content(role="model", parts=function_parts))
elif message.role == "user":
if isinstance(message.content, str):
contents.append(types.Content(role="user", parts=[types.Part(text=message.content)]))
elif isinstance(message.content, list):
parts = []
for part in message.content:
if isinstance(part, TextContent):
parts.append(types.Part(text=part.text))
elif isinstance(part, ImageContent):
# Extract MIME type and data from data URI
uri_parts = part.image_url.url.split(",", 1)
if len(uri_parts) == 2:
mime_type = uri_parts[0].split(":")[1].split(";")[0]
base64_data = uri_parts[1]
image_data = base64.b64decode(base64_data)
parts.append(types.Part.from_bytes(mime_type=mime_type, data=image_data))
contents.append(types.Content(role="user", parts=parts))

kwargs: dict[str, Any] = {}
kwargs["contents"] = contents
kwargs["model"] = request.model
config: dict[str, Any] = {}
config["system_instruction"] = system.rstrip()
config["automatic_function_calling"] = {"disable": True}

# Handle the possible tool choice options
if request.tool_choice:
tool_choice = request.tool_choice
if tool_choice == "auto":
config["tool_config"] = types.FunctionCallingConfig(mode=FunctionCallingConfigMode.AUTO)
elif tool_choice == "any":
config["tool_config"] = types.FunctionCallingConfig(mode=FunctionCallingConfigMode.ANY)
elif tool_choice == "none":
config["tool_config"] = types.FunctionCallingConfig(mode=FunctionCallingConfigMode.NONE)
elif isinstance(tool_choice, list):
config["tool_config"] = types.FunctionCallingConfig(
mode=FunctionCallingConfigMode.ANY, allowed_function_names=tool_choice
)
elif tool_choice not in (None, "auto", "any", "none"):
config["tool_config"] = types.FunctionCallingConfig(
mode=FunctionCallingConfigMode.ANY, allowed_function_names=[tool_choice]
)

# Handle tools
tools = []
for tool in request.tools or []:
tools.append(types.Tool(function_declarations=[tool])) # type: ignore
if tools:
config["tools"] = tools

# Everything else defined in GEMINI_PARAMETER_MAP goes into kwargs["config"]
request_kwargs = request.model_dump(mode="json", exclude_none=True)
for key, value in GEMINI_PARAMETER_MAP.items():
if value is not None and key in request_kwargs:
config[value] = request_kwargs.pop(key)

kwargs["config"] = types.GenerateContentConfig(**config)

start_time = time.time()
response: GenerateContentResponse = client(**kwargs)
end_time = time.time()
response_duration = round(end_time - start_time, 4)

finish_reason = "other"
if response.candidates and response.candidates[0].finish_reason:
finish_reason_str = str(response.candidates[0].finish_reason)
finish_reason = GEMINI_FINISH_REASON_MAP.get(finish_reason_str, "other")

tool_calls: list[ToolCall] = []
tool_call_objs = response.function_calls
if tool_call_objs:
for tool_call_obj in tool_call_objs:
tool_call_id = tool_call_obj.id if tool_call_obj.id else ""
tool_calls.append(
ToolCall(
id=tool_call_id,
function=Function(
name=tool_call_obj.name if tool_call_obj.name is not None else "",
arguments=tool_call_obj.args if tool_call_obj.args is not None else {},
),
)
)

assistant_message = ""
if (
response.candidates
and response.candidates[0].content
and response.candidates[0].content.parts
and response.candidates[0].content.parts[0].text
):
assistant_message = response.candidates[0].content.parts[0].text

choice = ChatCompletionChoice(
message=AssistantMessage(
role=Role.ASSISTANT,
content=assistant_message,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)

completion_tokens = 0
# usage_metadata can be None, so guard the token counts
if response.usage_metadata is not None:
if response.usage_metadata.thoughts_token_count:
completion_tokens = response.usage_metadata.thoughts_token_count
if response.usage_metadata.candidates_token_count:
completion_tokens += response.usage_metadata.candidates_token_count

# Set safe default for prompt_tokens
prompt_tokens = 0
if response.usage_metadata is not None and response.usage_metadata.prompt_token_count:
prompt_tokens = response.usage_metadata.prompt_token_count

chat_completion_response = ChatCompletionResponse(
choices=[choice],
completion_tokens=completion_tokens,
prompt_tokens=prompt_tokens,
response_duration=response_duration,
)
return chat_completion_response


def create_client_callable(client_class: type[genai.Client], **client_args: Any) -> Callable[..., Any]:
"""Creates a callable that instantiates and uses a Google genai client.

Args:
client_class: The Google genai client class to instantiate
**client_args: Arguments to pass to the client constructor

Returns:
A callable that creates a client and returns completion results
"""
filtered_args = {k: v for k, v in client_args.items() if v is not None}

def client_callable(**kwargs: Any) -> Any:
client = client_class(**filtered_args)
completion = client.models.generate_content(**kwargs)
return completion

return client_callable


def gemini_client(api_key: str | None = None) -> Callable[..., Any]:
if not api_key:
api_key = os.environ.get("GEMINI_API_KEY")
client_callable = create_client_callable(genai.Client, api_key=api_key)
return client_callable
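
As a worked example of the image handling in the user-message branch above (illustrative, matching the split logic in `gemini_chat_completion`):

```python
import base64

# ImageContent URLs are expected as data URIs: "data:<mime>;base64,<payload>".
data_uri = "data:image/png;base64," + base64.b64encode(b"fake-png-bytes").decode()
header, payload = data_uri.split(",", 1)        # same split as in gemini_chat_completion
mime_type = header.split(":")[1].split(";")[0]  # -> "image/png"
image_bytes = base64.b64decode(payload)         # bytes handed to types.Part.from_bytes
```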
5 changes: 1 addition & 4 deletions src/not_again_ai/llm/chat_completion/types.py
@@ -138,10 +138,7 @@ class ChatCompletionRequest(BaseModel):

class ChatCompletionChoice(BaseModel):
message: AssistantMessage
-finish_reason: Literal[
-"stop", "length", "tool_calls", "content_filter", "end_turn", "max_tokens", "stop_sequence", "tool_use"
-]
+finish_reason: str
json_message: dict[str, Any] | None = Field(default=None)
logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = Field(default=None)

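
Widening `finish_reason` from a `Literal` to a plain `str` makes room for the Gemini finish reasons mapped earlier in this diff; a quick check of which mapped values the old `Literal` would have rejected:

```python
old_allowed = {"stop", "length", "tool_calls", "content_filter",
               "end_turn", "max_tokens", "stop_sequence", "tool_use"}
gemini_mapped = {"stop", "max_tokens", "safety", "recitation", "language", "other",
                 "blocklist", "prohibited_content", "spii",
                 "malformed_function_call", "image_safety"}
print(sorted(gemini_mapped - old_allowed))
# ['blocklist', 'image_safety', 'language', 'malformed_function_call',
#  'other', 'prohibited_content', 'recitation', 'safety', 'spii']
```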
4 changes: 4 additions & 0 deletions src/not_again_ai/llm/image_gen/__init__.py
@@ -0,0 +1,4 @@
from not_again_ai.llm.image_gen.interface import create_image
from not_again_ai.llm.image_gen.types import ImageGenRequest

__all__ = ["ImageGenRequest", "create_image"]
24 changes: 24 additions & 0 deletions src/not_again_ai/llm/image_gen/interface.py
@@ -0,0 +1,24 @@
from collections.abc import Callable
from typing import Any

from not_again_ai.llm.image_gen.providers.openai_api import openai_create_image
from not_again_ai.llm.image_gen.types import ImageGenRequest, ImageGenResponse


def create_image(request: ImageGenRequest, provider: str, client: Callable[..., Any]) -> ImageGenResponse:
"""Get a image response from the given provider. Currently supported providers:
- `openai` - OpenAI
- `azure_openai` - Azure OpenAI

Args:
request: Request parameter object
provider: The supported provider name
client: Client information, see the provider's implementation for what can be provided

Returns:
ImageGenResponse: The image generation response.
"""
if provider == "openai" or provider == "azure_openai":
return openai_create_image(request, client)
else:
raise ValueError(f"Provider {provider} not supported")
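
A hedged sketch of calling the new interface; the `ImageGenRequest` fields and the `openai_client` helper are assumptions, since `types.py` and the provider module are not shown in this diff.

```python
# Hedged sketch -- ImageGenRequest fields and openai_client are assumptions
# (types.py and the provider module are not shown in this diff).
from not_again_ai.llm.image_gen import ImageGenRequest, create_image
from not_again_ai.llm.image_gen.providers.openai_api import openai_client  # hypothetical helper

request = ImageGenRequest(prompt="A watercolor fox", model="gpt-image-1")  # assumed fields
response = create_image(request, provider="openai", client=openai_client())
```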
Empty file.