Skip to content

Commit d7db6b2

Browse files
committed
Merge branch 'release/1.3' of github.com:NVIDIA/NeMo-Agent-Toolkit into david-fix-async-chat
Signed-off-by: David Gardner <[email protected]>
2 parents 584ab70 + 5382798 commit d7db6b2

File tree

2 files changed

+430
-22
lines changed

2 files changed

+430
-22
lines changed

src/nat/front_ends/mcp/tool_converter.py

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
from inspect import Parameter
1919
from inspect import Signature
2020
from typing import TYPE_CHECKING
21+
from typing import Any
2122

2223
from mcp.server.fastmcp import FastMCP
2324
from pydantic import BaseModel
25+
from pydantic.fields import FieldInfo
26+
from pydantic_core import PydanticUndefined
2427

2528
from nat.builder.context import ContextState
2629
from nat.builder.function import Function
@@ -31,6 +34,41 @@
3134

3235
logger = logging.getLogger(__name__)
3336

37+
# Sentinel: marks "optional; let Pydantic supply default/factory"
38+
_USE_PYDANTIC_DEFAULT = object()
39+
40+
41+
def is_field_optional(field: FieldInfo) -> tuple[bool, Any]:
42+
"""Determine if a Pydantic field is optional and extract its default value for MCP signatures.
43+
44+
For MCP tool signatures, we need to distinguish:
45+
- Required fields: marked with Parameter.empty
46+
- Optional with concrete default: use that default
47+
- Optional with factory: use sentinel so Pydantic can apply the factory later
48+
49+
Args:
50+
field: The Pydantic FieldInfo to check
51+
52+
Returns:
53+
A tuple of (is_optional, default_value):
54+
- (False, Parameter.empty) for required fields
55+
- (True, actual_default) for optional fields with explicit defaults
56+
- (True, _USE_PYDANTIC_DEFAULT) for optional fields with default_factory
57+
"""
58+
if field.is_required():
59+
return False, Parameter.empty
60+
61+
# Field is optional - has either default or factory
62+
if field.default is not PydanticUndefined:
63+
return True, field.default
64+
65+
# Factory case: mark optional in signature but don't fabricate a value
66+
if field.default_factory is not None:
67+
return True, _USE_PYDANTIC_DEFAULT
68+
69+
# Rare corner case: non-required yet no default surfaced
70+
return True, _USE_PYDANTIC_DEFAULT
71+
3472

3573
def create_function_wrapper(
3674
function_name: str,
@@ -76,12 +114,15 @@ def create_function_wrapper(
76114
# Get the field type and convert to appropriate Python type
77115
field_type = field.annotation
78116

117+
# Check if field is optional and get its default value
118+
_is_optional, param_default = is_field_optional(field)
119+
79120
# Add the parameter to our list
80121
parameters.append(
81122
Parameter(
82123
name=name,
83124
kind=Parameter.KEYWORD_ONLY,
84-
default=Parameter.empty if field.is_required else None,
125+
default=param_default,
85126
annotation=field_type,
86127
))
87128

@@ -140,33 +181,23 @@ async def call_with_observability(func_call):
140181
result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
141182
else:
142183
# Regular handling
143-
# Handle complex input schema - if we extracted fields from a nested schema,
144-
# we need to reconstruct the input
145-
if len(schema.model_fields) == 1 and len(parameters) > 1:
146-
# Get the field name from the original schema
147-
field_name = next(iter(schema.model_fields.keys()))
148-
field_type = schema.model_fields[field_name].annotation
149-
150-
# If it's a pydantic model, we need to create an instance
151-
if field_type and hasattr(field_type, "model_validate"):
152-
# Create the nested object
153-
nested_obj = field_type.model_validate(kwargs)
154-
# Call with the nested object
155-
kwargs = {field_name: nested_obj}
184+
# Strip sentinel values so Pydantic can apply defaults/factories
185+
cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not _USE_PYDANTIC_DEFAULT}
186+
187+
# Always validate with the declared schema
188+
# This handles defaults, factories, nested models, validators, etc.
189+
model_input = schema.model_validate(cleaned_kwargs)
156190

157191
# Call the NAT function with the parameters - special handling for Workflow
158192
if is_workflow:
159-
# For workflow with regular input, we'll assume the first parameter is the input
160-
input_value = list(kwargs.values())[0] if kwargs else ""
161-
162-
# Workflows have a run method that is an async context manager
163-
# that returns a Runner
164-
async with function.run(input_value) as runner:
193+
# Workflows expect the model instance directly
194+
async with function.run(model_input) as runner:
165195
# Get the result from the runner
166196
result = await runner.result(to_type=str)
167197
else:
168-
# Regular function call
169-
result = await call_with_observability(lambda: function.acall_invoke(**kwargs))
198+
# Regular function call - unpack the validated model
199+
result = await call_with_observability(lambda: function.acall_invoke(**model_input.model_dump())
200+
)
170201

171202
# Report completion
172203
if ctx:

0 commit comments

Comments
 (0)