12 changes: 6 additions & 6 deletions llm-service/app/routers/index/sessions/__init__.py
@@ -59,7 +59,7 @@
 from ....services.chat.suggested_questions import generate_suggested_questions
 from ....services.chat_history.chat_history_manager import (
     RagStudioChatMessage,
-    chat_history_manager,
+    get_chat_history_manager,
 )
 from ....services.chat_history.paginator import paginate
 from ....services.metadata_apis import session_metadata_api
@@ -142,7 +142,7 @@ class RagStudioChatHistoryResponse(BaseModel):
 def chat_history(
     session_id: int, limit: Optional[int] = None, offset: Optional[int] = None
 ) -> RagStudioChatHistoryResponse:
-    results = chat_history_manager.retrieve_chat_history(session_id=session_id)
+    results = get_chat_history_manager().retrieve_chat_history(session_id=session_id)

     paginated_results, previous_id, next_id = paginate(results, limit, offset)
     return RagStudioChatHistoryResponse(
@@ -158,8 +158,8 @@ def chat_history(
 )
 @exceptions.propagates
 def get_message_by_id(session_id: int, message_id: str) -> RagStudioChatMessage:
-    results: list[RagStudioChatMessage] = chat_history_manager.retrieve_chat_history(
-        session_id=session_id
+    results: list[RagStudioChatMessage] = (
+        get_chat_history_manager().retrieve_chat_history(session_id=session_id)
     )
     for message in results:
         if message.id == message_id:
@@ -175,14 +175,14 @@ def get_message_by_id(session_id: int, message_id: str) -> RagStudioChatMessage:
 )
 @exceptions.propagates
 def clear_chat_history(session_id: int) -> str:
-    chat_history_manager.clear_chat_history(session_id=session_id)
+    get_chat_history_manager().clear_chat_history(session_id=session_id)
     return "Chat history cleared."


 @router.delete("", summary="Deletes the requested session.")
 @exceptions.propagates
 def delete_session(session_id: int) -> str:
-    chat_history_manager.delete_chat_history(session_id=session_id)
+    get_chat_history_manager().delete_chat_history(session_id=session_id)
     return "Chat history deleted."

6 changes: 3 additions & 3 deletions llm-service/app/services/chat/chat.py
@@ -50,7 +50,7 @@
     Evaluation,
     RagMessage,
     RagStudioChatMessage,
-    chat_history_manager,
+    get_chat_history_manager,
 )
 from app.services.metadata_apis.session_metadata_api import Session
 from app.services.mlflow import record_rag_mlflow_run, record_direct_llm_mlflow_run
@@ -172,7 +172,7 @@ def finalize_response(
     record_rag_mlflow_run(
         new_chat_message, query_configuration, response_id, session, user_name
     )
-    chat_history_manager.append_to_history(session.id, [new_chat_message])
+    get_chat_history_manager().append_to_history(session.id, [new_chat_message])

     return new_chat_message

@@ -198,5 +198,5 @@ def direct_llm_chat(
         timestamp=time.time(),
         condensed_question=None,
     )
-    chat_history_manager.append_to_history(session.id, [new_chat_message])
+    get_chat_history_manager().append_to_history(session.id, [new_chat_message])
     return new_chat_message
4 changes: 2 additions & 2 deletions llm-service/app/services/chat/streaming_chat.py
@@ -53,7 +53,7 @@
 from app.services.chat_history.chat_history_manager import (
     RagStudioChatMessage,
     RagMessage,
-    chat_history_manager,
+    get_chat_history_manager,
 )
 from app.services.metadata_apis.session_metadata_api import Session
 from app.services.mlflow import record_direct_llm_mlflow_run
@@ -217,4 +217,4 @@ def _stream_direct_llm_chat(
         timestamp=time.time(),
         condensed_question=None,
     )
-    chat_history_manager.append_to_history(session.id, [new_chat_message])
+    get_chat_history_manager().append_to_history(session.id, [new_chat_message])
4 changes: 2 additions & 2 deletions llm-service/app/services/chat/utils.py
@@ -43,7 +43,7 @@
 from pydantic import BaseModel

 from app.services.chat_history.chat_history_manager import (
-    chat_history_manager,
+    get_chat_history_manager,
     RagPredictSourceNode,
 )

@@ -54,7 +54,7 @@ class RagContext(BaseModel):


 def retrieve_chat_history(session_id: int) -> List[RagContext]:
-    chat_history = chat_history_manager.retrieve_chat_history(session_id)[-10:]
+    chat_history = get_chat_history_manager().retrieve_chat_history(session_id)[-10:]
     history: List[RagContext] = []
     for message in chat_history:
         history.append(
17 changes: 14 additions & 3 deletions llm-service/app/services/chat_history/chat_history_manager.py
@@ -27,7 +27,7 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 #
-
+import functools
 from abc import ABCMeta, abstractmethod
 from typing import Optional, Literal

@@ -64,6 +64,9 @@ class RagStudioChatMessage(BaseModel):


 class ChatHistoryManager(metaclass=ABCMeta):
+    def __init__(self) -> None:
+        pass
+
     @abstractmethod
     def retrieve_chat_history(self, session_id: int) -> list[RagStudioChatMessage]:
         pass
@@ -85,7 +88,13 @@ def append_to_history(
         pass


-def _create_chat_history_manager() -> ChatHistoryManager:
+@functools.cache
+def _get_chat_history_manager() -> ChatHistoryManager:
+    """Create a ChatHistoryManager the first time this function is called, and return it.
+
+    This helper function can be monkey-patched for testing purposes.
+
+    """
     from app.services.chat_history.simple_chat_history_manager import (
         SimpleChatHistoryManager,
     )
@@ -101,4 +110,6 @@ def _create_chat_history_manager() -> ChatHistoryManager:
     return SimpleChatHistoryManager()


-chat_history_manager = _create_chat_history_manager()
+def get_chat_history_manager() -> ChatHistoryManager:
+    """Return a ChatHistoryManager based on the app's chat store config."""
+    return _get_chat_history_manager()
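
The pattern introduced here, a private factory memoized with functools.cache behind a thin public getter, defers construction until the first call and gives tests a single seam to patch. A minimal, self-contained sketch of the same idea (the InMemoryManager name and the test notes below are illustrative assumptions, not code from this PR):

    import functools


    class InMemoryManager:
        """Illustrative stand-in for a ChatHistoryManager implementation."""

        def __init__(self) -> None:
            self.store: dict[int, list[str]] = {}

        def retrieve_chat_history(self, session_id: int) -> list[str]:
            return self.store.get(session_id, [])


    @functools.cache
    def _get_manager() -> InMemoryManager:
        # The first call constructs the manager; later calls return the same
        # cached instance, so nothing is built at import time anymore.
        return InMemoryManager()


    def get_manager() -> InMemoryManager:
        # Public accessor used at call sites instead of a module-level global.
        return _get_manager()


    assert get_manager() is get_manager()  # one shared instance per process
    _get_manager.cache_clear()  # a test can drop the cache, or monkey-patch _get_manager, to swap in a fake

Compared with the old module-level chat_history_manager = _create_chat_history_manager(), nothing reads the chat store configuration until the first caller actually asks for the manager.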
18 changes: 8 additions & 10 deletions llm-service/app/services/chat_history/s3_chat_history_manager.py
@@ -35,10 +35,10 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 #
-
+import functools
 import json
 import logging
-from typing import List
+from typing import List, cast

 import boto3
 from boto3 import Session
@@ -57,18 +57,16 @@
 class S3ChatHistoryManager(ChatHistoryManager):
     """Chat history manager that uses S3 for storage."""

-    def __init__(self, bucket_name: str = settings.document_bucket):
-        self.bucket_name = bucket_name
+    def __init__(self) -> None:
+        super().__init__()
+        self.bucket_name = settings.document_bucket
         self.bucket_prefix = settings.document_bucket_prefix
-        self._s3_client: S3Client | None = None

-    @property
+    @functools.cached_property
     def s3_client(self) -> S3Client:
         """Lazy initialization of S3 client."""
-        if self._s3_client is None:
-            session: Session = boto3.session.Session()
-            self._s3_client = session.client("s3")
-        return self._s3_client
+        session: Session = boto3.session.Session()
+        return cast(S3Client, session.client("s3"))

     def _get_s3_key(self, session_id: int) -> str:
         """Build the S3 key for a session's chat history."""

Review comment from the PR author on the s3_client property (lines 63-67):
Turns out functools.cached_property is precisely the thing for "lazily initialize an instance attribute"!
This also helps avoid conflating self.s3_client and self._s3_client since, conventionally, both could be used inside a class's methods.
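
As the review comment notes, functools.cached_property runs the decorated method once per instance, on first attribute access, and stores the result on the instance, which is exactly the lazy-initialize-and-memoize behaviour the removed _s3_client sentinel implemented by hand. A small sketch under assumed names (ExpensiveClient and Storage are illustrative, not part of this diff):

    import functools


    class ExpensiveClient:
        """Illustrative stand-in for something costly to construct, such as a boto3 client."""

        def __init__(self) -> None:
            print("building client")


    class Storage:
        @functools.cached_property
        def client(self) -> ExpensiveClient:
            # Executed only on the first access to self.client; the returned value
            # is then stored on the instance and reused for every later access.
            return ExpensiveClient()


    s = Storage()
    s.client  # prints "building client"
    s.client  # no print: the cached value is returned directly

It also leaves a single public name, s3_client, instead of a property plus a private backing attribute.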
6 changes: 3 additions & 3 deletions llm-service/app/services/llm_completion.py
@@ -47,7 +47,7 @@
 from . import models
 from .chat_history.chat_history_manager import (
     RagStudioChatMessage,
-    chat_history_manager,
+    get_chat_history_manager,
 )
 from .query.query_configuration import QueryConfiguration

@@ -60,7 +60,7 @@ def make_chat_messages(x: RagStudioChatMessage) -> list[ChatMessage]:

 def completion(session_id: int, question: str, model_name: str) -> ChatResponse:
     model = models.LLM.get(model_name)
-    chat_history = chat_history_manager.retrieve_chat_history(session_id)[:10]
+    chat_history = get_chat_history_manager().retrieve_chat_history(session_id)[:10]
     messages = list(
         itertools.chain.from_iterable(
             map(lambda x: make_chat_messages(x), chat_history)
@@ -78,7 +78,7 @@ def stream_completion(
     Returns a generator that yields ChatResponse objects as they become available.
     """
     model = models.LLM.get(model_name)
-    chat_history = chat_history_manager.retrieve_chat_history(session_id)[:10]
+    chat_history = get_chat_history_manager().retrieve_chat_history(session_id)[:10]
     messages = list(
         itertools.chain.from_iterable(
             map(lambda x: make_chat_messages(x), chat_history)
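
Both completion helpers above take a slice of the stored history and flatten it with itertools.chain.from_iterable, since each stored RagStudioChatMessage expands into more than one model message via make_chat_messages. A simplified sketch of that flattening, using plain tuples instead of the real message types (Exchange is an assumed stand-in):

    import itertools
    from dataclasses import dataclass


    @dataclass
    class Exchange:
        """Assumed stand-in for a stored chat message holding one user turn and one reply."""

        user: str
        assistant: str


    def to_chat_messages(x: Exchange) -> list[tuple[str, str]]:
        # Mirrors the role of make_chat_messages: one stored exchange becomes two chat messages.
        return [("user", x.user), ("assistant", x.assistant)]


    history = [Exchange("hi", "hello"), Exchange("summarize the doc", "here is a summary...")]
    messages = list(itertools.chain.from_iterable(to_chat_messages(x) for x in history))
    # [("user", "hi"), ("assistant", "hello"), ("user", "summarize the doc"), ...]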
4 changes: 2 additions & 2 deletions llm-service/app/services/session.py
@@ -40,7 +40,7 @@

 from . import models
 from .chat_history.chat_history_manager import (
-    chat_history_manager,
+    get_chat_history_manager,
     RagStudioChatMessage,
 )
 from .metadata_apis import session_metadata_api
@@ -87,7 +87,7 @@

 def rename_session(session_id: int, user_name: Optional[str]) -> str:
     chat_history: list[RagStudioChatMessage] = (
-        chat_history_manager.retrieve_chat_history(session_id=session_id)
+        get_chat_history_manager().retrieve_chat_history(session_id=session_id)
     )
     if not chat_history:
         logger.info("No chat history found for session ID %s", session_id)