Skip to content

Commit 1029499

Browse files
committed
feat: impl nebulagraph mcp server
Signed-off-by: Chojan Shang <[email protected]>
1 parent 54581cc commit 1029499

File tree

4 files changed

+526
-0
lines changed

4 files changed

+526
-0
lines changed

.env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ NEBULA_HOST=your-nebulagraph-server-host
88
NEBULA_PORT=your-nebulagraph-server-port
99
NEBULA_USER=your-nebulagraph-server-user
1010
NEBULA_PASSWORD=your-nebulagraph-server-password
11+
12+
# For integration tests
13+
RUN_INTEGRATION_TESTS=false # set to true to run integration tests
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
from .server import main
12

3+
__all__ = ["main"]
Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
import argparse
2+
import os
3+
from contextlib import asynccontextmanager
4+
from dataclasses import dataclass
5+
from typing import AsyncIterator
6+
7+
from dotenv import load_dotenv
8+
from mcp.server.fastmcp import FastMCP
9+
from nebula3.Config import Config
10+
from nebula3.gclient.net import ConnectionPool
11+
12+
load_dotenv()
13+
14+
15+
@dataclass
16+
class NebulaContext:
17+
pool: ConnectionPool
18+
19+
20+
# Create a global connection pool
21+
config = Config()
22+
config.max_connection_pool_size = 10
23+
global_pool = ConnectionPool()
24+
25+
26+
def get_connection_pool() -> ConnectionPool:
27+
"""Get the global connection pool"""
28+
return global_pool
29+
30+
31+
@asynccontextmanager
32+
async def nebula_lifespan(server: FastMCP) -> AsyncIterator[NebulaContext]:
33+
"""This is a context manager for NebulaGraph connection."""
34+
try:
35+
if os.environ["NEBULA_VERSION"] != "v3":
36+
raise ValueError("NebulaGraph version must be v3")
37+
38+
# Initialize the connection
39+
global_pool.init(
40+
[
41+
(
42+
os.getenv("NEBULA_HOST", "127.0.0.1"),
43+
int(os.getenv("NEBULA_PORT", "9669")),
44+
)
45+
],
46+
config,
47+
)
48+
49+
yield NebulaContext(pool=global_pool)
50+
finally:
51+
# Clean up the connection
52+
global_pool.close()
53+
54+
55+
# Create MCP server
56+
mcp = FastMCP("NebulaGraph MCP Server", lifespan=nebula_lifespan)
57+
58+
59+
@mcp.resource("schema://space/{space}")
60+
def get_space_schema_resource(space: str) -> str:
61+
"""Get the schema information of the specified space
62+
Args:
63+
space: The space to get the schema for
64+
Returns:
65+
The schema information of the specified space
66+
"""
67+
pool = get_connection_pool()
68+
session = pool.get_session(
69+
os.getenv("NEBULA_USER", "root"), os.getenv("NEBULA_PASSWORD", "nebula")
70+
)
71+
72+
try:
73+
session.execute(f"USE {space}")
74+
# Get tags
75+
tags = session.execute("SHOW TAGS").column_values("Name")
76+
# Get edges
77+
edges = session.execute("SHOW EDGES").column_values("Name")
78+
79+
schema = f"Space: {space}\n\nTags:\n"
80+
for tag in tags:
81+
tag_result = session.execute(f"DESCRIBE TAG {tag}")
82+
schema += f"\n{tag}:\n"
83+
# Iterate through all rows
84+
for i in range(tag_result.row_size()):
85+
field = tag_result.row_values(i)
86+
schema += f" - {field[0]}: {field[1]}\n"
87+
88+
schema += "\nEdges:\n"
89+
for edge in edges:
90+
edge_result = session.execute(f"DESCRIBE EDGE {edge}")
91+
schema += f"\n{edge}:\n"
92+
# Iterate through all rows
93+
for i in range(edge_result.row_size()):
94+
field = edge_result.row_values(i)
95+
schema += f" - {field[0]}: {field[1]}\n"
96+
97+
return schema
98+
finally:
99+
session.release()
100+
101+
102+
@mcp.resource("path://space/{space}/from/{src}/to/{dst}/depth/{depth}/limit/{limit}")
103+
def get_path_resource(space: str, src: str, dst: str, depth: int, limit: int) -> str:
104+
"""Get the path between two vertices
105+
Args:
106+
space: The space to use
107+
src: The source vertex ID
108+
dst: The destination vertex ID
109+
depth: The maximum path depth
110+
limit: The maximum number of paths to return
111+
Returns:
112+
The path between the source and destination vertices
113+
"""
114+
pool = get_connection_pool()
115+
session = pool.get_session(
116+
os.getenv("NEBULA_USER", "root"), os.getenv("NEBULA_PASSWORD", "nebula")
117+
)
118+
119+
try:
120+
session.execute(f"USE {space}")
121+
122+
query = f"""FIND ALL PATH WITH PROP FROM "{src}" TO "{dst}" OVER * BIDIRECT UPTO {depth} STEPS
123+
YIELD PATH AS paths | LIMIT {limit}"""
124+
125+
result = session.execute(query)
126+
if result.is_succeeded():
127+
# Format the path results
128+
if result.row_size() > 0:
129+
output = f"Find paths from {src} to {dst}: \n\n"
130+
131+
# Iterate through all paths
132+
for i in range(result.row_size()):
133+
path = result.row_values(i)[0] # The path should be in the first column
134+
output += f"Path {i + 1}:\n{path}\n\n"
135+
136+
return output
137+
return f"No paths found from {src} to {dst}"
138+
else:
139+
return f"Query failed: {result.error_msg()}"
140+
finally:
141+
session.release()
142+
143+
144+
@mcp.tool()
145+
def list_spaces() -> str:
146+
"""List all available spaces
147+
Returns:
148+
The available spaces
149+
"""
150+
pool = get_connection_pool()
151+
session = pool.get_session(
152+
os.getenv("NEBULA_USER", "root"), os.getenv("NEBULA_PASSWORD", "nebula")
153+
)
154+
155+
try:
156+
result = session.execute("SHOW SPACES")
157+
if result.is_succeeded():
158+
spaces = result.column_values("Name")
159+
return "Available spaces:\n" + "\n".join(f"- {space}" for space in spaces)
160+
return f"Failed to list spaces: {result.error_msg()}"
161+
finally:
162+
session.release()
163+
164+
165+
@mcp.tool()
166+
def get_space_schema(space: str) -> str:
167+
"""Get the schema information of the specified space
168+
Args:
169+
space: The space to get the schema for
170+
Returns:
171+
The schema information of the specified space
172+
"""
173+
return get_space_schema_resource(space)
174+
175+
176+
@mcp.tool()
177+
def execute_query(query: str, space: str) -> str:
178+
"""Execute a query
179+
Args:
180+
query: The query to execute
181+
space: The space to use
182+
Returns:
183+
The results of the query
184+
"""
185+
pool = get_connection_pool()
186+
session = pool.get_session(
187+
os.getenv("NEBULA_USER", "root"), os.getenv("NEBULA_PASSWORD", "nebula")
188+
)
189+
190+
try:
191+
session.execute(f"USE {space}")
192+
result = session.execute(query)
193+
if result.is_succeeded():
194+
# Format the query results
195+
if result.row_size() > 0:
196+
columns = result.keys()
197+
output = "Results:\n"
198+
output += " | ".join(columns) + "\n"
199+
output += "-" * (len(" | ".join(columns))) + "\n"
200+
201+
# Iterate through all rows
202+
for i in range(result.row_size()):
203+
row = result.row_values(i)
204+
output += " | ".join(str(val) for val in row) + "\n"
205+
return output
206+
return "Query executed successfully (no results)"
207+
else:
208+
return f"Query failed: {result.error_msg()}"
209+
finally:
210+
session.release()
211+
212+
213+
@mcp.tool()
214+
def find_path(src: str, dst: str, space: str, depth: int = 3, limit: int = 10) -> str:
215+
"""Find paths between two vertices
216+
Args:
217+
src: The source vertex ID
218+
dst: The destination vertex ID
219+
space: The space to use
220+
depth: The maximum path depth
221+
limit: The maximum number of paths to return
222+
Returns:
223+
The path results
224+
"""
225+
return get_path_resource(space, src, dst, depth, limit)
226+
227+
228+
@mcp.resource("neighbors://space/{space}/vertex/{vertex}/depth/{depth}")
229+
def get_neighbors_resource(space: str, vertex: str, depth: int) -> str:
230+
"""Get the neighbors of the specified vertex
231+
Args:
232+
space: The space to use
233+
vertex: The vertex ID to query
234+
depth: The depth of the query
235+
Returns:
236+
The neighbors of the specified vertex
237+
"""
238+
pool = get_connection_pool()
239+
session = pool.get_session(
240+
os.getenv("NEBULA_USER", "root"), os.getenv("NEBULA_PASSWORD", "nebula")
241+
)
242+
243+
try:
244+
session.execute(f"USE {space}")
245+
246+
query = f"""
247+
MATCH (u)-[e*1..{depth}]-(v)
248+
WHERE id(u) == "{vertex}"
249+
RETURN DISTINCT v, e
250+
"""
251+
252+
result = session.execute(query)
253+
if result.is_succeeded():
254+
if result.row_size() > 0:
255+
output = f"Vertex {vertex} neighbors (depth {depth}):\n\n"
256+
for i in range(result.row_size()):
257+
row = result.row_values(i)
258+
neighbor_vertex = row[0]
259+
edges = row[1]
260+
output += f"Neighbor Vertex:\n{neighbor_vertex}\nEdges:\n{edges}\n\n"
261+
return output
262+
return f"No neighbors found for vertex {vertex}"
263+
else:
264+
return f"Query failed: {result.error_msg()}"
265+
finally:
266+
session.release()
267+
268+
269+
@mcp.tool()
270+
def find_neighbors(vertex: str, space: str, depth: int = 1) -> str:
271+
"""Find the neighbors of the specified vertex
272+
Args:
273+
vertex: The vertex ID to query
274+
space: The space to use
275+
depth: The depth of the query, default is 1
276+
Returns:
277+
The neighbors of the specified vertex
278+
"""
279+
return get_neighbors_resource(space, vertex, depth)
280+
281+
282+
def main():
283+
parser = argparse.ArgumentParser(description="NebulaGraph MCP server")
284+
parser.add_argument(
285+
"--transport",
286+
type=str,
287+
choices=["stdio", "sse"],
288+
default="stdio",
289+
help="Transport method (stdio or sse)",
290+
)
291+
292+
args = parser.parse_args()
293+
294+
if args.transport == "sse":
295+
mcp.run("sse")
296+
else:
297+
mcp.run("stdio")
298+
299+
300+
if __name__ == "__main__":
301+
main()

0 commit comments

Comments
 (0)