Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1239,8 +1239,7 @@ private static EquatableArray<VtableAttribute> GetVtableAttributesToAddOnLookupT
{
if (methodSymbol.Parameters[paramsIdx].RefKind != RefKind.Out)
{
var argumentType = context.SemanticModel.GetTypeInfo(invocation.ArgumentList.Arguments[idx].Expression);
AddVtableAttributesForType(argumentType, methodSymbol.Parameters[paramsIdx].Type);
AddVtableAttributesForExpression(invocation.ArgumentList.Arguments[idx].Expression, methodSymbol.Parameters[paramsIdx].Type);
}

// The method parameter can be declared as params which means
Expand Down Expand Up @@ -1269,8 +1268,7 @@ private static EquatableArray<VtableAttribute> GetVtableAttributesToAddOnLookupT
{
if (methodSymbol.Parameters[paramsIdx].RefKind != RefKind.Out)
{
var argumentType = context.SemanticModel.GetTypeInfo(objectCreation.ArgumentList.Arguments[idx].Expression);
AddVtableAttributesForType(argumentType, methodSymbol.Parameters[paramsIdx].Type);
AddVtableAttributesForExpression(objectCreation.ArgumentList.Arguments[idx].Expression, methodSymbol.Parameters[paramsIdx].Type);
}

if (!methodSymbol.Parameters[paramsIdx].IsParams)
Expand All @@ -1293,15 +1291,15 @@ private static EquatableArray<VtableAttribute> GetVtableAttributesToAddOnLookupT
(isGeneratedBindableCustomPropertyClass = GeneratorHelper.IsGeneratedBindableCustomPropertyClass(context.SemanticModel.Compilation, propertySymbol.ContainingSymbol)) ||
SymbolEqualityComparer.Default.Equals(propertySymbol.ContainingAssembly, context.SemanticModel.Compilation.Assembly)))
{
AddVtableAttributesForType(context.SemanticModel.GetTypeInfo(assignment.Right), propertySymbol.Type, isGeneratedBindableCustomPropertyClass);
AddVtableAttributesForExpression(assignment.Right, propertySymbol.Type, isGeneratedBindableCustomPropertyClass);
}
else if (leftSymbol is IFieldSymbol fieldSymbol &&
// WinRT interfaces don't have fields, so we don't need to check for them.
(isWinRTClassOrInterface(fieldSymbol.ContainingSymbol, false) ||
(isGeneratedBindableCustomPropertyClass = GeneratorHelper.IsGeneratedBindableCustomPropertyClass(context.SemanticModel.Compilation, fieldSymbol.ContainingSymbol)) ||
SymbolEqualityComparer.Default.Equals(fieldSymbol.ContainingAssembly, context.SemanticModel.Compilation.Assembly)))
{
AddVtableAttributesForType(context.SemanticModel.GetTypeInfo(assignment.Right), fieldSymbol.Type, isGeneratedBindableCustomPropertyClass);
AddVtableAttributesForExpression(assignment.Right, fieldSymbol.Type, isGeneratedBindableCustomPropertyClass);
}
}
else if (context.Node is VariableDeclarationSyntax variableDeclaration)
Expand All @@ -1314,8 +1312,7 @@ private static EquatableArray<VtableAttribute> GetVtableAttributesToAddOnLookupT
{
if (variable.Initializer != null)
{
var instantiatedType = context.SemanticModel.GetTypeInfo(variable.Initializer.Value);
AddVtableAttributesForType(instantiatedType, namedType);
AddVtableAttributesForExpression(variable.Initializer.Value, namedType);
}
}
}
Expand All @@ -1328,40 +1325,37 @@ private static EquatableArray<VtableAttribute> GetVtableAttributesToAddOnLookupT
var leftSymbol = context.SemanticModel.GetSymbolInfo(propertyDeclaration.Type).Symbol;
if (leftSymbol is INamedTypeSymbol namedType)
{
var instantiatedType = context.SemanticModel.GetTypeInfo(propertyDeclaration.Initializer.Value);
AddVtableAttributesForType(instantiatedType, namedType);
AddVtableAttributesForExpression(propertyDeclaration.Initializer.Value, namedType);
}
}
else if (propertyDeclaration.ExpressionBody != null)
{
var leftSymbol = context.SemanticModel.GetSymbolInfo(propertyDeclaration.Type).Symbol;
if (leftSymbol is INamedTypeSymbol namedType)
{
var instantiatedType = context.SemanticModel.GetTypeInfo(propertyDeclaration.ExpressionBody.Expression);
AddVtableAttributesForType(instantiatedType, namedType);
AddVtableAttributesForExpression(propertyDeclaration.ExpressionBody.Expression, namedType);
}
}
}
else if (context.Node is ReturnStatementSyntax { Expression: not null } returnDeclaration)
{
// Detect scenarios where the method or property being returned from is doing a box or cast of the type
// in the return statement.
var returnSymbol = context.SemanticModel.GetTypeInfo(returnDeclaration.Expression);
var parent = returnDeclaration.Ancestors().OfType<MemberDeclarationSyntax>().FirstOrDefault();
if (parent is MethodDeclarationSyntax methodDeclaration)
{
var methodReturnSymbol = context.SemanticModel.GetSymbolInfo(methodDeclaration.ReturnType).Symbol;
if (methodReturnSymbol is ITypeSymbol typeSymbol)
{
AddVtableAttributesForType(returnSymbol, typeSymbol);
AddVtableAttributesForExpression(returnDeclaration.Expression, typeSymbol);
}
}
else if (parent is BasePropertyDeclarationSyntax propertyDeclarationSyntax)
{
var propertyTypeSymbol = context.SemanticModel.GetSymbolInfo(propertyDeclarationSyntax.Type).Symbol;
if (propertyTypeSymbol is ITypeSymbol typeSymbol)
{
AddVtableAttributesForType(returnSymbol, typeSymbol);
AddVtableAttributesForExpression(returnDeclaration.Expression, typeSymbol);
}
}
}
Expand Down Expand Up @@ -1399,6 +1393,41 @@ SpecialType.System_Collections_Generic_IReadOnlyCollection_T or

