@@ -540,6 +540,7 @@ private ReflectionAIFunction(
540
540
public override string Description => FunctionDescriptor . Description ;
541
541
public override MethodInfo UnderlyingMethod => FunctionDescriptor . Method ;
542
542
public override JsonElement JsonSchema => FunctionDescriptor . JsonSchema ;
543
+ public override JsonElement ? ReturnJsonSchema => FunctionDescriptor . ReturnJsonSchema ;
543
544
public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor . JsonSerializerOptions ;
544
545
545
546
protected override async ValueTask < object ? > InvokeCoreAsync (
@@ -683,13 +684,17 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
683
684
ParameterMarshallers [ i ] = GetParameterMarshaller ( serializerOptions , options , parameters [ i ] ) ;
684
685
}
685
686
686
- // Get a marshaling delegate for the return value.
687
- ReturnParameterMarshaller = GetReturnParameterMarshaller ( key , serializerOptions ) ;
688
-
687
+ ReturnParameterMarshaller = GetReturnParameterMarshaller ( key , serializerOptions , out Type ? returnType ) ;
689
688
Method = key . Method ;
690
689
Name = key . Name ?? GetFunctionName ( key . Method ) ;
691
690
Description = key . Description ?? key . Method . GetCustomAttribute < DescriptionAttribute > ( inherit : true ) ? . Description ?? string . Empty ;
692
691
JsonSerializerOptions = serializerOptions ;
692
+ ReturnJsonSchema = returnType is null ? null : AIJsonUtilities . CreateJsonSchema (
693
+ returnType ,
694
+ description : key . Method . ReturnParameter . GetCustomAttribute < DescriptionAttribute > ( inherit : true ) ? . Description ,
695
+ serializerOptions : serializerOptions ,
696
+ inferenceOptions : schemaOptions ) ;
697
+
693
698
JsonSchema = AIJsonUtilities . CreateFunctionJsonSchema (
694
699
key . Method ,
695
700
title : string . Empty , // Forces skipping of the title keyword
@@ -703,6 +708,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
703
708
public MethodInfo Method { get ; }
704
709
public JsonSerializerOptions JsonSerializerOptions { get ; }
705
710
public JsonElement JsonSchema { get ; }
711
+ public JsonElement ? ReturnJsonSchema { get ; }
706
712
public Func < AIFunctionArguments , CancellationToken , object ? > [ ] ParameterMarshallers { get ; }
707
713
public Func < object ? , CancellationToken , ValueTask < object ? > > ReturnParameterMarshaller { get ; }
708
714
public ReflectionAIFunction ? CachedDefaultInstance { get ; set ; }
@@ -849,15 +855,16 @@ static void ThrowNullServices(string parameterName) =>
849
855
/// Gets a delegate for handling the result value of a method, converting it into the <see cref="Task{FunctionResult}"/> to return from the invocation.
850
856
/// </summary>
851
857
private static Func < object ? , CancellationToken , ValueTask < object ? > > GetReturnParameterMarshaller (
852
- DescriptorKey key , JsonSerializerOptions serializerOptions )
858
+ DescriptorKey key , JsonSerializerOptions serializerOptions , out Type ? returnType )
853
859
{
854
- Type returnType = key . Method . ReturnType ;
860
+ returnType = key . Method . ReturnType ;
855
861
JsonTypeInfo returnTypeInfo ;
856
862
Func < object ? , Type ? , CancellationToken , ValueTask < object ? > > ? marshalResult = key . MarshalResult ;
857
863
858
864
// Void
859
865
if ( returnType == typeof ( void ) )
860
866
{
867
+ returnType = null ;
861
868
if ( marshalResult is not null )
862
869
{
863
870
return ( result , cancellationToken ) => marshalResult ( null , null , cancellationToken ) ;
@@ -869,6 +876,7 @@ static void ThrowNullServices(string parameterName) =>
869
876
// Task
870
877
if ( returnType == typeof ( Task ) )
871
878
{
879
+ returnType = null ;
872
880
if ( marshalResult is not null )
873
881
{
874
882
return async ( result , cancellationToken ) =>
@@ -888,6 +896,7 @@ static void ThrowNullServices(string parameterName) =>
888
896
// ValueTask
889
897
if ( returnType == typeof ( ValueTask ) )
890
898
{
899
+ returnType = null ;
891
900
if ( marshalResult is not null )
892
901
{
893
902
return async ( result , cancellationToken ) =>
@@ -910,6 +919,8 @@ static void ThrowNullServices(string parameterName) =>
910
919
if ( returnType . GetGenericTypeDefinition ( ) == typeof ( Task < > ) )
911
920
{
912
921
MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition ( returnType , _taskGetResult ) ;
922
+ returnType = taskResultGetter . ReturnType ;
923
+
913
924
if ( marshalResult is not null )
914
925
{
915
926
return async ( taskObj , cancellationToken ) =>
@@ -920,7 +931,7 @@ static void ThrowNullServices(string parameterName) =>
920
931
} ;
921
932
}
922
933
923
- returnTypeInfo = serializerOptions . GetTypeInfo ( taskResultGetter . ReturnType ) ;
934
+ returnTypeInfo = serializerOptions . GetTypeInfo ( returnType ) ;
924
935
return async ( taskObj , cancellationToken ) =>
925
936
{
926
937
await ( ( Task ) ThrowIfNullResult ( taskObj ) ) . ConfigureAwait ( true ) ;
@@ -934,6 +945,7 @@ static void ThrowNullServices(string parameterName) =>
934
945
{
935
946
MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition ( returnType , _valueTaskAsTask ) ;
936
947
MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition ( valueTaskAsTask . ReturnType , _taskGetResult ) ;
948
+ returnType = asTaskResultGetter . ReturnType ;
937
949
938
950
if ( marshalResult is not null )
939
951
{
@@ -946,7 +958,7 @@ static void ThrowNullServices(string parameterName) =>
946
958
} ;
947
959
}
948
960
949
- returnTypeInfo = serializerOptions . GetTypeInfo ( asTaskResultGetter . ReturnType ) ;
961
+ returnTypeInfo = serializerOptions . GetTypeInfo ( returnType ) ;
950
962
return async ( taskObj , cancellationToken ) =>
951
963
{
952
964
var task = ( Task ) ReflectionInvoke ( valueTaskAsTask , ThrowIfNullResult ( taskObj ) , null ) ! ;
@@ -960,7 +972,8 @@ static void ThrowNullServices(string parameterName) =>
960
972
// For everything else, just serialize the result as-is.
961
973
if ( marshalResult is not null )
962
974
{
963
- return ( result , cancellationToken ) => marshalResult ( result , returnType , cancellationToken ) ;
975
+ Type returnTypeCopy = returnType ;
976
+ return ( result , cancellationToken ) => marshalResult ( result , returnTypeCopy , cancellationToken ) ;
964
977
}
965
978
966
979
returnTypeInfo = serializerOptions . GetTypeInfo ( returnType ) ;
0 commit comments