Skip to content

Commit 38b706c

Browse files
authored
🔑 feat: AWS / Bedrock Session Token Support (#167)
1 parent 87396d6 commit 38b706c

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ The following environment variables are required to run the application:
8484
- `AWS_DEFAULT_REGION`: (Optional) defaults to `us-east-1`
8585
- `AWS_ACCESS_KEY_ID`: (Optional) needed for bedrock embeddings
8686
- `AWS_SECRET_ACCESS_KEY`: (Optional) needed for bedrock embeddings
87+
- `AWS_SESSION_TOKEN`: (Optional) may be needed for bedrock embeddings
8788
- `GOOGLE_APPLICATION_CREDENTIALS`: (Optional) needed for Google VertexAI embeddings
8889
- `RAG_CHECK_EMBEDDING_CTX_LENGTH` (Optional) Default is true, disabling this will send raw input to the embedder, use this for custom embedding models.
8990

app/config.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ async def dispatch(self, request, call_next):
178178
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")
179179
AWS_ACCESS_KEY_ID = get_env_variable("AWS_ACCESS_KEY_ID", "")
180180
AWS_SECRET_ACCESS_KEY = get_env_variable("AWS_SECRET_ACCESS_KEY", "")
181+
AWS_SESSION_TOKEN = get_env_variable("AWS_SESSION_TOKEN", "")
181182
GOOGLE_APPLICATION_CREDENTIALS = get_env_variable("GOOGLE_APPLICATION_CREDENTIALS", "")
182183
env_value = get_env_variable("RAG_CHECK_EMBEDDING_CTX_LENGTH", "True").lower()
183184
RAG_CHECK_EMBEDDING_CTX_LENGTH = True if env_value == "true" else False
@@ -229,11 +230,16 @@ def init_embeddings(provider, model):
229230
elif provider == EmbeddingsProvider.BEDROCK:
230231
from langchain_aws import BedrockEmbeddings
231232

232-
session = boto3.Session(
233-
aws_access_key_id=AWS_ACCESS_KEY_ID,
234-
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
235-
region_name=AWS_DEFAULT_REGION,
236-
)
233+
session_kwargs = {
234+
"aws_access_key_id": AWS_ACCESS_KEY_ID,
235+
"aws_secret_access_key": AWS_SECRET_ACCESS_KEY,
236+
"region_name": AWS_DEFAULT_REGION,
237+
}
238+
239+
if AWS_SESSION_TOKEN:
240+
session_kwargs["aws_session_token"] = AWS_SESSION_TOKEN
241+
242+
session = boto3.Session(**session_kwargs)
237243
return BedrockEmbeddings(
238244
client=session.client("bedrock-runtime"),
239245
model_id=model,

0 commit comments

Comments
 (0)