Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,17 @@ build-backend = "setuptools.build_meta"

[project.scripts]
mcp-proxy = "mcp_proxy.__main__:main"

[tool.uv]
dev-dependencies = [
"pytest>=8.3.3",
"pytest-asyncio>=0.25.0",
]

[tool.pytest.ini_options]
pythonpath = "src"
addopts = [
"--import-mode=importlib",
]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
131 changes: 70 additions & 61 deletions src/mcp_proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,79 +7,89 @@
logger = logging.getLogger(__name__)


async def confugure_app(name: str, remote_app: ClientSession):
app = server.Server(name)
async def create_proxy_server(remote_app: ClientSession):
"""Create a server instance from a remote app."""

async def _list_prompts(_: t.Any) -> types.ServerResult:
result = await remote_app.list_prompts()
return types.ServerResult(result)
response = await remote_app.initialize()
capabilities = response.capabilities

app.request_handlers[types.ListPromptsRequest] = _list_prompts
app = server.Server(response.serverInfo.name)

async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
result = await remote_app.get_prompt(req.params.name, req.params.arguments)
return types.ServerResult(result)
if capabilities.prompts:
async def _list_prompts(_: t.Any) -> types.ServerResult:
result = await remote_app.list_prompts()
return types.ServerResult(result)

app.request_handlers[types.GetPromptRequest] = _get_prompt
app.request_handlers[types.ListPromptsRequest] = _list_prompts

async def _list_resources(_: t.Any) -> types.ServerResult:
result = await remote_app.list_resources()
return types.ServerResult(result)
async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
result = await remote_app.get_prompt(req.params.name, req.params.arguments)
return types.ServerResult(result)

app.request_handlers[types.ListResourcesRequest] = _list_resources
app.request_handlers[types.GetPromptRequest] = _get_prompt

# list_resource_templates() is not implemented in the client
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
# result = await remote_app.list_resource_templates()
# return types.ServerResult(result)
if capabilities.resources:
async def _list_resources(_: t.Any) -> types.ServerResult:
result = await remote_app.list_resources()
return types.ServerResult(result)

# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
app.request_handlers[types.ListResourcesRequest] = _list_resources

async def _read_resource(req: types.ReadResourceRequest):
result = await remote_app.read_resource(req.params.uri)
return types.ServerResult(result)
# list_resource_templates() is not implemented in the client
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
# result = await remote_app.list_resource_templates()
# return types.ServerResult(result)

app.request_handlers[types.ReadResourceRequest] = _read_resource
# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates

async def _set_logging_level(req: types.SetLevelRequest):
await remote_app.set_logging_level(req.params.level)
return types.ServerResult(types.EmptyResult())
async def _read_resource(req: types.ReadResourceRequest):
result = await remote_app.read_resource(req.params.uri)
return types.ServerResult(result)

app.request_handlers[types.SetLevelRequest] = _set_logging_level
app.request_handlers[types.ReadResourceRequest] = _read_resource

async def _subscribe_resource(req: types.SubscribeRequest):
await remote_app.subscribe_resource(req.params.uri)
return types.ServerResult(types.EmptyResult())
if capabilities.logging:
async def _set_logging_level(req: types.SetLevelRequest):
await remote_app.set_logging_level(req.params.level)
return types.ServerResult(types.EmptyResult())

app.request_handlers[types.SubscribeRequest] = _subscribe_resource
app.request_handlers[types.SetLevelRequest] = _set_logging_level

async def _unsubscribe_resource(req: types.UnsubscribeRequest):
await remote_app.unsubscribe_resource(req.params.uri)
return types.ServerResult(types.EmptyResult())
if capabilities.resources:
async def _subscribe_resource(req: types.SubscribeRequest):
await remote_app.subscribe_resource(req.params.uri)
return types.ServerResult(types.EmptyResult())

app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource
app.request_handlers[types.SubscribeRequest] = _subscribe_resource

async def _list_tools(_: t.Any):
tools = await remote_app.list_tools()
return types.ServerResult(tools)
async def _unsubscribe_resource(req: types.UnsubscribeRequest):
await remote_app.unsubscribe_resource(req.params.uri)
return types.ServerResult(types.EmptyResult())

app.request_handlers[types.ListToolsRequest] = _list_tools
app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource

