diff --git a/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs b/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs index 739132c28c..07b4aa6620 100644 --- a/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs +++ b/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs @@ -1239,8 +1239,7 @@ private static EquatableArray GetVtableAttributesToAddOnLookupT { if (methodSymbol.Parameters[paramsIdx].RefKind != RefKind.Out) { - var argumentType = context.SemanticModel.GetTypeInfo(invocation.ArgumentList.Arguments[idx].Expression); - AddVtableAttributesForType(argumentType, methodSymbol.Parameters[paramsIdx].Type); + AddVtableAttributesForExpression(invocation.ArgumentList.Arguments[idx].Expression, methodSymbol.Parameters[paramsIdx].Type); } // The method parameter can be declared as params which means @@ -1269,8 +1268,7 @@ private static EquatableArray GetVtableAttributesToAddOnLookupT { if (methodSymbol.Parameters[paramsIdx].RefKind != RefKind.Out) { - var argumentType = context.SemanticModel.GetTypeInfo(objectCreation.ArgumentList.Arguments[idx].Expression); - AddVtableAttributesForType(argumentType, methodSymbol.Parameters[paramsIdx].Type); + AddVtableAttributesForExpression(objectCreation.ArgumentList.Arguments[idx].Expression, methodSymbol.Parameters[paramsIdx].Type); } if (!methodSymbol.Parameters[paramsIdx].IsParams) @@ -1293,7 +1291,7 @@ private static EquatableArray GetVtableAttributesToAddOnLookupT (isGeneratedBindableCustomPropertyClass = GeneratorHelper.IsGeneratedBindableCustomPropertyClass(context.SemanticModel.Compilation, propertySymbol.ContainingSymbol)) || SymbolEqualityComparer.Default.Equals(propertySymbol.ContainingAssembly, context.SemanticModel.Compilation.Assembly))) { - AddVtableAttributesForType(context.SemanticModel.GetTypeInfo(assignment.Right), propertySymbol.Type, isGeneratedBindableCustomPropertyClass); + AddVtableAttributesForExpression(assignment.Right, propertySymbol.Type, isGeneratedBindableCustomPropertyClass); } else if (leftSymbol is IFieldSymbol fieldSymbol && // WinRT interfaces don't have fields, so we don't need to check for them. @@ -1301,7 +1299,7 @@ private static EquatableArray GetVtableAttributesToAddOnLookupT (isGeneratedBindableCustomPropertyClass = GeneratorHelper.IsGeneratedBindableCustomPropertyClass(context.SemanticModel.Compilation, fieldSymbol.ContainingSymbol)) || SymbolEqualityComparer.Default.Equals(fieldSymbol.ContainingAssembly, context.SemanticModel.Compilation.Assembly))) { - AddVtableAttributesForType(context.SemanticModel.GetTypeInfo(assignment.Right), fieldSymbol.Type, isGeneratedBindableCustomPropertyClass); + AddVtableAttributesForExpression(assignment.Right, fieldSymbol.Type, isGeneratedBindableCustomPropertyClass); } } else if (context.Node is VariableDeclarationSyntax variableDeclaration) @@ -1314,8 +1312,7 @@ private static EquatableArray GetVtableAttributesToAddOnLookupT { if (variable.Initializer != null) { - var instantiatedType = context.SemanticModel.GetTypeInfo(variable.Initializer.Value); - AddVtableAttributesForType(instantiatedType, namedType); + AddVtableAttributesForExpression(variable.Initializer.Value, namedType); } } } @@ -1328,8 +1325,7 @@ private static EquatableArray GetVtableAttributesToAddOnLookupT var leftSymbol = context.SemanticModel.GetSymbolInfo(propertyDeclaration.Type).Symbol; if (leftSymbol is INamedTypeSymbol namedType) { - var instantiatedType = context.SemanticModel.GetTypeInfo(propertyDeclaration.Initializer.Value); - AddVtableAttributesForType(instantiatedType, namedType); + AddVtableAttributesForExpression(propertyDeclaration.Initializer.Value, namedType); } } else if (propertyDeclaration.ExpressionBody != null) @@ -1337,8 +1333,7 @@ private static EquatableArray GetVtableAttributesToAddOnLookupT var leftSymbol = context.SemanticModel.GetSymbolInfo(propertyDeclaration.Type).Symbol; if (leftSymbol is INamedTypeSymbol namedType) { - var instantiatedType = context.SemanticModel.GetTypeInfo(propertyDeclaration.ExpressionBody.Expression); - AddVtableAttributesForType(instantiatedType, namedType); + AddVtableAttributesForExpression(propertyDeclaration.ExpressionBody.Expression, namedType); } } } @@ -1346,14 +1341,13 @@ private static EquatableArray GetVtableAttributesToAddOnLookupT { // Detect scenarios where the method or property being returned from is doing a box or cast of the type // in the return statement. - var returnSymbol = context.SemanticModel.GetTypeInfo(returnDeclaration.Expression); var parent = returnDeclaration.Ancestors().OfType().FirstOrDefault(); if (parent is MethodDeclarationSyntax methodDeclaration) { var methodReturnSymbol = context.SemanticModel.GetSymbolInfo(methodDeclaration.ReturnType).Symbol; if (methodReturnSymbol is ITypeSymbol typeSymbol) { - AddVtableAttributesForType(returnSymbol, typeSymbol); + AddVtableAttributesForExpression(returnDeclaration.Expression, typeSymbol); } } else if (parent is BasePropertyDeclarationSyntax propertyDeclarationSyntax) @@ -1361,7 +1355,7 @@ private static EquatableArray GetVtableAttributesToAddOnLookupT var propertyTypeSymbol = context.SemanticModel.GetSymbolInfo(propertyDeclarationSyntax.Type).Symbol; if (propertyTypeSymbol is ITypeSymbol typeSymbol) { - AddVtableAttributesForType(returnSymbol, typeSymbol); + AddVtableAttributesForExpression(returnDeclaration.Expression, typeSymbol); } } } @@ -1399,6 +1393,41 @@ SpecialType.System_Collections_Generic_IReadOnlyCollection_T or return vtableAttributes.ToImmutableArray(); + // Looks through parenthesized, cast, conditional (ternary) and switch expressions to reach the concrete + // leaf expressions that actually flow into the given target, gathering vtable information for each of them. + // This is required because eg. the type of a conditional or switch expression is just the common type of + // all of its branches (usually 'object' or an interface), which hides the concrete types that are the ones + // being boxed or cast and thus needing CCW vtable entries. + void AddVtableAttributesForExpression(ExpressionSyntax expression, ITypeSymbol convertedToTypeSymbol, bool isGeneratedBindableCustomPropertyClass = false) + { + switch (expression) + { + case ParenthesizedExpressionSyntax parenthesizedExpression: + AddVtableAttributesForExpression(parenthesizedExpression.Expression, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass); + break; + // The cast target type can itself be a concrete type being boxed or cast (eg. '(List)value'), so + // process it directly, but also look through to the operand to catch the concrete type being cast to + // something more general (eg. '(object)new List()'), which the cast type alone would hide. + case CastExpressionSyntax castExpression: + AddVtableAttributesForType(context.SemanticModel.GetTypeInfo(castExpression), convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass); + AddVtableAttributesForExpression(castExpression.Expression, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass); + break; + case ConditionalExpressionSyntax conditionalExpression: + AddVtableAttributesForExpression(conditionalExpression.WhenTrue, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass); + AddVtableAttributesForExpression(conditionalExpression.WhenFalse, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass); + break; + case SwitchExpressionSyntax switchExpression: + foreach (var switchExpressionArm in switchExpression.Arms) + { + AddVtableAttributesForExpression(switchExpressionArm.Expression, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass); + } + break; + default: + AddVtableAttributesForType(context.SemanticModel.GetTypeInfo(expression), convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass); + break; + } + } + // Helper to directly use 'AddVtableAttributesForTypeDirect' with 'TypeInfo' values void AddVtableAttributesForType(Microsoft.CodeAnalysis.TypeInfo instantiatedType, ITypeSymbol convertedToTypeSymbol, bool isGeneratedBindableCustomPropertyClass = false) { diff --git a/src/Tests/FunctionalTests/CCW/Program.cs b/src/Tests/FunctionalTests/CCW/Program.cs index 6159b431e1..5d9e4d6703 100644 --- a/src/Tests/FunctionalTests/CCW/Program.cs +++ b/src/Tests/FunctionalTests/CCW/Program.cs @@ -338,6 +338,34 @@ } #endif +// Regression test for https://github.com/microsoft/CsWinRT/issues/1947: the AOT source generator has to look +// through conditional (ternary) and switch expressions (as well as casts) to discover the concrete generic types +// being boxed. The 'List' types below only ever appear inside those expressions, so without that discovery their +// CCWs would be missing the 'IVector' vtable under Native AOT and the runtime class name check would fail. +ccw = MarshalInspectable.CreateMarshaler(BoxListViaTernary(true)); +if (!CheckRuntimeClassName(ccw, "Windows.Foundation.Collections.IVector`1")) +{ + return 139; +} + +ccw = MarshalInspectable.CreateMarshaler(BoxListViaTernary(false)); +if (!CheckRuntimeClassName(ccw, "Windows.Foundation.Collections.IVector`1")) +{ + return 140; +} + +ccw = MarshalInspectable.CreateMarshaler(BoxListViaSwitch(0)); +if (!CheckRuntimeClassName(ccw, "Windows.Foundation.Collections.IVector`1")) +{ + return 141; +} + +ccw = MarshalInspectable.CreateMarshaler(BoxListViaSwitch(1)); +if (!CheckRuntimeClassName(ccw, "Windows.Foundation.Collections.IVector`1")) +{ + return 142; +} + return 100; @@ -390,6 +418,25 @@ unsafe bool CheckRuntimeClassName(IObjectReference objRef, string expected) } } +// Boxes a generic list through a conditional (ternary) expression, so the concrete element types are only +// ever reachable through the ternary branches (see issue #1947). +static object BoxListViaTernary(bool flag) +{ + object boxed = flag ? new List() : (object)new List(); + return boxed; +} + +// Same as 'BoxListViaTernary', but exercising a switch expression instead of a conditional expression. +static object BoxListViaSwitch(int selector) +{ + object boxed = selector switch + { + 0 => new List(), + _ => (object)new List() + }; + return boxed; +} + sealed partial class ManagedProperties : IProperties1, IUriHandler { private readonly int _value; diff --git a/src/Tests/SourceGeneratorTest/AotOptimizerTests.cs b/src/Tests/SourceGeneratorTest/AotOptimizerTests.cs new file mode 100644 index 0000000000..58c162c7c1 --- /dev/null +++ b/src/Tests/SourceGeneratorTest/AotOptimizerTests.cs @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Basic.Reference.Assemblies; +using Generator; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using WinRT; + +namespace SourceGeneratorTest; + +[TestClass] +public class AotOptimizerTests +{ + // Regression tests for https://github.com/microsoft/CsWinRT/issues/1947. The AOT source generator has to look + // through conditional (ternary) and switch expressions, as well as casts, to discover the concrete generic types + // being boxed. Discovered types are registered in the generated CCW vtable lookup table, keyed by their runtime + // 'Type.ToString()' name (eg. 'System.Collections.Generic.List`1[System.Int32]'), so their presence in the + // generated sources proves the generator saw the boxing. + + [TestMethod] + public void ConditionalExpression_DiscoversConcreteTypesInBothBranches() + { + const string source = """ + using System.Collections.Generic; + + internal class Test + { + private static bool GetFlag() => true; + + public object M() + { + return GetFlag() ? new List() : (object)new List(); + } + } + """; + + string generated = RunAotOptimizer(source); + + Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.Int32]")); + Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.String]")); + } + + [TestMethod] + public void SwitchExpression_DiscoversConcreteTypesInAllArms() + { + const string source = """ + using System.Collections.Generic; + + internal class Test + { + public object M(int selector) + { + object boxed = selector switch + { + 0 => new List(), + _ => (object)new List() + }; + + return boxed; + } + } + """; + + string generated = RunAotOptimizer(source); + + Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.Byte]")); + Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.Single]")); + } + + [TestMethod] + public void CastExpression_DiscoversConcreteOperandType() + { + const string source = """ + using System.Collections.Generic; + + internal class Test + { + public object M() + { + object boxed = (object)new List(); + + return boxed; + } + } + """; + + string generated = RunAotOptimizer(source); + + Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.Double]")); + } + + [TestMethod] + public void ConditionalExpression_NoBoxing_DoesNotDiscoverConcreteTypes() + { + // The lists are assigned to their own concrete type, so nothing is boxed or cast and there is + // no work for the CCW lookup table generator to do. This guards against over-eager discovery. + const string source = """ + using System.Collections.Generic; + + internal class Test + { + private static bool GetFlag() => true; + + public List M() + { + return GetFlag() ? new List() : new List(); + } + } + """; + + string generated = RunAotOptimizer(source); + + Assert.IsFalse(generated.Contains("System.Collections.Generic.List`1[System.Int32]")); + } + + private static string RunAotOptimizer(string source) + { + SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Latest)); + + List references = new(Net80.References.All) + { + MetadataReference.CreateFromFile(typeof(ComWrappersSupport).Assembly.Location) + }; + + CSharpCompilation compilation = CSharpCompilation.Create( + "AotOptimizerTest", + new[] { syntaxTree }, + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary, allowUnsafe: true)); + + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: new[] { new WinRTAotSourceGenerator().AsSourceGenerator() }, + additionalTexts: ImmutableArray.Empty, + parseOptions: (CSharpParseOptions)syntaxTree.Options, + optionsProvider: new ConfigOptionsProvider()); + + driver = driver.RunGenerators(compilation); + + return string.Join( + Environment.NewLine, + driver.GetRunResult().GeneratedTrees.Select(static tree => tree.ToString())); + } + + private sealed class ConfigOptions : AnalyzerConfigOptions + { + public Dictionary Values { get; } = new() + { + ["build_property.AssemblyName"] = "AotOptimizerTest", + ["build_property.AssemblyVersion"] = "1.0.0.0", + ["build_property.CsWinRTComponent"] = "false", + ["build_property.CsWinRTAotOptimizerEnabled"] = "auto", + ["build_property.CsWinRTCcwLookupTableGeneratorEnabled"] = "true", + }; + + public override bool TryGetValue(string key, [NotNullWhen(true)] out string value) + { + return Values.TryGetValue(key, out value); + } + } + + private sealed class ConfigOptionsProvider : AnalyzerConfigOptionsProvider + { + public override AnalyzerConfigOptions GlobalOptions { get; } = new ConfigOptions(); + + public override AnalyzerConfigOptions GetOptions(SyntaxTree tree) => GlobalOptions; + + public override AnalyzerConfigOptions GetOptions(AdditionalText textFile) => GlobalOptions; + } +}