return vtableAttributes.ToImmutableArray();

// Looks through parenthesized, cast, conditional (ternary) and switch expressions to reach the concrete
// leaf expressions that actually flow into the given target, gathering vtable information for each of them.
// This is required because eg. the type of a conditional or switch expression is just the common type of
// all of its branches (usually 'object' or an interface), which hides the concrete types that are the ones
// being boxed or cast and thus needing CCW vtable entries.
void AddVtableAttributesForExpression(ExpressionSyntax expression, ITypeSymbol convertedToTypeSymbol, bool isGeneratedBindableCustomPropertyClass = false)
{
switch (expression)
{
case ParenthesizedExpressionSyntax parenthesizedExpression:
AddVtableAttributesForExpression(parenthesizedExpression.Expression, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass);
break;
// The cast target type can itself be a concrete type being boxed or cast (eg. '(List<int>)value'), so
// process it directly, but also look through to the operand to catch the concrete type being cast to
// something more general (eg. '(object)new List<string>()'), which the cast type alone would hide.
case CastExpressionSyntax castExpression:
AddVtableAttributesForType(context.SemanticModel.GetTypeInfo(castExpression), convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass);
AddVtableAttributesForExpression(castExpression.Expression, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass);
break;
case ConditionalExpressionSyntax conditionalExpression:
AddVtableAttributesForExpression(conditionalExpression.WhenTrue, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass);
AddVtableAttributesForExpression(conditionalExpression.WhenFalse, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass);
break;
case SwitchExpressionSyntax switchExpression:
foreach (var switchExpressionArm in switchExpression.Arms)
{
AddVtableAttributesForExpression(switchExpressionArm.Expression, convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass);
}
break;
default:
AddVtableAttributesForType(context.SemanticModel.GetTypeInfo(expression), convertedToTypeSymbol, isGeneratedBindableCustomPropertyClass);
break;
}
}

// Helper to directly use 'AddVtableAttributesForTypeDirect' with 'TypeInfo' values
void AddVtableAttributesForType(Microsoft.CodeAnalysis.TypeInfo instantiatedType, ITypeSymbol convertedToTypeSymbol, bool isGeneratedBindableCustomPropertyClass = false)
{
Expand Down
47 changes: 47 additions & 0 deletions src/Tests/FunctionalTests/CCW/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,34 @@
}
#endif

// Regression test for https://github.com/microsoft/CsWinRT/issues/1947: the AOT source generator has to look
// through conditional (ternary) and switch expressions (as well as casts) to discover the concrete generic types
// being boxed. The 'List<T>' types below only ever appear inside those expressions, so without that discovery their
// CCWs would be missing the 'IVector<T>' vtable under Native AOT and the runtime class name check would fail.
ccw = MarshalInspectable<object>.CreateMarshaler(BoxListViaTernary(true));
if (!CheckRuntimeClassName(ccw, "Windows.Foundation.Collections.IVector`1<Int32>"))
{
return 139;
}

ccw = MarshalInspectable<object>.CreateMarshaler(BoxListViaTernary(false));
if (!CheckRuntimeClassName(ccw, "Windows.Foundation.Collections.IVector`1<String>"))
{
return 140;
}

ccw = MarshalInspectable<object>.CreateMarshaler(BoxListViaSwitch(0));
if (!CheckRuntimeClassName(ccw, "Windows.Foundation.Collections.IVector`1<UInt8>"))
{
return 141;
}

ccw = MarshalInspectable<object>.CreateMarshaler(BoxListViaSwitch(1));
if (!CheckRuntimeClassName(ccw, "Windows.Foundation.Collections.IVector`1<Single>"))
{
return 142;
}

return 100;


Expand Down Expand Up @@ -390,6 +418,25 @@ unsafe bool CheckRuntimeClassName(IObjectReference objRef, string expected)
}
}

