Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ private static TryCatch<IQueryPipelineStage> TryCreateFullQueryPipeline(
allRanges: allFeedRanges,
isContinuationExpected: cosmosQueryContext.IsContinuationExpected,
maxConcurrency: inputParameters.MaxConcurrency,
fullTextScoreScope: inputParameters.FullTextScoreScope,
requestContinuationToken: inputParameters.InitialUserContinuationToken);
}

Expand Down Expand Up @@ -838,6 +839,7 @@ private InputParameters(
bool enableOptimisticDirectExecution,
bool isHybridSearchQueryPlanOptimizationDisabled,
bool enableDistributedQueryGatewayMode,
FullTextScoreScope fullTextScoreScope,
TestInjections testInjections)
{
this.SqlQuerySpec = sqlQuerySpec ?? throw new ArgumentNullException(nameof(sqlQuerySpec));
Expand All @@ -853,6 +855,7 @@ private InputParameters(
this.EnableOptimisticDirectExecution = enableOptimisticDirectExecution;
this.IsHybridSearchQueryPlanOptimizationDisabled = isHybridSearchQueryPlanOptimizationDisabled;
this.EnableDistributedQueryGatewayMode = enableDistributedQueryGatewayMode;
this.FullTextScoreScope = fullTextScoreScope;
this.TestInjections = testInjections;
}

Expand All @@ -870,6 +873,7 @@ public static InputParameters Create(
bool enableOptimisticDirectExecution,
bool isHybridSearchQueryPlanOptimizationDisabled,
bool enableDistributedQueryGatewayMode,
FullTextScoreScope fullTextScoreScope,
TestInjections testInjections)
{
if (sqlQuerySpec == null)
Expand Down Expand Up @@ -909,6 +913,7 @@ public static InputParameters Create(
enableOptimisticDirectExecution: enableOptimisticDirectExecution,
isHybridSearchQueryPlanOptimizationDisabled: isHybridSearchQueryPlanOptimizationDisabled,
enableDistributedQueryGatewayMode: enableDistributedQueryGatewayMode,
fullTextScoreScope: fullTextScoreScope,
testInjections: testInjections);
}

Expand All @@ -927,6 +932,7 @@ public static InputParameters Create(
public bool IsHybridSearchQueryPlanOptimizationDisabled { get; }
public bool EnableDistributedQueryGatewayMode { get; }
public bool UseLengthAwareRangeComparer { get; }
public FullTextScoreScope FullTextScoreScope { get; }

public InputParameters WithContinuationToken(CosmosElement token)
{
Expand All @@ -944,6 +950,7 @@ public InputParameters WithContinuationToken(CosmosElement token)
this.EnableOptimisticDirectExecution,
this.IsHybridSearchQueryPlanOptimizationDisabled,
this.EnableDistributedQueryGatewayMode,
this.FullTextScoreScope,
this.TestInjections);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ public static TryCatch<IQueryPipelineStage> MonadicCreate(
IReadOnlyList<FeedRangeEpk> allRanges,
int maxItemCount,
bool isContinuationExpected,
int maxConcurrency)
int maxConcurrency,
Cosmos.FullTextScoreScope fullTextScoreScope)
{
TryCatch<IQueryPipelineStage> ComponentPipelineFactory(QueryInfo rewrittenQueryInfo)
{
Expand Down Expand Up @@ -124,10 +125,16 @@ TryCatch<IQueryPipelineStage> ComponentPipelineFactory(QueryInfo rewrittenQueryI
queryInfo.GlobalStatisticsQuery,
sqlQuerySpec.Parameters);

// When FullTextScoreScope is Global, use allRanges (all partitions) for statistics.
// When FullTextScoreScope is Local, use targetRanges (only the filtered partitions) for statistics.
IReadOnlyList<FeedRangeEpk> statisticsTargetRanges = fullTextScoreScope == Cosmos.FullTextScoreScope.Global
? allRanges
: targetRanges;

TryCatch<IQueryPipelineStage> tryCatchGlobalStatisticsPipeline = ParallelCrossPartitionQueryPipelineStage.MonadicCreate(
documentContainer: documentContainer,
sqlQuerySpec: globalStatisticsQuerySpec,
targetRanges: allRanges,
targetRanges: statisticsTargetRanges,
queryPaginationOptions: queryExecutionOptions,
partitionKey: null,
containerQueryProperties: containerQueryProperties,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public static TryCatch<IQueryPipelineStage> MonadicCreate(
IReadOnlyList<FeedRangeEpk> allRanges,
bool isContinuationExpected,
int maxConcurrency,
FullTextScoreScope fullTextScoreScope,
CosmosElement requestContinuationToken)
{
if (documentContainer == null)
Expand Down Expand Up @@ -94,7 +95,7 @@ public static TryCatch<IQueryPipelineStage> MonadicCreate(
requestContinuationToken: requestContinuationToken);
}
else
{
{
MonadicCreatePipelineStage monadicCreatePipelineStage = (_) => HybridSearchCrossPartitionQueryPipelineStage.MonadicCreate(
documentContainer: documentContainer,
containerQueryProperties: containerQueryProperties,
Expand All @@ -105,7 +106,8 @@ public static TryCatch<IQueryPipelineStage> MonadicCreate(
allRanges: allRanges,
maxItemCount: maxItemCount,
isContinuationExpected: isContinuationExpected,
maxConcurrency: maxConcurrency);
maxConcurrency: maxConcurrency,
fullTextScoreScope: fullTextScoreScope);

if (hybridSearchQueryInfo.Skip != null)
{
Expand Down
1 change: 1 addition & 0 deletions Microsoft.Azure.Cosmos/src/Query/v3Query/QueryIterator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ public static QueryIterator Create(
enableOptimisticDirectExecution: queryRequestOptions.EnableOptimisticDirectExecution,
isHybridSearchQueryPlanOptimizationDisabled: queryRequestOptions.IsHybridSearchQueryPlanOptimizationDisabled,
enableDistributedQueryGatewayMode: queryRequestOptions.EnableDistributedQueryGatewayMode && (clientContext.ClientOptions.ConnectionMode == ConnectionMode.Gateway),
fullTextScoreScope: queryRequestOptions.FullTextScoreScope,
testInjections: queryRequestOptions.TestSettings);

return new QueryIterator(
Expand Down
19 changes: 19 additions & 0 deletions Microsoft.Azure.Cosmos/src/RequestOptions/QueryRequestOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,25 @@ public ConsistencyLevel? ConsistencyLevel
/// </summary>
public QueryTextMode QueryTextMode { get; set; } = QueryTextMode.None;

/// <summary>
/// Gets or sets the scope over which BM25 statistics are computed for FullTextScore in hybrid search queries.
/// </summary>
/// <value>
/// The scope for computing BM25 statistics. Defaults to <see cref="FullTextScoreScope.Global"/>.
/// </value>
/// <remarks>
/// <para>
/// With <see cref="FullTextScoreScope.Global"/>, the BM25 statistics (term frequency, inverse document frequency,
/// and document length) are computed across every document in the container, spanning all physical and logical partitions,
/// even when the query itself targets only some of them.
/// </para>
/// <para>
/// With <see cref="FullTextScoreScope.Local"/>, the statistics are computed only over the subset of documents
/// that fall within the partition key values specified in the query request. This suits multi-tenant scenarios where
/// scoring should reflect statistics that are accurate for a single tenant's dataset.
/// </para>
/// </remarks>
public FullTextScoreScope FullTextScoreScope { get; set; } = FullTextScoreScope.Global;

internal CosmosElement CosmosElementContinuationToken { get; set; }

internal string StartId { get; set; }
Expand Down
26 changes: 26 additions & 0 deletions Microsoft.Azure.Cosmos/src/Resource/Settings/FullTextScoreScope.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos
{
    /// <summary>
    /// Selects which documents contribute to the BM25 statistics that FullTextScore
    /// relies on in hybrid search queries.
    /// </summary>
    public enum FullTextScoreScope
    {
        /// <summary>
        /// BM25 statistics (term frequency, inverse document frequency, and document length)
        /// are computed over every document in the container, spanning all physical and
        /// logical partitions. This is the default behavior.
        /// </summary>
        Global = 0,

        /// <summary>
        /// BM25 statistics are computed only over the documents that fall within the
        /// partition key values specified in the query. This suits multi-tenant scenarios
        /// where scoring should reflect statistics that are accurate for a single tenant's dataset.
        /// </summary>
        Local = 1,
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John')
ORDER BY RANK FullTextScore(c.title, 'John')",
new List<List<int>>{ new List<int>{ 2, 57, 85 }, new List<int>{ 2, 85, 57 } }),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE (FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John')) AND (c.index = 2)
ORDER BY RANK FullTextScore(c.title, 'John')",
new List<List<int>>{ new List<int>{ 2 } }),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John')
ORDER BY RANK FullTextScore(c.title, 'John')",
new List<List<int>>{ new List<int>{ 2 } },
new PartitionKey(2)),
MakeSanityTest(@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
Expand Down Expand Up @@ -137,7 +150,6 @@ ORDER BY RANK RRF(VectorDistance(c.vector, {SampleVector}), FullTextScore(c.titl
}

[TestMethod]
[Ignore("This test is disabled because it needs an emulator refresh.")]
public async Task WeightedRankFusionTests()
{
List<SanityTestCase> testCases = new List<SanityTestCase>
Expand All @@ -148,33 +160,30 @@ FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, 'John'), FullTextScore(c.text, 'United States'), [1, 1])",
new List<List<int>>{
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 85, 57 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2, 22, 85, 57 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2, 22, 57, 85 },
}),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, 'John'), FullTextScore(c.text, 'United States'), [10, 10])",
new List<List<int>>{
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 85, 57 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2, 22, 57, 85 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2, 22, 85, 57 },
}),
MakeSanityTest(@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, 'John'), FullTextScore(c.text, 'United States'), [0.1, 0.1])",
new List<List<int>>{ new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25 } }),
new List<List<int>>{ new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2 } }),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, 'John'), FullTextScore(c.text, 'United States'), [-1, -1])",
new List<List<int>>{
new List<int>{ 85, 57, 66, 2, 22, 25, 77, 76, 80, 75, 24, 49, 54, 51, 81 },
new List<int>{ 57, 85, 2, 66, 22, 25, 80, 76, 77, 24, 75, 54, 49, 51, 61 },
}),
new List<List<int>>{ new List<int>{ 57, 85, 2, 22, 80, 76, 77, 24, 75, 54, 49, 51, 61 } }),
};

