Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ApiDiff: Allow deciding whether to include partial modifier in TypeDeclarationCSharpSyntaxRewriter output #45803

Merged
merged 2 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ public sealed class CSharpFileBuilder : IAssemblySymbolWriter, IDisposable
private readonly AdhocWorkspace _adhocWorkspace;
private readonly SyntaxGenerator _syntaxGenerator;
private readonly IEnumerable<MetadataReference> _metadataReferences;
private readonly bool _addPartialModifier;

public CSharpFileBuilder(ILog logger,
ISymbolFilter symbolFilter,
ISymbolFilter attributeDataSymbolFilter,
TextWriter textWriter,
string? exceptionMessage,
bool includeAssemblyAttributes,
IEnumerable<MetadataReference> metadataReferences)
IEnumerable<MetadataReference> metadataReferences,
bool addPartialModifier)
carlossanlop marked this conversation as resolved.
Show resolved Hide resolved
{
_logger = logger;
_textWriter = textWriter;
Expand All @@ -52,6 +54,7 @@ public CSharpFileBuilder(ILog logger,
_adhocWorkspace = new AdhocWorkspace();
_syntaxGenerator = SyntaxGenerator.GetGenerator(_adhocWorkspace, LanguageNames.CSharp);
_metadataReferences = metadataReferences;
_addPartialModifier = addPartialModifier;
}

/// <inheritdoc />
Expand Down Expand Up @@ -79,7 +82,7 @@ public void WriteAssembly(IAssemblySymbol assemblySymbol)

SyntaxNode compilationUnit = _syntaxGenerator.CompilationUnit(namespaceSyntaxNodes)
.WithAdditionalAnnotations(Formatter.Annotation, Simplifier.Annotation)
.Rewrite(new TypeDeclarationCSharpSyntaxRewriter())
.Rewrite(new TypeDeclarationCSharpSyntaxRewriter(addPartialModifier: true))
.Rewrite(new BodyBlockCSharpSyntaxRewriter(_exceptionMessage));

if (_includeAssemblyAttributes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ public static void Run(ILog logger,
textWriter,
exceptionMessage,
includeAssemblyAttributes,
loader.MetadataReferences);
loader.MetadataReferences,
addPartialModifier: true);

fileBuilder.WriteAssembly(assemblySymbol);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ namespace Microsoft.DotNet.GenAPI.SyntaxRewriter
/// </summary>
public class TypeDeclarationCSharpSyntaxRewriter : CSharpSyntaxRewriter
{
private readonly bool _addPartialModifier;

/// <summary>
/// Initializes a new instance of the <see cref="TypeDeclarationCSharpSyntaxRewriter"/> class, and allows deciding whether to insert the partial modifier for types or not.
/// </summary>
/// <param name="addPartialModifier">Determines whether to insert the partial modifier for types or not.</param>
public TypeDeclarationCSharpSyntaxRewriter(bool addPartialModifier) => _addPartialModifier = addPartialModifier;

/// <inheritdoc />
public override SyntaxNode? VisitInterfaceDeclaration(InterfaceDeclarationSyntax node)
{
Expand Down Expand Up @@ -83,15 +91,15 @@ public class TypeDeclarationCSharpSyntaxRewriter : CSharpSyntaxRewriter
}
}

private static T? VisitCommonTypeDeclaration<T>(T? node) where T : TypeDeclarationSyntax
private T? VisitCommonTypeDeclaration<T>(T? node) where T : TypeDeclarationSyntax
{
if (node == null)
{
return null;
}

node = RemoveBaseType(node, "global::System.Object");
return AddPartialModifier(node);
return _addPartialModifier ? AddPartialModifier(node) : node;
}

private static T? AddPartialModifier<T>(T? node) where T : TypeDeclarationSyntax =>
Expand Down
45 changes: 23 additions & 22 deletions test/Microsoft.DotNet.GenAPI.Tests/CSharpFileBuilderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ private void RunTest(string original,
stringWriter,
null,
false,
MetadataReferences);
MetadataReferences,
addPartialModifier: true);

