|
18 | 18 | from inspect import Parameter |
19 | 19 | from inspect import Signature |
20 | 20 | from typing import TYPE_CHECKING |
| 21 | +from typing import Any |
21 | 22 |
|
22 | 23 | from mcp.server.fastmcp import FastMCP |
23 | 24 | from pydantic import BaseModel |
| 25 | +from pydantic.fields import FieldInfo |
| 26 | +from pydantic_core import PydanticUndefined |
24 | 27 |
|
25 | 28 | from nat.builder.context import ContextState |
26 | 29 | from nat.builder.function import Function |
|
31 | 34 |
|
32 | 35 | logger = logging.getLogger(__name__) |
33 | 36 |
|
| 37 | +# Sentinel: marks "optional; let Pydantic supply default/factory" |
| 38 | +_USE_PYDANTIC_DEFAULT = object() |
| 39 | + |
| 40 | + |
| 41 | +def is_field_optional(field: FieldInfo) -> tuple[bool, Any]: |
| 42 | + """Determine if a Pydantic field is optional and extract its default value for MCP signatures. |
| 43 | +
|
| 44 | + For MCP tool signatures, we need to distinguish: |
| 45 | + - Required fields: marked with Parameter.empty |
| 46 | + - Optional with concrete default: use that default |
| 47 | + - Optional with factory: use sentinel so Pydantic can apply the factory later |
| 48 | +
|
| 49 | + Args: |
| 50 | + field: The Pydantic FieldInfo to check |
| 51 | +
|
| 52 | + Returns: |
| 53 | + A tuple of (is_optional, default_value): |
| 54 | + - (False, Parameter.empty) for required fields |
| 55 | + - (True, actual_default) for optional fields with explicit defaults |
| 56 | + - (True, _USE_PYDANTIC_DEFAULT) for optional fields with default_factory |
| 57 | + """ |
| 58 | + if field.is_required(): |
| 59 | + return False, Parameter.empty |
| 60 | + |
| 61 | + # Field is optional - has either default or factory |
| 62 | + if field.default is not PydanticUndefined: |
| 63 | + return True, field.default |
| 64 | + |
| 65 | + # Factory case: mark optional in signature but don't fabricate a value |
| 66 | + if field.default_factory is not None: |
| 67 | + return True, _USE_PYDANTIC_DEFAULT |
| 68 | + |
| 69 | + # Rare corner case: non-required yet no default surfaced |
| 70 | + return True, _USE_PYDANTIC_DEFAULT |
| 71 | + |
34 | 72 |
|
35 | 73 | def create_function_wrapper( |
36 | 74 | function_name: str, |
@@ -76,12 +114,15 @@ def create_function_wrapper( |
76 | 114 | # Get the field type and convert to appropriate Python type |
77 | 115 | field_type = field.annotation |
78 | 116 |
|
| 117 | + # Check if field is optional and get its default value |
| 118 | + _is_optional, param_default = is_field_optional(field) |
| 119 | + |
79 | 120 | # Add the parameter to our list |
80 | 121 | parameters.append( |
81 | 122 | Parameter( |
82 | 123 | name=name, |
83 | 124 | kind=Parameter.KEYWORD_ONLY, |
84 | | - default=Parameter.empty if field.is_required else None, |
| 125 | + default=param_default, |
85 | 126 | annotation=field_type, |
86 | 127 | )) |
87 | 128 |
|
@@ -140,33 +181,23 @@ async def call_with_observability(func_call): |
140 | 181 | result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str)) |
141 | 182 | else: |
142 | 183 | # Regular handling |
143 | | - # Handle complex input schema - if we extracted fields from a nested schema, |
144 | | - # we need to reconstruct the input |
145 | | - if len(schema.model_fields) == 1 and len(parameters) > 1: |
146 | | - # Get the field name from the original schema |
147 | | - field_name = next(iter(schema.model_fields.keys())) |
148 | | - field_type = schema.model_fields[field_name].annotation |
149 | | - |
150 | | - # If it's a pydantic model, we need to create an instance |
151 | | - if field_type and hasattr(field_type, "model_validate"): |
152 | | - # Create the nested object |
153 | | - nested_obj = field_type.model_validate(kwargs) |
154 | | - # Call with the nested object |
155 | | - kwargs = {field_name: nested_obj} |
| 184 | + # Strip sentinel values so Pydantic can apply defaults/factories |
| 185 | + cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not _USE_PYDANTIC_DEFAULT} |
| 186 | + |
| 187 | + # Always validate with the declared schema |
| 188 | + # This handles defaults, factories, nested models, validators, etc. |
| 189 | + model_input = schema.model_validate(cleaned_kwargs) |
156 | 190 |
|
157 | 191 | # Call the NAT function with the parameters - special handling for Workflow |
158 | 192 | if is_workflow: |
159 | | - # For workflow with regular input, we'll assume the first parameter is the input |
160 | | - input_value = list(kwargs.values())[0] if kwargs else "" |
161 | | - |
162 | | - # Workflows have a run method that is an async context manager |
163 | | - # that returns a Runner |
164 | | - async with function.run(input_value) as runner: |
| 193 | + # Workflows expect the model instance directly |
| 194 | + async with function.run(model_input) as runner: |
165 | 195 | # Get the result from the runner |
166 | 196 | result = await runner.result(to_type=str) |
167 | 197 | else: |
168 | | - # Regular function call |
169 | | - result = await call_with_observability(lambda: function.acall_invoke(**kwargs)) |
| 198 | + # Regular function call - unpack the validated model |
| 199 | + result = await call_with_observability(lambda: function.acall_invoke(**model_input.model_dump()) |
| 200 | + ) |
170 | 201 |
|
171 | 202 | # Report completion |
172 | 203 | if ctx: |
|
0 commit comments