Skip to content

Commit

Permalink
Use SQL system catalog views to check for the presence of a Recoverab…
Browse files Browse the repository at this point in the history
…le column. This removes the need for SELECT permissions to send a message to a queue table. (#1452)

Co-authored-by: Marc Wils <[email protected]>
  • Loading branch information
tmasternak and MarcWils authored Oct 17, 2024
1 parent c6b5359 commit 7ec3284
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace NServiceBus.Transport.SqlServer.AcceptanceTests.NativeTimeouts
namespace NServiceBus.Transport.SqlServer.AcceptanceTests.NativeTimeouts
{
using System;
using System.Collections.Generic;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace NServiceBus.Transport.SqlServer.IntegrationTests
namespace NServiceBus.Transport.SqlServer.IntegrationTests
{
using System;
using System.Threading;
Expand All @@ -23,7 +23,7 @@ public async Task SetUp()

await ResetQueue(addressParser, sqlConnectionFactory);

queue = new TableBasedQueue(addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName, false);
queue = new TableBasedQueue(addressParser.Parse(QueueTableName), QueueTableName, false);
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace NServiceBus.Transport.SqlServer.IntegrationTests
namespace NServiceBus.Transport.SqlServer.IntegrationTests
{
using System;
using System.Collections.Generic;
Expand Down Expand Up @@ -112,8 +112,7 @@ async Task PrepareAsync(CancellationToken cancellationToken = default)
Task PurgeOutputQueue(QueueAddressTranslator addressTranslator, CancellationToken cancellationToken = default)
{
purger = new QueuePurger(sqlConnectionFactory);
var queueAddress = addressTranslator.Parse(ValidAddress).QualifiedTableName;
queue = new TableBasedQueue(queueAddress, ValidAddress, true);
queue = new TableBasedQueue(addressTranslator.Parse(ValidAddress), ValidAddress, true);

return purger.Purge(queue, cancellationToken);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace NServiceBus.Transport.SqlServer.IntegrationTests
namespace NServiceBus.Transport.SqlServer.IntegrationTests
{
using System;
using System.Collections.Generic;
Expand Down Expand Up @@ -28,7 +28,7 @@ public async Task SetUp()

await CreateQueueIfNotExists(addressParser, sqlConnectionFactory);

queue = new TableBasedQueue(addressParser.Parse(QueueTableName).QualifiedTableName, QueueTableName, true);
queue = new TableBasedQueue(addressParser.Parse(QueueTableName), QueueTableName, true);
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace NServiceBus.Transport.SqlServer.IntegrationTests
namespace NServiceBus.Transport.SqlServer.IntegrationTests
{
using System;
using System.Linq;
Expand All @@ -23,7 +23,7 @@ public async Task Should_stop_receiving_messages_after_first_unsuccessful_receiv

var parser = new QueueAddressTranslator("nservicebus", "dbo", null, null);
var inputQueueName = "input";
var inputQueueAddress = parser.Parse(inputQueueName).Address;
var inputQueueAddress = parser.Parse(inputQueueName);
var inputQueue = new FakeTableBasedQueue(inputQueueAddress, queueSize, successfulReceives);

var connectionString = Environment.GetEnvironmentVariable("SqlServerTransportConnectionString") ?? @"Data Source=.\SQLEXPRESS;Initial Catalog=nservicebus;Integrated Security=True";
Expand All @@ -35,7 +35,7 @@ public async Task Should_stop_receiving_messages_after_first_unsuccessful_receiv
};

transport.Testing.QueueFactoryOverride = qa =>
qa == inputQueueAddress ? inputQueue : new TableBasedQueue(parser.Parse(qa).QualifiedTableName, qa, true);
qa == inputQueueAddress.Address ? inputQueue : new TableBasedQueue(parser.Parse(qa), qa, true);

var receiveSettings = new ReceiveSettings("receiver", new Transport.QueueAddress(inputQueueName), true, false, "error");
var hostSettings = new HostSettings("IntegrationTests", string.Empty, new StartupDiagnosticEntries(),
Expand Down Expand Up @@ -87,7 +87,7 @@ class FakeTableBasedQueue : TableBasedQueue
int queueSize;
int successfulReceives;

public FakeTableBasedQueue(string address, int queueSize, int successfulReceives) : base(address, "", true)
public FakeTableBasedQueue(CanonicalQueueAddress address, int queueSize, int successfulReceives) : base(address, "", true)
{
this.queueSize = queueSize;
this.successfulReceives = successfulReceives;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace NServiceBus.Transport.SqlServer.IntegrationTests
namespace NServiceBus.Transport.SqlServer.IntegrationTests
{
using System;
using System.Collections.Generic;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace NServiceBus.Transport.SqlServer.IntegrationTests
namespace NServiceBus.Transport.SqlServer.IntegrationTests
{
using System;
using System.Collections.Generic;
Expand Down Expand Up @@ -134,7 +134,7 @@ Task PurgeOutputQueue(QueueAddressTranslator addressParser, CancellationToken ca
{
purger = new QueuePurger(sqlConnectionFactory);
var queueAddress = addressParser.Parse(ValidAddress);
queue = new TableBasedQueue(queueAddress.QualifiedTableName, queueAddress.Address, true);
queue = new TableBasedQueue(queueAddress, queueAddress.Address, true);

return purger.Purge(queue, cancellationToken);
}
Expand Down
6 changes: 5 additions & 1 deletion src/NServiceBus.Transport.SqlServer/Queuing/SqlConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ THEN DATEADD(ms, @TimeToBeReceivedMs, GETUTCDATE()) END,
IF (@NOCOUNT = 'ON') SET NOCOUNT ON;
IF (@NOCOUNT = 'OFF') SET NOCOUNT OFF;";

public static readonly string CheckIfTableHasRecoverableText = "SELECT TOP (0) * FROM {0} WITH (NOLOCK);";
public static string CheckIfTableHasRecoverableText { get; set; } = @"
SELECT COUNT(*)
FROM {0}.sys.columns c
WHERE c.object_id = OBJECT_ID(N'{1}')
AND c.name = 'Recoverable'";

public static readonly string StoreDelayedMessageText =
@"
Expand Down
45 changes: 24 additions & 21 deletions src/NServiceBus.Transport.SqlServer/Queuing/TableBasedQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@ class TableBasedQueue
{
public string Name { get; }

public TableBasedQueue(string qualifiedTableName, string queueName, bool isStreamSupported)
public TableBasedQueue(CanonicalQueueAddress queueAddress, string queueName, bool isStreamSupported)
{
this.qualifiedTableName = qualifiedTableName;
qualifiedTableName = queueAddress.QualifiedTableName;
Name = queueName;
receiveCommand = Format(SqlConstants.ReceiveText, this.qualifiedTableName);
purgeCommand = Format(SqlConstants.PurgeText, this.qualifiedTableName);
purgeExpiredCommand = Format(SqlConstants.PurgeBatchOfExpiredMessagesText, this.qualifiedTableName);
checkExpiresIndexCommand = Format(SqlConstants.CheckIfExpiresIndexIsPresent, this.qualifiedTableName);
checkNonClusteredRowVersionIndexCommand = Format(SqlConstants.CheckIfNonClusteredRowVersionIndexIsPresent, this.qualifiedTableName);
checkHeadersColumnTypeCommand = Format(SqlConstants.CheckHeadersColumnType, this.qualifiedTableName);
receiveCommand = Format(SqlConstants.ReceiveText, qualifiedTableName);
purgeCommand = Format(SqlConstants.PurgeText, qualifiedTableName);
purgeExpiredCommand = Format(SqlConstants.PurgeBatchOfExpiredMessagesText, qualifiedTableName);
checkExpiresIndexCommand = Format(SqlConstants.CheckIfExpiresIndexIsPresent, qualifiedTableName);
checkNonClusteredRowVersionIndexCommand = Format(SqlConstants.CheckIfNonClusteredRowVersionIndexIsPresent, qualifiedTableName);
checkHeadersColumnTypeCommand = Format(SqlConstants.CheckHeadersColumnType, qualifiedTableName);
checkRecoverableColumnCommand = Format(SqlConstants.CheckIfTableHasRecoverableText, queueAddress.Catalog, qualifiedTableName);
this.isStreamSupported = isStreamSupported;
}

Expand Down Expand Up @@ -145,23 +146,24 @@ async Task<string> GetSendCommandText(SqlConnection connection, SqlTransaction t
return sendCommand;
}

var commandText = Format(SqlConstants.CheckIfTableHasRecoverableText, qualifiedTableName);
using (var command = new SqlCommand(commandText, connection, transaction))
using (var command = connection.CreateCommand())
{
using (var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false))
command.CommandText = checkRecoverableColumnCommand;
command.CommandType = CommandType.Text;
command.Transaction = transaction;

var rowsCount = await command.ExecuteScalarAsync<int>(nameof(checkRecoverableColumnCommand), cancellationToken).ConfigureAwait(false);
if (rowsCount > 0)
{
for (int fieldIndex = 0; fieldIndex < reader.FieldCount; fieldIndex++)
{
if (string.Equals("Recoverable", reader.GetName(fieldIndex), StringComparison.OrdinalIgnoreCase))
{
cachedSendCommand = Format(SqlConstants.SendTextWithRecoverable, qualifiedTableName);
return cachedSendCommand;
}
}
cachedSendCommand = Format(SqlConstants.SendTextWithRecoverable, qualifiedTableName);
return cachedSendCommand;
}
else
{

cachedSendCommand = Format(SqlConstants.SendText, qualifiedTableName);
return cachedSendCommand;
cachedSendCommand = Format(SqlConstants.SendText, qualifiedTableName);
return cachedSendCommand;
}
}
}
finally
Expand Down Expand Up @@ -237,6 +239,7 @@ public override string ToString()
string checkExpiresIndexCommand;
string checkNonClusteredRowVersionIndexCommand;
string checkHeadersColumnTypeCommand;
string checkRecoverableColumnCommand;
bool isStreamSupported;

readonly SemaphoreSlim sendCommandLock = new SemaphoreSlim(1, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public TableBasedQueue Get(string destination)
{
var address = addressTranslator.Parse(destination);
var key = Tuple.Create(address.QualifiedTableName, address.Address);
var queue = cache.GetOrAdd(key, x => new TableBasedQueue(x.Item1, x.Item2, isStreamSupported));
var queue = cache.GetOrAdd(key, x => new TableBasedQueue(address, x.Item2, isStreamSupported));

return queue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ public async Task ConfigureReceiveInfrastructure(ReceiveSettings[] receiveSettin

var schemaVerification = new SchemaInspector((queue, token) => connectionFactory.OpenNewConnection(token), validateExpiredIndex);

var queueFactory = transport.Testing.QueueFactoryOverride ?? (queueName => new TableBasedQueue(addressTranslator.Parse(queueName).QualifiedTableName, queueName, !isEncrypted));
var queueFactory = transport.Testing.QueueFactoryOverride ?? (queueName => new TableBasedQueue(addressTranslator.Parse(queueName), queueName, !isEncrypted));

//Create delayed delivery infrastructure
CanonicalQueueAddress delayedQueueCanonicalAddress = null;
Expand Down

0 comments on commit 7ec3284

Please sign in to comment.