Skip to content

Commit ab205e7

Browse files
authored
partners/openai + community: Async Azure AD token provider support for Azure OpenAI (#27488)
This PR introduces a new `azure_ad_async_token_provider` attribute to the `AzureOpenAI` and `AzureChatOpenAI` classes in `partners/openai` and `community` packages, given it's currently supported on `openai` package as [AsyncAzureADTokenProvider](https://github.com/openai/openai-python/blob/main/src/openai/lib/azure.py#L33) type. The reason for creating a new attribute is to avoid breaking changes. Let's say you have an existing code that uses a `AzureOpenAI` or `AzureChatOpenAI` instance to perform both sync and async operations. The `azure_ad_token_provider` will work exactly as it is today, while `azure_ad_async_token_provider` will override it for async requests. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
1 parent 3468442 commit ab205e7

File tree

6 files changed

+96
-11
lines changed

6 files changed

+96
-11
lines changed

libs/community/langchain_community/chat_models/azure_openai.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import os
77
import warnings
8-
from typing import Any, Callable, Dict, List, Union
8+
from typing import Any, Awaitable, Callable, Dict, List, Union
99

1010
from langchain_core._api.deprecation import deprecated
1111
from langchain_core.outputs import ChatResult
@@ -90,7 +90,13 @@ class AzureChatOpenAI(ChatOpenAI):
9090
azure_ad_token_provider: Union[Callable[[], str], None] = None
9191
"""A function that returns an Azure Active Directory token.
9292
93-
Will be invoked on every request.
93+
Will be invoked on every sync request. For async requests,
94+
will be invoked if `azure_ad_async_token_provider` is not provided.
95+
"""
96+
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
97+
"""A function that returns an Azure Active Directory token.
98+
99+
Will be invoked on every async request.
94100
"""
95101
model_version: str = ""
96102
"""Legacy, for openai<1.0.0 support."""
@@ -208,6 +214,12 @@ def validate_environment(cls, values: Dict) -> Dict:
208214
"http_client": values["http_client"],
209215
}
210216
values["client"] = openai.AzureOpenAI(**client_params).chat.completions
217+
218+
azure_ad_async_token_provider = values["azure_ad_async_token_provider"]
219+
220+
if azure_ad_async_token_provider:
221+
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider
222+
211223
values["async_client"] = openai.AsyncAzureOpenAI(
212224
**client_params
213225
).chat.completions

libs/community/langchain_community/embeddings/azure_openai.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import os
66
import warnings
7-
from typing import Any, Callable, Dict, Optional, Union
7+
from typing import Any, Awaitable, Callable, Dict, Optional, Union
88

99
from langchain_core._api.deprecation import deprecated
1010
from langchain_core.utils import get_from_dict_or_env
@@ -49,7 +49,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
4949
azure_ad_token_provider: Union[Callable[[], str], None] = None
5050
"""A function that returns an Azure Active Directory token.
5151
52-
Will be invoked on every request.
52+
Will be invoked on every sync request. For async requests,
53+
will be invoked if `azure_ad_async_token_provider` is not provided.
54+
"""
55+
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
56+
"""A function that returns an Azure Active Directory token.
57+
58+
Will be invoked on every async request.
5359
"""
5460
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
5561
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
@@ -162,6 +168,12 @@ def post_init_validator(self) -> Self:
162168
"http_client": self.http_client,
163169
}
164170
self.client = openai.AzureOpenAI(**client_params).embeddings
171+
172+
if self.azure_ad_async_token_provider:
173+
client_params["azure_ad_token_provider"] = (
174+
self.azure_ad_async_token_provider
175+
)
176+
165177
self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings
166178
else:
167179
self.client = openai.Embedding

libs/community/langchain_community/llms/openai.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
AbstractSet,
99
Any,
1010
AsyncIterator,
11+
Awaitable,
1112
Callable,
1213
Collection,
1314
Dict,
@@ -804,7 +805,13 @@ class AzureOpenAI(BaseOpenAI):
804805
azure_ad_token_provider: Union[Callable[[], str], None] = None
805806
"""A function that returns an Azure Active Directory token.
806807
807-
Will be invoked on every request.
808+
Will be invoked on every sync request. For async requests,
809+
will be invoked if `azure_ad_async_token_provider` is not provided.
810+
"""
811+
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
812+
"""A function that returns an Azure Active Directory token.
813+
814+
Will be invoked on every async request.
808815
"""
809816
openai_api_type: str = ""
810817
"""Legacy, for openai<1.0.0 support."""
@@ -922,6 +929,12 @@ def validate_environment(cls, values: Dict) -> Dict:
922929
"http_client": values["http_client"],
923930
}
924931
values["client"] = openai.AzureOpenAI(**client_params).completions
932+
933+
azure_ad_async_token_provider = values["azure_ad_async_token_provider"]
934+
935+
if azure_ad_async_token_provider:
936+
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider
937+
925938
values["async_client"] = openai.AsyncAzureOpenAI(
926939
**client_params
927940
).completions

