From 7ef81d0e5409dab6ba999b21f5214fc43bd0f08c Mon Sep 17 00:00:00 2001 From: Anthony Dresser Date: Fri, 28 Jul 2017 13:35:46 -0700 Subject: [PATCH] Multiple Connection Simple Execute (#421) * change simple execute to open a new connection and close it every query * updated tests for simple execute * removed an unnecessary connect * refactored code to be more readable * global try catch on simple execute * added multiple execution test * update execution to be asynchrous; update tests to account for asynchrounous nature --- .../QueryExecution/QueryExecutionService.cs | 169 +++++++++++------- .../Execution/ServiceIntegrationTests.cs | 50 +++++- 2 files changed, 151 insertions(+), 68 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 088d6f1666..c8f5211919 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -4,9 +4,12 @@ // using System; using System.Collections.Concurrent; +using System.Data.Common; +using System.Data.SqlClient; using System.IO; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests; @@ -118,6 +121,17 @@ private IFileStreamFactory BufferFileFactory /// internal SqlToolsSettings Settings { get; set; } + /// + /// Holds a map from the simple execute unique GUID and the underlying task that is being ran + /// + private readonly Lazy> simpleExecuteRequests = + new Lazy>(() => new ConcurrentDictionary()); + + /// + /// Holds a map from the simple execute unique GUID and the underlying task that is being ran + /// + internal ConcurrentDictionary ActiveSimpleExecuteRequests => simpleExecuteRequests.Value; + #endregion /// @@ -173,82 +187,111 @@ internal Task HandleExecuteRequest(ExecuteRequestParamsBase executeParams, /// /// Handles a request to execute a string and return the result /// - internal Task HandleSimpleExecuteRequest(SimpleExecuteParams executeParams, + internal async Task HandleSimpleExecuteRequest(SimpleExecuteParams executeParams, RequestContext requestContext) { - ExecuteStringParams executeStringParams = new ExecuteStringParams - { - Query = executeParams.QueryString, - // generate guid as the owner uri to make sure every query is unique - OwnerUri = Guid.NewGuid().ToString() - }; - - // get connection - ConnectionInfo connInfo; - if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connInfo)) - { - return requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri); - } - - if (connInfo.ConnectionDetails.MultipleActiveResultSets == null || connInfo.ConnectionDetails.MultipleActiveResultSets == false) { - // if multipleActive result sets is not allowed, don't specific a connection and make the ownerURI the true owneruri - connInfo = null; - executeStringParams.OwnerUri = executeParams.OwnerUri; - } - - Func queryCreateFailureAction = message => requestContext.SendError(message); - - ResultOnlyContext newContext = new ResultOnlyContext(requestContext); - - // handle sending event back when the query completes - Query.QueryAsyncEventHandler queryComplete = async q => + try { - Query removedQuery; - // check to make sure any results were recieved - if (q.Batches.Length == 0 || q.Batches[0].ResultSets.Count == 0) + string randomUri = Guid.NewGuid().ToString(); + ExecuteStringParams executeStringParams = new ExecuteStringParams { - await requestContext.SendError(SR.QueryServiceResultSetHasNoResults); - ActiveQueries.TryRemove(executeStringParams.OwnerUri, out removedQuery); - return; - } + Query = executeParams.QueryString, + // generate guid as the owner uri to make sure every query is unique + OwnerUri = randomUri + }; - var rowCount = q.Batches[0].ResultSets[0].RowCount; - // check to make sure there is a safe amount of rows to load into memory - if (rowCount > Int32.MaxValue) + // get connection + ConnectionInfo connInfo; + if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connInfo)) { - await requestContext.SendError(SR.QueryServiceResultSetTooLarge); - ActiveQueries.TryRemove(executeStringParams.OwnerUri, out removedQuery); + await requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri); return; } - SubsetParams subsetRequestParams = new SubsetParams - { - OwnerUri = executeStringParams.OwnerUri, - BatchIndex = 0, - ResultSetIndex = 0, - RowsStartIndex = 0, - RowsCount = Convert.ToInt32(rowCount) - }; - // get the data to send back - ResultSetSubset subset = await InterServiceResultSubset(subsetRequestParams); - SimpleExecuteResult result = new SimpleExecuteResult + ConnectParams connectParams = new ConnectParams { - RowCount = q.Batches[0].ResultSets[0].RowCount, - ColumnInfo = q.Batches[0].ResultSets[0].Columns, - Rows = subset.Rows + OwnerUri = randomUri, + Connection = connInfo.ConnectionDetails, + Type = ConnectionType.Default }; - await requestContext.SendResult(result); - // remove the active query since we are done with it - ActiveQueries.TryRemove(executeStringParams.OwnerUri, out removedQuery); - }; + + Task workTask = Task.Run(async () => { + await ConnectionService.Connect(connectParams); - // handle sending error back when query fails - Query.QueryAsyncErrorEventHandler queryFail = async (q, e) => - { - await requestContext.SendError(e); - }; + ConnectionInfo newConn; + ConnectionService.TryFindConnection(randomUri, out newConn); + + Func queryCreateFailureAction = message => requestContext.SendError(message); + + ResultOnlyContext newContext = new ResultOnlyContext(requestContext); + + // handle sending event back when the query completes + Query.QueryAsyncEventHandler queryComplete = async query => + { + try + { + // check to make sure any results were recieved + if (query.Batches.Length == 0 || query.Batches[0].ResultSets.Count == 0) + { + await requestContext.SendError(SR.QueryServiceResultSetHasNoResults); + return; + } + + var rowCount = query.Batches[0].ResultSets[0].RowCount; + // check to make sure there is a safe amount of rows to load into memory + if (rowCount > Int32.MaxValue) + { + await requestContext.SendError(SR.QueryServiceResultSetTooLarge); + return; + } + + SubsetParams subsetRequestParams = new SubsetParams + { + OwnerUri = randomUri, + BatchIndex = 0, + ResultSetIndex = 0, + RowsStartIndex = 0, + RowsCount = Convert.ToInt32(rowCount) + }; + // get the data to send back + ResultSetSubset subset = await InterServiceResultSubset(subsetRequestParams); + SimpleExecuteResult result = new SimpleExecuteResult + { + RowCount = query.Batches[0].ResultSets[0].RowCount, + ColumnInfo = query.Batches[0].ResultSets[0].Columns, + Rows = subset.Rows + }; + await requestContext.SendResult(result); + } + finally + { + Query removedQuery; + Task removedTask; + // remove the active query since we are done with it + ActiveQueries.TryRemove(randomUri, out removedQuery); + ActiveSimpleExecuteRequests.TryRemove(randomUri, out removedTask); + ConnectionService.Disconnect(new DisconnectParams(){ + OwnerUri = randomUri, + Type = null + }); + } + }; + + // handle sending error back when query fails + Query.QueryAsyncErrorEventHandler queryFail = async (q, e) => + { + await requestContext.SendError(e); + }; - return InterServiceExecuteQuery(executeStringParams, connInfo, newContext, null, queryCreateFailureAction, queryComplete, queryFail); + await InterServiceExecuteQuery(executeStringParams, newConn, newContext, null, queryCreateFailureAction, queryComplete, queryFail); + }); + + ActiveSimpleExecuteRequests.TryAdd(randomUri, workTask); + } + catch(Exception ex) + { + await requestContext.SendError(ex.ToString()); + } } /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ServiceIntegrationTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ServiceIntegrationTests.cs index 2a37a1f1b9..802373dab3 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ServiceIntegrationTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ServiceIntegrationTests.cs @@ -4,6 +4,7 @@ // using System; +using System.Linq; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests; @@ -431,15 +432,16 @@ public async Task SimpleExecuteErrorWithNoResultsTest() .Complete(); await queryService.HandleSimpleExecuteRequest(queryParams, efv.Object); - Query q; - queryService.ActiveQueries.TryGetValue(Constants.OwnerUri, out q); + await Task.WhenAll(queryService.ActiveSimpleExecuteRequests.Values); - // wait on the task to finish + Query q = queryService.ActiveQueries.Values.First(); + Assert.NotNull(q); q.ExecutionTask.Wait(); efv.Validate(); Assert.Equal(0, queryService.ActiveQueries.Count); + } [Fact] @@ -452,8 +454,11 @@ public async Task SimpleExecuteVerifyResultsTest() .Complete(); await queryService.HandleSimpleExecuteRequest(queryParams, efv.Object); - Query q; - queryService.ActiveQueries.TryGetValue(Constants.OwnerUri, out q); + await Task.WhenAll(queryService.ActiveSimpleExecuteRequests.Values); + + Query q = queryService.ActiveQueries.Values.First(); + + Assert.NotNull(q); // wait on the task to finish q.ExecutionTask.Wait(); @@ -463,6 +468,41 @@ public async Task SimpleExecuteVerifyResultsTest() Assert.Equal(0, queryService.ActiveQueries.Count); } + [Fact] + public async Task SimpleExecuteMultipleQueriesTest() + { + var queryService = Common.GetPrimedExecutionService(Common.StandardTestDataSet, true, false, null); + var queryParams = new SimpleExecuteParams { OwnerUri = Constants.OwnerUri, QueryString = Constants.StandardQuery }; + var efv1 = new EventFlowValidator() + .AddSimpleExecuteQueryResultValidator(Common.StandardTestDataSet) + .Complete(); + var efv2 = new EventFlowValidator() + .AddSimpleExecuteQueryResultValidator(Common.StandardTestDataSet) + .Complete(); + Task qT1 = queryService.HandleSimpleExecuteRequest(queryParams, efv1.Object); + Task qT2 = queryService.HandleSimpleExecuteRequest(queryParams, efv2.Object); + + await Task.WhenAll(qT1, qT2); + + await Task.WhenAll(queryService.ActiveSimpleExecuteRequests.Values); + + var queries = queryService.ActiveQueries.Values.Take(2).ToArray(); + Query q1 = queries[0]; + Query q2 = queries[1]; + + Assert.NotNull(q1); + Assert.NotNull(q2); + + // wait on the task to finish + q1.ExecutionTask.Wait(); + q2.ExecutionTask.Wait(); + + efv1.Validate(); + efv2.Validate(); + + Assert.Equal(0, queryService.ActiveQueries.Count); + } + #endregion private static WorkspaceService GetDefaultWorkspaceService(string query)