From a6b054f0c7eca60dc14f171e4f164cdacbf1a3c2 Mon Sep 17 00:00:00 2001 From: asrichesson <51305664+asrichesson@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:50:14 -0400 Subject: [PATCH] Handle cross database view definition (#2388) --- .../LanguageServices/LanguageService.cs | 63 +++--- .../Scripting/ScripterCore.cs | 198 ++++++++++-------- .../Scripting/SqlObjectIdentifier.cs | 19 ++ .../Scripting/Contracts/ScriptingObject.cs | 5 +- .../Scripting/ScriptAsScriptingOperation.cs | 13 +- .../LanguageServer/PeekDefinitionTests.cs | 175 ++++++++-------- .../LanguageServer/PeekDefinitionTests.cs | 5 +- 7 files changed, 268 insertions(+), 210 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Scripting/SqlObjectIdentifier.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index 29a3e19ed0..03dec1f0fe 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -1270,15 +1270,15 @@ private DefinitionResult QueueTask(TextDocumentPosition textDocumentPosition, Sc bindingTimeout: LanguageService.PeekDefinitionTimeout, bindOperation: (bindingContext, cancelToken) => { - string schemaName = this.GetSchemaName(scriptParseInfo, textDocumentPosition.Position, scriptFile); + Sql4PartIdentifier identifier = this.GetFullIdentifier(scriptParseInfo, textDocumentPosition.Position); + // Script object using SMO Scripter scripter = new Scripter(bindingContext.ServerConnection, connInfo); return scripter.GetScript( scriptParseInfo.ParseResult, textDocumentPosition.Position, bindingContext.MetadataDisplayInfoProvider, - tokenText, - schemaName); + identifier); }, timeoutOperation: (bindingContext) => { @@ -1422,14 +1422,13 @@ internal DefinitionResult GetDefinition(TextDocumentPosition textDocumentPositio /// Wrapper around find token method /// /// - /// - /// + /// /// token index - private int FindTokenWithCorrectOffset(ScriptParseInfo scriptParseInfo, int startLine, int startColumn) + private int FindTokenWithCorrectOffset(ScriptParseInfo scriptParseInfo, Position position) { - var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); + var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(position.Line, position.Character); var end = scriptParseInfo.ParseResult.Script.TokenManager.GetToken(tokenIndex).EndLocation; - if (end.LineNumber == startLine && end.ColumnNumber == startColumn) + if (end.LineNumber == position.Line && end.ColumnNumber == position.Character) { return tokenIndex + 1; } @@ -1437,33 +1436,41 @@ private int FindTokenWithCorrectOffset(ScriptParseInfo scriptParseInfo, int star } /// - /// Extract schema name for a token, if present + /// Returns a 4 part identifier at the position in a script, if present /// /// /// - /// - /// schema name - private string GetSchemaName(ScriptParseInfo scriptParseInfo, Position position, ScriptFile scriptFile) + /// + private Sql4PartIdentifier GetFullIdentifier(ScriptParseInfo scriptParseInfo, Position position) { - // Offset index by 1 for sql parser - int startLine = position.Line; - int startColumn = position.Character; - - // Get schema name - if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null) - { - var tokenIndex = FindTokenWithCorrectOffset(scriptParseInfo, startLine, startColumn); - var prevTokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.GetPreviousSignificantTokenIndex(tokenIndex); - var prevTokenText = scriptParseInfo.ParseResult.Script.TokenManager.GetText(prevTokenIndex); - if (prevTokenText != null && prevTokenText.Equals(".")) + if (scriptParseInfo?.ParseResult?.Script?.Tokens == null) return null; + var tokenManager = scriptParseInfo.ParseResult.Script.TokenManager; + int tokenIndex = this.FindTokenWithCorrectOffset(scriptParseInfo, position); + var identifiers = new string[4]; + //work backwards from the initial token to read identifier parts + for (int i = 0; i < identifiers.Length; i++) + { + if (i > 0) //consume separator dot { - var schemaTokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.GetPreviousSignificantTokenIndex(prevTokenIndex); - Token schemaToken = scriptParseInfo.ParseResult.Script.TokenManager.GetToken(schemaTokenIndex); - return TextUtilities.RemoveSquareBracketSyntax(schemaToken.Text); + tokenIndex = tokenManager.GetPreviousSignificantTokenIndex(tokenIndex); + if (tokenIndex < 0) break; + var period = tokenManager.GetText(tokenIndex); + if (period is null or not ".") break; + tokenIndex = tokenManager.GetPreviousSignificantTokenIndex(tokenIndex); } + + if (tokenIndex < 0) break; + string identifierText = tokenManager.GetText(tokenIndex); + if (string.IsNullOrEmpty(identifierText)) break; + identifiers[i] = TextUtilities.RemoveSquareBracketSyntax(identifierText); } - // if no schema name, returns null - return null; + return new Sql4PartIdentifier + { + ObjectName = identifiers[0], + SchemaName = identifiers[1], + DatabaseName = identifiers[2], + ServerName = identifiers[3] + }; } /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs index 86d0db788f..d77a8a9380 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs @@ -72,48 +72,56 @@ internal Database Database { if (this.serverConnection != null && !string.IsNullOrEmpty(this.serverConnection.DatabaseName)) { - try + // The default database name is the database name of the server connection + string dbName = this.serverConnection.DatabaseName; + + // If there is a query DbConnection, use that connection to get the database name + // This is preferred since it has the most current database name (in case of database switching) + if (this.connectionInfo?.TryGetConnection(ConnectionType.Query, out DbConnection connection) == true) { - // Reuse existing connection - Server server = new Server(this.serverConnection); - // The default database name is the database name of the server connection - string dbName = this.serverConnection.DatabaseName; - if (this.connectionInfo != null) + if (!string.IsNullOrEmpty(connection.Database)) { - // If there is a query DbConnection, use that connection to get the database name - // This is preferred since it has the most current database name (in case of database switching) - DbConnection connection; - if (connectionInfo.TryGetConnection(ConnectionType.Query, out connection)) - { - if (!string.IsNullOrEmpty(connection.Database)) - { - dbName = connection.Database; - } - } + dbName = connection.Database; } - this.database = new Database(server, dbName); - this.database.Refresh(); - } - catch (ConnectionFailureException cfe) - { - Logger.Error("Exception at PeekDefinition Database.get() : " + cfe.Message); - this.error = true; - this.errorMessage = (connectionInfo != null && connectionInfo.IsCloud) ? SR.PeekDefinitionAzureError(cfe.Message) : SR.PeekDefinitionError(cfe.Message); - return null; - } - catch (Exception ex) - { - Logger.Error("Exception at PeekDefinition Database.get() : " + ex.Message); - this.error = true; - this.errorMessage = SR.PeekDefinitionError(ex.Message); - return null; } + + this.database = this.GetDatabase(dbName); } } + return this.database; } } + private Database? GetDatabase(string dbName) + { + try + { + // Reuse existing connection + var server = new Server(this.serverConnection); + + var db = new Database(server, dbName); + db.Refresh(); + return db; + } + catch (ConnectionFailureException cfe) + { + Logger.Error("Exception at PeekDefinition Database.get() : " + cfe.Message); + this.error = true; + this.errorMessage = (connectionInfo != null && connectionInfo.IsCloud) + ? SR.PeekDefinitionAzureError(cfe.Message) + : SR.PeekDefinitionError(cfe.Message); + return null; + } + catch (Exception ex) + { + Logger.Error("Exception at PeekDefinition Database.get() : " + ex.Message); + this.error = true; + this.errorMessage = SR.PeekDefinitionError(ex.Message); + return null; + } + } + /// /// Add the given type, scriptgetter and the typeName string to the respective dictionaries /// @@ -129,34 +137,36 @@ private void AddSupportedType(DeclarationType type, string typeName, string quic /// /// Get the script of the selected token based on the type of the token /// - /// - /// - /// + /// + /// + /// + /// The object to be scripted /// Location object of the script file - internal DefinitionResult GetScript(ParseResult parseResult, Position position, IMetadataDisplayInfoProvider metadataDisplayInfoProvider, string tokenText, string schemaName) + internal DefinitionResult GetScript(ParseResult parseResult, Position position, IMetadataDisplayInfoProvider metadataDisplayInfoProvider, Sql3PartIdentifier identifier) { int parserLine = position.Line; int parserColumn = position.Character; // Get DeclarationItems from The Intellisense Resolver for the selected token. The type of the selected token is extracted from the declarationItem. IEnumerable declarationItems = GetCompletionsForToken(parseResult, parserLine, parserColumn, metadataDisplayInfoProvider); - if (declarationItems != null && declarationItems.Count() > 0) + if (declarationItems != null && declarationItems.Any()) { + Database? targetDb = identifier.DatabaseName == null ? this.Database : this.GetDatabase(identifier.DatabaseName); foreach (Declaration declarationItem in declarationItems) { if (declarationItem.Title == null) { continue; } - if (this.Database == null) + if (targetDb == null) { return GetDefinitionErrorResult(SR.PeekDefinitionDatabaseError); } - StringComparison caseSensitivity = this.Database.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase; + StringComparison caseSensitivity = targetDb.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase; // if declarationItem matches the selected token, script SMO using that type - if (declarationItem.Title.Equals(tokenText, caseSensitivity)) + if (declarationItem.Title.Equals(identifier.ObjectName, caseSensitivity)) { - return GetDefinitionUsingDeclarationType(declarationItem.Type, declarationItem.DatabaseQualifiedName, tokenText, schemaName); + return GetDefinitionUsingDeclarationType(declarationItem.Type, declarationItem.DatabaseQualifiedName, identifier); } } } @@ -164,7 +174,7 @@ internal DefinitionResult GetScript(ParseResult parseResult, Position position, { // if no declarationItem matched the selected token, we try to find the type of the token using QuickInfo.Text string quickInfoText = GetQuickInfoForToken(parseResult, parserLine, parserColumn, metadataDisplayInfoProvider); - return GetDefinitionUsingQuickInfoText(quickInfoText, tokenText, schemaName); + return GetDefinitionUsingQuickInfoText(quickInfoText, identifier); } // no definition found return GetDefinitionErrorResult(SR.PeekDefinitionNoResultsError); @@ -174,17 +184,17 @@ internal DefinitionResult GetScript(ParseResult parseResult, Position position, /// Script an object using the type extracted from quickInfo Text /// /// the text from the quickInfo for the selected token - /// The text of the selected token - /// Schema name + /// The object for the definition /// - internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText, string tokenText, string schemaName) + internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText, Sql3PartIdentifier identifier) { - if (this.Database == null) + Database? targetDb = identifier.DatabaseName == null ? this.Database : this.GetDatabase(identifier.DatabaseName); + if (targetDb == null) { return GetDefinitionErrorResult(SR.PeekDefinitionDatabaseError); } - StringComparison caseSensitivity = this.Database.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase; - string tokenType = GetTokenTypeFromQuickInfo(quickInfoText, tokenText, caseSensitivity); + StringComparison caseSensitivity = targetDb.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase; + string tokenType = GetTokenTypeFromQuickInfo(quickInfoText, identifier.ObjectName, caseSensitivity); if (tokenType != null) { if (sqlObjectTypesFromQuickInfo.TryGetValue(tokenType.ToLowerInvariant(), out string sqlObjectType)) @@ -193,15 +203,12 @@ internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText, // This workaround ensures that a schema name is present by attempting // to get the schema name from the declaration item. // If all fails, the default schema name is assumed to be "dbo" - if ((connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType)) && string.IsNullOrEmpty(schemaName)) + if (connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType) && string.IsNullOrEmpty(identifier.SchemaName)) { - string fullObjectName = this.GetFullObjectNameFromQuickInfo(quickInfoText, tokenText, caseSensitivity); - schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText); + string fullObjectName = this.GetFullObjectNameFromQuickInfo(quickInfoText, identifier.ObjectName, caseSensitivity); + identifier.SchemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, identifier.ObjectName); } - Location[] locations = GetSqlObjectDefinition( - tokenText, - schemaName, - sqlObjectType); + Location[] locations = GetSqlObjectDefinition(identifier, sqlObjectType); DefinitionResult result = new DefinitionResult { IsErrorResult = this.error, @@ -221,13 +228,13 @@ internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText, } /// - /// Script a object using the type extracted from declarationItem + /// Script an object using the type extracted from declarationItem /// - /// The Declaration object that matched with the selected token - /// The text of the selected token - /// Schema name + /// + /// + /// The object for the definition /// - internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type, string databaseQualifiedName, string tokenText, string schemaName) + internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type, string databaseQualifiedName, Sql3PartIdentifier identifier) { if (sqlObjectTypes.TryGetValue(type, out string sqlObjectType)) { @@ -235,15 +242,13 @@ internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type // This workaround ensures that a schema name is present by attempting // to get the schema name from the declaration item. // If all fails, the default schema name is assumed to be "dbo" - if ((connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType)) && string.IsNullOrEmpty(schemaName)) + if ((connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType)) && string.IsNullOrEmpty(identifier.SchemaName)) { string fullObjectName = databaseQualifiedName; - schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText); + identifier.SchemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, identifier.ObjectName); } - Location[] locations = GetSqlObjectDefinition( - tokenText, - schemaName, - sqlObjectType); + + Location[] locations = this.GetSqlObjectDefinition(identifier, sqlObjectType); DefinitionResult result = new DefinitionResult { IsErrorResult = this.error, @@ -256,8 +261,20 @@ internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type return GetDefinitionErrorResult(SR.PeekDefinitionTypeNotSupportedError); } + internal Location[] GetSqlObjectDefinition( + string objectName, + string schemaName, + string objectType) + { + return GetSqlObjectDefinition(new Sql3PartIdentifier + { + ObjectName = objectName, + SchemaName = schemaName, + }, objectType); + } + /// - /// Script a object using SMO and write to a file. + /// Script an object using SMO and write to a file. /// /// Function that returns the SMO scripts for an object /// SQL object name @@ -265,15 +282,15 @@ internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type /// Type of SQL object /// Location object representing URI and range of the script file internal Location[] GetSqlObjectDefinition( - string objectName, - string schemaName, - string objectType) + Sql3PartIdentifier identifier, + string objectType) { // script file destination - string tempFileName = (schemaName != null) ? Path.Combine(this.tempPath, string.Format("{0}.{1}.sql", schemaName, objectName)) - : Path.Combine(this.tempPath, string.Format("{0}.sql", objectName)); + string fileName = CreateFileName(identifier); - SmoScriptingOperation operation = InitScriptOperation(objectName, schemaName, objectType); + string tempFileName = Path.Combine(this.tempPath, fileName); + + SmoScriptingOperation operation = InitScriptOperation(identifier, objectType); operation.Execute(); string script = operation.ScriptText; @@ -289,7 +306,7 @@ internal Location[] GetSqlObjectDefinition( createSyntax = string.Format("CREATE"); foreach (string line in lines) { - if (LineContainsObject(line, objectName, createSyntax)) + if (LineContainsObject(line, identifier.ObjectName, createSyntax)) { createStatementLineNumber = lineCount; objectFound = true; @@ -311,6 +328,13 @@ internal Location[] GetSqlObjectDefinition( } } + private static string CreateFileName(Sql3PartIdentifier identifier) + { + if (identifier.DatabaseName != null) return $"{identifier.DatabaseName}.{identifier.SchemaName}.{identifier.ObjectName}.sql"; + if (identifier.SchemaName != null) return $"{identifier.SchemaName}.{identifier.ObjectName}.sql"; + return $"{identifier.ObjectName}.sql"; + } + #region Helper Methods /// /// Return schema name from the full name of the database. If schema is missing return dbo as schema name. @@ -351,15 +375,15 @@ internal Location[] GetLocationFromFile(string tempFileName, int lineNumber) // Create a location array containing the tempFile Uri, as expected by VSCode. Location[] locations = new[] { - new Location + new Location + { + Uri = tempFileName, + Range = new Range { - Uri = tempFileName, - Range = new Range - { - Start = new Position { Line = lineNumber, Character = 0}, - End = new Position { Line = lineNumber + 1, Character = 0} - } + Start = new Position { Line = lineNumber, Character = 0 }, + End = new Position { Line = lineNumber + 1, Character = 0 } } + } }; return locations; } @@ -453,18 +477,17 @@ internal IEnumerable GetCompletionsForToken(ParseResult parseResult /// /// Wrapper method that calls Resolver.FindCompletions /// - /// - /// + /// /// - /// /// - internal SmoScriptingOperation InitScriptOperation(string objectName, string schemaName, string objectType) + internal SmoScriptingOperation InitScriptOperation(Sql3PartIdentifier identifier, string objectType) { // object that has to be scripted ScriptingObject scriptingObject = new ScriptingObject { - Name = objectName, - Schema = schemaName, + Name = identifier.ObjectName, + Schema = identifier.SchemaName, + DatabaseName = identifier.DatabaseName ?? this.Database.Name, Type = objectType }; @@ -491,7 +514,6 @@ internal SmoScriptingOperation InitScriptOperation(string objectName, string sch ScriptPrimaryKeys = false, ScriptTriggers = false, UniqueKeys = false - }; List objectList = new List(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/SqlObjectIdentifier.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/SqlObjectIdentifier.cs new file mode 100644 index 0000000000..5c00b52f9b --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/SqlObjectIdentifier.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Scripting +{ + internal class Sql4PartIdentifier : Sql3PartIdentifier + { + public string? ServerName { get; set; } + } + + internal class Sql3PartIdentifier + { + public required string ObjectName { get; set; } + public string? SchemaName { get; set; } + public string? DatabaseName { get; set; } + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.SqlCore/Scripting/Contracts/ScriptingObject.cs b/src/Microsoft.SqlTools.SqlCore/Scripting/Contracts/ScriptingObject.cs index f8005f6eb4..dd7ee729f9 100644 --- a/src/Microsoft.SqlTools.SqlCore/Scripting/Contracts/ScriptingObject.cs +++ b/src/Microsoft.SqlTools.SqlCore/Scripting/Contracts/ScriptingObject.cs @@ -61,6 +61,8 @@ public sealed class ScriptingObject : IEquatable /// public string ParentTypeName { get; set; } + public string DatabaseName { get; set; } + public override string ToString() { string objectName = string.Empty; @@ -82,6 +84,7 @@ public override int GetHashCode() StringComparer.OrdinalIgnoreCase.GetHashCode(this.Schema ?? string.Empty) ^ StringComparer.OrdinalIgnoreCase.GetHashCode(this.ParentName ?? string.Empty) ^ StringComparer.OrdinalIgnoreCase.GetHashCode(this.ParentTypeName ?? string.Empty) ^ + StringComparer.OrdinalIgnoreCase.GetHashCode(this.DatabaseName ?? string.Empty) ^ StringComparer.OrdinalIgnoreCase.GetHashCode(this.Name ?? string.Empty); } @@ -105,7 +108,7 @@ public bool Equals(ScriptingObject other) string.Equals(this.Schema, other.Schema, StringComparison.OrdinalIgnoreCase) && string.Equals(this.ParentName, other.ParentName, StringComparison.OrdinalIgnoreCase) && string.Equals(this.ParentTypeName, other.ParentTypeName, StringComparison.OrdinalIgnoreCase) && - string.Equals(this.ParentTypeName, other.ParentTypeName, StringComparison.OrdinalIgnoreCase) && + string.Equals(this.DatabaseName, other.DatabaseName, StringComparison.OrdinalIgnoreCase) && string.Equals(this.Name, other.Name, StringComparison.OrdinalIgnoreCase); } } diff --git a/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs b/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs index 3c635dd626..aa3f09baea 100644 --- a/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs +++ b/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs @@ -74,7 +74,7 @@ public ScriptAsScriptingOperation(ScriptingParams parameters, string azureAccoun internal ServerConnection ServerConnection { get; set; } private string serverName; - private string databaseName; + private bool disconnectAtDispose = false; public override void Execute() @@ -181,7 +181,7 @@ private string GenerateScriptSelect(Server server, UrnCollection urns) ScriptingObject scriptingObject = this.Parameters.ScriptingObjects[0]; Urn objectUrn = urns[0]; string typeName = objectUrn.GetNameForType(scriptingObject.Type); - + var dbName = objectUrn.GetNameForType("Database"); // select from service broker if (string.Compare(typeName, "ServiceBroker", StringComparison.CurrentCultureIgnoreCase) == 0) { @@ -198,7 +198,7 @@ private string GenerateScriptSelect(Server server, UrnCollection urns) // select from table or view else { - Database db = server.Databases[databaseName]; + Database db = server.Databases[dbName]; bool isDw = db.IsSqlDw; script = ScriptingHelper.SelectFromTableOrView(server, objectUrn, isDw); } @@ -525,7 +525,6 @@ private UrnCollection CreateUrns(ServerConnection serverConnection) IEnumerable selectedObjects = new List(this.Parameters.ScriptingObjects); serverName = serverConnection.TrueName; - databaseName = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog; UrnCollection urnCollection = new UrnCollection(); foreach (var scriptingObject in selectedObjects) { @@ -534,7 +533,11 @@ private UrnCollection CreateUrns(ServerConnection serverConnection) // TODO: get the default schema scriptingObject.Schema = "dbo"; } - urnCollection.Add(scriptingObject.ToUrn(serverName, databaseName)); + if (string.IsNullOrEmpty(scriptingObject.DatabaseName)) + { + scriptingObject.DatabaseName = serverConnection.DatabaseName; + } + urnCollection.Add(scriptingObject.ToUrn(serverName, scriptingObject.DatabaseName)); } return urnCollection; } diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs index 71737a1b40..6de4327576 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs @@ -23,6 +23,7 @@ using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using System.Linq; namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.LanguageServices { @@ -121,7 +122,6 @@ public void LoggerGetValidTableDefinitionTest() test.LogMessage = "OnScriptingProgress ScriptingCompleted"; //Log message to verify. This message comes from SMO code. test.Verify(); // The log message should be absent since the tracing level is set to Off. test.Cleanup(); - } /// @@ -151,7 +151,7 @@ public void GetTableDefinitionInvalidObjectTest() public void GetTableDefinitionWithSchemaTest() { // Get live connectionInfo and serverConnection - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition("master"); ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); Scripter scripter = new Scripter(serverConnection, connInfo); @@ -203,7 +203,7 @@ public void GetDefinitionWithNoResultsFoundError() }; ScriptParseInfo scriptParseInfo = new ScriptParseInfo() { IsConnected = true }; Mock bindingContextMock = new Mock(); - DefinitionResult result = scripter.GetScript(scriptParseInfo.ParseResult, position, bindingContextMock.Object.MetadataDisplayInfoProvider, objectName, null); + DefinitionResult result = scripter.GetScript(scriptParseInfo.ParseResult, position, bindingContextMock.Object.MetadataDisplayInfoProvider, new Sql3PartIdentifier {ObjectName = objectName}); Assert.NotNull(result); Assert.True(result.IsErrorResult); @@ -291,6 +291,7 @@ public void GetValidViewDefinitionTest() Cleanup(locations); } + /// /// Test get definition for an invalid view object with no schema name and with active connection /// @@ -415,7 +416,7 @@ private void ValidatePeekTest(string databaseName, string objectName, string obj var connectionService = LiveConnectionHelper.GetLiveTestConnectionService(); connectionService.Disconnect(new DisconnectParams { - OwnerUri = connInfo.OwnerUri + OwnerUri = connInfo.OwnerUri }); } @@ -530,7 +531,6 @@ public async Task GetUserDefinedTableTypeDefinitionWithNonExistentFailureTest() string schemaName = "dbo"; string objectType = UserDefinedTableTypeTypeName; await ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); - } /// @@ -583,7 +583,6 @@ public void GetDefinitionUsingDeclarationTypeWithValidObjectTest() Assert.NotNull(result.Locations); Assert.False(result.IsErrorResult); Cleanup(result.Locations); - } /// @@ -620,12 +619,11 @@ public void GetDefinitionUsingQuickInfoTextWithValidObjectTest() string schemaName = "sys"; string quickInfoText = "view master.sys.objects"; - DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, schemaName); + DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, new Sql3PartIdentifier { ObjectName = objectName, SchemaName = schemaName }); Assert.NotNull(result); Assert.NotNull(result.Locations); Assert.False(result.IsErrorResult); Cleanup(result.Locations); - } /// @@ -643,7 +641,7 @@ public void GetDefinitionUsingQuickInfoTextWithNonexistentObjectTest() string schemaName = "sys"; string quickInfoText = "view master.sys.objects"; - DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, schemaName); + DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, new Sql3PartIdentifier { ObjectName = objectName, SchemaName = schemaName }); Assert.NotNull(result); Assert.True(result.IsErrorResult); } @@ -703,40 +701,13 @@ public void GetDatabaseWithQueryConnectionTest() public async Task GetDefinitionFromChildrenAndParents() { string queryString = "select * from master.sys.objects"; - // place the cursor on every token - - //cursor on objects - TextDocumentPosition objectDocument = CreateTextDocPositionWithCursor(26, OwnerUri); - //cursor on sys - TextDocumentPosition sysDocument = CreateTextDocPositionWithCursor(22, OwnerUri); - - //cursor on master - TextDocumentPosition masterDocument = CreateTextDocPositionWithCursor(17, OwnerUri); - - LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(null); - ScriptFile scriptFile = connectionResult.ScriptFile; - ConnectionInfo connInfo = connectionResult.ConnectionInfo; - connInfo.RemoveAllConnections(); - var bindingQueue = new ConnectedBindingQueue(); - bindingQueue.AddConnectionContext(connInfo); - scriptFile.Contents = queryString; - - var service = new LanguageService(); - service.RemoveScriptParseInfo(OwnerUri); - service.BindingQueue = bindingQueue; - await service.UpdateLanguageServiceOnConnection(connectionResult.ConnectionInfo); - Thread.Sleep(2000); - - ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; - await service.ParseAndBind(scriptFile, connInfo); - scriptInfo.ConnectionKey = bindingQueue.AddConnectionContext(connInfo); - service.ScriptParseInfoMap.TryAdd(OwnerUri, scriptInfo); + // place the cursor on every token // When I call the language service - var objectResult = service.GetDefinition(objectDocument, scriptFile, connInfo); - var sysResult = service.GetDefinition(sysDocument, scriptFile, connInfo); - var masterResult = service.GetDefinition(masterDocument, scriptFile, connInfo); + DefinitionResult objectResult = await this.PeekDefinitionAt(queryString, 26); //cursor on objects + DefinitionResult sysResult = await this.PeekDefinitionAt(queryString, 22); //cursor on sys + DefinitionResult masterResult = await this.PeekDefinitionAt(queryString, 17); //cursor on master // Then I expect the results to be non-null Assert.NotNull(objectResult); @@ -750,40 +721,72 @@ public async Task GetDefinitionFromChildrenAndParents() Cleanup(objectResult.Locations); Cleanup(sysResult.Locations); Cleanup(masterResult.Locations); - service.ScriptParseInfoMap.TryRemove(OwnerUri, out _); - connInfo.RemoveAllConnections(); } [Test] public async Task GetDefinitionFromProcedures() { - string queryString = "EXEC master.dbo.sp_MSrepl_startup"; // place the cursor on every token + // When I call the language service + DefinitionResult fnResult = await this.PeekDefinitionAt(queryString, 30); //cursor on objects + DefinitionResult sysResult = await this.PeekDefinitionAt(queryString, 14); //cursor on sys + DefinitionResult masterResult = await this.PeekDefinitionAt(queryString, 10); //cursor on master + + // Then I expect the results to be non-null + Assert.NotNull(fnResult); + Assert.NotNull(sysResult); + Assert.NotNull(masterResult); + + // And I expect the all results to be the same + Assert.True(CompareLocations(fnResult.Locations, sysResult.Locations)); + Assert.True(CompareLocations(fnResult.Locations, masterResult.Locations)); + + Cleanup(fnResult.Locations); + Cleanup(sysResult.Locations); + Cleanup(masterResult.Locations); + } + + [Test] + public async Task GetCrossDatabaseDefinition() + { + string queryString = "SELECT * FROM msdb.dbo.sysalerts"; + //cursor on objects - TextDocumentPosition fnDocument = CreateTextDocPositionWithCursor(30, TestUri); + DefinitionResult definition = await PeekDefinitionAt(queryString, 30, "master"); - //cursor on sys - TextDocumentPosition dboDocument = CreateTextDocPositionWithCursor(14, TestUri); + Assert.IsFalse(definition.IsErrorResult); + Location location = definition.Locations.Single(); - //cursor on master - TextDocumentPosition masterDocument = CreateTextDocPositionWithCursor(10, TestUri); + Assert.NotNull(location); + Assert.IsTrue(location.Uri.EndsWith("msdb.dbo.sysalerts.sql")); - LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(null); + Uri filePath = GetFilePath(location); + Assert.NotNull(filePath); + var scriptedFile = await File.ReadAllTextAsync(filePath.AbsolutePath); + Assert.IsTrue(scriptedFile.Contains("CREATE TABLE [dbo].[sysalerts]")); + + Cleanup(definition.Locations); + } + + private async Task PeekDefinitionAt(string fileContents, int column, string? databasename = null) + { + TextDocumentPosition fnDocument = this.CreateTextDocPositionWithCursor(column, TestUri); + + LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(databasename); ScriptFile scriptFile = connectionResult.ScriptFile; ConnectionInfo connInfo = connectionResult.ConnectionInfo; connInfo.RemoveAllConnections(); var bindingQueue = new ConnectedBindingQueue(); bindingQueue.AddConnectionContext(connInfo); - scriptFile.Contents = queryString; + scriptFile.Contents = fileContents; var service = new LanguageService(); service.RemoveScriptParseInfo(OwnerUri); service.BindingQueue = bindingQueue; await service.UpdateLanguageServiceOnConnection(connectionResult.ConnectionInfo); - Thread.Sleep(2000); ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; await service.ParseAndBind(scriptFile, connInfo); @@ -792,50 +795,44 @@ public async Task GetDefinitionFromProcedures() // When I call the language service var fnResult = service.GetDefinition(fnDocument, scriptFile, connInfo); - var sysResult = service.GetDefinition(dboDocument, scriptFile, connInfo); - var masterResult = service.GetDefinition(masterDocument, scriptFile, connInfo); - - // Then I expect the results to be non-null - Assert.NotNull(fnResult); - Assert.NotNull(sysResult); - Assert.NotNull(masterResult); - - // And I expect the all results to be the same - Assert.True(CompareLocations(fnResult.Locations, sysResult.Locations)); - Assert.True(CompareLocations(fnResult.Locations, masterResult.Locations)); - - Cleanup(fnResult.Locations); - Cleanup(sysResult.Locations); - Cleanup(masterResult.Locations); - service.ScriptParseInfoMap.TryRemove(TestUri, out _); - connInfo.RemoveAllConnections(); + return fnResult; } - /// - /// Helper method to clean up script files + /// Gets the path to the file of a Location on disk /// - private void Cleanup(Location[] locations) + /// + /// + private static Uri? GetFilePath(Location location) { - try + string filePath = location.Uri; + Uri fileUri = null; + if (Uri.IsWellFormedUriString(filePath, UriKind.Absolute)) + { + fileUri = new Uri(filePath); + } + else { - string filePath = locations[0].Uri; - Uri fileUri = null; + filePath = filePath.Replace("file:/", "file://"); if (Uri.IsWellFormedUriString(filePath, UriKind.Absolute)) { fileUri = new Uri(filePath); } - else - { - filePath = filePath.Replace("file:/", "file://"); - if (Uri.IsWellFormedUriString(filePath, UriKind.Absolute)) - { - fileUri = new Uri(filePath); - } - } - if (fileUri != null && File.Exists(fileUri.LocalPath)) + } + return fileUri; + } + + /// + /// Helper method to clean up script files + /// + private void Cleanup(Location[] locations) + { + try + { + foreach (var location in locations) { - File.Delete(fileUri.LocalPath); + var path = GetFilePath(location); + if (path != null) File.Delete(path.AbsolutePath); //does not throw if file doesn't exist } } catch (Exception) @@ -880,4 +877,12 @@ private TextDocumentPosition CreateTextDocPositionWithCursor(int column, string return textDocPos; } } -} \ No newline at end of file + + internal static class ScripterExtensions + { + internal static DefinitionResult GetDefinitionUsingDeclarationType(this Scripter scripter, DeclarationType type, string databaseQualifiedName, string objectName, string schemaName) + { + return scripter.GetDefinitionUsingDeclarationType(type, databaseQualifiedName, new Sql3PartIdentifier { ObjectName = objectName, SchemaName = schemaName }); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/PeekDefinitionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/PeekDefinitionTests.cs index b379e37037..16f04a2628 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/PeekDefinitionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/PeekDefinitionTests.cs @@ -72,7 +72,6 @@ public void GetSchemaFromDatabaseQualifiedNameWithValidNameTest() /// /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a valid object name and no schema /// - [Test] public void GetSchemaFromDatabaseQualifiedNameWithNoSchemaTest() { @@ -286,7 +285,7 @@ public void GetDefinitionUsingQuickInfoWithoutConnectionTest() Scripter peekDefinition = new Scripter(null, null); string objectName = "tableName"; string quickInfoText = "table master.dbo.tableName"; - DefinitionResult result = peekDefinition.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, null); + DefinitionResult result = peekDefinition.GetDefinitionUsingQuickInfoText(quickInfoText, new Sql3PartIdentifier { ObjectName = objectName }); Assert.NotNull(result); Assert.True(result.IsErrorResult); } @@ -301,7 +300,7 @@ public void GetDefinitionUsingDeclarationItemWithoutConnectionTest() Scripter peekDefinition = new Scripter(null, null); string objectName = "tableName"; string fullObjectName = "master.dbo.tableName"; - Assert.Throws(() => peekDefinition.GetDefinitionUsingDeclarationType(DeclarationType.Table, fullObjectName, objectName, null)); + Assert.Throws(() => peekDefinition.GetDefinitionUsingDeclarationType(DeclarationType.Table, fullObjectName, new Sql3PartIdentifier { ObjectName = objectName })); } } }