Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.CodeAnalysis.CSharp.UseCollectionExpression;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.UseCollectionExpression;
using Microsoft.CodeAnalysis.UseCollectionInitializer;

Expand Down Expand Up @@ -110,4 +111,59 @@ protected override bool IsValidContainingStatement(StatementSyntax node)
return node is not LocalDeclarationStatementSyntax localDecl ||
localDecl.UsingKeyword == default;
}

protected override bool ShouldSuppressDiagnostic(
SemanticModel semanticModel,
BaseObjectCreationExpressionSyntax objectCreationExpression,
ITypeSymbol objectType,
CancellationToken cancellationToken)
{
// Check if the type being created has a CollectionBuilder attribute that points to the method we're currently in.
// If so, suppress the diagnostic to avoid suggesting a change that would cause infinite recursion.
// For example, if we're inside the Create method of a CollectionBuilder, and we have:
// MyCustomCollection<T> collection = new();
// foreach (T item in items) { collection.Add(item); }
// We should NOT suggest changing it to:
// MyCustomCollection<T> collection = [.. items];
// Because that would recursively call the same Create method.

if (objectType is not INamedTypeSymbol namedType)
return false;

// For generic types, get the type definition to check for the attribute
var typeToCheck = namedType.OriginalDefinition;

// Look for CollectionBuilder attribute on the type
var collectionBuilderAttribute = typeToCheck.GetAttributes().FirstOrDefault(attr =>
attr.AttributeClass?.IsCollectionBuilderAttribute() == true);

if (collectionBuilderAttribute == null)
return false;

// Get the builder type and method name from the attribute.
// CollectionBuilderAttribute has exactly 2 constructor parameters: builderType and methodName
if (collectionBuilderAttribute.ConstructorArguments is not
[
{ Kind: TypedConstantKind.Type, Value: INamedTypeSymbol builderType },
{ Kind: TypedConstantKind.Primitive, Value: string methodName }
])
{
return false;
}

// Get the containing method we're currently analyzing
var containingMethod = semanticModel.GetEnclosingSymbol<IMethodSymbol>(objectCreationExpression.SpanStart, cancellationToken);
if (containingMethod == null)
return false;

// Check if the containing method matches the CollectionBuilder method
// We need to compare the original definitions in case the method is generic
if (containingMethod.Name == methodName &&
SymbolEqualityComparer.Default.Equals(containingMethod.ContainingType.OriginalDefinition, builderType.OriginalDefinition))
{
return true;
}

return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5752,4 +5752,138 @@ private void CodeFixErrorRepro_OldEnumerables()
LanguageVersion = LanguageVersion.CSharp12,
ReferenceAssemblies = ReferenceAssemblies.Net.Net80,
}.RunAsync();

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/70099")]
public Task TestNotInCollectionBuilderMethod()
=> TestMissingInRegularAndScriptAsync(
"""
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Runtime.CompilerServices;

namespace System.Runtime.CompilerServices
{
[AttributeUsage(AttributeTargets.All, Inherited = false, AllowMultiple = false)]
public sealed class CollectionBuilderAttribute : Attribute
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the definition necessary? e.g. it isn't included in the standard testworkspace references?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically not in the refeences. but we have a string constant for it. updating.

{
public CollectionBuilderAttribute(Type builderType, string methodName) { }
}
}

[CollectionBuilder(typeof(MyCustomCollection), nameof(MyCustomCollection.Create))]
internal class MyCustomCollection<T> : Collection<T>
{
}

internal static class MyCustomCollection
{
public static MyCustomCollection<T> Create<T>(ReadOnlySpan<T> items)
{
MyCustomCollection<T> collection = new();
foreach (T item in items)
{
collection.Add(item);
}

return collection;
}
}
""");

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/70099")]
public Task TestCollectionBuilderOutsideMethod()
=> TestInRegularAndScriptAsync(
"""
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Runtime.CompilerServices;

namespace System.Runtime.CompilerServices
{
[AttributeUsage(AttributeTargets.All, Inherited = false, AllowMultiple = false)]
public sealed class CollectionBuilderAttribute : Attribute
{
public CollectionBuilderAttribute(Type builderType, string methodName) { }
}
}

[CollectionBuilder(typeof(MyCustomCollection), nameof(MyCustomCollection.Create))]
internal class MyCustomCollection<T> : Collection<T>
{
}

internal static class MyCustomCollection
{
public static MyCustomCollection<T> Create<T>(ReadOnlySpan<T> items)
{
MyCustomCollection<T> collection = new();
foreach (T item in items)
{
collection.Add(item);
}

return collection;
}
}

class C
{
void M()
{
MyCustomCollection<int> c = [|new|]();
[|c.Add(|]1);
}
}
""",
"""
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Runtime.CompilerServices;

namespace System.Runtime.CompilerServices
{
[AttributeUsage(AttributeTargets.All, Inherited = false, AllowMultiple = false)]
public sealed class CollectionBuilderAttribute : Attribute
{
public CollectionBuilderAttribute(Type builderType, string methodName) { }
}
}

[CollectionBuilder(typeof(MyCustomCollection), nameof(MyCustomCollection.Create))]
internal class MyCustomCollection<T> : Collection<T>
{
}

internal static class MyCustomCollection
{
public static MyCustomCollection<T> Create<T>(ReadOnlySpan<T> items)
{
MyCustomCollection<T> collection = new();
foreach (T item in items)
{
collection.Add(item);
}

return collection;
}
}

class C
{
void M()
{
MyCustomCollection<int> c =
[
1
];
}
}
""");
}

Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ protected abstract bool CanUseCollectionExpression(

protected abstract bool IsValidContainingStatement(TStatementSyntax node);

/// <summary>
/// Returns true if the diagnostic should be suppressed for the given object creation expression and type.
/// This is used to prevent suggesting collection expressions in scenarios where they would cause recursion,
/// such as inside CollectionBuilder methods.
/// </summary>
protected virtual bool ShouldSuppressDiagnostic(
SemanticModel semanticModel,
TObjectCreationExpressionSyntax objectCreationExpression,
ITypeSymbol objectType,
CancellationToken cancellationToken)
{
return false;
}

protected sealed override void InitializeWorker(AnalysisContext context)
=> context.RegisterCompilationStartAction(OnCompilationStart);

Expand Down Expand Up @@ -149,6 +163,10 @@ private void AnalyzeNode(
if (objectType.Type == null || !objectType.Type.AllInterfaces.Contains(ienumerableType))
return;

// Check if the diagnostic should be suppressed (e.g., inside CollectionBuilder methods)
if (ShouldSuppressDiagnostic(semanticModel, objectCreationExpression, objectType.Type, cancellationToken))
return;

// Analyze the surrounding statements. First, try a broader set of statements if the language supports
// collection expressions.
var syntaxFacts = this.SyntaxFacts;
Expand Down
Loading