diff --git a/lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php b/lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php index 72e8a8bea..97d1ba723 100644 --- a/lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php +++ b/lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php @@ -8,10 +8,15 @@ use Doctrine\ODM\MongoDB\Aggregation\Stage; use Doctrine\ODM\MongoDB\Persisters\DocumentPersister; use Doctrine\ODM\MongoDB\Query\Expr; +use InvalidArgumentException; use MongoDB\BSON\Binary; use MongoDB\BSON\Decimal128; use MongoDB\BSON\Int64; +use function array_is_list; +use function is_array; +use function sprintf; + /** * @phpstan-type Vector list|list|list|Binary * @phpstan-type VectorSearchStageExpression array{ @@ -28,12 +33,15 @@ */ class VectorSearch extends Stage { - private ?bool $exact = null; - private ?Expr $filter = null; - private ?string $index = null; - private ?int $limit = null; - private ?int $numCandidates = null; - private ?string $path = null; + /** @see Binary::TYPE_VECTOR introduced in ext-mongodb 2.2 */ + private const BINARY_TYPE_VECTOR = 9; + + private ?bool $exact = null; + private array|Expr|null $filter = null; + private ?string $index = null; + private ?int $limit = null; + private ?int $numCandidates = null; + private ?string $path = null; /** @phpstan-var Vector|null */ private array|Binary|null $queryVector = null; @@ -50,8 +58,10 @@ public function getExpression(): array $params['exact'] = $this->exact; } - if ($this->filter !== null) { + if ($this->filter instanceof Expr) { $params['filter'] = $this->filter->getQuery(); + } elseif (is_array($this->filter)) { + $params['filter'] = $this->filter; } if ($this->index !== null) { @@ -84,7 +94,8 @@ public function exact(bool $exact): static return $this; } - public function filter(Expr $filter): static + /** @phpstan-param array|Expr $filter */ + public function filter(array|Expr $filter): static { $this->filter = $filter; @@ -122,6 +133,18 @@ public function path(string $path): static /** @phpstan-param Vector $queryVector */ public function queryVector(array|Binary $queryVector): static { + if ($queryVector === []) { + throw new InvalidArgumentException('Query vector cannot be an empty array.'); + } + + if (is_array($queryVector) && ! array_is_list($queryVector)) { + throw new InvalidArgumentException('Query vector must be a list of numbers, got an associative array.'); + } + + if ($queryVector instanceof Binary && $queryVector->getType() !== self::BINARY_TYPE_VECTOR) { + throw new InvalidArgumentException(sprintf('Binary query vector must be of type 9 (Vector), got %d.', $queryVector->getType())); + } + $this->queryVector = $queryVector; return $this; diff --git a/phpstan-baseline.neon b/phpstan-baseline.neon index e39430ff4..b123bec67 100644 --- a/phpstan-baseline.neon +++ b/phpstan-baseline.neon @@ -330,6 +330,12 @@ parameters: count: 1 path: lib/Doctrine/ODM/MongoDB/Aggregation/Stage/UnionWith.php + - + message: '#^Property Doctrine\\ODM\\MongoDB\\Aggregation\\Stage\\VectorSearch\:\:\$filter type has no value type specified in iterable type array\.$#' + identifier: missingType.iterableValue + count: 1 + path: lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php + - message: '#^Return type \(Doctrine\\ODM\\MongoDB\\Mapping\\ClassMetadataFactoryInterface\) of method Doctrine\\ODM\\MongoDB\\DocumentManager\:\:getMetadataFactory\(\) should be compatible with return type \(Doctrine\\Persistence\\Mapping\\ClassMetadataFactory\\>\) of method Doctrine\\Persistence\\ObjectManager\:\:getMetadataFactory\(\)$#' identifier: method.childReturnType diff --git a/tests/Doctrine/ODM/MongoDB/Tests/Aggregation/Stage/VectorSearchTest.php b/tests/Doctrine/ODM/MongoDB/Tests/Aggregation/Stage/VectorSearchTest.php index e17d75428..314296d57 100644 --- a/tests/Doctrine/ODM/MongoDB/Tests/Aggregation/Stage/VectorSearchTest.php +++ b/tests/Doctrine/ODM/MongoDB/Tests/Aggregation/Stage/VectorSearchTest.php @@ -10,8 +10,10 @@ use Doctrine\ODM\MongoDB\Tests\BaseTestCase; use Documents\User; use Documents\VectorEmbedding; +use InvalidArgumentException; use MongoDB\BSON\Binary; use MongoDB\BSON\VectorType; +use PHPUnit\Framework\Attributes\TestWith; use function enum_exists; @@ -27,12 +29,19 @@ public function testEmptyStage(): void public function testExact(): void { - [$stage, $builder] = $this->createVectorSearchStage(); + [$stage] = $this->createVectorSearchStage(); $stage->exact(true); self::assertSame(['$vectorSearch' => ['exact' => true]], $stage->getExpression()); } - public function testFilter(): void + public function testFilterArray(): void + { + [$stage] = $this->createVectorSearchStage(); + $stage->filter(['status' => ['$ne' => 'inactive']]); + self::assertSame(['$vectorSearch' => ['filter' => ['status' => ['$ne' => 'inactive']]]], $stage->getExpression()); + } + + public function testFilterExpr(): void { [$stage, $builder] = $this->createVectorSearchStage(); $stage->filter($builder->matchExpr()->field('status')->notEqual('inactive')); @@ -97,6 +106,17 @@ public function testQueryVectorAcceptsBinary(): void self::assertSame(['$vectorSearch' => ['queryVector' => $binaryVector]], $stage->getExpression()); } + #[TestWith([new Binary("\x03\x00\x01\x02\x03", Binary::TYPE_GENERIC), 'Binary query vector must be of type 9 (Vector), got 0.'])] + #[TestWith([[1 => 1, 2 => 3], 'Query vector must be a list of numbers, got an associative array.'])] + #[TestWith([[], 'Query vector cannot be an empty array.'])] + public function testQueryVectorInvalidType(mixed $queryVector, string $message): void + { + [$stage] = $this->createVectorSearchStage(); + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage($message); + $stage->queryVector($queryVector); + } + public function testChainingAllOptions(): void { [$stage, $builder] = $this->createVectorSearchStage();