Skip to content

feat: MCP instrumentation for tool, resources, and prompt clean #1675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
from typing import Any, AsyncGenerator, Callable, Collection, Tuple, cast

from opentelemetry import context, propagate
from opentelemetry import trace as trace_api
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.trace import Status, StatusCode
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper

from openinference.instrumentation import safe_json_dumps
from openinference.instrumentation.mcp.package import _instruments
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes


class MCPInstrumentor(BaseInstrumentor): # type: ignore
Expand All @@ -19,6 +23,35 @@ def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs: Any) -> None:
# Instrument high-level MCP client operations
register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.client.session",
"ClientSession.call_tool",
self._wrap_call_tool,
),
"mcp.client.session",
)

register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.client.session",
"ClientSession.get_prompt",
self._wrap_get_prompt,
),
"mcp.client.session",
)

register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.client.session",
"ClientSession.read_resource",
self._wrap_read_resource,
),
"mcp.client.session",
)

# Existing transport-level instrumentation for context propagation
register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.client.streamable_http",
Expand Down Expand Up @@ -61,6 +94,14 @@ def _instrument(self, **kwargs: Any) -> None:
),
"mcp.server.stdio",
)
register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.server.lowlevel.server",
"Server._handle_request",
self._wrap_handle_request,
),
"mcp.server.lowlevel.server",
)

# While we prefer to instrument the lowest level primitive, the transports above, it doesn't
# mean context will be propagated to handlers automatically. Notably, the MCP SDK passes
Expand All @@ -77,8 +118,12 @@ def _instrument(self, **kwargs: Any) -> None:
)

def _uninstrument(self, **kwargs: Any) -> None:
    """Remove every wrapper installed by ``_instrument``.

    Mirrors the registrations in ``_instrument``: the three high-level
    ``ClientSession`` operations, the stdio transport hooks, and the
    low-level server request handler.
    """
    unwrap("mcp.client.session", "ClientSession.call_tool")
    unwrap("mcp.client.session", "ClientSession.get_prompt")
    unwrap("mcp.client.session", "ClientSession.read_resource")
    unwrap("mcp.client.stdio", "stdio_client")
    unwrap("mcp.server.stdio", "stdio_server")
    # Server._handle_request is wrapped during _instrument and must be
    # unwrapped too; previously this line was a duplicate unwrap of
    # ClientSession.call_tool, leaving the server handler wrapped.
    unwrap("mcp.server.lowlevel.server", "Server._handle_request")
    # NOTE(review): the streamable_http transport and BaseSession.__init__
    # wrappers (installed in _instrument) are not unwrapped here either —
    # confirm whether they should be, against the full _instrument body.

