Skip to content

Commit ce48579

Browse files
authored
Provide support for exposing .NET classes to COM through source generation (#83755)
1 parent 3a28f6e commit ce48579

25 files changed

+655
-257
lines changed
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Collections.Generic;
5+
using System.Collections.Immutable;
6+
using System.IO;
7+
using System.Linq;
8+
using Microsoft.CodeAnalysis;
9+
using Microsoft.CodeAnalysis.CSharp.Syntax;
10+
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
11+
using Microsoft.CodeAnalysis.CSharp;
12+
13+
namespace Microsoft.Interop
14+
{
15+
[Generator]
16+
public class ComClassGenerator : IIncrementalGenerator
17+
{
18+
private sealed record ComClassInfo(string ClassName, ContainingSyntaxContext ContainingSyntaxContext, ContainingSyntax ClassSyntax, SequenceEqualImmutableArray<string> ImplementedInterfacesNames);
19+
public void Initialize(IncrementalGeneratorInitializationContext context)
20+
{
21+
// Get all types with the [GeneratedComClassAttribute] attribute.
22+
var attributedClasses = context.SyntaxProvider
23+
.ForAttributeWithMetadataName(
24+
TypeNames.GeneratedComClassAttribute,
25+
static (node, ct) => node is ClassDeclarationSyntax,
26+
static (context, ct) =>
27+
{
28+
var type = (INamedTypeSymbol)context.TargetSymbol;
29+
var syntax = (ClassDeclarationSyntax)context.TargetNode;
30+
ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
31+
foreach (INamedTypeSymbol iface in type.AllInterfaces)
32+
{
33+
if (iface.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute))
34+
{
35+
names.Add(iface.ToDisplayString());
36+
}
37+
}
38+
return new ComClassInfo(
39+
type.ToDisplayString(),
40+
new ContainingSyntaxContext(syntax),
41+
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
42+
new(names.ToImmutable()));
43+
});
44+
45+
var className = attributedClasses.Select(static (info, ct) => info.ClassName);
46+
47+
var classInfoType = attributedClasses
48+
.Select(static (info, ct) => new { info.ClassName, info.ImplementedInterfacesNames })
49+
.Select(static (info, ct) => GenerateClassInfoType(info.ImplementedInterfacesNames.Array).NormalizeWhitespace());
50+
51+
var attribute = attributedClasses
52+
.Select(static (info, ct) => new { info.ContainingSyntaxContext, info.ClassSyntax })
53+
.Select(static (info, ct) => GenerateClassInfoAttributeOnUserType(info.ContainingSyntaxContext, info.ClassSyntax).NormalizeWhitespace());
54+
55+
context.RegisterSourceOutput(className.Zip(classInfoType).Zip(attribute), static (context, classInfo) =>
56+
{
57+
var ((className, classInfoType), attribute) = classInfo;
58+
StringWriter writer = new();
59+
writer.WriteLine(classInfoType.ToFullString());
60+
writer.WriteLine();
61+
writer.WriteLine(attribute);
62+
context.AddSource(className, writer.ToString());
63+
});
64+
}
65+
66+
private const string ClassInfoTypeName = "ComClassInformation";
67+
68+
private static readonly AttributeSyntax s_comExposedClassAttributeTemplate =
69+
Attribute(
70+
GenericName(TypeNames.ComExposedClassAttribute)
71+
.AddTypeArgumentListArguments(
72+
IdentifierName(ClassInfoTypeName)));
73+
private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax) =>
74+
containingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(
75+
TypeDeclaration(classSyntax.TypeKind, classSyntax.Identifier)
76+
.WithModifiers(classSyntax.Modifiers)
77+
.WithTypeParameterList(classSyntax.TypeParameters)
78+
.AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate))));
79+
private static ClassDeclarationSyntax GenerateClassInfoType(ImmutableArray<string> implementedInterfaces)
80+
{
81+
const string vtablesField = "s_vtables";
82+
const string vtablesLocal = "vtables";
83+
const string detailsTempLocal = "details";
84+
const string countIdentifier = "count";
85+
var typeDeclaration = ClassDeclaration(ClassInfoTypeName)
86+
.AddModifiers(
87+
Token(SyntaxKind.FileKeyword),
88+
Token(SyntaxKind.SealedKeyword),
89+
Token(SyntaxKind.UnsafeKeyword))
90+
.AddBaseListTypes(SimpleBaseType(ParseTypeName(TypeNames.IComExposedClass)))
91+
.AddMembers(
92+
FieldDeclaration(
93+
VariableDeclaration(
94+
PointerType(
95+
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
96+
SingletonSeparatedList(VariableDeclarator(vtablesField))))
97+
.AddModifiers(
98+
Token(SyntaxKind.PrivateKeyword),
99+
Token(SyntaxKind.StaticKeyword),
100+
Token(SyntaxKind.VolatileKeyword)));
101+
List<StatementSyntax> vtableInitializationBlock = new()
102+
{
103+
// ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(<ClassInfoTypeName>), sizeof(ComInterfaceEntry) * <numInterfaces>);
104+
LocalDeclarationStatement(
105+
VariableDeclaration(
106+
PointerType(
107+
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
108+
SingletonSeparatedList(
109+
VariableDeclarator(vtablesLocal)
110+
.WithInitializer(EqualsValueClause(
111+
CastExpression(
112+
PointerType(
113+
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
114+
InvocationExpression(
115+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
116+
ParseTypeName(TypeNames.System_Runtime_CompilerServices_RuntimeHelpers),
117+
IdentifierName("AllocateTypeAssociatedMemory")))
118+
.AddArgumentListArguments(
119+
Argument(TypeOfExpression(IdentifierName(ClassInfoTypeName))),
120+
Argument(
121+
BinaryExpression(
122+
SyntaxKind.MultiplyExpression,
123+
SizeOfExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
124+
LiteralExpression(
125+
SyntaxKind.NumericLiteralExpression,
126+
Literal(implementedInterfaces.Length))))))))))),
127+
// IIUnknownDerivedDetails details;
128+
LocalDeclarationStatement(
129+
VariableDeclaration(
130+
ParseTypeName(TypeNames.IIUnknownDerivedDetails),
131+
SingletonSeparatedList(
132+
VariableDeclarator(detailsTempLocal))))
133+
};
134+
for (int i = 0; i < implementedInterfaces.Length; i++)
135+
{
136+
string ifaceName = implementedInterfaces[i];
137+
138+
// details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<ifaceName>).TypeHandle);
139+
vtableInitializationBlock.Add(
140+
ExpressionStatement(
141+
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
142+
IdentifierName(detailsTempLocal),
143+
InvocationExpression(
144+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
145+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
146+
ParseTypeName(TypeNames.StrategyBasedComWrappers),
147+
IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
148+
IdentifierName("GetIUnknownDerivedDetails")),
149+
ArgumentList(
150+
SingletonSeparatedList(
151+
Argument(
152+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
153+
TypeOfExpression(ParseName(ifaceName)),
154+
IdentifierName("TypeHandle")))))))));
155+
// vtable[i] = new() { IID = details.Iid, Vtable = details.ManagedVirtualMethodTable };
156+
vtableInitializationBlock.Add(
157+
ExpressionStatement(
158+
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
159+
ElementAccessExpression(
160+
IdentifierName(vtablesLocal),
161+
BracketedArgumentList(
162+
SingletonSeparatedList(
163+
Argument(
164+
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(i)))))),
165+
ImplicitObjectCreationExpression(
166+
ArgumentList(),
167+
InitializerExpression(SyntaxKind.ObjectInitializerExpression,
168+
SeparatedList(
169+
new ExpressionSyntax[]
170+
{
171+
AssignmentExpression(
172+
SyntaxKind.SimpleAssignmentExpression,
173+
IdentifierName("IID"),
174+
MemberAccessExpression(
175+
SyntaxKind.SimpleMemberAccessExpression,
176+
IdentifierName(detailsTempLocal),
177+
IdentifierName("Iid"))),
178+
AssignmentExpression(
179+
SyntaxKind.SimpleAssignmentExpression,
180+
IdentifierName("Vtable"),
181+
CastExpression(
182+
IdentifierName("nint"),
183+
MemberAccessExpression(
184+
SyntaxKind.SimpleMemberAccessExpression,
185+
IdentifierName(detailsTempLocal),
186+
IdentifierName("ManagedVirtualMethodTable"))))
187+
}))))));
188+
}
189+
190+
// s_vtable = vtable;
191+
vtableInitializationBlock.Add(
192+
ExpressionStatement(
193+
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
194+
IdentifierName(vtablesField),
195+
IdentifierName(vtablesLocal))));
196+
197+
BlockSyntax getComInterfaceEntriesMethodBody = Block(
198+
// count = <count>;
199+
ExpressionStatement(
200+
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
201+
IdentifierName(countIdentifier),
202+
LiteralExpression(SyntaxKind.NumericLiteralExpression,
203+
Literal(implementedInterfaces.Length)))),
204+
// if (s_vtable == null)
205+
// { initializer block }
206+
IfStatement(
207+
BinaryExpression(SyntaxKind.EqualsExpression,
208+
IdentifierName(vtablesField),
209+
LiteralExpression(SyntaxKind.NullLiteralExpression)),
210+
Block(vtableInitializationBlock)),
211+
// return s_vtable;
212+
ReturnStatement(IdentifierName(vtablesField)));
213+
214+
typeDeclaration = typeDeclaration.AddMembers(
215+
// public static unsafe ComWrappers.ComInterfaceDispatch* GetComInterfaceEntries(out int count)
216+
// { body }
217+
MethodDeclaration(
218+
PointerType(
219+
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
220+
"GetComInterfaceEntries")
221+
.AddParameterListParameters(
222+
Parameter(Identifier(countIdentifier))
223+
.WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))
224+
.AddModifiers(Token(SyntaxKind.OutKeyword)))
225+
.WithBody(getComInterfaceEntriesMethodBody)
226+
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)));
227+
228+
return typeDeclaration;
229+
}
230+
}
231+
}

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public static class StepNames
4545

