|
16 | 16 | from __future__ import annotations |
17 | 17 |
|
18 | 18 | import asyncio |
| 19 | +import logging |
19 | 20 |
|
| 21 | +from zep_cloud import ApiError |
| 22 | +from zep_cloud import NotFoundError |
20 | 23 | from zep_cloud.client import AsyncZep |
21 | 24 | from zep_cloud.types import Message |
22 | 25 |
|
| 26 | +from nat.builder.context import Context |
23 | 27 | from nat.memory.interfaces import MemoryEditor |
24 | 28 | from nat.memory.models import MemoryItem |
25 | 29 |
|
| 30 | +logger = logging.getLogger(__name__) |
| 31 | + |
26 | 32 |
|
27 | 33 | class ZepEditor(MemoryEditor): |
28 | 34 | """ |
29 | | - Wrapper class that implements NAT interfaces for Zep Integrations Async. |
| 35 | + Wrapper class that implements NAT interfaces for Zep v3 Integrations Async. |
| 36 | + Uses thread-based memory management with automatic user creation. |
30 | 37 | """ |
31 | 38 |
|
32 | | - def __init__(self, zep_client: AsyncZep): |
| 39 | + def __init__(self, zep_client: AsyncZep) -> None: |
33 | 40 | """ |
34 | | - Initialize class with Predefined Mem0 Client. |
| 41 | + Initialize class with Zep v3 AsyncZep Client. |
35 | 42 |
|
36 | 43 | Args: |
37 | | - zep_client (AsyncZep): Async client instance. |
| 44 | + zep_client (AsyncZep): Async client instance. |
38 | 45 | """ |
39 | 46 | self._client = zep_client |
40 | 47 |
|
41 | | - async def add_items(self, items: list[MemoryItem]) -> None: |
| 48 | + async def _ensure_user_exists(self, user_id: str) -> None: |
| 49 | + """ |
| 50 | + Ensure a user exists in Zep v3, creating if necessary. |
| 51 | +
|
| 52 | + Args: |
| 53 | + user_id (str): The user ID to check/create. |
42 | 54 | """ |
43 | | - Insert Multiple MemoryItems into the memory. Each MemoryItem is translated and uploaded. |
| 55 | + logger.debug("Checking if Zep user exists") |
| 56 | + try: |
| 57 | + await self._client.user.get(user_id=user_id) |
| 58 | + logger.debug("Zep user already exists") |
| 59 | + except NotFoundError: |
| 60 | + # User doesn't exist, create with basic info |
| 61 | + logger.info("Zep user not found, creating...") |
| 62 | + try: |
| 63 | + # Set defaults only for default_user, otherwise use just user_id |
| 64 | + if user_id == "default_user": |
| 65 | + |
| 66 | + first_name = "Jane" |
| 67 | + last_name = "Doe" |
| 68 | + await self._client.user.add(user_id=user_id, |
| 69 | + email=email, |
| 70 | + first_name=first_name, |
| 71 | + last_name=last_name) |
| 72 | + else: |
| 73 | + # For non-default users, just use user_id (email/names not required) |
| 74 | + await self._client.user.add(user_id=user_id) |
| 75 | + |
| 76 | + logger.info("Created Zep user") |
| 77 | + except ApiError as e: |
| 78 | + # Check if user was created by another request (409 Conflict) |
| 79 | + if e.response_data and e.response_data.get("status_code") == 409: |
| 80 | + logger.info("Zep user already exists (409), continuing") |
| 81 | + else: |
| 82 | + logger.error("Failed creating Zep user: %s", str(e)) # noqa: TRY400 |
| 83 | + raise |
| 84 | + except ApiError as e: |
| 85 | + logger.error("Failed fetching Zep user: %s", str(e)) # noqa: TRY400 |
| 86 | + raise |
| 87 | + |
| 88 | + async def add_items(self, items: list[MemoryItem], **kwargs) -> None: |
44 | 89 | """ |
| 90 | + Insert Multiple MemoryItems into the memory using Zep v3 thread API. |
| 91 | + Each MemoryItem is translated and uploaded to a thread. |
| 92 | + Uses conversation_id from NAT context as thread_id for multi-thread support. |
| 93 | +
|
| 94 | + Args: |
| 95 | + items (list[MemoryItem]): The items to be added. |
| 96 | + kwargs (dict): Provider-specific keyword arguments. |
| 97 | +
|
| 98 | + - ignore_roles (list[str], optional): List of role types to ignore when adding |
| 99 | + messages to graph memory. Available roles: system, assistant, user, |
| 100 | + function, tool. |
| 101 | + """ |
| 102 | + # Extract Zep-specific parameters |
| 103 | + ignore_roles = kwargs.get("ignore_roles", None) |
45 | 104 |
|
46 | 105 | coroutines = [] |
| 106 | + created_threads: set[str] = set() |
| 107 | + ensured_users: set[str] = set() |
47 | 108 |
|
48 | | - # Iteratively insert memories into Mem0 |
| 109 | + # Iteratively insert memories into Zep using threads |
49 | 110 | for memory_item in items: |
50 | 111 | conversation = memory_item.conversation |
51 | | - session_id = memory_item.user_id |
| 112 | + user_id = memory_item.user_id or "default_user" # Validate user_id |
| 113 | + |
| 114 | + # Get thread_id from NAT context (unique per UI conversation) |
| 115 | + thread_id = Context.get().conversation_id |
| 116 | + |
| 117 | + # Fallback to default thread ID if no conversation_id available |
| 118 | + if not thread_id: |
| 119 | + thread_id = "default_zep_thread" |
| 120 | + |
52 | 121 | messages = [] |
| 122 | + |
| 123 | + # Ensure user exists before creating thread (only once per user) |
| 124 | + if user_id not in ensured_users: |
| 125 | + await self._ensure_user_exists(user_id) |
| 126 | + ensured_users.add(user_id) |
| 127 | + |
| 128 | + # Skip if no conversation data |
| 129 | + if not conversation: |
| 130 | + continue |
| 131 | + |
53 | 132 | for msg in conversation: |
54 | | - messages.append(Message(content=msg["content"], role_type=msg["role"])) |
| 133 | + # Create Message - role field instead of role_type in V3 |
| 134 | + message = Message(content=msg["content"], role=msg["role"]) |
| 135 | + messages.append(message) |
| 136 | + |
| 137 | + # Ensure thread exists once per thread_id |
| 138 | + thread_ready = True |
| 139 | + if thread_id not in created_threads: |
| 140 | + logger.info("Ensuring Zep thread exists (thread_id=%s)", thread_id) |
| 141 | + try: |
| 142 | + await self._client.thread.create(thread_id=thread_id, user_id=user_id) |
| 143 | + logger.info("Created Zep thread (thread_id=%s)", thread_id) |
| 144 | + created_threads.add(thread_id) |
| 145 | + except ApiError as create_error: |
| 146 | + if create_error.response_data and create_error.response_data.get("status_code") == 409: |
| 147 | + logger.info("Zep thread already exists (thread_id=%s)", thread_id) |
| 148 | + created_threads.add(thread_id) |
| 149 | + else: |
| 150 | + logger.exception("Thread create failed (thread_id=%s)", thread_id) |
| 151 | + thread_ready = False |
| 152 | + |
| 153 | + # Skip this item if thread creation failed unexpectedly |
| 154 | + if not thread_ready: |
| 155 | + continue |
| 156 | + |
| 157 | + # Add messages to thread using Zep v3 API |
| 158 | + logger.info("Queueing add_messages (thread_id=%s, count=%d)", thread_id, len(messages)) |
55 | 159 |
|
56 | | - coroutines.append(self._client.memory.add(session_id=session_id, messages=messages)) |
| 160 | + # Build add_messages parameters |
| 161 | + add_messages_params = {"thread_id": thread_id, "messages": messages} |
| 162 | + if ignore_roles is not None: |
| 163 | + add_messages_params["ignore_roles"] = ignore_roles |
| 164 | + |
| 165 | + coroutines.append(self._client.thread.add_messages(**add_messages_params)) |
57 | 166 |
|
58 | 167 | await asyncio.gather(*coroutines) |
59 | 168 |
|
60 | | - async def search(self, query: str, top_k: int = 5, **kwargs) -> list[MemoryItem]: |
| 169 | + async def search(self, query: str, top_k: int = 5, **kwargs) -> list[MemoryItem]: # noqa: ARG002 |
61 | 170 | """ |
62 | | - Retrieve items relevant to the given query. |
| 171 | + Retrieve memory from Zep v3 using the high-level get_user_context API. |
| 172 | + Uses conversation_id from NAT context as thread_id for multi-thread support. |
| 173 | +
|
| 174 | + Zep returns pre-formatted memory optimized for LLM consumption, including |
| 175 | + relevant facts, timestamps, and structured information from its knowledge graph. |
63 | 176 |
|
64 | 177 | Args: |
65 | | - query (str): The query string to match. |
66 | | - top_k (int): Maximum number of items to return. |
67 | | - kwargs: Other keyword arguments for search. |
| 178 | + query (str): The query string (not used by Zep's high-level API, included for interface compatibility). |
| 179 | + top_k (int): Maximum number of items to return (not used by Zep's context API). |
| 180 | + kwargs: Zep-specific keyword arguments. |
| 181 | +
|
| 182 | + - user_id (str, required for response construction): Used only to construct the |
| 183 | + returned MemoryItem. Zep v3's thread.get_user_context() only requires thread_id. |
| 184 | + - mode (str, optional): Retrieval mode. Zep server default is "summary". This |
| 185 | + implementation uses mode="basic" (NAT's default) for performance (P95 < 200ms). |
| 186 | + "summary" provides more comprehensive memory at the cost of latency. |
68 | 187 |
|
69 | 188 | Returns: |
70 | | - list[MemoryItem]: The most relevant MemoryItems for the given query. |
| 189 | + list[MemoryItem]: A single MemoryItem containing the formatted context from Zep. |
71 | 190 | """ |
| 191 | + # Validate required kwargs |
| 192 | + if "user_id" not in kwargs or not kwargs["user_id"]: |
| 193 | + raise ValueError("user_id is required.") |
| 194 | + user_id = kwargs.pop("user_id") |
| 195 | + mode = kwargs.pop("mode", "basic") # Get mode, default to "basic" for fast retrieval |
| 196 | + |
| 197 | + # Get thread_id from NAT context |
| 198 | + thread_id = Context.get().conversation_id |
72 | 199 |
|
73 | | - session_id = kwargs.pop("user_id") # Ensure user ID is in keyword arguments |
74 | | - limit = top_k |
| 200 | + # Fallback to default thread ID if no conversation_id available |
| 201 | + if not thread_id: |
| 202 | + thread_id = "default_zep_thread" |
75 | 203 |
|
76 | | - search_result = await self._client.memory.search_sessions(session_ids=[session_id], |
77 | | - text=query, |
78 | | - limit=limit, |
79 | | - search_scope="messages", |
80 | | - **kwargs) |
| 204 | + try: |
| 205 | + # Use Zep v3 thread.get_user_context - returns pre-formatted context |
| 206 | + memory_response = await self._client.thread.get_user_context(thread_id=thread_id, mode=mode) |
| 207 | + context_string = memory_response.context or "" |
81 | 208 |
|
82 | | - # Construct MemoryItem instances |
83 | | - memories = [] |
| 209 | + # Return as a single MemoryItem with the formatted context |
| 210 | + if context_string: |
| 211 | + return [ |
| 212 | + MemoryItem(conversation=[], |
| 213 | + user_id=user_id, |
| 214 | + memory=context_string, |
| 215 | + metadata={ |
| 216 | + "mode": mode, "thread_id": thread_id |
| 217 | + }) |
| 218 | + ] |
| 219 | + else: |
| 220 | + return [] |
84 | 221 |
|
85 | | - for res in search_result.results: |
86 | | - memories.append( |
87 | | - MemoryItem(conversation=[], |
88 | | - user_id=session_id, |
89 | | - memory=res.message.content, |
90 | | - metadata={ |
91 | | - "relevance_score": res.score, |
92 | | - "created_at": res.message.created_at, |
93 | | - "updated_at": res.message.updated_at |
94 | | - })) |
| 222 | + except NotFoundError: |
| 223 | + # Thread doesn't exist or no context available |
| 224 | + return [] |
| 225 | + except ApiError as e: |
| 226 | + logger.error("get_user_context failed (thread_id=%s): %s", thread_id, str(e)) # noqa: TRY400 |
| 227 | + raise |
95 | 228 |
|
96 | | - return memories |
| 229 | + async def remove_items(self, **kwargs) -> None: |
| 230 | + """ |
| 231 | + Remove memory items based on provided criteria. |
| 232 | +
|
| 233 | + Supports two deletion modes: |
| 234 | +
|
| 235 | + 1. Delete a specific thread by thread_id |
| 236 | + 2. Delete all threads for a user by user_id |
| 237 | +
|
| 238 | + Args: |
| 239 | + kwargs: Additional parameters. |
| 240 | +
|
| 241 | + - thread_id (str, optional): Thread ID to delete a specific thread. |
| 242 | + - user_id (str, optional): User ID to delete all threads for that user. |
| 243 | + """ |
| 244 | + if "thread_id" in kwargs: |
| 245 | + # Delete specific thread |
| 246 | + thread_id = kwargs.pop("thread_id") |
| 247 | + logger.info("Deleting thread (thread_id=%s)", thread_id) |
| 248 | + await self._client.thread.delete(thread_id=thread_id) |
| 249 | + elif "user_id" in kwargs: |
| 250 | + # Delete all threads for a user |
| 251 | + user_id = kwargs.pop("user_id") |
| 252 | + logger.debug("Deleting all threads for user (user_id=%s)", user_id) |
97 | 253 |
|
98 | | - async def remove_items(self, **kwargs): |
| 254 | + # Get all threads for this user |
| 255 | + threads = await self._client.user.get_threads(user_id=user_id) |
| 256 | + logger.debug("Found %d threads for user (user_id=%s)", len(threads), user_id) |
99 | 257 |
|
100 | | - if "session_id" in kwargs: |
101 | | - session_id = kwargs.pop("session_id") |
102 | | - await self._client.memory.delete(session_id) |
| 258 | + # Delete each thread |
| 259 | + delete_coroutines = [] |
| 260 | + for thread in threads: |
| 261 | + if thread.thread_id: |
| 262 | + logger.debug("Queueing deletion of thread (thread_id=%s)", thread.thread_id) |
| 263 | + delete_coroutines.append(self._client.thread.delete(thread_id=thread.thread_id)) |
103 | 264 |
|
| 265 | + if delete_coroutines: |
| 266 | + await asyncio.gather(*delete_coroutines) |
| 267 | + logger.info("Deleted %d threads for user", len(delete_coroutines)) |
104 | 268 | else: |
105 | | - raise ValueError("session_id not provided as part of the tool call. ") |
| 269 | + raise ValueError("Either thread_id or user_id is required.") |
0 commit comments