diff --git a/src/CodeFixes/SyntaxGeneratorExtensions.cs b/src/CodeFixes/SyntaxGeneratorExtensions.cs index cbb2d6ae9..04e7bd91a 100644 --- a/src/CodeFixes/SyntaxGeneratorExtensions.cs +++ b/src/CodeFixes/SyntaxGeneratorExtensions.cs @@ -29,14 +29,14 @@ public static SyntaxNode InsertArguments(this SyntaxGenerator generator, SyntaxN return invocation.WithArgumentList(SyntaxFactory.ArgumentList(arguments)); } - if (syntax is ObjectCreationExpressionSyntax creation) + if (syntax is BaseObjectCreationExpressionSyntax creation) { SeparatedSyntaxList arguments = creation.ArgumentList?.Arguments ?? []; arguments = arguments.InsertRange(index, items.OfType()); return creation.WithArgumentList(SyntaxFactory.ArgumentList(arguments)); } - throw new ArgumentException($"Must be of type {nameof(InvocationExpressionSyntax)} or {nameof(ObjectCreationExpressionSyntax)} but is of type {syntax.GetType().Name}", nameof(syntax)); + throw new ArgumentException($"Must be of type {nameof(InvocationExpressionSyntax)} or {nameof(BaseObjectCreationExpressionSyntax)} but is of type {syntax.GetType().Name}", nameof(syntax)); } public static SyntaxNode ReplaceArgument(this SyntaxGenerator generator, IOperation operation, int index, SyntaxNode item) @@ -59,13 +59,13 @@ public static SyntaxNode ReplaceArgument(this SyntaxGenerator generator, SyntaxN return invocation.WithArgumentList(SyntaxFactory.ArgumentList(arguments)); } - if (syntax is ObjectCreationExpressionSyntax creation) + if (syntax is BaseObjectCreationExpressionSyntax creation) { SeparatedSyntaxList arguments = creation.ArgumentList?.Arguments ?? []; arguments = arguments.RemoveAt(index).Insert(index, argument); return creation.WithArgumentList(SyntaxFactory.ArgumentList(arguments)); } - throw new ArgumentException($"Must be of type {nameof(InvocationExpressionSyntax)} or {nameof(ObjectCreationExpressionSyntax)} but is of type {syntax.GetType().Name}", nameof(syntax)); + throw new ArgumentException($"Must be of type {nameof(InvocationExpressionSyntax)} or {nameof(BaseObjectCreationExpressionSyntax)} but is of type {syntax.GetType().Name}", nameof(syntax)); } } diff --git a/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorCodeFixTests.cs b/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorCodeFixTests.cs index 97ab8e121..9184f6e00 100644 --- a/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorCodeFixTests.cs +++ b/tests/Moq.Analyzers.Test/SetExplicitMockBehaviorCodeFixTests.cs @@ -33,6 +33,26 @@ public static IEnumerable TestData() ], }.WithNamespaces().WithMoqReferenceAssemblyGroups(); + IEnumerable mockConstructorsWithTargetTypedNew = new object[][] + { + [ + """Mock mock = {|Moq1400:new()|};""", + """Mock mock = new(MockBehavior.Loose);""", + ], + [ + """Mock mock = {|Moq1400:new(MockBehavior.Default)|};""", + """Mock mock = new(MockBehavior.Loose);""", + ], + [ + """Mock mock = new(MockBehavior.Loose);""", + """Mock mock = new(MockBehavior.Loose);""", + ], + [ + """Mock mock = new(MockBehavior.Strict);""", + """Mock mock = new(MockBehavior.Strict);""", + ], + }.WithNamespaces().WithMoqReferenceAssemblyGroups(); + IEnumerable mockConstructorsWithExpressions = new object[][] { [ @@ -89,7 +109,7 @@ public static IEnumerable TestData() ], }.WithNamespaces().WithNewMoqReferenceAssemblyGroups(); - return mockConstructors.Union(mockConstructorsWithExpressions).Union(fluentBuilders).Union(mockRepositories); + return mockConstructors.Union(mockConstructorsWithTargetTypedNew).Union(mockConstructorsWithExpressions).Union(fluentBuilders).Union(mockRepositories); } [Theory]