@@ -54,7 +54,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
5454 var externalInterfaceSymbols = attributedInterfaces . SelectMany ( static ( data , ct ) =>
5555 {
5656 return ComInterfaceInfo . CreateInterfaceInfoForBaseInterfacesInOtherCompilations ( data . Symbol ) ;
57- } ) ;
57+ } ) . Collect ( ) . SelectMany ( static ( data , ct ) => data . Distinct ( ComInterfaceInfo . EqualityComparerForExternalIfaces . Instance ) ) ;
5858
5959 var interfaceSymbolsWithoutDiagnostics = interfaceSymbolsToGenerateWithoutDiagnostics . Concat ( externalInterfaceSymbols ) ;
6060
@@ -84,11 +84,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
8484 . SelectMany ( static ( data , ct ) =>
8585 {
8686 return ComMethodContext . CalculateAllMethods ( data , ct ) ;
87- } )
88- // Now that we've determined method offsets, we can remove all externally defined methods.
89- // We'll also filter out methods originally declared on externally defined base interfaces
90- // as we may not be able to emit them into our assembly.
91- . Where ( context => ! context . Method . OriginalDeclaringInterface . IsExternallyDefined ) ;
87+ } ) ;
9288
9389 // Now that we've determined method offsets, we can remove all externally defined interfaces.
9490 var interfaceContextsToGenerate = interfaceContexts . Where ( context => ! context . IsExternallyDefined ) ;
@@ -107,13 +103,20 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
107103 return new ComMethodContext (
108104 data . Method ,
109105 data . OwningInterface ,
110- CalculateStubInformation ( data . Method . MethodInfo . Syntax , symbolMap [ data . Method . MethodInfo ] , data . Method . Index , env , data . OwningInterface . Info , ct ) ) ;
106+ CalculateStubInformation (
107+ data . Method . MethodInfo . Syntax ,
108+ symbolMap [ data . Method . MethodInfo ] ,
109+ data . Method . Index ,
110+ env ,
111+ data . OwningInterface . Info ,
112+ ct ) ) ;
111113 } ) . WithTrackingName ( StepNames . CalculateStubInformation ) ;
112114
113115 var interfaceAndMethodsContexts = comMethodContexts
114116 . Collect ( )
115117 . Combine ( interfaceContextsToGenerate . Collect ( ) )
116- . SelectMany ( ( data , ct ) => GroupComContextsForInterfaceGeneration ( data . Left , data . Right , ct ) ) ;
118+ . SelectMany ( ( data , ct ) =>
119+ GroupComContextsForInterfaceGeneration ( data . Left , data . Right , ct ) ) ;
117120
118121 // Generate the code for the managed-to-unmanaged stubs.
119122 var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts
@@ -256,12 +259,22 @@ private static bool IsHResultLikeType(ManagedTypeInfo type)
256259 || typeName . Equals ( "hresult" , StringComparison . OrdinalIgnoreCase ) ;
257260 }
258261
259- private static IncrementalMethodStubGenerationContext CalculateStubInformation ( MethodDeclarationSyntax syntax , IMethodSymbol symbol , int index , StubEnvironment environment , ComInterfaceInfo owningInterfaceInfo , CancellationToken ct )
262+ /// <summary>
263+ /// Calculates the shared information needed for both source-available and sourceless stub generation.
264+ /// </summary>
265+ private static IncrementalMethodStubGenerationContext CalculateSharedStubInformation (
266+ IMethodSymbol symbol ,
267+ int index ,
268+ StubEnvironment environment ,
269+ ISignatureDiagnosticLocations diagnosticLocations ,
270+ ComInterfaceInfo owningInterfaceInfo ,
271+ CancellationToken ct )
260272 {
261273 ct . ThrowIfCancellationRequested ( ) ;
262274 INamedTypeSymbol ? lcidConversionAttrType = environment . LcidConversionAttrType ;
263275 INamedTypeSymbol ? suppressGCTransitionAttrType = environment . SuppressGCTransitionAttrType ;
264276 INamedTypeSymbol ? unmanagedCallConvAttrType = environment . UnmanagedCallConvAttrType ;
277+
265278 // Get any attributes of interest on the method
266279 AttributeData ? lcidConversionAttr = null ;
267280 AttributeData ? suppressGCTransitionAttribute = null ;
@@ -282,8 +295,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
282295 }
283296 }
284297
285- var locations = new MethodSignatureDiagnosticLocations ( syntax ) ;
286- var generatorDiagnostics = new GeneratorDiagnosticsBag ( new DiagnosticDescriptorProvider ( ) , locations , SR . ResourceManager , typeof ( FxResources . Microsoft . Interop . ComInterfaceGenerator . SR ) ) ;
298+ var generatorDiagnostics = new GeneratorDiagnosticsBag ( new DiagnosticDescriptorProvider ( ) , diagnosticLocations , SR . ResourceManager , typeof ( FxResources . Microsoft . Interop . ComInterfaceGenerator . SR ) ) ;
287299
288300 if ( lcidConversionAttr is not null )
289301 {
@@ -293,8 +305,8 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
293305
294306 GeneratedComInterfaceCompilationData . TryGetGeneratedComInterfaceAttributeFromInterface ( symbol . ContainingType , out var generatedComAttribute ) ;
295307 var generatedComInterfaceAttributeData = GeneratedComInterfaceCompilationData . GetDataFromAttribute ( generatedComAttribute ) ;
296- // Create the stub.
297308
309+ // Create the stub.
298310 var signatureContext = SignatureContext . Create (
299311 symbol ,
300312 DefaultMarshallingInfoParser . Create (
@@ -387,21 +399,14 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
387399 GeneratorDiagnostics . SizeOfInCollectionMustBeDefinedAtCallReturnValue ) ;
388400 }
389401
390- var containingSyntaxContext = new ContainingSyntaxContext ( syntax ) ;
391-
392- var methodSyntaxTemplate = new ContainingSyntax ( new SyntaxTokenList ( syntax . Modifiers . Where ( static m => ! m . IsKind ( SyntaxKind . NewKeyword ) ) ) . StripAccessibilityModifiers ( ) , SyntaxKind . MethodDeclaration , syntax . Identifier , syntax . TypeParameterList ) ;
393-
394402 ImmutableArray < FunctionPointerUnmanagedCallingConventionSyntax > callConv = VirtualMethodPointerStubGenerator . GenerateCallConvSyntaxFromAttributes (
395403 suppressGCTransitionAttribute ,
396404 unmanagedCallConvAttribute ,
397405 ImmutableArray . Create ( FunctionPointerUnmanagedCallingConvention ( Identifier ( "MemberFunction" ) ) ) ) ;
398406
399407 var declaringType = ManagedTypeInfo . CreateTypeInfoForTypeSymbol ( symbol . ContainingType ) ;
400408
401- var virtualMethodIndexData = new VirtualMethodIndexData ( index , ImplicitThisParameter : true , direction , true , ExceptionMarshalling . Com ) ;
402-
403409 MarshallingInfo exceptionMarshallingInfo ;
404-
405410 if ( generatedComInterfaceAttributeData . ExceptionToUnmanagedMarshaller is null )
406411 {
407412 exceptionMarshallingInfo = new ComExceptionMarshalling ( ) ;
@@ -418,11 +423,9 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
418423
419424 return new IncrementalMethodStubGenerationContext (
420425 signatureContext ,
421- containingSyntaxContext ,
422- methodSyntaxTemplate ,
423- locations ,
426+ diagnosticLocations ,
424427 callConv . ToSequenceEqualImmutableArray ( SyntaxEquivalentComparer . Instance ) ,
425- virtualMethodIndexData ,
428+ new VirtualMethodIndexData ( index , ImplicitThisParameter : true , direction , true , ExceptionMarshalling . Com ) ,
426429 exceptionMarshallingInfo ,
427430 environment . EnvironmentFlags ,
428431 owningInterfaceInfo . Type ,
@@ -431,6 +434,45 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
431434 ComInterfaceDispatchMarshallingInfo . Instance ) ;
432435 }
433436
437+ private static IncrementalMethodStubGenerationContext CalculateStubInformation ( MethodDeclarationSyntax ? syntax , IMethodSymbol symbol , int index , StubEnvironment environment , ComInterfaceInfo owningInterface , CancellationToken ct )
438+ {
439+ ISignatureDiagnosticLocations locations = syntax is null
440+ ? NoneSignatureDiagnosticLocations . Instance
441+ : new MethodSignatureDiagnosticLocations ( syntax ) ;
442+
443+ var sourcelessStubInformation = CalculateSharedStubInformation (
444+ symbol ,
445+ index ,
446+ environment ,
447+ locations ,
448+ owningInterface ,
449+ ct ) ;
450+
451+ if ( syntax is null )
452+ return sourcelessStubInformation ;
453+
454+ var containingSyntaxContext = new ContainingSyntaxContext ( syntax ) ;
455+ var methodSyntaxTemplate = new ContainingSyntax (
456+ new SyntaxTokenList ( syntax . Modifiers . Where ( static m => ! m . IsKind ( SyntaxKind . NewKeyword ) ) ) . StripAccessibilityModifiers ( ) ,
457+ SyntaxKind . MethodDeclaration ,
458+ syntax . Identifier ,
459+ syntax . TypeParameterList ) ;
460+
461+ return new SourceAvailableIncrementalMethodStubGenerationContext (
462+ sourcelessStubInformation . SignatureContext ,
463+ containingSyntaxContext ,
464+ methodSyntaxTemplate ,
465+ locations ,
466+ sourcelessStubInformation . CallingConvention ,
467+ sourcelessStubInformation . VtableIndexData ,
468+ sourcelessStubInformation . ExceptionMarshallingInfo ,
469+ sourcelessStubInformation . EnvironmentFlags ,
470+ sourcelessStubInformation . TypeKeyOwner ,
471+ sourcelessStubInformation . DeclaringType ,
472+ sourcelessStubInformation . Diagnostics ,
473+ ComInterfaceDispatchMarshallingInfo . Instance ) ;
474+ }
475+
434476 private static MarshalDirection GetDirectionFromOptions ( ComInterfaceOptions options )
435477 {
436478 if ( options . HasFlag ( ComInterfaceOptions . ManagedObjectWrapper | ComInterfaceOptions . ComObjectWrapper ) )
@@ -520,12 +562,12 @@ static bool MethodEquals(ComMethodContext a, ComMethodContext b)
520562 private static InterfaceDeclarationSyntax GenerateImplementationInterface ( ComInterfaceAndMethodsContext interfaceGroup , CancellationToken _ )
521563 {
522564 var definingType = interfaceGroup . Interface . Info . Type ;
523- var shadowImplementations = interfaceGroup . InheritedMethods . Select ( m => ( Method : m , ManagedToUnmanagedStub : m . ManagedToUnmanagedStub ) )
565+ var shadowImplementations = interfaceGroup . InheritedMethods . Where ( m => ! m . IsExternallyDefined ) . Select ( m => ( Method : m , ManagedToUnmanagedStub : m . ManagedToUnmanagedStub ) )
524566 . Where ( p => p . ManagedToUnmanagedStub is GeneratedStubCodeContext )
525567 . Select ( ctx => ( ( GeneratedStubCodeContext ) ctx . ManagedToUnmanagedStub ) . Stub . Node
526568 . WithExplicitInterfaceSpecifier (
527569 ExplicitInterfaceSpecifier ( ParseName ( definingType . FullTypeName ) ) ) ) ;
528- var inheritedStubs = interfaceGroup . InheritedMethods . Select ( m => m . UnreachableExceptionStub ) ;
570+ var inheritedStubs = interfaceGroup . InheritedMethods . Where ( m => ! m . IsExternallyDefined ) . Select ( m => m . UnreachableExceptionStub ) ;
529571 return ImplementationInterfaceTemplate
530572 . AddBaseListTypes ( SimpleBaseType ( definingType . Syntax ) )
531573 . WithMembers (
@@ -661,7 +703,6 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
661703
662704 BlockSyntax fillBaseInterfaceSlots ;
663705
664-
665706 if ( interfaceMethods . Interface . Base is null )
666707 {
667708 // If we don't have a base interface, we need to manually fill in the base iUnknown slots.
@@ -740,7 +781,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
740781 }
741782 else
742783 {
743- // NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInteraceDetailsStrategy .GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <startingOffset >));
784+ // NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy .GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <baseVTableSize >));
744785 fillBaseInterfaceSlots = Block (
745786 MethodInvocationStatement (
746787 TypeSyntaxes . System_Runtime_InteropServices_NativeMemory ,
@@ -750,7 +791,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
750791 TypeSyntaxes . StrategyBasedComWrappers
751792 . Dot ( IdentifierName ( "DefaultIUnknownInterfaceDetailsStrategy" ) ) ,
752793 IdentifierName ( "GetIUnknownDerivedDetails" ) ,
753- Argument ( //baseInterfaceTypeInfo.BaseInterface.FullTypeName)),
794+ Argument (
754795 TypeOfExpression ( ParseTypeName ( interfaceMethods . Interface . Base . Info . Type . FullTypeName ) )
755796 . Dot ( IdentifierName ( "TypeHandle" ) ) ) )
756797 . Dot ( IdentifierName ( "ManagedVirtualMethodTable" ) ) ) ,
@@ -767,7 +808,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
767808 ParenthesizedExpression (
768809 BinaryExpression ( SyntaxKind . MultiplyExpression ,
769810 SizeOfExpression ( PointerType ( PredefinedType ( Token ( SyntaxKind . VoidKeyword ) ) ) ) ,
770- LiteralExpression ( SyntaxKind . NumericLiteralExpression , Literal ( interfaceMethods . InheritedMethods . Count ( ) + 3 ) ) ) ) ) ) ) ) ;
811+ LiteralExpression ( SyntaxKind . NumericLiteralExpression , Literal ( interfaceMethods . BaseVTableSize ) ) ) ) ) ) ) ) ;
771812 }
772813
773814 var validDeclaredMethods = interfaceMethods . DeclaredMethods
@@ -787,7 +828,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
787828 IdentifierName ( $ "{ declaredMethodContext . MethodInfo . MethodName } _{ declaredMethodContext . GenerationContext . VtableIndexData . Index } ") ) ,
788829 PrefixUnaryExpression (
789830 SyntaxKind . AddressOfExpression ,
790- IdentifierName ( $ "ABI_{ declaredMethodContext . GenerationContext . StubMethodSyntaxTemplate . Identifier } ") ) ) ) ) ;
831+ IdentifierName ( $ "ABI_{ ( ( SourceAvailableIncrementalMethodStubGenerationContext ) declaredMethodContext . GenerationContext ) . StubMethodSyntaxTemplate . Identifier } ") ) ) ) ) ;
791832 }
792833
793834 return ImplementationInterfaceTemplate
0 commit comments