diff --git a/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/CSharpFileBuilder.cs b/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/CSharpFileBuilder.cs index c4a287843beb..10cf790e827a 100644 --- a/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/CSharpFileBuilder.cs +++ b/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/CSharpFileBuilder.cs @@ -34,6 +34,7 @@ public sealed class CSharpFileBuilder : IAssemblySymbolWriter, IDisposable private readonly AdhocWorkspace _adhocWorkspace; private readonly SyntaxGenerator _syntaxGenerator; private readonly IEnumerable _metadataReferences; + private readonly bool _addPartialModifier; public CSharpFileBuilder(ILog logger, ISymbolFilter symbolFilter, @@ -41,7 +42,8 @@ public CSharpFileBuilder(ILog logger, TextWriter textWriter, string? exceptionMessage, bool includeAssemblyAttributes, - IEnumerable metadataReferences) + IEnumerable metadataReferences, + bool addPartialModifier) { _logger = logger; _textWriter = textWriter; @@ -52,6 +54,7 @@ public CSharpFileBuilder(ILog logger, _adhocWorkspace = new AdhocWorkspace(); _syntaxGenerator = SyntaxGenerator.GetGenerator(_adhocWorkspace, LanguageNames.CSharp); _metadataReferences = metadataReferences; + _addPartialModifier = addPartialModifier; } /// @@ -79,7 +82,7 @@ public void WriteAssembly(IAssemblySymbol assemblySymbol) SyntaxNode compilationUnit = _syntaxGenerator.CompilationUnit(namespaceSyntaxNodes) .WithAdditionalAnnotations(Formatter.Annotation, Simplifier.Annotation) - .Rewrite(new TypeDeclarationCSharpSyntaxRewriter()) + .Rewrite(new TypeDeclarationCSharpSyntaxRewriter(addPartialModifier: true)) .Rewrite(new BodyBlockCSharpSyntaxRewriter(_exceptionMessage)); if (_includeAssemblyAttributes) diff --git a/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/GenAPIApp.cs b/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/GenAPIApp.cs index 51bf68c42017..5fbdd6c574c3 100644 --- a/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/GenAPIApp.cs +++ b/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/GenAPIApp.cs @@ -81,7 +81,8 @@ public static void Run(ILog logger, textWriter, exceptionMessage, includeAssemblyAttributes, - loader.MetadataReferences); + loader.MetadataReferences, + addPartialModifier: true); fileBuilder.WriteAssembly(assemblySymbol); } diff --git a/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/SyntaxRewriter/TypeDeclarationCSharpSyntaxRewriter.cs b/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/SyntaxRewriter/TypeDeclarationCSharpSyntaxRewriter.cs index 80a061eb461f..99b92239dfa8 100644 --- a/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/SyntaxRewriter/TypeDeclarationCSharpSyntaxRewriter.cs +++ b/src/Compatibility/GenAPI/Microsoft.DotNet.GenAPI/SyntaxRewriter/TypeDeclarationCSharpSyntaxRewriter.cs @@ -17,6 +17,14 @@ namespace Microsoft.DotNet.GenAPI.SyntaxRewriter /// public class TypeDeclarationCSharpSyntaxRewriter : CSharpSyntaxRewriter { + private readonly bool _addPartialModifier; + + /// + /// Initializes a new instance of the class, and allows deciding whether to insert the partial modifier for types or not. + /// + /// Determines whether to insert the partial modifier for types or not. + public TypeDeclarationCSharpSyntaxRewriter(bool addPartialModifier) => _addPartialModifier = addPartialModifier; + /// public override SyntaxNode? VisitInterfaceDeclaration(InterfaceDeclarationSyntax node) { @@ -83,7 +91,7 @@ public class TypeDeclarationCSharpSyntaxRewriter : CSharpSyntaxRewriter } } - private static T? VisitCommonTypeDeclaration(T? node) where T : TypeDeclarationSyntax + private T? VisitCommonTypeDeclaration(T? node) where T : TypeDeclarationSyntax { if (node == null) { @@ -91,7 +99,7 @@ public class TypeDeclarationCSharpSyntaxRewriter : CSharpSyntaxRewriter } node = RemoveBaseType(node, "global::System.Object"); - return AddPartialModifier(node); + return _addPartialModifier ? AddPartialModifier(node) : node; } private static T? AddPartialModifier(T? node) where T : TypeDeclarationSyntax => diff --git a/test/Microsoft.DotNet.GenAPI.Tests/CSharpFileBuilderTests.cs b/test/Microsoft.DotNet.GenAPI.Tests/CSharpFileBuilderTests.cs index 68009dbd4f0e..55ad81c0c048 100644 --- a/test/Microsoft.DotNet.GenAPI.Tests/CSharpFileBuilderTests.cs +++ b/test/Microsoft.DotNet.GenAPI.Tests/CSharpFileBuilderTests.cs @@ -61,7 +61,8 @@ private void RunTest(string original, stringWriter, null, false, - MetadataReferences); + MetadataReferences, + addPartialModifier: true); using Stream assemblyStream = SymbolFactory.EmitAssemblyStreamFromSyntax(original, enableNullable: true, allowUnsafe: allowUnsafe, assemblyName: assemblyName); AssemblySymbolLoader assemblySymbolLoader = new(resolveAssemblyReferences: true, includeInternalSymbols: includeInternalSymbols); @@ -231,7 +232,7 @@ public void TestRecordDeclaration() { RunTest(original: """ namespace Foo - { + { public record RecordClass; public record RecordClass1(int i); public record RecordClass2(string s, int i); @@ -240,7 +241,7 @@ public record DerivedRecord2(string x, int i, double d) : RecordClass2(default(s public record DerivedRecord3(string x, int i, double d) : RecordClass2(default(string)!, i); public record DerivedRecord4(double d) : RecordClass2(default(string)!, default); public record DerivedRecord5() : RecordClass2(default(string)!, default); - + public record RecordClassWithMethods(int i) { public void DoSomething() { } @@ -345,11 +346,11 @@ public void TestRecordStructDeclaration() RunTest(original: """ namespace Foo { - - public record struct RecordStruct; + + public record struct RecordStruct; public record struct RecordStruct1(int i); public record struct RecordStruct2(string s, int i); - + public record struct RecordStructWithMethods(int i) { public void DoSomething() { } @@ -367,10 +368,10 @@ public record struct RecordStructWithConstructors(int i) public RecordStructWithConstructors() : this(1) { } public RecordStructWithConstructors(string s) : this(int.Parse(s)) { } } - + } """, - expected: """ + expected: """ namespace Foo { public partial struct RecordStruct : System.IEquatable @@ -1644,12 +1645,12 @@ public class B { public B(int i) {} } - + public class C : B { internal C() : base(0) {} } - + public class D : B { internal D(int i) : base(i) {} @@ -1672,7 +1673,7 @@ public partial class B { public B(int i) {} } - + public partial class C : B { internal C() : base(default) {} @@ -1702,12 +1703,12 @@ public class B { public B(int i) {} } - + public class C : B { internal C() : base(0) {} } - + public class D : B { internal D(int i) : base(i) {} @@ -1781,8 +1782,8 @@ namespace A public partial class B { protected B() {} - } - + } + public partial class C : B { internal C() {} @@ -1935,7 +1936,7 @@ public class B : A public class D { } public class Id { } - + public class V { } } """, @@ -2828,7 +2829,7 @@ public class Foo : System.Collections.ICollection, System.Collections.Generic } } - + """, // https://github.com/dotnet/sdk/issues/32195 tracks interface expansion expected: """ @@ -2909,7 +2910,7 @@ namespace N { public ref struct C where T : unmanaged { - public required (string? k, dynamic v, nint n) X { get; init; } + public required (string? k, dynamic v, nint n) X { get; init; } } public static class E @@ -2918,7 +2919,7 @@ public static void M(this object c, scoped System.ReadOnlySpan values) { } } } """, - expected: """ + expected: """ namespace N { public ref partial struct C @@ -2982,7 +2983,7 @@ public void TestExplicitInterfaceNonGenericCollections() namespace a { #pragma warning disable CS8597 - + public partial class MyStringCollection : ICollection, IEnumerable, IList { public int Count { get { throw null; } } @@ -3006,7 +3007,7 @@ public void RemoveAt(int index) { } void ICollection.CopyTo(Array array, int index) { } IEnumerator IEnumerable.GetEnumerator() { throw null; } int IList.Add(object? value) { throw null; } - bool IList.Contains(object? value) { throw null; } + bool IList.Contains(object? value) { throw null; } int IList.IndexOf(object? value) { throw null; } void IList.Insert(int index, object? value) { } void IList.Remove(object? value) { } @@ -3015,7 +3016,7 @@ void IList.Remove(object? value) { } #pragma warning restore CS8597 } """, - expected: """ + expected: """ namespace a { public partial class MyStringCollection : System.Collections.ICollection, System.Collections.IEnumerable, System.Collections.IList diff --git a/test/Microsoft.DotNet.GenAPI.Tests/SyntaxRewriter/TypeDeclarationCSharpSyntaxRewriterTests.cs b/test/Microsoft.DotNet.GenAPI.Tests/SyntaxRewriter/TypeDeclarationCSharpSyntaxRewriterTests.cs index c2f498c1651c..f33ea9ec90e4 100644 --- a/test/Microsoft.DotNet.GenAPI.Tests/SyntaxRewriter/TypeDeclarationCSharpSyntaxRewriterTests.cs +++ b/test/Microsoft.DotNet.GenAPI.Tests/SyntaxRewriter/TypeDeclarationCSharpSyntaxRewriterTests.cs @@ -10,7 +10,7 @@ public class TypeDeclarationCSharpSyntaxRewriterTests : CSharpSyntaxRewriterTest [Fact] public void TestRemoveSystemObjectAsBaseClass() { - CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(), + CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(addPartialModifier: true), original: """ namespace A { @@ -32,7 +32,7 @@ partial class B [Fact] public void TestAddPartialKeyword() { - CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(), + CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(addPartialModifier: true), original: """ namespace A { @@ -54,7 +54,7 @@ partial interface D { } [Fact] public void TestPartialTypeDeclaration() { - CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(), + CompareSyntaxTree(new TypeDeclarationCSharpSyntaxRewriter(addPartialModifier: true), original: """ namespace A {