Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 60 additions & 9 deletions src/Generators/AzureFunctions/DurableFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,24 @@ public class DurableFunction
public DurableFunctionKind Kind { get; }
public TypedParameter Parameter { get; }
public string ReturnType { get; }
public bool ReturnsVoid { get; }

public DurableFunction(
string fullTypeName,
string name,
DurableFunctionKind kind,
TypedParameter parameter,
ITypeSymbol returnType,
ITypeSymbol? returnType,
bool returnsVoid,
HashSet<string> requiredNamespaces)
{
this.FullTypeName = fullTypeName;
this.RequiredNamespaces = requiredNamespaces;
this.Name = name;
this.Kind = kind;
this.Parameter = parameter;
this.ReturnType = SyntaxNodeUtility.GetRenderedTypeExpression(returnType, false);
this.ReturnType = returnType != null ? SyntaxNodeUtility.GetRenderedTypeExpression(returnType, false) : string.Empty;
this.ReturnsVoid = returnsVoid;
}

public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method, out DurableFunction? function)
Expand All @@ -59,12 +62,54 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,
return false;
}

INamedTypeSymbol taskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1")!;
INamedTypeSymbol returnSymbol = (INamedTypeSymbol)model.GetTypeInfo(returnType).Type!;
if (SymbolEqualityComparer.Default.Equals(returnSymbol.OriginalDefinition, taskSymbol))
ITypeSymbol? returnTypeSymbol = model.GetTypeInfo(returnType).Type;
if (returnTypeSymbol == null || returnTypeSymbol.TypeKind == TypeKind.Error)
{
// this is a Task<T> return value, lets pull out the generic.
returnSymbol = (INamedTypeSymbol)returnSymbol.TypeArguments[0];
function = null;
return false;
}

bool returnsVoid = false;
INamedTypeSymbol? returnSymbol = null;

Comment on lines +73 to +74
Copy link

Copilot AI Dec 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states "we'll use object as a placeholder since it won't be used", but this is misleading. The returnSymbol is actually used: it's passed to GetRenderedTypeExpression in the constructor (line 41) to set the ReturnType property, and it's added to the usedTypes list (line 117) to determine required namespaces. Consider revising the comment to more accurately reflect that object is used as a safe default type for namespace resolution, even though the actual return type won't be rendered in the generated code due to the ReturnsVoid flag.

Copilot uses AI. Check for mistakes.
// Check if it's a void return type
if (returnTypeSymbol.SpecialType == SpecialType.System_Void)
{
returnsVoid = true;
// returnSymbol is left as null since void has no type to track
}
// Check if it's Task (non-generic)
else if (returnTypeSymbol is INamedTypeSymbol namedReturn)
{
INamedTypeSymbol? nonGenericTaskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task");
Comment on lines +83 to +84
Copy link

Copilot AI Dec 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states "we'll use object as a placeholder since it won't be used", but this is misleading. The returnSymbol is actually used: it's passed to GetRenderedTypeExpression in the constructor (line 41) to set the ReturnType property, and it's added to the usedTypes list (line 117) to determine required namespaces. Consider revising the comment to more accurately reflect that object is used as a safe default type for namespace resolution, even though the actual return type won't be rendered in the generated code due to the ReturnsVoid flag.

Copilot uses AI. Check for mistakes.
if (nonGenericTaskSymbol != null && SymbolEqualityComparer.Default.Equals(namedReturn, nonGenericTaskSymbol))
{
returnsVoid = true;
// returnSymbol is left as null since Task (non-generic) has no return type to track
}
// Check if it's Task<T>
else
{
INamedTypeSymbol? taskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1");
returnSymbol = namedReturn;
if (taskSymbol != null && SymbolEqualityComparer.Default.Equals(returnSymbol.OriginalDefinition, taskSymbol))
{
// this is a Task<T> return value, lets pull out the generic.
ITypeSymbol typeArg = returnSymbol.TypeArguments[0];
if (typeArg is not INamedTypeSymbol namedTypeArg)
{
function = null;
return false;
}
returnSymbol = namedTypeArg;
}
}
}
else
{
// returnTypeSymbol is not INamedTypeSymbol, which is unexpected
function = null;
return false;
}

if (!SyntaxNodeUtility.TryGetParameter(model, method, kind, out TypedParameter? parameter) || parameter == null)
Expand All @@ -79,12 +124,18 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,
return false;
}

// Build list of types used for namespace resolution
List<INamedTypeSymbol> usedTypes = new()
{
returnSymbol,
parameter.Type
};

// Only include return type if it's not void
if (returnSymbol != null)
{
usedTypes.Add(returnSymbol);
}

if (!SyntaxNodeUtility.TryGetRequiredNamespaces(usedTypes, out HashSet<string>? requiredNamespaces))
{
function = null;
Expand All @@ -93,7 +144,7 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,

requiredNamespaces!.UnionWith(GetRequiredGlobalNamespaces());

function = new DurableFunction(fullTypeName!, name, kind, parameter, returnSymbol, requiredNamespaces);
function = new DurableFunction(fullTypeName!, name, kind, parameter, returnSymbol, returnsVoid, requiredNamespaces);
return true;
}

