Skip to content

Commit 5382798

Browse files
authored
Fixes to detect optional parameters in tool conversion used by "nat mcp serve" (#1133)
Parameters with default values were incorrectly marked as required in MCP tool schemas. Now checking for PydanticUndefined and properly extracting default values from Pydantic fields to correctly identify optional parameters. This PR also adds comprehensive unit tests for the tool schema conversion - I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/resources/contributing.md). - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - Any contribution which contains commits that are not Signed-Off will not be accepted. - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. * **New Features** * Improved detection and handling of optional parameters and default/default-factory values in tool wrapper generation, with unified validation and consistent invocation behavior for workflows and regular functions. * Ensures parameter descriptions, order, and type annotations are preserved in generated wrappers. * **Tests** * Added comprehensive tests covering optional/default combinations, wrapper signatures, execution paths, and observability propagation. ## Summary by CodeRabbit ## Release Notes * **New Features** * Added improved support for optional parameters in MCP tools with enhanced default value handling. * **Bug Fixes** * Fixed inconsistent behavior with nested schema validation and default factory application in tool parameters. * **Tests** * Added comprehensive test coverage for optional field handling and parameter validation scenarios. Authors: - Will Killian (https://github.com/willkill07) - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah) Approvers: - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah) URL: #1133
1 parent 3a179be commit 5382798

File tree

2 files changed

+430
-22
lines changed

2 files changed

+430
-22
lines changed

src/nat/front_ends/mcp/tool_converter.py

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
from inspect import Parameter
1919
from inspect import Signature
2020
from typing import TYPE_CHECKING
21+
from typing import Any
2122

2223
from mcp.server.fastmcp import FastMCP
2324
from pydantic import BaseModel
25+
from pydantic.fields import FieldInfo
26+
from pydantic_core import PydanticUndefined
2427

2528
from nat.builder.context import ContextState
2629
from nat.builder.function import Function
@@ -31,6 +34,41 @@
3134

3235
logger = logging.getLogger(__name__)
3336

37+
# Sentinel: marks "optional; let Pydantic supply default/factory"
38+
_USE_PYDANTIC_DEFAULT = object()
39+
40+
41+
def is_field_optional(field: FieldInfo) -> tuple[bool, Any]:
42+
"""Determine if a Pydantic field is optional and extract its default value for MCP signatures.
43+
44+
For MCP tool signatures, we need to distinguish:
45+
- Required fields: marked with Parameter.empty
46+
- Optional with concrete default: use that default
47+
- Optional with factory: use sentinel so Pydantic can apply the factory later
48+
49+
Args:
50+
field: The Pydantic FieldInfo to check
51+
52+
Returns:
53+
A tuple of (is_optional, default_value):
54+
- (False, Parameter.empty) for required fields
55+
- (True, actual_default) for optional fields with explicit defaults
56+
- (True, _USE_PYDANTIC_DEFAULT) for optional fields with default_factory
57+
"""
58+
if field.is_required():
59+
return False, Parameter.empty
60+
61+
# Field is optional - has either default or factory
62+
if field.default is not PydanticUndefined:
63+
return True, field.default
64+
65+
# Factory case: mark optional in signature but don't fabricate a value
66+
if field.default_factory is not None:
67+
return True, _USE_PYDANTIC_DEFAULT
68+
69+
# Rare corner case: non-required yet no default surfaced
70+
return True, _USE_PYDANTIC_DEFAULT
71+
3472

3573
def create_function_wrapper(
3674
function_name: str,
@@ -76,12 +114,15 @@ def create_function_wrapper(
76114
# Get the field type and convert to appropriate Python type
77115
field_type = field.annotation
78116

117+
# Check if field is optional and get its default value
118+
_is_optional, param_default = is_field_optional(field)
119+
79120
# Add the parameter to our list
80121
parameters.append(
81122
Parameter(
82123
name=name,
83124
kind=Parameter.KEYWORD_ONLY,
84-
default=Parameter.empty if field.is_required else None,
125+
default=param_default,
85126
annotation=field_type,
86127
))
87128

@@ -140,33 +181,23 @@ async def call_with_observability(func_call):
140181
result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
141182
else:
142183
# Regular handling
143-
# Handle complex input schema - if we extracted fields from a nested schema,
144-
# we need to reconstruct the input
145-
if len(schema.model_fields) == 1 and len(parameters) > 1:
146-
# Get the field name from the original schema
147-
field_name = next(iter(schema.model_fields.keys()))
148-
field_type = schema.model_fields[field_name].annotation
149-
150-
# If it's a pydantic model, we need to create an instance
151-
if field_type and hasattr(field_type, "model_validate"):
152-
# Create the nested object
153-
nested_obj = field_type.model_validate(kwargs)
154-
# Call with the nested object
155-
kwargs = {field_name: nested_obj}
184+
# Strip sentinel values so Pydantic can apply defaults/factories
185+
cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not _USE_PYDANTIC_DEFAULT}
186+
187+
# Always validate with the declared schema
188+
# This handles defaults, factories, nested models, validators, etc.
189+
model_input = schema.model_validate(cleaned_kwargs)
156190

157191
# Call the NAT function with the parameters - special handling for Workflow
158192
if is_workflow:
159-
# For workflow with regular input, we'll assume the first parameter is the input
160-
input_value = list(kwargs.values())[0] if kwargs else ""
161-
162-
# Workflows have a run method that is an async context manager
163-
# that returns a Runner
164-
async with function.run(input_value) as runner:
193+
# Workflows expect the model instance directly
194+
async with function.run(model_input) as runner:
165195
# Get the result from the runner
166196
result = await runner.result(to_type=str)
167197
else:
168-
# Regular function call
169-
result = await call_with_observability(lambda: function.acall_invoke(**kwargs))
198+
# Regular function call - unpack the validated model
199+
result = await call_with_observability(lambda: function.acall_invoke(**model_input.model_dump())
200+
)
170201

171202
# Report completion
172203
if ctx:

0 commit comments

Comments
 (0)