diff --git a/src/Orleans.Analyzers/AnalyzerReleases.Unshipped.md b/src/Orleans.Analyzers/AnalyzerReleases.Unshipped.md
index cdf4f1397e0..f42839be1b4 100644
--- a/src/Orleans.Analyzers/AnalyzerReleases.Unshipped.md
+++ b/src/Orleans.Analyzers/AnalyzerReleases.Unshipped.md
@@ -1 +1,7 @@
; Please do not edit this file manually, it should only be updated through code fix application.
+
+### New Rules
+
+Rule ID | Category | Severity | Notes
+--------|----------|----------|-------
+ORLEANS0014 | Usage | Warning | ConfigureAwaitAnalyzer, Grain code should not use ConfigureAwait(false) or ConfigureAwait without ContinueOnCapturedContext
diff --git a/src/Orleans.Analyzers/ConfigureAwaitAnalyzer.cs b/src/Orleans.Analyzers/ConfigureAwaitAnalyzer.cs
new file mode 100644
index 00000000000..d62232155f6
--- /dev/null
+++ b/src/Orleans.Analyzers/ConfigureAwaitAnalyzer.cs
@@ -0,0 +1,264 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Diagnostics;
+using System;
+using System.Collections.Immutable;
+
+namespace Orleans.Analyzers;
+
+///
+/// An analyzer that warns when grain code uses ConfigureAwait(false) or ConfigureAwait(ConfigureAwaitOptions)
+/// without the ContinueOnCapturedContext flag.
+///
+[DiagnosticAnalyzer(LanguageNames.CSharp)]
+public class ConfigureAwaitAnalyzer : DiagnosticAnalyzer
+{
+ public const string RuleId = "ORLEANS0014";
+
+ private static readonly LocalizableString Title = new LocalizableResourceString(
+ nameof(Resources.AvoidConfigureAwaitFalseInGrainTitle),
+ Resources.ResourceManager,
+ typeof(Resources));
+
+ private static readonly LocalizableString MessageFormat = new LocalizableResourceString(
+ nameof(Resources.AvoidConfigureAwaitFalseInGrainMessageFormat),
+ Resources.ResourceManager,
+ typeof(Resources));
+
+ private static readonly LocalizableString Description = new LocalizableResourceString(
+ nameof(Resources.AvoidConfigureAwaitFalseInGrainDescription),
+ Resources.ResourceManager,
+ typeof(Resources));
+
+ private static readonly DiagnosticDescriptor Rule = new(
+ id: RuleId,
+ title: Title,
+ messageFormat: MessageFormat,
+ category: "Usage",
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true,
+ description: Description);
+
+ public override ImmutableArray SupportedDiagnostics => ImmutableArray.Create(Rule);
+
+ public override void Initialize(AnalysisContext context)
+ {
+ context.EnableConcurrentExecution();
+ context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
+ context.RegisterSyntaxNodeAction(AnalyzeInvocation, SyntaxKind.InvocationExpression);
+ }
+
+ private static void AnalyzeInvocation(SyntaxNodeAnalysisContext context)
+ {
+ var invocation = (InvocationExpressionSyntax)context.Node;
+
+ // Check if this is a ConfigureAwait call
+ if (!IsConfigureAwaitCall(invocation, out var methodName))
+ {
+ return;
+ }
+
+ // Check if this code is inside a grain class
+ if (!IsInsideGrainClass(invocation, context.SemanticModel))
+ {
+ return;
+ }
+
+ // Get the symbol for the invocation to analyze the argument
+ var symbolInfo = context.SemanticModel.GetSymbolInfo(invocation, context.CancellationToken);
+ if (symbolInfo.Symbol is not IMethodSymbol methodSymbol)
+ {
+ return;
+ }
+
+ // Only check ConfigureAwait method
+ if (!string.Equals(methodSymbol.Name, "ConfigureAwait", StringComparison.Ordinal))
+ {
+ return;
+ }
+
+ // Check if it's a ConfigureAwait method on a Task-like type
+ var containingType = methodSymbol.ContainingType;
+ if (!IsTaskLikeType(containingType))
+ {
+ return;
+ }
+
+ // Get the arguments
+ var arguments = invocation.ArgumentList?.Arguments;
+ if (arguments is null || arguments.Value.Count == 0)
+ {
+ return;
+ }
+
+ var firstArgument = arguments.Value[0];
+ var argumentType = context.SemanticModel.GetTypeInfo(firstArgument.Expression, context.CancellationToken).Type;
+
+ if (argumentType is null)
+ {
+ return;
+ }
+
+ // Check for ConfigureAwait(bool) overload
+ if (argumentType.SpecialType == SpecialType.System_Boolean)
+ {
+ var constantValue = context.SemanticModel.GetConstantValue(firstArgument.Expression, context.CancellationToken);
+ if (constantValue.HasValue && constantValue.Value is false)
+ {
+ // ConfigureAwait(false) is not allowed
+ context.ReportDiagnostic(Diagnostic.Create(Rule, invocation.GetLocation()));
+ }
+ return;
+ }
+
+ // Check for ConfigureAwait(ConfigureAwaitOptions) overload
+ if (IsConfigureAwaitOptionsType(argumentType))
+ {
+ if (!HasContinueOnCapturedContextFlag(firstArgument.Expression, context.SemanticModel, context.CancellationToken))
+ {
+ context.ReportDiagnostic(Diagnostic.Create(Rule, invocation.GetLocation()));
+ }
+ }
+ }
+
+ private static bool IsConfigureAwaitCall(InvocationExpressionSyntax invocation, out string methodName)
+ {
+ methodName = null;
+
+ if (invocation.Expression is MemberAccessExpressionSyntax memberAccess)
+ {
+ methodName = memberAccess.Name.Identifier.Text;
+ return string.Equals(methodName, "ConfigureAwait", StringComparison.Ordinal);
+ }
+
+ return false;
+ }
+
+ private static bool IsInsideGrainClass(SyntaxNode node, SemanticModel semanticModel)
+ {
+ // Walk up to find the containing type declaration
+ var current = node.Parent;
+ while (current is not null)
+ {
+ if (current is ClassDeclarationSyntax classDeclaration)
+ {
+ var typeSymbol = semanticModel.GetDeclaredSymbol(classDeclaration);
+ if (typeSymbol is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGrainClass())
+ {
+ return true;
+ }
+ }
+ else if (current is StructDeclarationSyntax or RecordDeclarationSyntax)
+ {
+ // If we hit a struct or record before finding a grain class, we're not in a grain
+ // (structs and records can't be grains)
+ return false;
+ }
+
+ current = current.Parent;
+ }
+
+ return false;
+ }
+
+ private static bool IsTaskLikeType(INamedTypeSymbol type)
+ {
+ if (type is null)
+ {
+ return false;
+ }
+
+ var fullName = type.ToDisplayString(NullableFlowState.None);
+
+ // Check for common task-like types that have ConfigureAwait
+ return fullName.StartsWith("System.Threading.Tasks.Task", StringComparison.Ordinal)
+ || fullName.StartsWith("System.Threading.Tasks.ValueTask", StringComparison.Ordinal)
+ || fullName.StartsWith("System.Runtime.CompilerServices.ConfiguredTaskAwaitable", StringComparison.Ordinal)
+ || fullName.StartsWith("System.Runtime.CompilerServices.ConfiguredValueTaskAwaitable", StringComparison.Ordinal)
+ || fullName.StartsWith("System.Collections.Generic.IAsyncEnumerable", StringComparison.Ordinal)
+ || fullName.StartsWith("System.Runtime.CompilerServices.ConfiguredCancelableAsyncEnumerable", StringComparison.Ordinal);
+ }
+
+ private static bool IsConfigureAwaitOptionsType(ITypeSymbol type)
+ {
+ if (type is null)
+ {
+ return false;
+ }
+
+ return string.Equals(
+ type.ToDisplayString(NullableFlowState.None),
+ "System.Threading.Tasks.ConfigureAwaitOptions",
+ StringComparison.Ordinal);
+ }
+
+ private static bool HasContinueOnCapturedContextFlag(ExpressionSyntax expression, SemanticModel semanticModel, System.Threading.CancellationToken cancellationToken)
+ {
+ // ConfigureAwaitOptions.ContinueOnCapturedContext has value 1
+ const int ContinueOnCapturedContextValue = 1;
+
+ // Try to get the constant value
+ var constantValue = semanticModel.GetConstantValue(expression, cancellationToken);
+ if (constantValue.HasValue && constantValue.Value is int intValue)
+ {
+ // Check if ContinueOnCapturedContext flag (value 1) is set
+ return (intValue & ContinueOnCapturedContextValue) != 0;
+ }
+
+ // If we can't determine the value at compile time, we need to analyze the expression
+ // to check if it includes ContinueOnCapturedContext
+ return ExpressionIncludesContinueOnCapturedContext(expression, semanticModel, cancellationToken);
+ }
+
+ private static bool ExpressionIncludesContinueOnCapturedContext(ExpressionSyntax expression, SemanticModel semanticModel, System.Threading.CancellationToken cancellationToken)
+ {
+ // Handle member access like ConfigureAwaitOptions.ContinueOnCapturedContext
+ if (expression is MemberAccessExpressionSyntax memberAccess)
+ {
+ var memberName = memberAccess.Name.Identifier.Text;
+ if (string.Equals(memberName, "ContinueOnCapturedContext", StringComparison.Ordinal))
+ {
+ return true;
+ }
+ }
+
+ // Handle binary OR expressions like ConfigureAwaitOptions.ContinueOnCapturedContext | ConfigureAwaitOptions.ForceYielding
+ if (expression is BinaryExpressionSyntax binaryExpression &&
+ binaryExpression.IsKind(SyntaxKind.BitwiseOrExpression))
+ {
+ return ExpressionIncludesContinueOnCapturedContext(binaryExpression.Left, semanticModel, cancellationToken)
+ || ExpressionIncludesContinueOnCapturedContext(binaryExpression.Right, semanticModel, cancellationToken);
+ }
+
+ // Handle parenthesized expressions
+ if (expression is ParenthesizedExpressionSyntax parenthesized)
+ {
+ return ExpressionIncludesContinueOnCapturedContext(parenthesized.Expression, semanticModel, cancellationToken);
+ }
+
+ // Handle cast expressions
+ if (expression is CastExpressionSyntax castExpression)
+ {
+ return ExpressionIncludesContinueOnCapturedContext(castExpression.Expression, semanticModel, cancellationToken);
+ }
+
+ // If we encounter a variable or method call, we can't statically determine the flags
+ // In this case, we give the benefit of the doubt and don't report
+ if (expression is IdentifierNameSyntax or InvocationExpressionSyntax)
+ {
+ // Try to get the constant value as a fallback
+ var constantValue = semanticModel.GetConstantValue(expression, cancellationToken);
+ if (constantValue.HasValue && constantValue.Value is int intValue)
+ {
+ const int ContinueOnCapturedContextValue = 1;
+ return (intValue & ContinueOnCapturedContextValue) != 0;
+ }
+
+ // Can't determine - don't report false positives
+ return true;
+ }
+
+ return false;
+ }
+}
diff --git a/src/Orleans.Analyzers/ConfigureAwaitCodeFix.cs b/src/Orleans.Analyzers/ConfigureAwaitCodeFix.cs
new file mode 100644
index 00000000000..cfed7bcc2e9
--- /dev/null
+++ b/src/Orleans.Analyzers/ConfigureAwaitCodeFix.cs
@@ -0,0 +1,146 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CodeActions;
+using Microsoft.CodeAnalysis.CodeFixes;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using System;
+using System.Collections.Immutable;
+using System.Composition;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+
+namespace Orleans.Analyzers;
+
+///
+/// A code fix provider that converts ConfigureAwait(false) to ConfigureAwait(true) and
+/// adds ContinueOnCapturedContext to ConfigureAwait(ConfigureAwaitOptions) calls.
+///
+[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(ConfigureAwaitCodeFix)), Shared]
+public class ConfigureAwaitCodeFix : CodeFixProvider
+{
+ public sealed override ImmutableArray FixableDiagnosticIds => ImmutableArray.Create(ConfigureAwaitAnalyzer.RuleId);
+ public sealed override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;
+
+ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
+ {
+ var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
+ var diagnostic = context.Diagnostics.First();
+ var diagnosticSpan = diagnostic.Location.SourceSpan;
+
+ // Find the invocation expression identified by the diagnostic
+ var node = root.FindNode(diagnosticSpan);
+ var invocation = node.FirstAncestorOrSelf();
+
+ if (invocation is null)
+ {
+ return;
+ }
+
+ // Get semantic model to determine which fix to apply
+ var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false);
+ var symbolInfo = semanticModel.GetSymbolInfo(invocation, context.CancellationToken);
+
+ if (symbolInfo.Symbol is not IMethodSymbol methodSymbol)
+ {
+ return;
+ }
+
+ // Check the parameter type to determine which fix to apply
+ if (methodSymbol.Parameters.Length == 1)
+ {
+ var parameterType = methodSymbol.Parameters[0].Type;
+
+ if (parameterType.SpecialType == SpecialType.System_Boolean)
+ {
+ // Fix for ConfigureAwait(bool) - change false to true
+ context.RegisterCodeFix(
+ CodeAction.Create(
+ title: Resources.ConfigureAwaitCodeFixTitle,
+ createChangedDocument: ct => FixConfigureAwaitBoolAsync(context.Document, invocation, ct),
+ equivalenceKey: ConfigureAwaitAnalyzer.RuleId + "_Bool"),
+ diagnostic);
+ }
+ else if (string.Equals(parameterType.ToDisplayString(), "System.Threading.Tasks.ConfigureAwaitOptions", StringComparison.Ordinal))
+ {
+ // Fix for ConfigureAwait(ConfigureAwaitOptions) - add ContinueOnCapturedContext flag
+ context.RegisterCodeFix(
+ CodeAction.Create(
+ title: Resources.ConfigureAwaitCodeFixTitle,
+ createChangedDocument: ct => FixConfigureAwaitOptionsAsync(context.Document, invocation, semanticModel, ct),
+ equivalenceKey: ConfigureAwaitAnalyzer.RuleId + "_Options"),
+ diagnostic);
+ }
+ }
+ }
+
+ private static async Task FixConfigureAwaitBoolAsync(
+ Document document,
+ InvocationExpressionSyntax invocation,
+ CancellationToken cancellationToken)
+ {
+ var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
+
+ // Create new argument with 'true' instead of 'false'
+ var newArgument = Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression));
+ var newArgumentList = ArgumentList(SingletonSeparatedList(newArgument));
+
+ // Replace the argument list
+ var newInvocation = invocation.WithArgumentList(newArgumentList);
+ var newRoot = root.ReplaceNode(invocation, newInvocation);
+
+ return document.WithSyntaxRoot(newRoot);
+ }
+
+ private static async Task FixConfigureAwaitOptionsAsync(
+ Document document,
+ InvocationExpressionSyntax invocation,
+ SemanticModel semanticModel,
+ CancellationToken cancellationToken)
+ {
+ var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
+
+ var arguments = invocation.ArgumentList?.Arguments;
+ if (arguments is null || arguments.Value.Count == 0)
+ {
+ return document;
+ }
+
+ var existingArgument = arguments.Value[0].Expression;
+
+ // Check if the existing argument is ConfigureAwaitOptions.None
+ var constantValue = semanticModel.GetConstantValue(existingArgument, cancellationToken);
+ ExpressionSyntax newExpression;
+
+ if (constantValue.HasValue && constantValue.Value is int intValue && intValue == 0)
+ {
+ // If it's None (0), just replace with ContinueOnCapturedContext
+ newExpression = MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("ConfigureAwaitOptions"),
+ IdentifierName("ContinueOnCapturedContext"));
+ }
+ else
+ {
+ // Otherwise, add ContinueOnCapturedContext using bitwise OR
+ var continueOnCapturedContext = MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("ConfigureAwaitOptions"),
+ IdentifierName("ContinueOnCapturedContext"));
+
+ newExpression = BinaryExpression(
+ SyntaxKind.BitwiseOrExpression,
+ existingArgument.WithoutTrivia(),
+ continueOnCapturedContext);
+ }
+
+ var newArgument = Argument(newExpression);
+ var newArgumentList = ArgumentList(SingletonSeparatedList(newArgument));
+
+ var newInvocation = invocation.WithArgumentList(newArgumentList);
+ var newRoot = root.ReplaceNode(invocation, newInvocation);
+
+ return document.WithSyntaxRoot(newRoot);
+ }
+}
diff --git a/src/Orleans.Analyzers/Constants.cs b/src/Orleans.Analyzers/Constants.cs
index c61794f2a26..21dbbc7df42 100644
--- a/src/Orleans.Analyzers/Constants.cs
+++ b/src/Orleans.Analyzers/Constants.cs
@@ -6,6 +6,9 @@ internal static class Constants
public const string IAddressibleFullyQualifiedName = "Orleans.Runtime.IAddressable";
public const string GrainBaseFullyQualifiedName = "Orleans.Grain";
+ public const string IGrainBaseFullyQualifiedName = "Orleans.IGrainBase";
+ public const string IGrainFullyQualifiedName = "Orleans.IGrain";
+ public const string ISystemTargetFullyQualifiedName = "Orleans.ISystemTarget";
public const string IdAttributeName = "Id";
public const string IdAttributeFullyQualifiedName = "global::Orleans.IdAttribute";
diff --git a/src/Orleans.Analyzers/Resources.Designer.cs b/src/Orleans.Analyzers/Resources.Designer.cs
index 08fdb8cb687..3edbbe21402 100644
--- a/src/Orleans.Analyzers/Resources.Designer.cs
+++ b/src/Orleans.Analyzers/Resources.Designer.cs
@@ -1,4 +1,4 @@
-//------------------------------------------------------------------------------
+//------------------------------------------------------------------------------
//
// This code was generated by a tool.
// Runtime Version:4.0.30319.42000
@@ -257,5 +257,41 @@ internal static string IncorrectAttributeUseTitleDescription {
return ResourceManager.GetString("IncorrectAttributeUseTitleDescription", resourceCulture);
}
}
+
+ ///
+ /// Looks up a localized string similar to Grain code must maintain the grain's synchronization context...
+ ///
+ internal static string AvoidConfigureAwaitFalseInGrainDescription {
+ get {
+ return ResourceManager.GetString("AvoidConfigureAwaitFalseInGrainDescription", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to ConfigureAwait in grain code must not use 'false' and must include ConfigureAwaitOptions.ContinueOnCapturedContext.
+ ///
+ internal static string AvoidConfigureAwaitFalseInGrainMessageFormat {
+ get {
+ return ResourceManager.GetString("AvoidConfigureAwaitFalseInGrainMessageFormat", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Avoid ConfigureAwait(false) or ConfigureAwait not specifying ContinueOnCapturedContext in grain code.
+ ///
+ internal static string AvoidConfigureAwaitFalseInGrainTitle {
+ get {
+ return ResourceManager.GetString("AvoidConfigureAwaitFalseInGrainTitle", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Use ConfigureAwait with ContinueOnCapturedContext.
+ ///
+ internal static string ConfigureAwaitCodeFixTitle {
+ get {
+ return ResourceManager.GetString("ConfigureAwaitCodeFixTitle", resourceCulture);
+ }
+ }
}
}
diff --git a/src/Orleans.Analyzers/Resources.resx b/src/Orleans.Analyzers/Resources.resx
index 09e12ca19ea..b6f1c96f653 100644
--- a/src/Orleans.Analyzers/Resources.resx
+++ b/src/Orleans.Analyzers/Resources.resx
@@ -183,4 +183,16 @@
This attribute should not be used on grain implementations.
+
+ Avoid ConfigureAwait(false) or ConfigureAwait not specifying ContinueOnCapturedContext in grain code
+
+
+ ConfigureAwait in grain code must not use 'false' and must include ConfigureAwaitOptions.ContinueOnCapturedContext
+
+
+ Grain code must maintain the grain's execution context. Using ConfigureAwait(false) or ConfigureAwait without ContinueOnCapturedContext can cause the continuation to run outside the grain's context, leading to concurrency issues and loss of grain identity.
+
+
+ Use ConfigureAwait with ContinueOnCapturedContext
+
\ No newline at end of file
diff --git a/src/Orleans.Analyzers/SyntaxHelpers.cs b/src/Orleans.Analyzers/SyntaxHelpers.cs
index 6d78c1a77b0..892b75dece7 100644
--- a/src/Orleans.Analyzers/SyntaxHelpers.cs
+++ b/src/Orleans.Analyzers/SyntaxHelpers.cs
@@ -148,5 +148,97 @@ public static bool ExtendsGrainInterface(this INamedTypeSymbol symbol)
return false;
}
+ public static bool InheritsGrainClass(this ClassDeclarationSyntax declaration, SemanticModel semanticModel)
+ {
+ var baseTypes = declaration.BaseList?.Types;
+ if (baseTypes is null)
+ {
+ return false;
+ }
+
+ foreach (var baseTypeSyntax in baseTypes)
+ {
+ var baseTypeSymbol = semanticModel.GetTypeInfo(baseTypeSyntax.Type).Type;
+ if (baseTypeSymbol is INamedTypeSymbol currentTypeSymbol)
+ {
+ if (currentTypeSymbol.IsGenericType &&
+ currentTypeSymbol.TypeParameters.Length == 1 &&
+ currentTypeSymbol.BaseType is { } baseBaseTypeSymbol)
+ {
+ currentTypeSymbol = baseBaseTypeSymbol;
+ }
+
+ if (Constants.GrainBaseFullyQualifiedName.Equals(currentTypeSymbol.ToDisplayString(NullableFlowState.None), StringComparison.Ordinal))
+ {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ public static bool IsGrainClass(this INamedTypeSymbol typeSymbol)
+ {
+ if (typeSymbol is null || typeSymbol.TypeKind != TypeKind.Class)
+ {
+ return false;
+ }
+
+ // Check if the type implements IGrain or ISystemTarget interface
+ foreach (var interfaceSymbol in typeSymbol.AllInterfaces)
+ {
+ var interfaceName = interfaceSymbol.ToDisplayString(NullableFlowState.None);
+ if (Constants.IGrainFullyQualifiedName.Equals(interfaceName, StringComparison.Ordinal) ||
+ Constants.ISystemTargetFullyQualifiedName.Equals(interfaceName, StringComparison.Ordinal))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ public static AttributeArgumentBag GetArgumentBag(this AttributeSyntax attribute, SemanticModel semanticModel)
+ {
+ if (attribute is null)
+ {
+ return default;
+ }
+
+ var argument = attribute.ArgumentList?.Arguments.FirstOrDefault();
+ if (argument is null || argument.Expression is not { } expression)
+ {
+ return default;
+ }
+
+ var constantValue = semanticModel.GetConstantValue(expression);
+ return constantValue.HasValue && constantValue.Value is T value ?
+ new(value, attribute.GetLocation()) : default;
+ }
+
+ public static IEnumerable GetAttributeSyntaxes(this SyntaxList attributeLists, string attributeName) =>
+ attributeLists
+ .SelectMany(attributeList => attributeList.Attributes)
+ .Where(attribute => attribute.IsAttribute(attributeName));
+
+ public static string GetArgumentValue(this AttributeSyntax attribute, SemanticModel semanticModel)
+ {
+ if (attribute?.ArgumentList == null || attribute.ArgumentList.Arguments.Count == 0)
+ {
+ return null;
+ }
+
+ var symbolInfo = semanticModel.GetSymbolInfo(attribute);
+ if (symbolInfo.Symbol == null && symbolInfo.CandidateSymbols.Length == 0)
+ {
+ return null;
+ }
+
+ var argumentExpression = attribute.ArgumentList.Arguments[0].Expression;
+ var constant = semanticModel.GetConstantValue(argumentExpression);
+
+ return constant.HasValue ? constant.Value?.ToString() : null;
+ }
}
}
diff --git a/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.MessageSink.cs b/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.MessageSink.cs
index 58311a000b0..0369a3493db 100644
--- a/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.MessageSink.cs
+++ b/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.MessageSink.cs
@@ -35,7 +35,7 @@ public async Task StopProcessingEdgesAsync(CancellationToken cancellationToken)
}
_pendingMessageEvent.Signal();
- await _processPendingEdgesTask.WaitAsync(cancellationToken).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
+ await _processPendingEdgesTask.WaitAsync(cancellationToken).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing | ConfigureAwaitOptions.ContinueOnCapturedContext);
LogTraceServiceStopped(_logger, nameof(ActivationRepartitioner));
}
diff --git a/src/Orleans.Streaming/PersistentStreams/PersistentStreamPullingAgent.cs b/src/Orleans.Streaming/PersistentStreams/PersistentStreamPullingAgent.cs
index db93df6a100..146a1f0e430 100644
--- a/src/Orleans.Streaming/PersistentStreams/PersistentStreamPullingAgent.cs
+++ b/src/Orleans.Streaming/PersistentStreams/PersistentStreamPullingAgent.cs
@@ -158,7 +158,7 @@ public async Task Shutdown()
Task localReceiverInitTask = receiverInitTask;
if (localReceiverInitTask != null)
{
- await localReceiverInitTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
+ await localReceiverInitTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing | ConfigureAwaitOptions.ContinueOnCapturedContext);
receiverInitTask = null;
}
diff --git a/test/Analyzers.Tests/ConfigureAwaitAnalyzerTest.cs b/test/Analyzers.Tests/ConfigureAwaitAnalyzerTest.cs
new file mode 100644
index 00000000000..6d39b6ff644
--- /dev/null
+++ b/test/Analyzers.Tests/ConfigureAwaitAnalyzerTest.cs
@@ -0,0 +1,1430 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CodeActions;
+using Microsoft.CodeAnalysis.CodeFixes;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.Diagnostics;
+using Microsoft.CodeAnalysis.Text;
+using Orleans.Analyzers;
+using System.Collections.Immutable;
+using System.Reflection;
+using System.Text;
+using Xunit;
+
+namespace Analyzers.Tests;
+
+///
+/// Tests for the analyzer that warns against ConfigureAwait(false) or ConfigureAwait without
+/// ContinueOnCapturedContext in grain code. Grains must maintain their synchronization context
+/// to ensure proper execution within the grain's activation context.
+///
+[TestCategory("BVT"), TestCategory("Analyzer")]
+public class ConfigureAwaitAnalyzerTest : DiagnosticAnalyzerTestBase
+{
+ private static readonly string[] Usings = new[] {
+ "System",
+ "System.Threading.Tasks",
+ "Orleans"
+ };
+
+ private async Task VerifyHasDiagnostic(string code)
+ {
+ var (diagnostics, _) = await GetDiagnosticsAsync(code, Array.Empty());
+
+ Assert.NotEmpty(diagnostics);
+ var diagnostic = diagnostics.First();
+
+ Assert.Equal(ConfigureAwaitAnalyzer.RuleId, diagnostic.Id);
+ Assert.Equal(DiagnosticSeverity.Warning, diagnostic.Severity);
+ }
+
+ private async Task VerifyHasNoDiagnostic(string code)
+ {
+ var (diagnostics, _) = await GetDiagnosticsAsync(code, Array.Empty());
+ Assert.Empty(diagnostics);
+ }
+
+ private async Task VerifyCodeFix(string originalCode, string expectedFixedCode, string[] extraUsings = null)
+ {
+ extraUsings ??= Array.Empty();
+
+ // Prepend usings
+ var sb = new StringBuilder();
+ foreach (var @using in Usings.Concat(extraUsings))
+ {
+ sb.AppendLine($"using {@using};");
+ }
+ sb.AppendLine(originalCode);
+ var fullOriginalCode = sb.ToString();
+
+ sb.Clear();
+ foreach (var @using in Usings.Concat(extraUsings))
+ {
+ sb.AppendLine($"using {@using};");
+ }
+ sb.AppendLine(expectedFixedCode);
+ var fullExpectedCode = sb.ToString();
+
+ // Create project and get diagnostics
+ var project = CreateProject(fullOriginalCode);
+ var document = project.Documents.First();
+ var compilation = await project.GetCompilationAsync();
+
+ var analyzer = new ConfigureAwaitAnalyzer();
+ var compilationWithAnalyzers = compilation
+ .WithOptions(compilation.Options.WithSpecificDiagnosticOptions(
+ analyzer.SupportedDiagnostics.ToDictionary(d => d.Id, d => ReportDiagnostic.Default)))
+ .WithAnalyzers(ImmutableArray.Create(analyzer));
+
+ var diagnostics = await compilationWithAnalyzers.GetAnalyzerDiagnosticsAsync();
+ Assert.NotEmpty(diagnostics);
+
+ // Apply code fix
+ var codeFixer = new ConfigureAwaitCodeFix();
+ var actions = new List();
+ var context = new CodeFixContext(
+ document,
+ diagnostics.First(),
+ (action, _) => actions.Add(action),
+ CancellationToken.None);
+
+ await codeFixer.RegisterCodeFixesAsync(context);
+ Assert.NotEmpty(actions);
+
+ var operations = await actions.First().GetOperationsAsync(CancellationToken.None);
+ var changedSolution = operations.OfType().Single().ChangedSolution;
+ var changedDocument = changedSolution.GetDocument(document.Id);
+ var changedText = await changedDocument.GetTextAsync();
+
+ Assert.Equal(fullExpectedCode, changedText.ToString());
+ }
+
+ private static Project CreateProject(string source)
+ {
+ const string fileName = "Test.cs";
+
+ var projectId = ProjectId.CreateNewId(debugName: "TestProject");
+ var documentId = DocumentId.CreateNewId(projectId, fileName);
+
+ var assemblies = new[]
+ {
+ typeof(Task).Assembly,
+ typeof(Orleans.IGrain).Assembly,
+ typeof(Orleans.Grain).Assembly,
+ typeof(Attribute).Assembly,
+ typeof(int).Assembly,
+ typeof(object).Assembly,
+ };
+
+ var metadataReferences = assemblies
+ .SelectMany(x => x.GetReferencedAssemblies().Select(Assembly.Load))
+ .Concat(assemblies)
+ .Distinct()
+ .Select(x => MetadataReference.CreateFromFile(x.Location))
+ .Cast()
+ .ToList();
+
+ var assemblyPath = Path.GetDirectoryName(typeof(object).Assembly.Location);
+ metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(assemblyPath, "mscorlib.dll")));
+ metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(assemblyPath, "System.dll")));
+ metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(assemblyPath, "System.Core.dll")));
+ metadataReferences.Add(MetadataReference.CreateFromFile(Path.Combine(assemblyPath, "System.Runtime.dll")));
+
+ var solution = new AdhocWorkspace()
+ .CurrentSolution
+ .AddProject(projectId, "TestProject", "TestProject", LanguageNames.CSharp)
+ .AddMetadataReferences(projectId, metadataReferences)
+ .AddDocument(documentId, fileName, SourceText.From(source));
+
+ return solution.GetProject(projectId)
+ .WithCompilationOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
+ }
+
+ #region ConfigureAwait(false) in Grain
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a grain class triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a generic grain class triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InGenericGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+
+ public class MyState { }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(true) in a grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitTrue_InGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(true);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ #endregion
+
+ #region ConfigureAwait(false) in non-grain class
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a plain class (no inheritance) does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InPlainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class MyService
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a class implementing a non-grain interface does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InClassImplementingNonGrainInterface_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public interface IMyService
+ {
+ Task DoSomething();
+ }
+
+ public class MyService : IMyService
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a class inheriting from a non-grain base class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InClassInheritingNonGrainBase_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class BaseService
+ {
+ }
+
+ public class MyService : BaseService
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a class with deep non-grain inheritance does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InClassWithDeepNonGrainInheritance_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class GrandparentService
+ {
+ }
+
+ public class ParentService : GrandparentService
+ {
+ }
+
+ public class MyService : ParentService
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a struct does not trigger a diagnostic.
+ /// Structs cannot be grains.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InStruct_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public struct MyStruct
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a record does not trigger a diagnostic.
+ /// Records cannot be grains (they don't inherit from Grain or implement IGrainBase).
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InRecord_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public record MyRecord
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a record struct does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InRecordStruct_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public record struct MyRecordStruct
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a class implementing IDisposable (not a grain interface) does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InClassImplementingIDisposable_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class MyService : IDisposable
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+
+ public void Dispose() { }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a static class does not trigger a diagnostic.
+ /// Static classes cannot be grains.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InStaticClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public static class MyStaticHelper
+ {
+ public static async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in an abstract non-grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InAbstractNonGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public abstract class MyAbstractService
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a generic non-grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InGenericNonGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGenericService
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a nested class inside a non-grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InNestedClassInsideNonGrain_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class OuterService
+ {
+ public class InnerService
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(ConfigureAwaitOptions) without ContinueOnCapturedContext
+ /// in a non-grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_ForceYielding_InNonGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyService
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ #endregion
+
+ #region No ConfigureAwait
+
+ ///
+ /// Verifies that awaiting without ConfigureAwait in a grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task NoConfigureAwait_InGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ #endregion
+
+ #region IGrainBase implementation
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a class implementing IGrainBase triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InIGrainBaseImplementation_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ using Orleans.Runtime;
+
+ public class MyGrain : IGrainBase, IMyGrain
+ {
+ public IGrainContext GrainContext { get; }
+
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ #endregion
+
+ #region ISystemTarget implementation
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a class implementing ISystemTarget triggers a diagnostic.
+ /// ISystemTarget is in the Orleans namespace (not Orleans.Runtime), defined in Orleans.Core.Abstractions.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InISystemTargetImplementation_ShouldTriggerDiagnostic()
+ {
+ // Note: ISystemTarget is defined in namespace Orleans (in Orleans.Core.Abstractions assembly),
+ // so no additional using is needed since we already have "using Orleans;"
+ var code = """
+ public class MySystemTarget : ISystemTarget
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ #endregion
+
+ #region Nested classes and lambdas
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a lambda within a grain class triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InLambdaInsideGrain_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public Task DoSomething()
+ {
+ Func action = async () =>
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ };
+ return action();
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a nested class within a grain class triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InNestedClassInsideGrain_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public Task DoSomething() => Task.CompletedTask;
+
+ private class NestedClass
+ {
+ public async Task DoWork()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ // The nested class is inside a grain class, so it should still trigger
+ return VerifyHasDiagnostic(code);
+ }
+
+ #endregion
+
+ #region Inherited grain classes
+
+ ///
+ /// Verifies that ConfigureAwait(false) in a class that inherits from another grain class triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_InInheritedGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ public class BaseGrain : Grain
+ {
+ }
+
+ public class MyGrain : BaseGrain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ #endregion
+
+ #region ValueTask
+
+ ///
+ /// Verifies that ConfigureAwait(false) on ValueTask in a grain class triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_OnValueTask_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await GetValueAsync().ConfigureAwait(false);
+ }
+
+ private ValueTask GetValueAsync() => ValueTask.CompletedTask;
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) on ValueTask<T> in a grain class triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_OnGenericValueTask_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ var result = await GetValueAsync().ConfigureAwait(false);
+ }
+
+ private ValueTask GetValueAsync() => ValueTask.FromResult(42);
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(true) on ValueTask in a grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitTrue_OnValueTask_InGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await GetValueAsync().ConfigureAwait(true);
+ }
+
+ private ValueTask GetValueAsync() => ValueTask.CompletedTask;
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) on ValueTask in a non-grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_OnValueTask_InNonGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class MyService
+ {
+ public async Task DoSomething()
+ {
+ await GetValueAsync().ConfigureAwait(false);
+ }
+
+ private ValueTask GetValueAsync() => ValueTask.CompletedTask;
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ #endregion
+
+ #region IAsyncEnumerable
+
+ ///
+ /// Verifies that ConfigureAwait(false) on IAsyncEnumerable in await foreach in a grain class triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_OnIAsyncEnumerable_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ using System.Collections.Generic;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await foreach (var item in GetItemsAsync().ConfigureAwait(false))
+ {
+ // Process item
+ }
+ }
+
+ private async IAsyncEnumerable GetItemsAsync()
+ {
+ yield return 1;
+ await Task.Delay(1);
+ yield return 2;
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(true) on IAsyncEnumerable in await foreach in a grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitTrue_OnIAsyncEnumerable_InGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ using System.Collections.Generic;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await foreach (var item in GetItemsAsync().ConfigureAwait(true))
+ {
+ // Process item
+ }
+ }
+
+ private async IAsyncEnumerable GetItemsAsync()
+ {
+ yield return 1;
+ await Task.Delay(1);
+ yield return 2;
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(false) on IAsyncEnumerable in a non-grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_OnIAsyncEnumerable_InNonGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ using System.Collections.Generic;
+
+ public class MyService
+ {
+ public async Task DoSomething()
+ {
+ await foreach (var item in GetItemsAsync().ConfigureAwait(false))
+ {
+ // Process item
+ }
+ }
+
+ private async IAsyncEnumerable GetItemsAsync()
+ {
+ yield return 1;
+ await Task.Delay(1);
+ yield return 2;
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that await foreach without ConfigureAwait in a grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task NoConfigureAwait_OnIAsyncEnumerable_InGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ using System.Collections.Generic;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await foreach (var item in GetItemsAsync())
+ {
+ // Process item
+ }
+ }
+
+ private async IAsyncEnumerable GetItemsAsync()
+ {
+ yield return 1;
+ await Task.Delay(1);
+ yield return 2;
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ #endregion
+
+ #region Task
+
+ ///
+ /// Verifies that ConfigureAwait(false) on Task<T> in a grain class triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitFalse_OnGenericTask_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ var result = await Task.FromResult(42).ConfigureAwait(false);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(true) on Task<T> in a grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitTrue_OnGenericTask_InGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ var result = await Task.FromResult(42).ConfigureAwait(true);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ #endregion
+
+ #region ConfigureAwait(ConfigureAwaitOptions)
+
+ ///
+ /// Verifies that ConfigureAwait(ConfigureAwaitOptions.None) in a grain class triggers a diagnostic
+ /// because it doesn't include ContinueOnCapturedContext.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_None_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.None);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(ConfigureAwaitOptions.ForceYielding) in a grain class triggers a diagnostic
+ /// because it doesn't include ContinueOnCapturedContext.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_ForceYielding_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing) in a grain class triggers a diagnostic
+ /// because it doesn't include ContinueOnCapturedContext.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_SuppressThrowing_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(ConfigureAwaitOptions.ContinueOnCapturedContext) in a grain class
+ /// does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_ContinueOnCapturedContext_InGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ContinueOnCapturedContext);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait with combined flags including ContinueOnCapturedContext
+ /// does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_CombinedWithContinueOnCapturedContext_InGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ContinueOnCapturedContext | ConfigureAwaitOptions.ForceYielding);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait with combined flags NOT including ContinueOnCapturedContext
+ /// triggers a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_CombinedWithoutContinueOnCapturedContext_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ForceYielding | ConfigureAwaitOptions.SuppressThrowing);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(ConfigureAwaitOptions) in a non-grain class does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_None_InNonGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyService
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.None);
+ }
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(ConfigureAwaitOptions) on Task<T> works correctly.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_None_OnGenericTask_InGrainClass_ShouldTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ var result = await Task.FromResult(42).ConfigureAwait(ConfigureAwaitOptions.None);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasDiagnostic(code);
+ }
+
+ ///
+ /// Verifies that ConfigureAwait(ConfigureAwaitOptions.ContinueOnCapturedContext) on Task<T> does not trigger a diagnostic.
+ ///
+ [Fact]
+ public Task ConfigureAwaitOptions_ContinueOnCapturedContext_OnGenericTask_InGrainClass_ShouldNotTriggerDiagnostic()
+ {
+ var code = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ var result = await Task.FromResult(42).ConfigureAwait(ConfigureAwaitOptions.ContinueOnCapturedContext);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyHasNoDiagnostic(code);
+ }
+
+ #endregion
+
+ #region Code Fix Tests
+
+ ///
+ /// Verifies that the code fix converts ConfigureAwait(false) to ConfigureAwait(true).
+ ///
+ [Fact]
+ public Task CodeFix_ConfigureAwaitFalse_ChangesToTrue()
+ {
+ var originalCode = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(false);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ var expectedFixedCode = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(true);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyCodeFix(originalCode, expectedFixedCode);
+ }
+
+ ///
+ /// Verifies that the code fix converts ConfigureAwait(false) to ConfigureAwait(true) on ValueTask.
+ ///
+ [Fact]
+ public Task CodeFix_ConfigureAwaitFalse_OnValueTask_ChangesToTrue()
+ {
+ var originalCode = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await GetValueAsync().ConfigureAwait(false);
+ }
+
+ private ValueTask GetValueAsync() => ValueTask.CompletedTask;
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ var expectedFixedCode = """
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await GetValueAsync().ConfigureAwait(true);
+ }
+
+ private ValueTask GetValueAsync() => ValueTask.CompletedTask;
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyCodeFix(originalCode, expectedFixedCode);
+ }
+
+ ///
+ /// Verifies that the code fix converts ConfigureAwait(ConfigureAwaitOptions.None) to ConfigureAwait(ConfigureAwaitOptions.ContinueOnCapturedContext).
+ ///
+ [Fact]
+ public Task CodeFix_ConfigureAwaitOptionsNone_ChangesToContinueOnCapturedContext()
+ {
+ var originalCode = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.None);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ var expectedFixedCode = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ContinueOnCapturedContext);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyCodeFix(originalCode, expectedFixedCode);
+ }
+
+ ///
+ /// Verifies that the code fix adds ContinueOnCapturedContext to ConfigureAwait(ConfigureAwaitOptions.ForceYielding).
+ ///
+ [Fact]
+ public Task CodeFix_ConfigureAwaitOptionsForceYielding_AddsContinueOnCapturedContext()
+ {
+ var originalCode = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ var expectedFixedCode = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ForceYielding | ConfigureAwaitOptions.ContinueOnCapturedContext);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyCodeFix(originalCode, expectedFixedCode);
+ }
+
+ ///
+ /// Verifies that the code fix adds ContinueOnCapturedContext to combined ConfigureAwaitOptions.
+ ///
+ [Fact]
+ public Task CodeFix_ConfigureAwaitOptionsCombined_AddsContinueOnCapturedContext()
+ {
+ var originalCode = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ForceYielding | ConfigureAwaitOptions.SuppressThrowing);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ var expectedFixedCode = """
+ using System.Threading.Tasks;
+
+ public class MyGrain : Grain, IMyGrain
+ {
+ public async Task DoSomething()
+ {
+ await Task.Delay(100).ConfigureAwait(ConfigureAwaitOptions.ForceYielding | ConfigureAwaitOptions.SuppressThrowing | ConfigureAwaitOptions.ContinueOnCapturedContext);
+ }
+ }
+
+ public interface IMyGrain : IGrainWithGuidKey
+ {
+ Task DoSomething();
+ }
+ """;
+
+ return VerifyCodeFix(originalCode, expectedFixedCode);
+ }
+
+ #endregion
+}