Expand Down
12 changes: 11 additions & 1 deletion src/Generators/AzureFunctions/TypedParameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,17 @@ public TypedParameter(INamedTypeSymbol type, string name)

public override string ToString()
{
return $"{SyntaxNodeUtility.GetRenderedTypeExpression(this.Type, false)} {this.Name}";
// Use the type as-is, preserving the nullability annotation from the source
string typeExpression = SyntaxNodeUtility.GetRenderedTypeExpression(this.Type, false);

// Special case: if the type is exactly System.Object (not a nullable object), make it nullable
// This is because object parameters are typically nullable in the context of Durable Functions
if (this.Type.SpecialType == SpecialType.System_Object && this.Type.NullableAnnotation != NullableAnnotation.Annotated)
{
typeExpression = "object?";
}

return $"{typeExpression} {this.Name}";
}
}
}
17 changes: 16 additions & 1 deletion src/Generators/DurableTaskSourceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,21 @@ static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableTaskTypeIn

static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableFunction activity)
{
sourceBuilder.AppendLine($@"
if (activity.ReturnsVoid)
{
sourceBuilder.AppendLine($@"
/// <summary>
/// Calls the <see cref=""{activity.FullTypeName}""/> activity.
/// </summary>
/// <inheritdoc cref=""TaskOrchestrationContext.CallActivityAsync(TaskName, object?, TaskOptions?)""/>
public static Task Call{activity.Name}Async(this TaskOrchestrationContext ctx, {activity.Parameter}, TaskOptions? options = null)
{{
return ctx.CallActivityAsync(""{activity.Name}"", {activity.Parameter.Name}, options);
}}");
}
else
{
sourceBuilder.AppendLine($@"
/// <summary>
/// Calls the <see cref=""{activity.FullTypeName}""/> activity.
/// </summary>
Expand All @@ -444,6 +458,7 @@ static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableFunction a
{{
return ctx.CallActivityAsync<{activity.ReturnType}>(""{activity.Name}"", {activity.Parameter.Name}, options);
}}");
}
}

static void AddEventWaitMethod(StringBuilder sourceBuilder, DurableEventTypeInfo eventInfo)
Expand Down
73 changes: 73 additions & 0 deletions test/Generators.Tests/AzureFunctionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,79 @@ await TestHelpers.RunTestAsync<DurableTaskSourceGenerator>(
isDurableFunctions: true);
}

[Fact]
public async Task Activities_SimpleFunctionTrigger_VoidReturn()
{
string code = @"
using Microsoft.Azure.Functions.Worker;
using Microsoft.DurableTask;

public class Activities
{
[Function(nameof(FlakeyActivity))]
public static void FlakeyActivity([ActivityTrigger] object _)
{
throw new System.ApplicationException(""Kah-BOOOOM!!!"");
}
}";

string expectedOutput = TestHelpers.WrapAndFormat(
GeneratedClassName,
methodList: @"
/// <summary>
/// Calls the <see cref=""Activities.FlakeyActivity""/> activity.
/// </summary>
/// <inheritdoc cref=""TaskOrchestrationContext.CallActivityAsync(TaskName, object?, TaskOptions?)""/>
public static Task CallFlakeyActivityAsync(this TaskOrchestrationContext ctx, object? _, TaskOptions? options = null)
{
return ctx.CallActivityAsync(""FlakeyActivity"", _, options);
}",
isDurableFunctions: true);

await TestHelpers.RunTestAsync<DurableTaskSourceGenerator>(
GeneratedFileName,
code,
expectedOutput,
isDurableFunctions: true);
}

[Fact]
public async Task Activities_SimpleFunctionTrigger_TaskReturn()
{
string code = @"
using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker;
using Microsoft.DurableTask;

public class Activities
{
[Function(nameof(FlakeyActivity))]
public static Task FlakeyActivity([ActivityTrigger] object _)
{
throw new System.ApplicationException(""Kah-BOOOOM!!!"");
}
}";

string expectedOutput = TestHelpers.WrapAndFormat(
GeneratedClassName,
methodList: @"
/// <summary>
/// Calls the <see cref=""Activities.FlakeyActivity""/> activity.
/// </summary>
/// <inheritdoc cref=""TaskOrchestrationContext.CallActivityAsync(TaskName, object?, TaskOptions?)""/>
public static Task CallFlakeyActivityAsync(this TaskOrchestrationContext ctx, object? _, TaskOptions? options = null)
{
return ctx.CallActivityAsync(""FlakeyActivity"", _, options);
}",
isDurableFunctions: true);

await TestHelpers.RunTestAsync<DurableTaskSourceGenerator>(
GeneratedFileName,
code,
expectedOutput,
isDurableFunctions: true);
}

/// <summary>
/// Verifies that using the class-based activity syntax generates a <see cref="TaskOrchestrationContext"/>
/// extension method as well as an <see cref="ActivityTriggerAttribute"/> function definition.
Expand Down
Loading