libs/partners/openai/langchain_openai/chat_models/azure.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,18 @@
44

55
import logging
66
import os
7-
from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, TypeVar, Union
7+
from typing import (
8+
Any,
9+
Awaitable,
10+
Callable,
11+
Dict,
12+
List,
13+
Optional,
14+
Type,
15+
TypedDict,
16+
TypeVar,
17+
Union,
18+
)
819

920
import openai
1021
from langchain_core.language_models.chat_models import LangSmithParams
@@ -494,7 +505,14 @@ class Joke(BaseModel):
494505
azure_ad_token_provider: Union[Callable[[], str], None] = None
495506
"""A function that returns an Azure Active Directory token.
496507
497-
Will be invoked on every request.
508+
Will be invoked on every sync request. For async requests,
509+
will be invoked if `azure_ad_async_token_provider` is not provided.
510+
"""
511+
512+
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
513+
"""A function that returns an Azure Active Directory token.
514+
515+
Will be invoked on every async request.
498516
"""
499517

500518
model_version: str = ""
@@ -633,6 +651,12 @@ def validate_environment(self) -> Self:
633651
self.client = self.root_client.chat.completions
634652
if not self.async_client:
635653
async_specific = {"http_client": self.http_async_client}
654+
655+
if self.azure_ad_async_token_provider:
656+
client_params["azure_ad_token_provider"] = (
657+
self.azure_ad_async_token_provider
658+
)
659+
636660
self.root_async_client = openai.AsyncAzureOpenAI(
637661
**client_params,
638662
**async_specific, # type: ignore[arg-type]

libs/partners/openai/langchain_openai/embeddings/azure.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Callable, Optional, Union
5+
from typing import Awaitable, Callable, Optional, Union
66

77
import openai
88
from langchain_core.utils import from_env, secret_from_env
@@ -146,7 +146,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
146146
azure_ad_token_provider: Union[Callable[[], str], None] = None
147147
"""A function that returns an Azure Active Directory token.
148148
149-
Will be invoked on every request.
149+
Will be invoked on every sync request. For async requests,
150+
will be invoked if `azure_ad_async_token_provider` is not provided.
151+
"""
152+
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
153+
"""A function that returns an Azure Active Directory token.
154+
155+
Will be invoked on every async request.
150156
"""
151157
openai_api_type: Optional[str] = Field(
152158
default_factory=from_env("OPENAI_API_TYPE", default="azure")
@@ -203,6 +209,12 @@ def validate_environment(self) -> Self:
203209
).embeddings
204210
if not self.async_client:
205211
async_specific: dict = {"http_client": self.http_async_client}
212+
213+
if self.azure_ad_async_token_provider:
214+
client_params["azure_ad_token_provider"] = (
215+
self.azure_ad_async_token_provider
216+
)
217+
206218
self.async_client = openai.AsyncAzureOpenAI(
207219
**client_params, # type: ignore[arg-type]
208220
**async_specific,

libs/partners/openai/langchain_openai/llms/azure.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
4+
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union
55

66
import openai
77
from langchain_core.language_models import LangSmithParams
@@ -73,7 +73,13 @@ class AzureOpenAI(BaseOpenAI):
7373
azure_ad_token_provider: Union[Callable[[], str], None] = None
7474
"""A function that returns an Azure Active Directory token.
7575
76-
Will be invoked on every request.
76+
Will be invoked on every sync request. For async requests,
77+
will be invoked if `azure_ad_async_token_provider` is not provided.
78+
"""
79+
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
80+
"""A function that returns an Azure Active Directory token.
81+
82+
Will be invoked on every async request.
7783
"""
7884
openai_api_type: Optional[str] = Field(
7985
default_factory=from_env("OPENAI_API_TYPE", default="azure")
@@ -158,6 +164,12 @@ def validate_environment(self) -> Self:
158164
).completions
159165
if not self.async_client:
160166
async_specific = {"http_client": self.http_async_client}
167+
168+
if self.azure_ad_async_token_provider:
169+
client_params["azure_ad_token_provider"] = (
170+
self.azure_ad_async_token_provider
171+
)
172+
161173
self.async_client = openai.AsyncAzureOpenAI(
162174
**client_params,
163175
**async_specific, # type: ignore[arg-type]

0 commit comments

Comments
 (0)