// Boxes a generic list through a conditional (ternary) expression, so the concrete element types are only
// ever reachable through the ternary branches (see issue #1947).
static object BoxListViaTernary(bool flag)
{
object boxed = flag ? new List<int>() : (object)new List<string>();
return boxed;
}

// Same as 'BoxListViaTernary', but exercising a switch expression instead of a conditional expression.
static object BoxListViaSwitch(int selector)
{
object boxed = selector switch
{
0 => new List<byte>(),
_ => (object)new List<float>()
};
return boxed;
}

sealed partial class ManagedProperties : IProperties1, IUriHandler
{
private readonly int _value;
Expand Down
177 changes: 177 additions & 0 deletions src/Tests/SourceGeneratorTest/AotOptimizerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Basic.Reference.Assemblies;
using Generator;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using WinRT;

namespace SourceGeneratorTest;

[TestClass]
public class AotOptimizerTests
{
// Regression tests for https://github.com/microsoft/CsWinRT/issues/1947. The AOT source generator has to look
// through conditional (ternary) and switch expressions, as well as casts, to discover the concrete generic types
// being boxed. Discovered types are registered in the generated CCW vtable lookup table, keyed by their runtime
// 'Type.ToString()' name (eg. 'System.Collections.Generic.List`1[System.Int32]'), so their presence in the
// generated sources proves the generator saw the boxing.

[TestMethod]
public void ConditionalExpression_DiscoversConcreteTypesInBothBranches()
{
const string source = """
using System.Collections.Generic;

internal class Test
{
private static bool GetFlag() => true;

public object M()
{
return GetFlag() ? new List<int>() : (object)new List<string>();
}
}
""";

string generated = RunAotOptimizer(source);

Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.Int32]"));
Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.String]"));
}

[TestMethod]
public void SwitchExpression_DiscoversConcreteTypesInAllArms()
{
const string source = """
using System.Collections.Generic;

internal class Test
{
public object M(int selector)
{
object boxed = selector switch
{
0 => new List<byte>(),
_ => (object)new List<float>()
};

return boxed;
}
}
""";

string generated = RunAotOptimizer(source);

Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.Byte]"));
Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.Single]"));
}

[TestMethod]
public void CastExpression_DiscoversConcreteOperandType()
{
const string source = """
using System.Collections.Generic;

internal class Test
{
public object M()
{
object boxed = (object)new List<double>();

return boxed;
}
}
""";

string generated = RunAotOptimizer(source);

Assert.IsTrue(generated.Contains("System.Collections.Generic.List`1[System.Double]"));
}

[TestMethod]
public void ConditionalExpression_NoBoxing_DoesNotDiscoverConcreteTypes()
{
// The lists are assigned to their own concrete type, so nothing is boxed or cast and there is
// no work for the CCW lookup table generator to do. This guards against over-eager discovery.
const string source = """
using System.Collections.Generic;

internal class Test
{
private static bool GetFlag() => true;

public List<int> M()
{
return GetFlag() ? new List<int>() : new List<int>();
}
}
""";

string generated = RunAotOptimizer(source);

Assert.IsFalse(generated.Contains("System.Collections.Generic.List`1[System.Int32]"));
}

private static string RunAotOptimizer(string source)
{
SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Latest));

List<MetadataReference> references = new(Net80.References.All)
{
MetadataReference.CreateFromFile(typeof(ComWrappersSupport).Assembly.Location)
};

CSharpCompilation compilation = CSharpCompilation.Create(
"AotOptimizerTest",
new[] { syntaxTree },
references,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary, allowUnsafe: true));

GeneratorDriver driver = CSharpGeneratorDriver.Create(
generators: new[] { new WinRTAotSourceGenerator().AsSourceGenerator() },
additionalTexts: ImmutableArray<AdditionalText>.Empty,
parseOptions: (CSharpParseOptions)syntaxTree.Options,
optionsProvider: new ConfigOptionsProvider());

driver = driver.RunGenerators(compilation);

return string.Join(
Environment.NewLine,
driver.GetRunResult().GeneratedTrees.Select(static tree => tree.ToString()));
}

private sealed class ConfigOptions : AnalyzerConfigOptions
{
public Dictionary<string, string> Values { get; } = new()
{
["build_property.AssemblyName"] = "AotOptimizerTest",
["build_property.AssemblyVersion"] = "1.0.0.0",
["build_property.CsWinRTComponent"] = "false",
["build_property.CsWinRTAotOptimizerEnabled"] = "auto",
["build_property.CsWinRTCcwLookupTableGeneratorEnabled"] = "true",
};

public override bool TryGetValue(string key, [NotNullWhen(true)] out string value)
{
return Values.TryGetValue(key, out value);
}
}

private sealed class ConfigOptionsProvider : AnalyzerConfigOptionsProvider
{
public override AnalyzerConfigOptions GlobalOptions { get; } = new ConfigOptions();

public override AnalyzerConfigOptions GetOptions(SyntaxTree tree) => GlobalOptions;

public override AnalyzerConfigOptions GetOptions(AdditionalText textFile) => GlobalOptions;
}
}