async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
try:
result = await remote_app.call_tool(
req.params.name, (req.params.arguments or {})
)
return types.ServerResult(result)
except Exception as e:
return types.ServerResult(
types.CallToolResult(
content=[types.TextContent(type="text", text=str(e))],
isError=True,
if capabilities.tools:
async def _list_tools(_: t.Any):
tools = await remote_app.list_tools()
return types.ServerResult(tools)

app.request_handlers[types.ListToolsRequest] = _list_tools

async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
try:
result = await remote_app.call_tool(
req.params.name, (req.params.arguments or {})
)
return types.ServerResult(result)
except Exception as e:
return types.ServerResult(
types.CallToolResult(
content=[types.TextContent(type="text", text=str(e))],
isError=True,
)
)
)

app.request_handlers[types.CallToolRequest] = _call_tool
app.request_handlers[types.CallToolRequest] = _call_tool

async def _send_progress_notification(req: types.ProgressNotification):
await remote_app.send_progress_notification(
Expand All @@ -96,19 +106,18 @@ async def _complete(req: types.CompleteRequest):

app.request_handlers[types.CompleteRequest] = _complete

async with server.stdio_server() as (read_stream, write_stream):
await app.run(
read_stream,
write_stream,
app.create_initialization_options(),
)
return app


async def run_sse_client(url: str):
from mcp.client.sse import sse_client

async with sse_client(url=url) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
response = await session.initialize()

await confugure_app(response.serverInfo.name, session)
app = await create_proxy_server(session)
async with server.stdio_server() as (read_stream, write_stream):
await app.run(
read_stream,
write_stream,
app.create_initialization_options(),
)
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for mcp-proxy."""
109 changes: 109 additions & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Tests for the mcp-proxy module.

Tests are running in two modes:
- One where the server is exercised directly though an in memory client, just to
set a baseline for the expected behavior.
- Another where the server is exercised through a proxy server, which forwards
the requests to the original server.

The same test code is run on both to ensure parity.
"""

from typing import Any
from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager, AbstractAsyncContextManager

import pytest

from mcp import types
from mcp.client.session import ClientSession
from mcp.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_connected_server_and_client_session

from mcp_proxy import create_proxy_server

TOOL_INPUT_SCHEMA = {
"type": "object",
"properties": {
"input1": {"type": "string"}
}
}

SessionContextManager = Callable[[Server], AbstractAsyncContextManager[ClientSession]]

# Direct server connection
in_memory: SessionContextManager = create_connected_server_and_client_session

@asynccontextmanager
async def proxy(server: Server) -> AsyncGenerator[ClientSession, None]:
"""Create a connection to the server through the proxy server."""
async with in_memory(server) as session:
wrapped_server = await create_proxy_server(session)
async with in_memory(wrapped_server) as wrapped_session:
yield wrapped_session


@pytest.fixture(params=["server", "proxy"], scope="function")
def session_generator(request: Any) -> SessionContextManager:
"""Fixture that returns a client creation strategy either direct or using the proxy."""
if request.param == "server":
return in_memory
return proxy


async def test_list_prompts(session_generator: SessionContextManager):
"""Test list_prompts."""

server = Server("prompt-server")

@server.list_prompts()
async def list_prompts() -> list[types.Prompt]:
return [types.Prompt(name="prompt1")]

async with session_generator(server) as session:
result = await session.initialize()
assert result.serverInfo.name == "prompt-server"
assert result.capabilities
assert result.capabilities.prompts
assert not result.capabilities.tools
assert not result.capabilities.resources
assert not result.capabilities.logging

result = await session.list_prompts()
assert result.prompts == [types.Prompt(name="prompt1")]

with pytest.raises(McpError, match="Method not found"):
await session.list_tools()


async def test_list_tools(session_generator: SessionContextManager):
"""Test list_tools."""

server = Server("tools-server")

@server.list_tools()
async def list_tools() -> list[types.Tool]:
return [types.Tool(
name="tool-name",
description="tool-description",
inputSchema=TOOL_INPUT_SCHEMA
)]

async with session_generator(server) as session:
result = await session.initialize()
assert result.serverInfo.name == "tools-server"
assert result.capabilities
assert result.capabilities.tools
assert not result.capabilities.prompts
assert not result.capabilities.resources
assert not result.capabilities.logging

result = await session.list_tools()
assert len(result.tools) == 1
assert result.tools[0].name == "tool-name"
assert result.tools[0].description == "tool-description"
assert result.tools[0].inputSchema == TOOL_INPUT_SCHEMA

with pytest.raises(McpError, match="Method not found"):
await session.list_prompts()
Loading