Skip to content

Commit 691f361

Browse files
Add AIFunction.ReturnJsonSchema (#6447)
* Add AIFunction.ReturnJsonSchema * Use `MethodInfo.ReturnParameter`. * Use null return schema for void returning functions. * Remove experimental attribute and suppress warning. * Remove suppression and list new API as stable. --------- Co-authored-by: Jeff Handley <[email protected]>
1 parent af446a2 commit 691f361

File tree

5 files changed

+59
-9
lines changed

5 files changed

+59
-9
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ public abstract class AIFunction : AITool
3838
/// </remarks>
3939
public virtual JsonElement JsonSchema => AIJsonUtilities.DefaultJsonSchema;
4040

41+
/// <summary>Gets a JSON Schema describing the function's return value.</summary>
42+
/// <remarks>
43+
/// A <see langword="null"/> typically reflects a function that doesn't specify a return schema
44+
/// or a function that returns <see cref="void"/>, <see cref="Task"/>, or <see cref="ValueTask"/>.
45+
/// </remarks>
46+
public virtual JsonElement? ReturnJsonSchema => null;
47+
4148
/// <summary>
4249
/// Gets the underlying <see cref="MethodInfo"/> that this <see cref="AIFunction"/> might be wrapping.
4350
/// </summary>

src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ private ReflectionAIFunction(
540540
public override string Description => FunctionDescriptor.Description;
541541
public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method;
542542
public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema;
543+
public override JsonElement? ReturnJsonSchema => FunctionDescriptor.ReturnJsonSchema;
543544
public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions;
544545

545546
protected override async ValueTask<object?> InvokeCoreAsync(
@@ -683,13 +684,17 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
683684
ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]);
684685
}
685686

