From a6b054f0c7eca60dc14f171e4f164cdacbf1a3c2 Mon Sep 17 00:00:00 2001
From: asrichesson <51305664+asrichesson@users.noreply.github.com>
Date: Wed, 18 Sep 2024 13:50:14 -0400
Subject: [PATCH] Handle cross database view definition (#2388)
---
.../LanguageServices/LanguageService.cs | 63 +++---
.../Scripting/ScripterCore.cs | 198 ++++++++++--------
.../Scripting/SqlObjectIdentifier.cs | 19 ++
.../Scripting/Contracts/ScriptingObject.cs | 5 +-
.../Scripting/ScriptAsScriptingOperation.cs | 13 +-
.../LanguageServer/PeekDefinitionTests.cs | 175 ++++++++--------
.../LanguageServer/PeekDefinitionTests.cs | 5 +-
7 files changed, 268 insertions(+), 210 deletions(-)
create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Scripting/SqlObjectIdentifier.cs
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
index 29a3e19ed0..03dec1f0fe 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
@@ -1270,15 +1270,15 @@ private DefinitionResult QueueTask(TextDocumentPosition textDocumentPosition, Sc
bindingTimeout: LanguageService.PeekDefinitionTimeout,
bindOperation: (bindingContext, cancelToken) =>
{
- string schemaName = this.GetSchemaName(scriptParseInfo, textDocumentPosition.Position, scriptFile);
+ Sql4PartIdentifier identifier = this.GetFullIdentifier(scriptParseInfo, textDocumentPosition.Position);
+
// Script object using SMO
Scripter scripter = new Scripter(bindingContext.ServerConnection, connInfo);
return scripter.GetScript(
scriptParseInfo.ParseResult,
textDocumentPosition.Position,
bindingContext.MetadataDisplayInfoProvider,
- tokenText,
- schemaName);
+ identifier);
},
timeoutOperation: (bindingContext) =>
{
@@ -1422,14 +1422,13 @@ internal DefinitionResult GetDefinition(TextDocumentPosition textDocumentPositio
/// Wrapper around find token method
///
///
- ///
- ///
+ ///
/// token index
- private int FindTokenWithCorrectOffset(ScriptParseInfo scriptParseInfo, int startLine, int startColumn)
+ private int FindTokenWithCorrectOffset(ScriptParseInfo scriptParseInfo, Position position)
{
- var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn);
+ var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(position.Line, position.Character);
var end = scriptParseInfo.ParseResult.Script.TokenManager.GetToken(tokenIndex).EndLocation;
- if (end.LineNumber == startLine && end.ColumnNumber == startColumn)
+ if (end.LineNumber == position.Line && end.ColumnNumber == position.Character)
{
return tokenIndex + 1;
}
@@ -1437,33 +1436,41 @@ private int FindTokenWithCorrectOffset(ScriptParseInfo scriptParseInfo, int star
}
///
- /// Extract schema name for a token, if present
+ /// Returns a 4 part identifier at the position in a script, if present
///
///
///
- ///
- /// schema name
- private string GetSchemaName(ScriptParseInfo scriptParseInfo, Position position, ScriptFile scriptFile)
+ ///
+ private Sql4PartIdentifier GetFullIdentifier(ScriptParseInfo scriptParseInfo, Position position)
{
- // Offset index by 1 for sql parser
- int startLine = position.Line;
- int startColumn = position.Character;
-
- // Get schema name
- if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null)
- {
- var tokenIndex = FindTokenWithCorrectOffset(scriptParseInfo, startLine, startColumn);
- var prevTokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.GetPreviousSignificantTokenIndex(tokenIndex);
- var prevTokenText = scriptParseInfo.ParseResult.Script.TokenManager.GetText(prevTokenIndex);
- if (prevTokenText != null && prevTokenText.Equals("."))
+ if (scriptParseInfo?.ParseResult?.Script?.Tokens == null) return null;
+ var tokenManager = scriptParseInfo.ParseResult.Script.TokenManager;
+ int tokenIndex = this.FindTokenWithCorrectOffset(scriptParseInfo, position);
+ var identifiers = new string[4];
+ //work backwards from the initial token to read identifier parts
+ for (int i = 0; i < identifiers.Length; i++)
+ {
+ if (i > 0) //consume separator dot
{
- var schemaTokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.GetPreviousSignificantTokenIndex(prevTokenIndex);
- Token schemaToken = scriptParseInfo.ParseResult.Script.TokenManager.GetToken(schemaTokenIndex);
- return TextUtilities.RemoveSquareBracketSyntax(schemaToken.Text);
+ tokenIndex = tokenManager.GetPreviousSignificantTokenIndex(tokenIndex);
+ if (tokenIndex < 0) break;
+ var period = tokenManager.GetText(tokenIndex);
+ if (period is null or not ".") break;
+ tokenIndex = tokenManager.GetPreviousSignificantTokenIndex(tokenIndex);
}
+
+ if (tokenIndex < 0) break;
+ string identifierText = tokenManager.GetText(tokenIndex);
+ if (string.IsNullOrEmpty(identifierText)) break;
+ identifiers[i] = TextUtilities.RemoveSquareBracketSyntax(identifierText);
}
- // if no schema name, returns null
- return null;
+ return new Sql4PartIdentifier
+ {
+ ObjectName = identifiers[0],
+ SchemaName = identifiers[1],
+ DatabaseName = identifiers[2],
+ ServerName = identifiers[3]
+ };
}
///
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs
index 86d0db788f..d77a8a9380 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs
@@ -72,48 +72,56 @@ internal Database Database
{
if (this.serverConnection != null && !string.IsNullOrEmpty(this.serverConnection.DatabaseName))
{
- try
+ // The default database name is the database name of the server connection
+ string dbName = this.serverConnection.DatabaseName;
+
+ // If there is a query DbConnection, use that connection to get the database name
+ // This is preferred since it has the most current database name (in case of database switching)
+ if (this.connectionInfo?.TryGetConnection(ConnectionType.Query, out DbConnection connection) == true)
{
- // Reuse existing connection
- Server server = new Server(this.serverConnection);
- // The default database name is the database name of the server connection
- string dbName = this.serverConnection.DatabaseName;
- if (this.connectionInfo != null)
+ if (!string.IsNullOrEmpty(connection.Database))
{
- // If there is a query DbConnection, use that connection to get the database name
- // This is preferred since it has the most current database name (in case of database switching)
- DbConnection connection;
- if (connectionInfo.TryGetConnection(ConnectionType.Query, out connection))
- {
- if (!string.IsNullOrEmpty(connection.Database))
- {
- dbName = connection.Database;
- }
- }
+ dbName = connection.Database;
}
- this.database = new Database(server, dbName);
- this.database.Refresh();
- }
- catch (ConnectionFailureException cfe)
- {
- Logger.Error("Exception at PeekDefinition Database.get() : " + cfe.Message);
- this.error = true;
- this.errorMessage = (connectionInfo != null && connectionInfo.IsCloud) ? SR.PeekDefinitionAzureError(cfe.Message) : SR.PeekDefinitionError(cfe.Message);
- return null;
- }
- catch (Exception ex)
- {
- Logger.Error("Exception at PeekDefinition Database.get() : " + ex.Message);
- this.error = true;
- this.errorMessage = SR.PeekDefinitionError(ex.Message);
- return null;
}
+
+ this.database = this.GetDatabase(dbName);
}
}
+
return this.database;
}
}
+ private Database? GetDatabase(string dbName)
+ {
+ try
+ {
+ // Reuse existing connection
+ var server = new Server(this.serverConnection);
+
+ var db = new Database(server, dbName);
+ db.Refresh();
+ return db;
+ }
+ catch (ConnectionFailureException cfe)
+ {
+ Logger.Error("Exception at PeekDefinition Database.get() : " + cfe.Message);
+ this.error = true;
+ this.errorMessage = (connectionInfo != null && connectionInfo.IsCloud)
+ ? SR.PeekDefinitionAzureError(cfe.Message)
+ : SR.PeekDefinitionError(cfe.Message);
+ return null;
+ }
+ catch (Exception ex)
+ {
+ Logger.Error("Exception at PeekDefinition Database.get() : " + ex.Message);
+ this.error = true;
+ this.errorMessage = SR.PeekDefinitionError(ex.Message);
+ return null;
+ }
+ }
+
///
/// Add the given type, scriptgetter and the typeName string to the respective dictionaries
///
@@ -129,34 +137,36 @@ private void AddSupportedType(DeclarationType type, string typeName, string quic
///
/// Get the script of the selected token based on the type of the token
///
- ///
- ///
- ///
+ ///
+ ///
+ ///
+ /// The object to be scripted
/// Location object of the script file
- internal DefinitionResult GetScript(ParseResult parseResult, Position position, IMetadataDisplayInfoProvider metadataDisplayInfoProvider, string tokenText, string schemaName)
+ internal DefinitionResult GetScript(ParseResult parseResult, Position position, IMetadataDisplayInfoProvider metadataDisplayInfoProvider, Sql3PartIdentifier identifier)
{
int parserLine = position.Line;
int parserColumn = position.Character;
// Get DeclarationItems from The Intellisense Resolver for the selected token. The type of the selected token is extracted from the declarationItem.
IEnumerable declarationItems = GetCompletionsForToken(parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
- if (declarationItems != null && declarationItems.Count() > 0)
+ if (declarationItems != null && declarationItems.Any())
{
+ Database? targetDb = identifier.DatabaseName == null ? this.Database : this.GetDatabase(identifier.DatabaseName);
foreach (Declaration declarationItem in declarationItems)
{
if (declarationItem.Title == null)
{
continue;
}
- if (this.Database == null)
+ if (targetDb == null)
{
return GetDefinitionErrorResult(SR.PeekDefinitionDatabaseError);
}
- StringComparison caseSensitivity = this.Database.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase;
+ StringComparison caseSensitivity = targetDb.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase;
// if declarationItem matches the selected token, script SMO using that type
- if (declarationItem.Title.Equals(tokenText, caseSensitivity))
+ if (declarationItem.Title.Equals(identifier.ObjectName, caseSensitivity))
{
- return GetDefinitionUsingDeclarationType(declarationItem.Type, declarationItem.DatabaseQualifiedName, tokenText, schemaName);
+ return GetDefinitionUsingDeclarationType(declarationItem.Type, declarationItem.DatabaseQualifiedName, identifier);
}
}
}
@@ -164,7 +174,7 @@ internal DefinitionResult GetScript(ParseResult parseResult, Position position,
{
// if no declarationItem matched the selected token, we try to find the type of the token using QuickInfo.Text
string quickInfoText = GetQuickInfoForToken(parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
- return GetDefinitionUsingQuickInfoText(quickInfoText, tokenText, schemaName);
+ return GetDefinitionUsingQuickInfoText(quickInfoText, identifier);
}
// no definition found
return GetDefinitionErrorResult(SR.PeekDefinitionNoResultsError);
@@ -174,17 +184,17 @@ internal DefinitionResult GetScript(ParseResult parseResult, Position position,
/// Script an object using the type extracted from quickInfo Text
///
/// the text from the quickInfo for the selected token
- /// The text of the selected token
- /// Schema name
+ /// The object for the definition
///
- internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText, string tokenText, string schemaName)
+ internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText, Sql3PartIdentifier identifier)
{
- if (this.Database == null)
+ Database? targetDb = identifier.DatabaseName == null ? this.Database : this.GetDatabase(identifier.DatabaseName);
+ if (targetDb == null)
{
return GetDefinitionErrorResult(SR.PeekDefinitionDatabaseError);
}
- StringComparison caseSensitivity = this.Database.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase;
- string tokenType = GetTokenTypeFromQuickInfo(quickInfoText, tokenText, caseSensitivity);
+ StringComparison caseSensitivity = targetDb.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase;
+ string tokenType = GetTokenTypeFromQuickInfo(quickInfoText, identifier.ObjectName, caseSensitivity);
if (tokenType != null)
{
if (sqlObjectTypesFromQuickInfo.TryGetValue(tokenType.ToLowerInvariant(), out string sqlObjectType))
@@ -193,15 +203,12 @@ internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText,
// This workaround ensures that a schema name is present by attempting
// to get the schema name from the declaration item.
// If all fails, the default schema name is assumed to be "dbo"
- if ((connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType)) && string.IsNullOrEmpty(schemaName))
+ if (connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType) && string.IsNullOrEmpty(identifier.SchemaName))
{
- string fullObjectName = this.GetFullObjectNameFromQuickInfo(quickInfoText, tokenText, caseSensitivity);
- schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText);
+ string fullObjectName = this.GetFullObjectNameFromQuickInfo(quickInfoText, identifier.ObjectName, caseSensitivity);
+ identifier.SchemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, identifier.ObjectName);
}
- Location[] locations = GetSqlObjectDefinition(
- tokenText,
- schemaName,
- sqlObjectType);
+ Location[] locations = GetSqlObjectDefinition(identifier, sqlObjectType);
DefinitionResult result = new DefinitionResult
{
IsErrorResult = this.error,
@@ -221,13 +228,13 @@ internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText,
}
///
- /// Script a object using the type extracted from declarationItem
+ /// Script an object using the type extracted from declarationItem
///
- /// The Declaration object that matched with the selected token
- /// The text of the selected token
- /// Schema name
+ ///
+ ///
+ /// The object for the definition
///
- internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type, string databaseQualifiedName, string tokenText, string schemaName)
+ internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type, string databaseQualifiedName, Sql3PartIdentifier identifier)
{
if (sqlObjectTypes.TryGetValue(type, out string sqlObjectType))
{
@@ -235,15 +242,13 @@ internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type
// This workaround ensures that a schema name is present by attempting
// to get the schema name from the declaration item.
// If all fails, the default schema name is assumed to be "dbo"
- if ((connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType)) && string.IsNullOrEmpty(schemaName))
+ if ((connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType)) && string.IsNullOrEmpty(identifier.SchemaName))
{
string fullObjectName = databaseQualifiedName;
- schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText);
+ identifier.SchemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, identifier.ObjectName);
}
- Location[] locations = GetSqlObjectDefinition(
- tokenText,
- schemaName,
- sqlObjectType);
+
+ Location[] locations = this.GetSqlObjectDefinition(identifier, sqlObjectType);
DefinitionResult result = new DefinitionResult
{
IsErrorResult = this.error,
@@ -256,8 +261,20 @@ internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type
return GetDefinitionErrorResult(SR.PeekDefinitionTypeNotSupportedError);
}
+ internal Location[] GetSqlObjectDefinition(
+ string objectName,
+ string schemaName,
+ string objectType)
+ {
+ return GetSqlObjectDefinition(new Sql3PartIdentifier
+ {
+ ObjectName = objectName,
+ SchemaName = schemaName,
+ }, objectType);
+ }
+
///
- /// Script a object using SMO and write to a file.
+ /// Script an object using SMO and write to a file.
///
/// Function that returns the SMO scripts for an object
/// SQL object name
@@ -265,15 +282,15 @@ internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type
/// Type of SQL object
/// Location object representing URI and range of the script file
internal Location[] GetSqlObjectDefinition(
- string objectName,
- string schemaName,
- string objectType)
+ Sql3PartIdentifier identifier,
+ string objectType)
{
// script file destination
- string tempFileName = (schemaName != null) ? Path.Combine(this.tempPath, string.Format("{0}.{1}.sql", schemaName, objectName))
- : Path.Combine(this.tempPath, string.Format("{0}.sql", objectName));
+ string fileName = CreateFileName(identifier);
- SmoScriptingOperation operation = InitScriptOperation(objectName, schemaName, objectType);
+ string tempFileName = Path.Combine(this.tempPath, fileName);
+
+ SmoScriptingOperation operation = InitScriptOperation(identifier, objectType);
operation.Execute();
string script = operation.ScriptText;
@@ -289,7 +306,7 @@ internal Location[] GetSqlObjectDefinition(
createSyntax = string.Format("CREATE");
foreach (string line in lines)
{
- if (LineContainsObject(line, objectName, createSyntax))
+ if (LineContainsObject(line, identifier.ObjectName, createSyntax))
{
createStatementLineNumber = lineCount;
objectFound = true;
@@ -311,6 +328,13 @@ internal Location[] GetSqlObjectDefinition(
}
}
+ private static string CreateFileName(Sql3PartIdentifier identifier)
+ {
+ if (identifier.DatabaseName != null) return $"{identifier.DatabaseName}.{identifier.SchemaName}.{identifier.ObjectName}.sql";
+ if (identifier.SchemaName != null) return $"{identifier.SchemaName}.{identifier.ObjectName}.sql";
+ return $"{identifier.ObjectName}.sql";
+ }
+
#region Helper Methods
///
/// Return schema name from the full name of the database. If schema is missing return dbo as schema name.
@@ -351,15 +375,15 @@ internal Location[] GetLocationFromFile(string tempFileName, int lineNumber)
// Create a location array containing the tempFile Uri, as expected by VSCode.
Location[] locations = new[]
{
- new Location
+ new Location
+ {
+ Uri = tempFileName,
+ Range = new Range
{
- Uri = tempFileName,
- Range = new Range
- {
- Start = new Position { Line = lineNumber, Character = 0},
- End = new Position { Line = lineNumber + 1, Character = 0}
- }
+ Start = new Position { Line = lineNumber, Character = 0 },
+ End = new Position { Line = lineNumber + 1, Character = 0 }
}
+ }
};
return locations;
}
@@ -453,18 +477,17 @@ internal IEnumerable GetCompletionsForToken(ParseResult parseResult
///
/// Wrapper method that calls Resolver.FindCompletions
///
- ///
- ///
+ ///
///
- ///
///
- internal SmoScriptingOperation InitScriptOperation(string objectName, string schemaName, string objectType)
+ internal SmoScriptingOperation InitScriptOperation(Sql3PartIdentifier identifier, string objectType)
{
// object that has to be scripted
ScriptingObject scriptingObject = new ScriptingObject
{
- Name = objectName,
- Schema = schemaName,
+ Name = identifier.ObjectName,
+ Schema = identifier.SchemaName,
+ DatabaseName = identifier.DatabaseName ?? this.Database.Name,
Type = objectType
};
@@ -491,7 +514,6 @@ internal SmoScriptingOperation InitScriptOperation(string objectName, string sch
ScriptPrimaryKeys = false,
ScriptTriggers = false,
UniqueKeys = false
-
};
List objectList = new List();
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/SqlObjectIdentifier.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/SqlObjectIdentifier.cs
new file mode 100644
index 0000000000..5c00b52f9b
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/SqlObjectIdentifier.cs
@@ -0,0 +1,19 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Scripting
+{
+ internal class Sql4PartIdentifier : Sql3PartIdentifier
+ {
+ public string? ServerName { get; set; }
+ }
+
+ internal class Sql3PartIdentifier
+ {
+ public required string ObjectName { get; set; }
+ public string? SchemaName { get; set; }
+ public string? DatabaseName { get; set; }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.SqlTools.SqlCore/Scripting/Contracts/ScriptingObject.cs b/src/Microsoft.SqlTools.SqlCore/Scripting/Contracts/ScriptingObject.cs
index f8005f6eb4..dd7ee729f9 100644
--- a/src/Microsoft.SqlTools.SqlCore/Scripting/Contracts/ScriptingObject.cs
+++ b/src/Microsoft.SqlTools.SqlCore/Scripting/Contracts/ScriptingObject.cs
@@ -61,6 +61,8 @@ public sealed class ScriptingObject : IEquatable
///
public string ParentTypeName { get; set; }
+ public string DatabaseName { get; set; }
+
public override string ToString()
{
string objectName = string.Empty;
@@ -82,6 +84,7 @@ public override int GetHashCode()
StringComparer.OrdinalIgnoreCase.GetHashCode(this.Schema ?? string.Empty) ^
StringComparer.OrdinalIgnoreCase.GetHashCode(this.ParentName ?? string.Empty) ^
StringComparer.OrdinalIgnoreCase.GetHashCode(this.ParentTypeName ?? string.Empty) ^
+ StringComparer.OrdinalIgnoreCase.GetHashCode(this.DatabaseName ?? string.Empty) ^
StringComparer.OrdinalIgnoreCase.GetHashCode(this.Name ?? string.Empty);
}
@@ -105,7 +108,7 @@ public bool Equals(ScriptingObject other)
string.Equals(this.Schema, other.Schema, StringComparison.OrdinalIgnoreCase) &&
string.Equals(this.ParentName, other.ParentName, StringComparison.OrdinalIgnoreCase) &&
string.Equals(this.ParentTypeName, other.ParentTypeName, StringComparison.OrdinalIgnoreCase) &&
- string.Equals(this.ParentTypeName, other.ParentTypeName, StringComparison.OrdinalIgnoreCase) &&
+ string.Equals(this.DatabaseName, other.DatabaseName, StringComparison.OrdinalIgnoreCase) &&
string.Equals(this.Name, other.Name, StringComparison.OrdinalIgnoreCase);
}
}
diff --git a/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs b/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs
index 3c635dd626..aa3f09baea 100644
--- a/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs
+++ b/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs
@@ -74,7 +74,7 @@ public ScriptAsScriptingOperation(ScriptingParams parameters, string azureAccoun
internal ServerConnection ServerConnection { get; set; }
private string serverName;
- private string databaseName;
+
private bool disconnectAtDispose = false;
public override void Execute()
@@ -181,7 +181,7 @@ private string GenerateScriptSelect(Server server, UrnCollection urns)
ScriptingObject scriptingObject = this.Parameters.ScriptingObjects[0];
Urn objectUrn = urns[0];
string typeName = objectUrn.GetNameForType(scriptingObject.Type);
-
+ var dbName = objectUrn.GetNameForType("Database");
// select from service broker
if (string.Compare(typeName, "ServiceBroker", StringComparison.CurrentCultureIgnoreCase) == 0)
{
@@ -198,7 +198,7 @@ private string GenerateScriptSelect(Server server, UrnCollection urns)
// select from table or view
else
{
- Database db = server.Databases[databaseName];
+ Database db = server.Databases[dbName];
bool isDw = db.IsSqlDw;
script = ScriptingHelper.SelectFromTableOrView(server, objectUrn, isDw);
}
@@ -525,7 +525,6 @@ private UrnCollection CreateUrns(ServerConnection serverConnection)
IEnumerable selectedObjects = new List(this.Parameters.ScriptingObjects);
serverName = serverConnection.TrueName;
- databaseName = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog;
UrnCollection urnCollection = new UrnCollection();
foreach (var scriptingObject in selectedObjects)
{
@@ -534,7 +533,11 @@ private UrnCollection CreateUrns(ServerConnection serverConnection)
// TODO: get the default schema
scriptingObject.Schema = "dbo";
}
- urnCollection.Add(scriptingObject.ToUrn(serverName, databaseName));
+ if (string.IsNullOrEmpty(scriptingObject.DatabaseName))
+ {
+ scriptingObject.DatabaseName = serverConnection.DatabaseName;
+ }
+ urnCollection.Add(scriptingObject.ToUrn(serverName, scriptingObject.DatabaseName));
}
return urnCollection;
}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs
index 71737a1b40..6de4327576 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs
@@ -23,6 +23,7 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+using System.Linq;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.LanguageServices
{
@@ -121,7 +122,6 @@ public void LoggerGetValidTableDefinitionTest()
test.LogMessage = "OnScriptingProgress ScriptingCompleted"; //Log message to verify. This message comes from SMO code.
test.Verify(); // The log message should be absent since the tracing level is set to Off.
test.Cleanup();
-
}
///
@@ -151,7 +151,7 @@ public void GetTableDefinitionInvalidObjectTest()
public void GetTableDefinitionWithSchemaTest()
{
// Get live connectionInfo and serverConnection
- ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition();
+ ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition("master");
ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo);
Scripter scripter = new Scripter(serverConnection, connInfo);
@@ -203,7 +203,7 @@ public void GetDefinitionWithNoResultsFoundError()
};
ScriptParseInfo scriptParseInfo = new ScriptParseInfo() { IsConnected = true };
Mock bindingContextMock = new Mock();
- DefinitionResult result = scripter.GetScript(scriptParseInfo.ParseResult, position, bindingContextMock.Object.MetadataDisplayInfoProvider, objectName, null);
+ DefinitionResult result = scripter.GetScript(scriptParseInfo.ParseResult, position, bindingContextMock.Object.MetadataDisplayInfoProvider, new Sql3PartIdentifier {ObjectName = objectName});
Assert.NotNull(result);
Assert.True(result.IsErrorResult);
@@ -291,6 +291,7 @@ public void GetValidViewDefinitionTest()
Cleanup(locations);
}
+
///
/// Test get definition for an invalid view object with no schema name and with active connection
///
@@ -415,7 +416,7 @@ private void ValidatePeekTest(string databaseName, string objectName, string obj
var connectionService = LiveConnectionHelper.GetLiveTestConnectionService();
connectionService.Disconnect(new DisconnectParams
{
- OwnerUri = connInfo.OwnerUri
+ OwnerUri = connInfo.OwnerUri
});
}
@@ -530,7 +531,6 @@ public async Task GetUserDefinedTableTypeDefinitionWithNonExistentFailureTest()
string schemaName = "dbo";
string objectType = UserDefinedTableTypeTypeName;
await ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName);
-
}
///
@@ -583,7 +583,6 @@ public void GetDefinitionUsingDeclarationTypeWithValidObjectTest()
Assert.NotNull(result.Locations);
Assert.False(result.IsErrorResult);
Cleanup(result.Locations);
-
}
///
@@ -620,12 +619,11 @@ public void GetDefinitionUsingQuickInfoTextWithValidObjectTest()
string schemaName = "sys";
string quickInfoText = "view master.sys.objects";
- DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, schemaName);
+ DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, new Sql3PartIdentifier { ObjectName = objectName, SchemaName = schemaName });
Assert.NotNull(result);
Assert.NotNull(result.Locations);
Assert.False(result.IsErrorResult);
Cleanup(result.Locations);
-
}
///
@@ -643,7 +641,7 @@ public void GetDefinitionUsingQuickInfoTextWithNonexistentObjectTest()
string schemaName = "sys";
string quickInfoText = "view master.sys.objects";
- DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, schemaName);
+ DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, new Sql3PartIdentifier { ObjectName = objectName, SchemaName = schemaName });
Assert.NotNull(result);
Assert.True(result.IsErrorResult);
}
@@ -703,40 +701,13 @@ public void GetDatabaseWithQueryConnectionTest()
public async Task GetDefinitionFromChildrenAndParents()
{
string queryString = "select * from master.sys.objects";
- // place the cursor on every token
-
- //cursor on objects
- TextDocumentPosition objectDocument = CreateTextDocPositionWithCursor(26, OwnerUri);
- //cursor on sys
- TextDocumentPosition sysDocument = CreateTextDocPositionWithCursor(22, OwnerUri);
-
- //cursor on master
- TextDocumentPosition masterDocument = CreateTextDocPositionWithCursor(17, OwnerUri);
-
- LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(null);
- ScriptFile scriptFile = connectionResult.ScriptFile;
- ConnectionInfo connInfo = connectionResult.ConnectionInfo;
- connInfo.RemoveAllConnections();
- var bindingQueue = new ConnectedBindingQueue();
- bindingQueue.AddConnectionContext(connInfo);
- scriptFile.Contents = queryString;
-
- var service = new LanguageService();
- service.RemoveScriptParseInfo(OwnerUri);
- service.BindingQueue = bindingQueue;
- await service.UpdateLanguageServiceOnConnection(connectionResult.ConnectionInfo);
- Thread.Sleep(2000);
-
- ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true };
- await service.ParseAndBind(scriptFile, connInfo);
- scriptInfo.ConnectionKey = bindingQueue.AddConnectionContext(connInfo);
- service.ScriptParseInfoMap.TryAdd(OwnerUri, scriptInfo);
+ // place the cursor on every token
// When I call the language service
- var objectResult = service.GetDefinition(objectDocument, scriptFile, connInfo);
- var sysResult = service.GetDefinition(sysDocument, scriptFile, connInfo);
- var masterResult = service.GetDefinition(masterDocument, scriptFile, connInfo);
+ DefinitionResult objectResult = await this.PeekDefinitionAt(queryString, 26); //cursor on objects
+ DefinitionResult sysResult = await this.PeekDefinitionAt(queryString, 22); //cursor on sys
+ DefinitionResult masterResult = await this.PeekDefinitionAt(queryString, 17); //cursor on master
// Then I expect the results to be non-null
Assert.NotNull(objectResult);
@@ -750,40 +721,72 @@ public async Task GetDefinitionFromChildrenAndParents()
Cleanup(objectResult.Locations);
Cleanup(sysResult.Locations);
Cleanup(masterResult.Locations);
- service.ScriptParseInfoMap.TryRemove(OwnerUri, out _);
- connInfo.RemoveAllConnections();
}
[Test]
public async Task GetDefinitionFromProcedures()
{
-
string queryString = "EXEC master.dbo.sp_MSrepl_startup";
// place the cursor on every token
+ // When I call the language service
+ DefinitionResult fnResult = await this.PeekDefinitionAt(queryString, 30); //cursor on objects
+ DefinitionResult sysResult = await this.PeekDefinitionAt(queryString, 14); //cursor on sys
+ DefinitionResult masterResult = await this.PeekDefinitionAt(queryString, 10); //cursor on master
+
+ // Then I expect the results to be non-null
+ Assert.NotNull(fnResult);
+ Assert.NotNull(sysResult);
+ Assert.NotNull(masterResult);
+
+ // And I expect the all results to be the same
+ Assert.True(CompareLocations(fnResult.Locations, sysResult.Locations));
+ Assert.True(CompareLocations(fnResult.Locations, masterResult.Locations));
+
+ Cleanup(fnResult.Locations);
+ Cleanup(sysResult.Locations);
+ Cleanup(masterResult.Locations);
+ }
+
+ [Test]
+ public async Task GetCrossDatabaseDefinition()
+ {
+ string queryString = "SELECT * FROM msdb.dbo.sysalerts";
+
//cursor on objects
- TextDocumentPosition fnDocument = CreateTextDocPositionWithCursor(30, TestUri);
+ DefinitionResult definition = await PeekDefinitionAt(queryString, 30, "master");
- //cursor on sys
- TextDocumentPosition dboDocument = CreateTextDocPositionWithCursor(14, TestUri);
+ Assert.IsFalse(definition.IsErrorResult);
+ Location location = definition.Locations.Single();
- //cursor on master
- TextDocumentPosition masterDocument = CreateTextDocPositionWithCursor(10, TestUri);
+ Assert.NotNull(location);
+ Assert.IsTrue(location.Uri.EndsWith("msdb.dbo.sysalerts.sql"));
- LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(null);
+ Uri filePath = GetFilePath(location);
+ Assert.NotNull(filePath);
+ var scriptedFile = await File.ReadAllTextAsync(filePath.AbsolutePath);
+ Assert.IsTrue(scriptedFile.Contains("CREATE TABLE [dbo].[sysalerts]"));
+
+ Cleanup(definition.Locations);
+ }
+
+ private async Task PeekDefinitionAt(string fileContents, int column, string? databasename = null)
+ {
+ TextDocumentPosition fnDocument = this.CreateTextDocPositionWithCursor(column, TestUri);
+
+ LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(databasename);
ScriptFile scriptFile = connectionResult.ScriptFile;
ConnectionInfo connInfo = connectionResult.ConnectionInfo;
connInfo.RemoveAllConnections();
var bindingQueue = new ConnectedBindingQueue();
bindingQueue.AddConnectionContext(connInfo);
- scriptFile.Contents = queryString;
+ scriptFile.Contents = fileContents;
var service = new LanguageService();
service.RemoveScriptParseInfo(OwnerUri);
service.BindingQueue = bindingQueue;
await service.UpdateLanguageServiceOnConnection(connectionResult.ConnectionInfo);
- Thread.Sleep(2000);
ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true };
await service.ParseAndBind(scriptFile, connInfo);
@@ -792,50 +795,44 @@ public async Task GetDefinitionFromProcedures()
// When I call the language service
var fnResult = service.GetDefinition(fnDocument, scriptFile, connInfo);
- var sysResult = service.GetDefinition(dboDocument, scriptFile, connInfo);
- var masterResult = service.GetDefinition(masterDocument, scriptFile, connInfo);
-
- // Then I expect the results to be non-null
- Assert.NotNull(fnResult);
- Assert.NotNull(sysResult);
- Assert.NotNull(masterResult);
-
- // And I expect the all results to be the same
- Assert.True(CompareLocations(fnResult.Locations, sysResult.Locations));
- Assert.True(CompareLocations(fnResult.Locations, masterResult.Locations));
-
- Cleanup(fnResult.Locations);
- Cleanup(sysResult.Locations);
- Cleanup(masterResult.Locations);
- service.ScriptParseInfoMap.TryRemove(TestUri, out _);
- connInfo.RemoveAllConnections();
+ return fnResult;
}
-
///
- /// Helper method to clean up script files
+ /// Gets the path to the file of a Location on disk
///
- private void Cleanup(Location[] locations)
+ ///
+ ///
+ private static Uri? GetFilePath(Location location)
{
- try
+ string filePath = location.Uri;
+ Uri fileUri = null;
+ if (Uri.IsWellFormedUriString(filePath, UriKind.Absolute))
+ {
+ fileUri = new Uri(filePath);
+ }
+ else
{
- string filePath = locations[0].Uri;
- Uri fileUri = null;
+ filePath = filePath.Replace("file:/", "file://");
if (Uri.IsWellFormedUriString(filePath, UriKind.Absolute))
{
fileUri = new Uri(filePath);
}
- else
- {
- filePath = filePath.Replace("file:/", "file://");
- if (Uri.IsWellFormedUriString(filePath, UriKind.Absolute))
- {
- fileUri = new Uri(filePath);
- }
- }
- if (fileUri != null && File.Exists(fileUri.LocalPath))
+ }
+ return fileUri;
+ }
+
+ ///
+ /// Helper method to clean up script files
+ ///
+ private void Cleanup(Location[] locations)
+ {
+ try
+ {
+ foreach (var location in locations)
{
- File.Delete(fileUri.LocalPath);
+ var path = GetFilePath(location);
+ if (path != null) File.Delete(path.AbsolutePath); //does not throw if file doesn't exist
}
}
catch (Exception)
@@ -880,4 +877,12 @@ private TextDocumentPosition CreateTextDocPositionWithCursor(int column, string
return textDocPos;
}
}
-}
\ No newline at end of file
+
+ internal static class ScripterExtensions
+ {
+ internal static DefinitionResult GetDefinitionUsingDeclarationType(this Scripter scripter, DeclarationType type, string databaseQualifiedName, string objectName, string schemaName)
+ {
+ return scripter.GetDefinitionUsingDeclarationType(type, databaseQualifiedName, new Sql3PartIdentifier { ObjectName = objectName, SchemaName = schemaName });
+ }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/PeekDefinitionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/PeekDefinitionTests.cs
index b379e37037..16f04a2628 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/PeekDefinitionTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/PeekDefinitionTests.cs
@@ -72,7 +72,6 @@ public void GetSchemaFromDatabaseQualifiedNameWithValidNameTest()
///
/// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a valid object name and no schema
///
-
[Test]
public void GetSchemaFromDatabaseQualifiedNameWithNoSchemaTest()
{
@@ -286,7 +285,7 @@ public void GetDefinitionUsingQuickInfoWithoutConnectionTest()
Scripter peekDefinition = new Scripter(null, null);
string objectName = "tableName";
string quickInfoText = "table master.dbo.tableName";
- DefinitionResult result = peekDefinition.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, null);
+ DefinitionResult result = peekDefinition.GetDefinitionUsingQuickInfoText(quickInfoText, new Sql3PartIdentifier { ObjectName = objectName });
Assert.NotNull(result);
Assert.True(result.IsErrorResult);
}
@@ -301,7 +300,7 @@ public void GetDefinitionUsingDeclarationItemWithoutConnectionTest()
Scripter peekDefinition = new Scripter(null, null);
string objectName = "tableName";
string fullObjectName = "master.dbo.tableName";
- Assert.Throws(() => peekDefinition.GetDefinitionUsingDeclarationType(DeclarationType.Table, fullObjectName, objectName, null));
+ Assert.Throws(() => peekDefinition.GetDefinitionUsingDeclarationType(DeclarationType.Table, fullObjectName, new Sql3PartIdentifier { ObjectName = objectName }));
}
}
}