4646
public void Initialize(IncrementalGeneratorInitializationContext context)
4747
{
48-
// Get all methods with the [GeneratedComInterface] attribute.
48+
// Get all types with the [GeneratedComInterface] attribute.
4949
var attributedInterfaces = context.SyntaxProvider
5050
.ForAttributeWithMetadataName(
5151
TypeNames.GeneratedComInterfaceAttribute,
@@ -62,7 +62,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
6262
return new { data.Syntax, data.Symbol, Diagnostic = diagnostic };
6363
});
6464

65-
// Split the methods we want to generate and the ones we don't into two separate groups.
65+
// Split the types we want to generate and the ones we don't into two separate groups.
6666
var interfacesToGenerate = interfacesWithDiagnostics.Where(static data => data.Diagnostic is null);
6767
var invalidTypeDiagnostics = interfacesWithDiagnostics.Where(static data => data.Diagnostic is not null);
6868

@@ -726,7 +726,7 @@ private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceC
726726
.WithExpressionBody(
727727
ArrowExpressionClause(
728728
ConditionalExpression(
729-
BinaryExpression(SyntaxKind.EqualsExpression,
729+
BinaryExpression(SyntaxKind.NotEqualsExpression,
730730
IdentifierName(vtableFieldName),
731731
LiteralExpression(SyntaxKind.NullLiteralExpression)),
732732
IdentifierName(vtableFieldName),

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ internal static class IncrementalValuesProviderExtensions
3131
});
3232
}
3333

34+
/// <summary>
35+
/// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed.
36+
/// </summary>
37+
/// <typeparam name="TNode">A syntax node kind.</typeparam>
38+
/// <param name="provider">The input nodes</param>
39+
/// <returns>A provider of the formatted syntax nodes.</returns>
40+
/// <remarks>
41+
/// Normalizing whitespace is very expensive, so if a generator will have cases where the input information into the step
42+
/// that creates <paramref name="provider"/> may change but the results of <paramref name="provider"/> will say the same,
43+
/// using this method to format the code in a separate step will reduce the amount of work the generator repeats when the
44+
/// output code will not change.
45+
/// </remarks>
3446
public static IncrementalValuesProvider<TNode> SelectNormalized<TNode>(this IncrementalValuesProvider<TNode> provider)
3547
where TNode : SyntaxNode
3648
{

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,22 @@ public static string MarshalEx(InteropGenerationOptions options)
117117

118118
public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceDispatch = "System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch";
119119

120+
public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry = "System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry";
121+
122+
public const string StrategyBasedComWrappers = "System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers";
123+
120124
public const string IIUnknownInterfaceType = "System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType";
121125
public const string IUnknownDerivedAttribute = "System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute";
126+
public const string IIUnknownDerivedDetails = "System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails";
122127

123128
public const string ComWrappersUnwrapper = "System.Runtime.InteropServices.Marshalling.ComWrappersUnwrapper";
124129
public const string UnmanagedObjectUnwrapperAttribute = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapperAttribute`1";
125130

126131
public const string IUnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.IUnmanagedObjectUnwrapper";
127132
public const string UnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper";
133+
134+
public const string GeneratedComClassAttribute = "System.Runtime.InteropServices.Marshalling.GeneratedComClassAttribute";
135+
public const string ComExposedClassAttribute = "System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute";
136+
public const string IComExposedClass = "System.Runtime.InteropServices.Marshalling.IComExposedClass";
128137
}
129138
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Text;
8+
using System.Threading.Tasks;
9+
10+
namespace System.Runtime.InteropServices.Marshalling
11+
{
12+
/// <summary>
13+
/// An attribute to mark this class as a type whose instances should be exposed to COM.
14+
/// </summary>
15+
/// <typeparam name="T">The type that provides information about how to expose the attributed type to COM.</typeparam>
16+
[AttributeUsage(AttributeTargets.Class, Inherited = false)]
17+
public sealed class ComExposedClassAttribute<T> : Attribute, IComExposedDetails
18+
where T : IComExposedClass
19+
{
20+
/// <inheritdoc />
21+
public unsafe ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) => T.GetComInterfaceEntries(out count);
22+
}
23+
}

0 commit comments

Comments
 (0)