686-
// Get a marshaling delegate for the return value.
687-
ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions);
688-
687+
ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions, out Type? returnType);
689688
Method = key.Method;
690689
Name = key.Name ?? GetFunctionName(key.Method);
691690
Description = key.Description ?? key.Method.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description ?? string.Empty;
692691
JsonSerializerOptions = serializerOptions;
692+
ReturnJsonSchema = returnType is null ? null : AIJsonUtilities.CreateJsonSchema(
693+
returnType,
694+
description: key.Method.ReturnParameter.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description,
695+
serializerOptions: serializerOptions,
696+
inferenceOptions: schemaOptions);
697+
693698
JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema(
694699
key.Method,
695700
title: string.Empty, // Forces skipping of the title keyword
@@ -703,6 +708,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
703708
public MethodInfo Method { get; }
704709
public JsonSerializerOptions JsonSerializerOptions { get; }
705710
public JsonElement JsonSchema { get; }
711+
public JsonElement? ReturnJsonSchema { get; }
706712
public Func<AIFunctionArguments, CancellationToken, object?>[] ParameterMarshallers { get; }
707713
public Func<object?, CancellationToken, ValueTask<object?>> ReturnParameterMarshaller { get; }
708714
public ReflectionAIFunction? CachedDefaultInstance { get; set; }
@@ -849,15 +855,16 @@ static void ThrowNullServices(string parameterName) =>
849855
/// Gets a delegate for handling the result value of a method, converting it into the <see cref="Task{FunctionResult}"/> to return from the invocation.
850856
/// </summary>
851857
private static Func<object?, CancellationToken, ValueTask<object?>> GetReturnParameterMarshaller(
852-
DescriptorKey key, JsonSerializerOptions serializerOptions)
858+
DescriptorKey key, JsonSerializerOptions serializerOptions, out Type? returnType)
853859
{
854-
Type returnType = key.Method.ReturnType;
860+
returnType = key.Method.ReturnType;
855861
JsonTypeInfo returnTypeInfo;
856862
Func<object?, Type?, CancellationToken, ValueTask<object?>>? marshalResult = key.MarshalResult;
857863

858864
// Void
859865
if (returnType == typeof(void))
860866
{
867+
returnType = null;
861868
if (marshalResult is not null)
862869
{
863870
return (result, cancellationToken) => marshalResult(null, null, cancellationToken);
@@ -869,6 +876,7 @@ static void ThrowNullServices(string parameterName) =>
869876
// Task
870877
if (returnType == typeof(Task))
871878
{
879+
returnType = null;
872880
if (marshalResult is not null)
873881
{
874882
return async (result, cancellationToken) =>
@@ -888,6 +896,7 @@ static void ThrowNullServices(string parameterName) =>
888896
// ValueTask
889897
if (returnType == typeof(ValueTask))
890898
{
899+
returnType = null;
891900
if (marshalResult is not null)
892901
{
893902
return async (result, cancellationToken) =>
@@ -910,6 +919,8 @@ static void ThrowNullServices(string parameterName) =>
910919
if (returnType.GetGenericTypeDefinition() == typeof(Task<>))
911920
{
912921
MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult);
922+
returnType = taskResultGetter.ReturnType;
923+
913924
if (marshalResult is not null)
914925
{
915926
return async (taskObj, cancellationToken) =>
@@ -920,7 +931,7 @@ static void ThrowNullServices(string parameterName) =>
920931
};
921932
}
922933

923-
returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType);
934+
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
924935
return async (taskObj, cancellationToken) =>
925936
{
926937
await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(true);
@@ -934,6 +945,7 @@ static void ThrowNullServices(string parameterName) =>
934945
{
935946
MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask);
936947
MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult);
948+
returnType = asTaskResultGetter.ReturnType;
937949

938950
if (marshalResult is not null)
939951
{
@@ -946,7 +958,7 @@ static void ThrowNullServices(string parameterName) =>
946958
};
947959
}
948960

949-
returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType);
961+
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
950962
return async (taskObj, cancellationToken) =>
951963
{
952964
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
@@ -960,7 +972,8 @@ static void ThrowNullServices(string parameterName) =>
960972
// For everything else, just serialize the result as-is.
961973
if (marshalResult is not null)
962974
{
963-
return (result, cancellationToken) => marshalResult(result, returnType, cancellationToken);
975+
Type returnTypeCopy = returnType;
976+
return (result, cancellationToken) => marshalResult(result, returnTypeCopy, cancellationToken);
964977
}
965978

966979
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);

src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@
169169
"Member": "virtual System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.AIFunction.JsonSerializerOptions { get; }",
170170
"Stage": "Stable"
171171
},
172+
{
173+
"Member": "virtual System.Text.Json.JsonElement? Microsoft.Extensions.AI.AIFunction.ReturnJsonSchema { get; }",
174+
"Stage": "Stable"
175+
},
172176
{
173177
"Member": "virtual System.Reflection.MethodInfo? Microsoft.Extensions.AI.AIFunction.UnderlyingMethod { get; }",
174178
"Stage": "Stable"

src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ private static JsonNode CreateJsonSchemaCore(
227227
(schemaObj ??= [])[DescriptionPropertyName] = description;
228228
}
229229

230-
return schemaObj ?? (JsonNode)true;
230+
return schemaObj ?? new JsonObject();
231231
}
232232

233233
if (type == typeof(void))

test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,27 @@ public async Task Returns_AsyncReturnTypesSupported_Async()
100100
AIFunction func;
101101

102102
func = AIFunctionFactory.Create(Task<string> (string a) => Task.FromResult(a + " " + a));
103+
Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString());
103104
AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync(new() { ["a"] = "test" }));
104105

105106
func = AIFunctionFactory.Create(ValueTask<string> (string a, string b) => new ValueTask<string>(b + " " + a));
107+
Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString());
106108
AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync(new() { ["b"] = "hello", ["a"] = "world" }));
107109

108110
long result = 0;
109111
func = AIFunctionFactory.Create(async Task (int a, long b) => { result = a + b; await Task.Yield(); });
112+
Assert.Null(func.ReturnJsonSchema);
110113
AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
111114
Assert.Equal(3, result);
112115

113116
result = 0;
114117
func = AIFunctionFactory.Create(async ValueTask (int a, long b) => { result = a + b; await Task.Yield(); });
118+
Assert.Null(func.ReturnJsonSchema);
115119
AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
116120
Assert.Equal(3, result);
117121