using Stream assemblyStream = SymbolFactory.EmitAssemblyStreamFromSyntax(original, enableNullable: true, allowUnsafe: allowUnsafe, assemblyName: assemblyName);
AssemblySymbolLoader assemblySymbolLoader = new(resolveAssemblyReferences: true, includeInternalSymbols: includeInternalSymbols);
Expand Down Expand Up @@ -231,7 +232,7 @@ public void TestRecordDeclaration()
{
RunTest(original: """
namespace Foo
{
{
public record RecordClass;
public record RecordClass1(int i);
public record RecordClass2(string s, int i);
Expand All @@ -240,7 +241,7 @@ public record DerivedRecord2(string x, int i, double d) : RecordClass2(default(s
public record DerivedRecord3(string x, int i, double d) : RecordClass2(default(string)!, i);
public record DerivedRecord4(double d) : RecordClass2(default(string)!, default);
public record DerivedRecord5() : RecordClass2(default(string)!, default);

public record RecordClassWithMethods(int i)
{
public void DoSomething() { }
Expand Down Expand Up @@ -345,11 +346,11 @@ public void TestRecordStructDeclaration()
RunTest(original: """
namespace Foo
{
public record struct RecordStruct;

public record struct RecordStruct;
public record struct RecordStruct1(int i);
public record struct RecordStruct2(string s, int i);

public record struct RecordStructWithMethods(int i)
{
public void DoSomething() { }
Expand All @@ -367,10 +368,10 @@ public record struct RecordStructWithConstructors(int i)
public RecordStructWithConstructors() : this(1) { }
public RecordStructWithConstructors(string s) : this(int.Parse(s)) { }
}

}
""",
expected: """
expected: """
namespace Foo
{
public partial struct RecordStruct : System.IEquatable<RecordStruct>
Expand Down Expand Up @@ -1644,12 +1645,12 @@ public class B
{
public B(int i) {}
}

public class C : B
{
internal C() : base(0) {}
}

public class D : B
{
internal D(int i) : base(i) {}
Expand All @@ -1672,7 +1673,7 @@ public partial class B
{
public B(int i) {}
}

public partial class C : B
{
internal C() : base(default) {}
Expand Down Expand Up @@ -1702,12 +1703,12 @@ public class B
{
public B(int i) {}
}

public class C : B
{
internal C() : base(0) {}
}

public class D : B
{
internal D(int i) : base(i) {}
Expand Down Expand Up @@ -1781,8 +1782,8 @@ namespace A
public partial class B
{
protected B() {}
}
}

public partial class C : B
{
internal C() {}
Expand Down Expand Up @@ -1935,7 +1936,7 @@ public class B : A
public class D { }

public class Id { }

public class V { }
}
""",
Expand Down Expand Up @@ -2828,7 +2829,7 @@ public class Foo<T> : System.Collections.ICollection, System.Collections.Generic

}
}

""",
// https://github.com/dotnet/sdk/issues/32195 tracks interface expansion
expected: """
Expand Down Expand Up @@ -2909,7 +2910,7 @@ namespace N {
public ref struct C<T>
where T : unmanaged
{
public required (string? k, dynamic v, nint n) X { get; init; }
public required (string? k, dynamic v, nint n) X { get; init; }
}

public static class E
Expand All @@ -2918,7 +2919,7 @@ public static void M<T>(this object c, scoped System.ReadOnlySpan<T> values) { }
}
}
""",
expected: """
expected: """
namespace N
{
public ref partial struct C<T>
Expand Down Expand Up @@ -2982,7 +2983,7 @@ public void TestExplicitInterfaceNonGenericCollections()
namespace a
{
#pragma warning disable CS8597

public partial class MyStringCollection : ICollection, IEnumerable, IList
{
public int Count { get { throw null; } }
Expand All @@ -3006,7 +3007,7 @@ public void RemoveAt(int index) { }
void ICollection.CopyTo(Array array, int index) { }
IEnumerator IEnumerable.GetEnumerator() { throw null; }
int IList.Add(object? value) { throw null; }
bool IList.Contains(object? value) { throw null; }
bool IList.Contains(object? value) { throw null; }
int IList.IndexOf(object? value) { throw null; }
void IList.Insert(int index, object? value) { }
void IList.Remove(object? value) { }
Expand All @@ -3015,7 +3016,7 @@ void IList.Remove(object? value) { }
#pragma warning restore CS8597
}
""",
expected: """
expected: """
namespace a
{
public partial class MyStringCollection : System.Collections.ICollection, System.Collections.IEnumerable, System.Collections.IList
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class TypeDeclarationCSharpSyntaxRewriterTests : CSharpSyntaxRewriterTest
[Fact]
public void TestRemoveSystemObjectAsBaseClass()
{
CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(),
CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(addPartialModifier: true),
original: """
namespace A
{
Expand All @@ -32,7 +32,7 @@ partial class B
[Fact]
public void TestAddPartialKeyword()
{
CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(),
CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(addPartialModifier: true),
original: """
namespace A
{
Expand All @@ -54,7 +54,7 @@ partial interface D { }
[Fact]
public void TestPartialTypeDeclaration()
{
CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(),
CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(addPartialModifier: true),
original: """
namespace A
{
Expand Down
Loading