Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@
use Doctrine\ODM\MongoDB\Aggregation\Stage;
use Doctrine\ODM\MongoDB\Persisters\DocumentPersister;
use Doctrine\ODM\MongoDB\Query\Expr;
use InvalidArgumentException;
use MongoDB\BSON\Binary;
use MongoDB\BSON\Decimal128;
use MongoDB\BSON\Int64;

use function array_is_list;
use function is_array;
use function sprintf;

/**
* @phpstan-type Vector list<int|Int64>|list<float|Decimal128>|list<bool|0|1>|Binary
* @phpstan-type VectorSearchStageExpression array{
Expand All @@ -28,12 +33,15 @@
*/
class VectorSearch extends Stage
{
private ?bool $exact = null;
private ?Expr $filter = null;
private ?string $index = null;
private ?int $limit = null;
private ?int $numCandidates = null;
private ?string $path = null;
/** @see Binary::TYPE_VECTOR introduced in ext-mongodb 2.2 */
private const BINARY_TYPE_VECTOR = 9;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, will this stick around even after the ext-mongodb 2.2.0 release? I assume you don't want to raise the driver dependency just for this feature.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly, we don't want to update the required version for this feature. But someone could receive a binary vector from the server and use in a query. That's why we have this constant duplicated here.


private ?bool $exact = null;
Comment thread
GromNaN marked this conversation as resolved.
private array|Expr|null $filter = null;
private ?string $index = null;
private ?int $limit = null;
private ?int $numCandidates = null;
private ?string $path = null;
/** @phpstan-var Vector|null */
private array|Binary|null $queryVector = null;

Expand All @@ -50,8 +58,10 @@ public function getExpression(): array
$params['exact'] = $this->exact;
}

if ($this->filter !== null) {
if ($this->filter instanceof Expr) {
$params['filter'] = $this->filter->getQuery();
} elseif (is_array($this->filter)) {
$params['filter'] = $this->filter;
}

if ($this->index !== null) {
Expand Down Expand Up @@ -84,7 +94,8 @@ public function exact(bool $exact): static
return $this;
}

public function filter(Expr $filter): static
/** @phpstan-param array<string, mixed>|Expr $filter */
public function filter(array|Expr $filter): static
{
$this->filter = $filter;

Expand Down Expand Up @@ -122,6 +133,18 @@ public function path(string $path): static
/** @phpstan-param Vector $queryVector */
public function queryVector(array|Binary $queryVector): static
{
if ($queryVector === []) {
throw new InvalidArgumentException('Query vector cannot be an empty array.');
}

if (is_array($queryVector) && ! array_is_list($queryVector)) {
throw new InvalidArgumentException('Query vector must be a list of numbers, got an associative array.');
Comment thread
GromNaN marked this conversation as resolved.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you actually intend to validate the contents of the list, or will you rely on the server to report that error?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is a partial validation. I don't know what vector type is expected so iterating over the values would just cost time.

}

if ($queryVector instanceof Binary && $queryVector->getType() !== self::BINARY_TYPE_VECTOR) {
throw new InvalidArgumentException(sprintf('Binary query vector must be of type 9 (Vector), got %d.', $queryVector->getType()));
}

$this->queryVector = $queryVector;

return $this;
Expand Down
6 changes: 6 additions & 0 deletions phpstan-baseline.neon
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,12 @@ parameters:
count: 1
path: lib/Doctrine/ODM/MongoDB/Aggregation/Stage/UnionWith.php

-
message: '#^Property Doctrine\\ODM\\MongoDB\\Aggregation\\Stage\\VectorSearch\:\:\$filter type has no value type specified in iterable type array\.$#'
identifier: missingType.iterableValue
count: 1
path: lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php

-
message: '#^Return type \(Doctrine\\ODM\\MongoDB\\Mapping\\ClassMetadataFactoryInterface\) of method Doctrine\\ODM\\MongoDB\\DocumentManager\:\:getMetadataFactory\(\) should be compatible with return type \(Doctrine\\Persistence\\Mapping\\ClassMetadataFactory\<Doctrine\\Persistence\\Mapping\\ClassMetadata\<object\>\>\) of method Doctrine\\Persistence\\ObjectManager\:\:getMetadataFactory\(\)$#'
identifier: method.childReturnType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
use Doctrine\ODM\MongoDB\Tests\BaseTestCase;
use Documents\User;
use Documents\VectorEmbedding;
use InvalidArgumentException;
use MongoDB\BSON\Binary;
use MongoDB\BSON\VectorType;
use PHPUnit\Framework\Attributes\TestWith;

use function enum_exists;

Expand All @@ -27,12 +29,19 @@ public function testEmptyStage(): void

public function testExact(): void
{
[$stage, $builder] = $this->createVectorSearchStage();
[$stage] = $this->createVectorSearchStage();
$stage->exact(true);
self::assertSame(['$vectorSearch' => ['exact' => true]], $stage->getExpression());
}

public function testFilter(): void
public function testFilterArray(): void
{
[$stage] = $this->createVectorSearchStage();
$stage->filter(['status' => ['$ne' => 'inactive']]);
self::assertSame(['$vectorSearch' => ['filter' => ['status' => ['$ne' => 'inactive']]]], $stage->getExpression());
}

public function testFilterExpr(): void
{
[$stage, $builder] = $this->createVectorSearchStage();
$stage->filter($builder->matchExpr()->field('status')->notEqual('inactive'));
Expand Down Expand Up @@ -97,6 +106,17 @@ public function testQueryVectorAcceptsBinary(): void
self::assertSame(['$vectorSearch' => ['queryVector' => $binaryVector]], $stage->getExpression());
}

#[TestWith([new Binary("\x03\x00\x01\x02\x03", Binary::TYPE_GENERIC), 'Binary query vector must be of type 9 (Vector), got 0.'])]
#[TestWith([[1 => 1, 2 => 3], 'Query vector must be a list of numbers, got an associative array.'])]
#[TestWith([[], 'Query vector cannot be an empty array.'])]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh cool, I had not seen these attributes before.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before attributes, it was the @testWith annotation but the syntax was awful.

public function testQueryVectorInvalidType(mixed $queryVector, string $message): void
{
[$stage] = $this->createVectorSearchStage();
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage($message);
$stage->queryVector($queryVector);
}

public function testChainingAllOptions(): void
{
[$stage, $builder] = $this->createVectorSearchStage();
Expand Down