From 8b6c33a6b9e24156a9abedb737acb5ec0e3d7de8 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Mon, 26 May 2025 15:57:39 +0100 Subject: [PATCH 1/2] Port #528 and #2091 to netfx --- .../Microsoft/Data/SqlClient/SqlCommand.cs | 8 +- .../Microsoft/Data/SqlClient/SqlCommand.cs | 312 +++++++++++++++--- .../Data/SqlClient/SqlInternalConnection.cs | 2 - 3 files changed, 262 insertions(+), 60 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index dcafb119eb..3e0cced22a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -35,7 +35,7 @@ public sealed partial class SqlCommand : DbCommand, ICloneable { private static int _objectTypeCount; // EventSource Counter private const int MaxRPCNameLength = 1046; - internal readonly int ObjectID = Interlocked.Increment(ref _objectTypeCount); private string _commandText; + internal readonly int ObjectID = Interlocked.Increment(ref _objectTypeCount); internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext { @@ -113,6 +113,7 @@ protected override void AfterCleared(SqlCommand owner) } } + private string _commandText; private CommandType _commandType; private int? _commandTimeout; private UpdateRowSource _updatedRowSource = UpdateRowSource.Both; @@ -2645,7 +2646,7 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok { s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); source.SetException(e); - context.Dispose(); + context?.Dispose(); } return returnedTask; @@ -2794,7 +2795,7 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior, } source.SetException(e); - context.Dispose(); + context?.Dispose(); } return returnedTask; @@ -3051,7 +3052,6 @@ private Task InternalExecuteXmlReaderAsync(CancellationToken cancella } context.Set(this, source, registration, operationId); - Task returnedTask = source.Task; try { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index 70d9f72ebc..3f5ee052cd 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -42,6 +42,82 @@ public sealed class SqlCommand : DbCommand, ICloneable private const int MaxRPCNameLength = 1046; internal readonly int ObjectID = System.Threading.Interlocked.Increment(ref _objectTypeCount); + internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext + { + public Guid OperationID; + public CommandBehavior CommandBehavior; + + public SqlCommand Command => _owner; + public TaskCompletionSource TaskCompletionSource => _source; + + public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, CommandBehavior behavior, Guid operationID) + { + base.Set(command, source, disposable); + CommandBehavior = behavior; + OperationID = operationID; + } + + protected override void Clear() + { + OperationID = default; + CommandBehavior = default; + } + + protected override void AfterCleared(SqlCommand owner) + { + owner?.SetCachedCommandExecuteReaderAsyncContext(this); + } + } + + internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext + { + public Guid OperationID; + + public SqlCommand Command => _owner; + + public TaskCompletionSource TaskCompletionSource => _source; + + public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, Guid operationID) + { + base.Set(command, source, disposable); + OperationID = operationID; + } + + protected override void Clear() + { + OperationID = default; + } + + protected override void AfterCleared(SqlCommand owner) + { + owner?.SetCachedCommandExecuteNonQueryAsyncContext(this); + } + } + + internal sealed class ExecuteXmlReaderAsyncCallContext : AAsyncCallContext + { + public Guid OperationID; + + public SqlCommand Command => _owner; + public TaskCompletionSource TaskCompletionSource => _source; + + public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, Guid operationID) + { + base.Set(command, source, disposable); + OperationID = operationID; + } + + protected override void Clear() + { + OperationID = default; + } + + protected override void AfterCleared(SqlCommand owner) + { + owner?.SetCachedCommandExecuteXmlReaderContext(this); + } + } + private string _commandText; private CommandType _commandType; private int? _commandTimeout; @@ -2554,6 +2630,26 @@ private SqlDataReader EndExecuteReaderInternal(IAsyncResult asyncResult) } } + private void CleanupExecuteReaderAsync(Task task, TaskCompletionSource source, Guid operationId) + { + if (task.IsFaulted) + { + Exception e = task.Exception.InnerException; + source.SetException(e); + } + else + { + if (task.IsCanceled) + { + source.SetCanceled(); + } + else + { + source.SetResult(task.Result); + } + } + } + private IAsyncResult BeginExecuteReaderAsync(CommandBehavior behavior, AsyncCallback callback, object stateObject) { return BeginExecuteReaderInternal(behavior, callback, stateObject, CommandTimeout, inRetry: false, asyncWrite: true); @@ -2892,6 +2988,7 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok { SqlClientEventSource.Log.TryCorrelationTraceEvent(" ObjectID {0}, ActivityID {1}", ObjectID, ActivityCorrelator.Current); SqlConnection.ExecutePermission.Demand(); + Guid operationId = Guid.Empty; // connection can be used as state in RegisterForConnectionCloseNotification continuation // to avoid an allocation so use it as the state value if possible but it can be changed if @@ -2910,39 +3007,67 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok } Task returnedTask = source.Task; + + ExecuteNonQueryAsyncCallContext context = new ExecuteNonQueryAsyncCallContext(); + context.Set(this, source, registration, operationId); try { returnedTask = RegisterForConnectionCloseNotification(returnedTask); - Task.Factory.FromAsync(BeginExecuteNonQueryAsync, EndExecuteNonQueryAsync, null).ContinueWith((t) => - { - registration.Dispose(); - if (t.IsFaulted) + Task.Factory.FromAsync( + beginMethod: static (AsyncCallback callback, object stateObject) => { - Exception e = t.Exception.InnerException; - source.SetException(e); - } - else + return ((ExecuteNonQueryAsyncCallContext)stateObject).Command.BeginExecuteNonQueryAsync(callback, stateObject); + }, + endMethod: static (IAsyncResult asyncResult) => { - if (t.IsCanceled) - { - source.SetCanceled(); - } - else - { - source.SetResult(t.Result); - } - } - }, TaskScheduler.Default); + return ((ExecuteNonQueryAsyncCallContext)asyncResult.AsyncState).Command.EndExecuteNonQueryAsync(asyncResult); + }, + state: context + ) + .ContinueWith( + static (Task task) => + { + ExecuteNonQueryAsyncCallContext context = (ExecuteNonQueryAsyncCallContext)task.AsyncState; + SqlCommand command = context.Command; + Guid operationId = context.OperationID; + TaskCompletionSource source = context.TaskCompletionSource; + context.Dispose(); + + command.CleanupAfterExecuteNonQueryAsync(task, source, operationId); + }, + scheduler: TaskScheduler.Default + ); } catch (Exception e) { source.SetException(e); + context?.Dispose(); } return returnedTask; } + private void CleanupAfterExecuteNonQueryAsync(Task task, TaskCompletionSource source, Guid operationId) + { + if (task.IsFaulted) + { + Exception e = task.Exception.InnerException; + source.SetException(e); + } + else + { + if (task.IsCanceled) + { + source.SetCanceled(); + } + else + { + source.SetResult(task.Result); + } + } + } + /// protected override Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) { @@ -2992,6 +3117,7 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior, { SqlClientEventSource.Log.TryCorrelationTraceEvent(" ObjectID {0}, behavior={1}, ActivityID {2}", ObjectID, (int)behavior, ActivityCorrelator.Current); SqlConnection.ExecutePermission.Demand(); + Guid operationId = default(Guid); // connection can be used as state in RegisterForConnectionCloseNotification continuation // to avoid an allocation so use it as the state value if possible but it can be changed if @@ -3010,39 +3136,80 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior, } Task returnedTask = source.Task; + ExecuteReaderAsyncCallContext context = null; try { returnedTask = RegisterForConnectionCloseNotification(returnedTask); - Task.Factory.FromAsync(BeginExecuteReaderAsync, EndExecuteReaderAsync, behavior, null).ContinueWith((t) => + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, null); + } + if (context is null) { - registration.Dispose(); - if (t.IsFaulted) + context = new ExecuteReaderAsyncCallContext(); + } + context.Set(this, source, registration, behavior, operationId); + + Task.Factory.FromAsync( + beginMethod: static (AsyncCallback callback, object stateObject) => { - Exception e = t.Exception.InnerException; - source.SetException(e); - } - else + ExecuteReaderAsyncCallContext args = (ExecuteReaderAsyncCallContext)stateObject; + return args.Command.BeginExecuteReaderInternal(args.CommandBehavior, callback, stateObject, args.Command.CommandTimeout, inRetry: false, asyncWrite: true); + }, + endMethod: static (IAsyncResult asyncResult) => { - if (t.IsCanceled) - { - source.SetCanceled(); - } - else - { - source.SetResult(t.Result); - } - } - }, TaskScheduler.Default); + ExecuteReaderAsyncCallContext args = (ExecuteReaderAsyncCallContext)asyncResult.AsyncState; + return args.Command.EndExecuteReaderAsync(asyncResult); + }, + state: context + ).ContinueWith( + continuationAction: static (Task task) => + { + ExecuteReaderAsyncCallContext context = (ExecuteReaderAsyncCallContext)task.AsyncState; + SqlCommand command = context.Command; + Guid operationId = context.OperationID; + TaskCompletionSource source = context.TaskCompletionSource; + context.Dispose(); + + command.CleanupExecuteReaderAsync(task, source, operationId); + }, + scheduler: TaskScheduler.Default + ); } catch (Exception e) { source.SetException(e); + context?.Dispose(); } return returnedTask; } + private void SetCachedCommandExecuteReaderAsyncContext(ExecuteReaderAsyncCallContext instance) + { + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + Interlocked.CompareExchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, instance, null); + } + } + + private void SetCachedCommandExecuteNonQueryAsyncContext(ExecuteNonQueryAsyncCallContext instance) + { + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + Interlocked.CompareExchange(ref sqlInternalConnection.CachedCommandExecuteNonQueryAsyncContext, instance, null); + } + } + + private void SetCachedCommandExecuteXmlReaderContext(ExecuteXmlReaderAsyncCallContext instance) + { + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + Interlocked.CompareExchange(ref sqlInternalConnection.CachedCommandExecuteXmlReaderAsyncContext, instance, null); + } + } + /// public override Task ExecuteScalarAsync(CancellationToken cancellationToken) => // Do not use retry logic here as internal call to ExecuteReaderAsync handles retry logic. @@ -3116,7 +3283,9 @@ private Task InternalExecuteScalarAsync(CancellationToken cancellationTo // exception thrown by Dispose... source.SetException(e); } - }, TaskScheduler.Default); + }, + TaskScheduler.Default + ); } return source.Task; }, TaskScheduler.Default).Unwrap(); @@ -3141,6 +3310,7 @@ private Task InternalExecuteXmlReaderAsync(CancellationToken cancella { SqlClientEventSource.Log.TryCorrelationTraceEvent(" ObjectID {0}, ActivityID {1}", ObjectID, ActivityCorrelator.Current); SqlConnection.ExecutePermission.Demand(); + Guid operationId = Guid.Empty; // connection can be used as state in RegisterForConnectionCloseNotification continuation // to avoid an allocation so use it as the state value if possible but it can be changed if @@ -3158,31 +3328,45 @@ private Task InternalExecuteXmlReaderAsync(CancellationToken cancella registration = cancellationToken.Register(s_cancelIgnoreFailure, this); } + ExecuteXmlReaderAsyncCallContext context = null; + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteXmlReaderAsyncContext, null); + } + if (context is null) + { + context = new ExecuteXmlReaderAsyncCallContext(); + } + context.Set(this, source, registration, operationId); + Task returnedTask = source.Task; try { returnedTask = RegisterForConnectionCloseNotification(returnedTask); - Task.Factory.FromAsync(BeginExecuteXmlReaderAsync, EndExecuteXmlReaderAsync, null).ContinueWith((t) => - { - registration.Dispose(); - if (t.IsFaulted) + Task.Factory.FromAsync( + beginMethod: static (AsyncCallback callback, object stateObject) => { - Exception e = t.Exception.InnerException; - source.SetException(e); - } - else + return ((ExecuteXmlReaderAsyncCallContext)stateObject).Command.BeginExecuteXmlReaderAsync(callback, stateObject); + }, + endMethod: static (IAsyncResult asyncResult) => { - if (t.IsCanceled) - { - source.SetCanceled(); - } - else - { - source.SetResult(t.Result); - } - } - }, TaskScheduler.Default); + return ((ExecuteXmlReaderAsyncCallContext)asyncResult.AsyncState).Command.EndExecuteXmlReaderAsync(asyncResult); + }, + state: context + ).ContinueWith( + static (Task task) => + { + ExecuteXmlReaderAsyncCallContext context = (ExecuteXmlReaderAsyncCallContext)task.AsyncState; + SqlCommand command = context.Command; + Guid operationId = context.OperationID; + TaskCompletionSource source = context.TaskCompletionSource; + context.Dispose(); + + command.CleanupAfterExecuteXmlReaderAsync(task, source, operationId); + }, + TaskScheduler.Default + ); } catch (Exception e) { @@ -3192,6 +3376,26 @@ private Task InternalExecuteXmlReaderAsync(CancellationToken cancella return returnedTask; } + private void CleanupAfterExecuteXmlReaderAsync(Task task, TaskCompletionSource source, Guid operationId) + { + if (task.IsFaulted) + { + Exception e = task.Exception.InnerException; + source.SetException(e); + } + else + { + if (task.IsCanceled) + { + source.SetCanceled(); + } + else + { + source.SetResult(task.Result); + } + } + } + /// public void RegisterColumnEncryptionKeyStoreProvidersOnCommand(IDictionary customProviders) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs index a39c83d831..742c8b2865 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs @@ -27,11 +27,9 @@ internal abstract class SqlInternalConnection : DbConnectionInternal private bool _isGlobalTransactionEnabledForServer; // Whether Global Transactions are enabled for this Azure SQL DB Server private static readonly Guid s_globalTransactionTMID = new("1c742caf-6680-40ea-9c26-6b6846079764"); // ID of the Non-MSDTC, Azure SQL DB Transaction Manager -#if NET internal SqlCommand.ExecuteReaderAsyncCallContext CachedCommandExecuteReaderAsyncContext; internal SqlCommand.ExecuteNonQueryAsyncCallContext CachedCommandExecuteNonQueryAsyncContext; internal SqlCommand.ExecuteXmlReaderAsyncCallContext CachedCommandExecuteXmlReaderAsyncContext; -#endif internal SqlDataReader.Snapshot CachedDataReaderSnapshot; internal SqlDataReader.IsDBNullAsyncCallContext CachedDataReaderIsDBNullContext; internal SqlDataReader.ReadAsyncCallContext CachedDataReaderReadAsyncContext; From d60c800680e12066fcbcb40a317cb9ac76232646 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Tue, 27 May 2025 17:54:30 +0100 Subject: [PATCH 2/2] Following review - redundant null check --- .../netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs | 2 +- .../netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 3e0cced22a..4fbbd619e8 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -2646,7 +2646,7 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok { s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); source.SetException(e); - context?.Dispose(); + context.Dispose(); } return returnedTask; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index 3f5ee052cd..5bc7eed53a 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -3042,7 +3042,7 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok catch (Exception e) { source.SetException(e); - context?.Dispose(); + context.Dispose(); } return returnedTask;