118122
func = AIFunctionFactory.Create((int count) => SimpleIAsyncEnumerable(count), serializerOptions: JsonContext.Default.Options);
123+
Assert.Equal("""{"type":"array","items":{"type":"integer"}}""", func.ReturnJsonSchema.ToString());
119124
AssertExtensions.EqualFunctionCallResults(new int[] { 0, 1, 2, 3, 4 }, await func.InvokeAsync(new() { ["count"] = 5 }), JsonContext.Default.Options);
120125

121126
static async IAsyncEnumerable<int> SimpleIAsyncEnumerable(int count)
@@ -220,6 +225,8 @@ public async Task AIFunctionFactoryOptions_SupportsSkippingParameters()
220225
Assert.DoesNotContain("firstParameter", func.JsonSchema.ToString());
221226
Assert.Contains("secondParameter", func.JsonSchema.ToString());
222227

228+
Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString());
229+
223230
var result = (JsonElement?)await func.InvokeAsync(new()
224231
{
225232
["firstParameter"] = "test",
@@ -265,6 +272,8 @@ public async Task AIFunctionArguments_SatisfiesParameters()
265272
Assert.DoesNotContain("services", func.JsonSchema.ToString());
266273
Assert.DoesNotContain("arguments", func.JsonSchema.ToString());
267274

275+
Assert.Equal("""{"type":"integer"}""", func.ReturnJsonSchema.ToString());
276+
268277
await Assert.ThrowsAsync<ArgumentNullException>("arguments.Services", () => func.InvokeAsync(arguments).AsTask());
269278

270279
arguments.Services = sp;
@@ -430,6 +439,8 @@ public async Task FromKeyedServices_ResolvesFromServiceProvider()
430439
Assert.Contains("myInteger", f.JsonSchema.ToString());
431440
Assert.DoesNotContain("service", f.JsonSchema.ToString());
432441

442+
Assert.Equal("""{"type":"integer"}""", f.ReturnJsonSchema.ToString());
443+
433444
Exception e = await Assert.ThrowsAsync<ArgumentException>("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask());
434445

435446
var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp });
@@ -451,6 +462,8 @@ public async Task FromKeyedServices_NullKeysBindToNonKeyedServices()
451462
Assert.Contains("myInteger", f.JsonSchema.ToString());
452463
Assert.DoesNotContain("service", f.JsonSchema.ToString());
453464

465+
Assert.Equal("""{"type":"integer"}""", f.ReturnJsonSchema.ToString());
466+
454467
Exception e = await Assert.ThrowsAsync<ArgumentException>("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask());
455468

456469
var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp });
@@ -743,6 +756,7 @@ public async Task MarshalResult_TypeIsDeclaredTypeEvenWhenDerivedTypeReturned()
743756
Assert.Equal(cts.Token, cancellationToken);
744757
return "marshalResultInvoked";
745758
},
759+
SerializerOptions = JsonContext.Default.Options,
746760
});
747761

748762
object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
@@ -760,6 +774,17 @@ public async Task AIFunctionFactory_DefaultDefaultParameter()
760774
Assert.Contains("00000000-0000-0000-0000-000000000000,0", result?.ToString());
761775
}
762776

777+
[Fact]
778+
public void AIFunctionFactory_ReturnTypeWithDescriptionAttribute()
779+
{
780+
AIFunction f = AIFunctionFactory.Create(Add, serializerOptions: JsonContext.Default.Options);
781+
782+
Assert.Equal("""{"description":"The summed result","type":"integer"}""", f.ReturnJsonSchema.ToString());
783+
784+
[return: Description("The summed result")]
785+
static int Add(int a, int b) => a + b;
786+
}
787+
763788
private sealed class MyService(int value)
764789
{
765790
public int Value => value;
@@ -853,5 +878,6 @@ private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() =>
853878
[JsonSerializable(typeof(string))]
854879
[JsonSerializable(typeof(Guid))]
855880
[JsonSerializable(typeof(StructWithDefaultCtor))]
881+
[JsonSerializable(typeof(B))]
856882
private partial class JsonContext : JsonSerializerContext;
857883
}

0 commit comments

Comments
 (0)