@asynccontextmanager
async def _wrap_transport_with_callback(
Expand All @@ -98,6 +143,21 @@ async def _wrap_plain_transport(
async with wrapped(*args, **kwargs) as (read_stream, write_stream):
yield InstrumentedStreamReader(read_stream), InstrumentedStreamWriter(write_stream)

async def _wrap_handle_request(
self, wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
) -> Any:
token = None
try:
# Message has been deserialized, we need to extract the traceparent
_meta = {"traceparent": args[1].params.meta.traceparent}
ctx = propagate.extract(_meta)
token = context.attach(ctx)
finally:
res = await wrapped(*args, **kwargs)
if token:
context.detach(token)
return res

def _base_session_init_wrapper(
self, wrapped: Callable[..., None], instance: Any, args: Any, kwargs: Any
) -> None:
Expand All @@ -110,6 +170,143 @@ def _base_session_init_wrapper(
)
setattr(instance, "_incoming_message_stream_writer", ContextSavingStreamWriter(writer))

async def _wrap_call_tool(
    self, wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
) -> Any:
    """Wrap MCP ``ClientSession.call_tool`` with an OpenInference TOOL span.

    Records the tool name, arguments, and JSON-serialized result on the
    span, and marks the span as errored if the call raises.
    """
    # call_tool(name, arguments=...) may be invoked positionally or by
    # keyword; check both so positional calls are not recorded as
    # "unknown" with an empty input payload.
    name = kwargs.get("name", args[0] if args else "unknown")
    arguments = kwargs.get("arguments", args[1] if len(args) > 1 else None)

    attributes = {
        SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
        SpanAttributes.TOOL_NAME: name,
        SpanAttributes.INPUT_VALUE: safe_json_dumps({"name": name, "arguments": arguments}),
        SpanAttributes.INPUT_MIME_TYPE: "application/json",
    }
    # NOTE(review): _request_id is a private session attribute used here as
    # a session id — confirm semantics. Guarded so its absence can never
    # break the instrumented call.
    session_id = getattr(instance, "_request_id", None)
    if session_id is not None:
        attributes[SpanAttributes.SESSION_ID] = session_id

    tracer = trace_api.get_tracer(__name__)
    with tracer.start_as_current_span(
        f"mcp.call_tool.{name}",
        attributes=attributes,
    ) as span:
        if arguments:
            span.set_attribute(SpanAttributes.TOOL_PARAMETERS, safe_json_dumps(arguments))
        try:
            result = await wrapped(*args, **kwargs)

            # TODO: handle content types
            if hasattr(result, "content") and result.content:
                span.set_attribute(SpanAttributes.OUTPUT_VALUE, safe_json_dumps(result.content))
                span.set_attribute(SpanAttributes.OUTPUT_MIME_TYPE, "application/json")

            span.set_status(Status(StatusCode.OK))
            return result
        except Exception as e:
            span.set_status(Status(StatusCode.ERROR, str(e)))
            span.record_exception(e)
            raise

# TODO: update once OpenInference support for prompts grows
async def _wrap_get_prompt(
    self, wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
) -> Any:
    """Wrap MCP ``ClientSession.get_prompt`` with a tracing span.

    Records the prompt name and input, and the rendered messages on
    success; marks the span as errored if the call raises.
    """
    # get_prompt(name, arguments=...) may be invoked positionally or by
    # keyword; check both so positional calls are not labeled "unknown".
    name = kwargs.get("name", args[0] if args else "unknown")

    attributes = {
        SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.UNKNOWN.value,
        SpanAttributes.PROMPT_ID: name,
        SpanAttributes.INPUT_VALUE: safe_json_dumps({"name": name, **kwargs}),
        SpanAttributes.INPUT_MIME_TYPE: "application/json",
    }
    # NOTE(review): _request_id is a private session attribute used here as
    # a session id — confirm semantics. Guarded so its absence can never
    # break the instrumented call.
    session_id = getattr(instance, "_request_id", None)
    if session_id is not None:
        attributes[SpanAttributes.SESSION_ID] = session_id

    tracer = trace_api.get_tracer(__name__)
    with tracer.start_as_current_span(
        f"mcp.get_prompt.{name}",
        attributes=attributes,
    ) as span:
        try:
            result = await wrapped(*args, **kwargs)

            if hasattr(result, "messages") and result.messages:
                span.set_attribute(
                    SpanAttributes.OUTPUT_VALUE,
                    safe_json_dumps(
                        {"description": result.description, "messages": result.messages}
                    ),
                )
                span.set_attribute(SpanAttributes.OUTPUT_MIME_TYPE, "application/json")

            span.set_status(Status(StatusCode.OK))
            return result
        except Exception as e:
            span.set_status(Status(StatusCode.ERROR, str(e)))
            span.record_exception(e)
            raise

# TODO: further e2e testing once Phoenix error #7687 is resolved
async def _wrap_read_resource(
    self, wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
) -> Any:
    """Wrap MCP ``ClientSession.read_resource`` with a RETRIEVER span.

    Records the resource URI as input and each returned content item as
    a retrieval document; marks the span as errored if the call raises.
    """
    import logging  # local import: keeps this fix self-contained

    # read_resource(uri) may be invoked positionally or by keyword;
    # previously `args[0]` raised IndexError for `read_resource(uri=...)`.
    uri = kwargs.get("uri", args[0] if args else "unknown")

    attributes = {
        SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
        SpanAttributes.INPUT_VALUE: safe_json_dumps({"uri": str(uri)}),
        SpanAttributes.INPUT_MIME_TYPE: "application/json",
    }
    # NOTE(review): _request_id is a private session attribute used here as
    # a session id — confirm semantics. Guarded so its absence can never
    # break the instrumented call.
    session_id = getattr(instance, "_request_id", None)
    if session_id is not None:
        attributes[SpanAttributes.SESSION_ID] = session_id

    tracer = trace_api.get_tracer(__name__)
    with tracer.start_as_current_span(
        f"mcp.read_resource.{uri}",
        attributes=attributes,
    ) as span:
        try:
            result = await wrapped(*args, **kwargs)

            if hasattr(result, "contents") and result.contents:
                for i, content in enumerate(result.contents):
                    # Content items carry either `text` or `blob`, never
                    # both; `content.text` on a blob item raises
                    # AttributeError (pydantic model), so probe safely.
                    document = getattr(content, "text", None) or getattr(
                        content, "blob", None
                    )
                    if document:
                        span.set_attribute(
                            f"retrieval.documents.{i}.document.content", document
                        )
                        span.set_attribute(
                            f"retrieval.documents.{i}.document.metadata",
                            safe_json_dumps({"type": content.__class__.__name__}),
                        )
                    else:
                        # Fail silently for now, but leave a trace in the
                        # logs instead of printing to stdout.
                        logging.getLogger(__name__).debug(
                            "Unknown document type: %s", type(content)
                        )
                serialized_contents = [content.model_dump() for content in result.contents]
                span.set_attribute(
                    SpanAttributes.OUTPUT_VALUE,
                    safe_json_dumps({"contents": serialized_contents}),
                )
                span.set_attribute(SpanAttributes.OUTPUT_MIME_TYPE, "application/json")

            span.set_status(Status(StatusCode.OK))
            return result

        except Exception as e:
            span.set_status(Status(StatusCode.ERROR, str(e)))
            span.record_exception(e)
            raise


class InstrumentedStreamReader(ObjectProxy): # type: ignore
# ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.3.0"
__version__ = "1.4.0"
Loading