# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Awaitable, Callable, Optional, Union

from openai.lib._pydantic import to_strict_json_schema
from pydantic import BaseModel

from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.chat import OpenAIResponsesChatGenerator
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.tools import ToolsType, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable


@component
class AzureOpenAIResponsesChatGenerator(OpenAIResponsesChatGenerator):
    """
    Completes chats using OpenAI's Responses API on Azure.

    It works with the gpt-5 and o-series models and supports streaming responses
    from the OpenAI API. It uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format for input and output.

    You can customize how the text is generated by passing parameters to the
    OpenAI API. Use the `**generation_kwargs` argument when you initialize
    the component or when you run it. Any parameter that works with
    `openai.responses.create` will work here too.

    For details on OpenAI API parameters, see
    [OpenAI documentation](https://platform.openai.com/docs/api-reference/responses).

    ### Usage example

    ```python
    from haystack.components.generators.chat import AzureOpenAIResponsesChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]

    client = AzureOpenAIResponsesChatGenerator(
        azure_endpoint="https://example-resource.azure.openai.com/",
        generation_kwargs={"reasoning": {"effort": "low", "summary": "auto"}}
    )
    response = client.run(messages)
    print(response)
    ```
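
    A sketch of structured output via a Pydantic model passed as `text_format`
    (assumes the deployed model supports structured outputs; `CityInfo` is a
    hypothetical schema used only for illustration):

    ```python
    from pydantic import BaseModel

    from haystack.components.generators.chat import AzureOpenAIResponsesChatGenerator
    from haystack.dataclasses import ChatMessage


    class CityInfo(BaseModel):
        city: str
        country: str


    client = AzureOpenAIResponsesChatGenerator(
        azure_endpoint="https://example-resource.azure.openai.com/",
        generation_kwargs={"text_format": CityInfo},
    )
    response = client.run([ChatMessage.from_user("Tell me about Paris.")])
    print(response["replies"][0].text)
    ```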
    """

    # ruff: noqa: PLR0913
    def __init__(
        self,
        *,
        api_key: Union[Secret, Callable[[], str], Callable[[], Awaitable[str]]] = Secret.from_env_var(
            "AZURE_OPENAI_API_KEY", strict=False
        ),
        azure_endpoint: Optional[str] = None,
        azure_deployment: str = "gpt-5-mini",
        streaming_callback: Optional[StreamingCallbackT] = None,
        organization: Optional[str] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        tools: Optional[ToolsType] = None,
        tools_strict: bool = False,
        http_client_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Initialize the AzureOpenAIResponsesChatGenerator component.

        :param api_key: The API key to use for authentication. Can be:
            - A `Secret` object containing the API key.
            - A `Secret` object containing an [Azure Active Directory token](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id).
            - A function that returns an Azure Active Directory token.
        :param azure_endpoint: The endpoint of the deployed model, for example `"https://example-resource.azure.openai.com/"`.
        :param azure_deployment: The deployment of the model, usually the model name.
        :param organization: Your organization ID, defaults to `None`. For help, see
            [Setting up your organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param streaming_callback: A callback function called when a new token is received from the stream.
            It accepts [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk)
            as an argument.
        :param timeout: Timeout for OpenAI client calls. If not set, it defaults to the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries: Maximum number of retries to contact OpenAI after an internal error.
            If not set, it defaults to the `OPENAI_MAX_RETRIES` environment variable, or 5.
        :param generation_kwargs: Other parameters to use for the model. These parameters are sent
            directly to the OpenAI endpoint.
            See the OpenAI [documentation](https://platform.openai.com/docs/api-reference/responses) for
            more details. A short example sketch appears at the end of this docstring.
            Some of the supported parameters:
            - `temperature`: What sampling temperature to use. Higher values like 0.8 make the output more random,
                while lower values like 0.2 make it more focused and deterministic.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. For example, 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `previous_response_id`: The ID of the previous response.
                Use this to create multi-turn conversations.
            - `text_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
                If provided, the output is always validated against this
                format (unless the model returns a tool call).
                For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
                Notes:
                - This parameter accepts Pydantic models and JSON schemas for the latest models, starting from GPT-4o.
                    Older models only support a basic version of structured outputs through `{"type": "json_object"}`.
                    For detailed information on JSON mode, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs#json-mode).
                - For structured outputs with streaming,
                    `text_format` must be a JSON schema and not a Pydantic model.
            - `reasoning`: A dictionary of parameters for reasoning. For example:
                - `summary`: The summary of the reasoning.
                - `effort`: The level of effort to put into the reasoning. Can be `low`, `medium`, or `high`.
                - `generate_summary`: Whether to generate a summary of the reasoning.
                Note: OpenAI does not return the reasoning tokens, but you can view the summary if it is enabled.
                For details, see the [OpenAI Reasoning documentation](https://platform.openai.com/docs/guides/reasoning).
        :param tools:
            A list of Tool and/or Toolset objects, or a single Toolset, for which the model can prepare calls.
        :param tools_strict:
            Whether to enable strict schema adherence for tool calls. If set to `True`, the model follows exactly
            the schema provided in the `parameters` field of the tool definition, but this may increase latency.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
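
        A minimal sketch combining some of the parameters above (the endpoint and
        the `"resp_123"` response ID are placeholders; `generation_kwargs` can also
        be passed to `run`, as noted in the class docstring):

        ```python
        from haystack.dataclasses import ChatMessage

        client = AzureOpenAIResponsesChatGenerator(
            azure_endpoint="https://example-resource.azure.openai.com/",
            generation_kwargs={"reasoning": {"effort": "low", "summary": "auto"}},
        )
        result = client.run(
            [ChatMessage.from_user("Summarize that in one sentence.")],
            generation_kwargs={"previous_response_id": "resp_123"},
        )
        ```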
        """
        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if azure_endpoint is None:
            raise ValueError(
                "You must provide `azure_endpoint` or set the `AZURE_OPENAI_ENDPOINT` environment variable."
            )
        self._azure_endpoint = azure_endpoint
        self._azure_deployment = azure_deployment
        super().__init__(
            api_key=api_key,  # type: ignore[arg-type]
            model=self._azure_deployment,
            streaming_callback=streaming_callback,
            api_base_url=f"{self._azure_endpoint.rstrip('/')}/openai/v1",
            organization=organization,
            generation_kwargs=generation_kwargs,
            timeout=timeout,
            max_retries=max_retries,
            tools=tools,
            tools_strict=tools_strict,
            http_client_kwargs=http_client_kwargs,
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
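
        A round-trip sketch (assuming `client` was created as in the class
        docstring, with `AZURE_OPENAI_API_KEY` set in the environment):

        ```python
        data = client.to_dict()
        restored = AzureOpenAIResponsesChatGenerator.from_dict(data)
        ```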
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None

        # API key can be a secret or a callable
        serialized_api_key = (
            serialize_callable(self.api_key)
            if callable(self.api_key)
            else self.api_key.to_dict()
            if isinstance(self.api_key, Secret)
            else None
        )

        # If the response format is a Pydantic model, it's converted to OpenAI's JSON schema format.
        # If it's already a JSON schema (a dict), it's left as is.
        generation_kwargs = self.generation_kwargs.copy()
        response_format = generation_kwargs.get("response_format")
        # Guard with isinstance(..., type): issubclass raises TypeError on non-class values such as dict schemas
        if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
            json_schema = {
                "type": "json_schema",
                "json_schema": {
                    "name": response_format.__name__,
                    "strict": True,
                    "schema": to_strict_json_schema(response_format),
                },
            }
            generation_kwargs["response_format"] = json_schema

        # OpenAI/MCP tools are passed as a list of dictionaries
        serialized_tools: Union[dict[str, Any], list[dict[str, Any]], None]
        if self.tools and isinstance(self.tools, list) and isinstance(self.tools[0], dict):
            # mypy can't infer that self.tools is list[dict] here
            serialized_tools = self.tools  # type: ignore[assignment]
        else:
            serialized_tools = serialize_tools_or_toolset(self.tools)  # type: ignore[arg-type]

        return default_to_dict(
            self,
            azure_endpoint=self._azure_endpoint,
            api_key=serialized_api_key,
            azure_deployment=self._azure_deployment,
            streaming_callback=callback_name,
            organization=self.organization,
            generation_kwargs=generation_kwargs,
            timeout=self.timeout,
            max_retries=self.max_retries,
            tools=serialized_tools,
            tools_strict=self.tools_strict,
            http_client_kwargs=self.http_client_kwargs,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AzureOpenAIResponsesChatGenerator":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary representation of this component.
        :returns:
            The deserialized component instance.
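
        A sketch of restoring a component whose `api_key` was a callable; the
        callable is stored by import path, so `get_token` must be importable at
        module level (it is a hypothetical token provider):

        ```python
        def get_token() -> str:
            # placeholder; a real implementation would return an Azure AD token
            return "dummy-token"

        client = AzureOpenAIResponsesChatGenerator(
            azure_endpoint="https://example-resource.azure.openai.com/",
            api_key=get_token,
        )
        restored = AzureOpenAIResponsesChatGenerator.from_dict(client.to_dict())
        ```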
        """
        serialized_api_key = data["init_parameters"].get("api_key")
        # If it's a dict, it's most likely a Secret
        if isinstance(serialized_api_key, dict):
            deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        # If it's a str, it's most likely a serialized callable
        elif isinstance(serialized_api_key, str):
            data["init_parameters"]["api_key"] = deserialize_callable(serialized_api_key)

        # We only deserialize the tools if they are Haystack tools,
        # because OpenAI tools are not serialized in the same way
        tools = data["init_parameters"].get("tools")
        if tools and (
            (isinstance(tools, dict) and tools.get("type") == "haystack.tools.toolset.Toolset")
            or (isinstance(tools, list) and tools[0].get("type") == "haystack.tools.tool.Tool")
        ):
            deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")

        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)