await this.RunTests(testCases);
Expand All @@ -190,41 +199,55 @@ await this.CreateIngestQueryDeleteAsync(
collectionTypes: CollectionTypes.MultiPartition, // | CollectionTypes.SinglePartition,
documents: documents,
query: (container, _) => RunTests(container, testCases),
partitionKey: "/index",
indexingPolicy: CompositeIndexPolicy);
}

private static async Task RunTests(Container container, IEnumerable<SanityTestCase> testCases)
{
foreach (SanityTestCase testCase in testCases)
foreach (FullTextScoreScope fullTextScoreScope in new[]{ FullTextScoreScope.Local, FullTextScoreScope.Global })
{
List<TextDocument> result = await RunQueryCombinationsAsync<TextDocument>(
container,
testCase.Query,
queryRequestOptions: null,
queryDrainingMode: QueryDrainingMode.HoldState);
foreach (SanityTestCase testCase in testCases)
{
QueryRequestOptions testRequestOptions = new QueryRequestOptions
{
FullTextScoreScope = fullTextScoreScope,
};

IEnumerable<int> actual = result.Select(document => document.Index);
if (testCase.PartitionKey.HasValue)
{
testRequestOptions.PartitionKey = testCase.PartitionKey;
}

bool match = false;
foreach (IReadOnlyList<int> expectedIndices in testCase.ExpectedIndices)
{
if (expectedIndices.SequenceEqual(actual))
List<TextDocument> result = await RunQueryCombinationsAsync<TextDocument>(
container,
testCase.Query,
queryRequestOptions: testRequestOptions,
queryDrainingMode: QueryDrainingMode.HoldState);

IEnumerable<int> actual = result.Select(document => document.Index);

bool match = false;
foreach (IReadOnlyList<int> expectedIndices in testCase.ExpectedIndices)
{
match = true;
break;
if (expectedIndices.SequenceEqual(actual))
{
match = true;
break;
}
}
}

if (!match)
{
Trace.WriteLine($"Query: {testCase.Query}");
Trace.WriteLine($"Actual: {string.Join(", ", actual)}");
if (!match)
{
Trace.WriteLine($"Query: {testCase.Query}");
Trace.WriteLine($"Actual: {string.Join(", ", actual)}");

string errorMessage = @"The query results did not match any of the expected results." +
"Please set HybridSearchCrossPartitionQueryPipelineStage.HybridSearchDebugTraceHelpers.Enabled = true to debug." +
"Usually, the failure may be due to some swaps in the results that have equal scores. You can see this in the debug output." +
"The solution is to add another expected result that matches the actual results (provided the scores are in decresing order).";
Assert.Fail(errorMessage);
string errorMessage = @"The query results did not match any of the expected results." +
"Please set HybridSearchCrossPartitionQueryPipelineStage.HybridSearchDebugTraceHelpers.Enabled = true to debug." +
"Usually, the failure may be due to some swaps in the results that have equal scores. You can see this in the debug output." +
"The solution is to add another expected result that matches the actual results (provided the scores are in decresing order).";
Assert.Fail(errorMessage);
}
}
}
}
Expand Down Expand Up @@ -254,12 +277,13 @@ private static IndexingPolicy CreateIndexingPolicy()
return policy;
}

private static SanityTestCase MakeSanityTest(string query, IReadOnlyList<IReadOnlyList<int>> expectedIndices)
private static SanityTestCase MakeSanityTest(string query, IReadOnlyList<IReadOnlyList<int>> expectedIndices, PartitionKey? partitionKey = null)
{
return new SanityTestCase
{
Query = query,
ExpectedIndices = expectedIndices,
PartitionKey = partitionKey,
};
}

Expand All @@ -268,6 +292,8 @@ private sealed class SanityTestCase
public string Query { get; init; }

public IReadOnlyList<IReadOnlyList<int>> ExpectedIndices { get; init; }

public PartitionKey? PartitionKey { get; init; }
}

private sealed class TextDocument
Expand Down
Loading
Loading