Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 89 additions & 5 deletions src/Builders/PromptBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class PromptBuilder
*/
protected ?ModelInterface $model = null;

/**
* @var list<ModelInterface|string> Preferred models to evaluate in order.
*/
protected array $modelPreferences = [];

/**
* @var string|null The provider ID or class name.
*/
Expand Down Expand Up @@ -215,6 +220,49 @@ public function usingModel(ModelInterface $model): self
return $this;
}

/**
* Sets preferred models to evaluate in order.
*
* @since n.e.x.t
*
* @param mixed ...$preferredModels The preferred models as model IDs or instances.
* @return self
*
* @throws InvalidArgumentException When a preferred model has an invalid type or identifier.
*/
public function usingModelPreference(...$preferredModels): self
{
if ($preferredModels === []) {
throw new InvalidArgumentException('At least one model preference must be provided.');
}

$normalizedPreferences = [];

foreach ($preferredModels as $preferredModel) {
if ($preferredModel instanceof ModelInterface) {
$normalizedPreferences[] = $preferredModel;
continue;
}

if (is_string($preferredModel)) {
$trimmed = trim($preferredModel);
if ($trimmed === '') {
throw new InvalidArgumentException('Model preference identifiers cannot be empty.');
}
$normalizedPreferences[] = $trimmed;
continue;
}

throw new InvalidArgumentException(
'Model preferences must be model identifiers or instances of ModelInterface.'
);
}

$this->modelPreferences = $normalizedPreferences;

return $this;
}

/**
* Sets the model configuration.
*
Expand Down Expand Up @@ -1024,6 +1072,40 @@ protected function appendPartToMessages(MessagePart $part): void
$this->messages[] = new UserMessage([$part]);
}

/**
* Gets the first preferred model that can be instantiated.
*
* @since n.e.x.t
*
* @return ModelInterface|null The instantiated preferred model, or null if none found.
*/
private function getAvailablePreferredModel(): ?ModelInterface
{
if ($this->modelPreferences === []) {
return null;
}

foreach ($this->modelPreferences as $preferredModel) {
if ($preferredModel instanceof ModelInterface) {
return $preferredModel;
}

$providerIds = $this->providerIdOrClassName !== null
? [$this->providerIdOrClassName]
: $this->registry->getRegisteredProviderIds();

foreach ($providerIds as $providerId) {
try {
return $this->registry->getProviderModel($providerId, $preferredModel);
} catch (InvalidArgumentException $exception) {
continue;
}
}
}

return null;
}

/**
* Gets the model to use for generation.
*
Expand All @@ -1040,11 +1122,13 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
{
$requirements = ModelRequirements::fromPromptData($capability, $this->messages, $this->modelConfig);

// If a model has been explicitly set, return it
if ($this->model !== null) {
$this->model->setConfig($this->modelConfig);
$this->registry->bindModelDependencies($this->model);
return $this->model;
$explicitModel = $this->model ?? $this->getAvailablePreferredModel();

if ($explicitModel !== null) {
$explicitModel->setConfig($this->modelConfig);
$this->registry->bindModelDependencies($explicitModel);
$this->model = $explicitModel;
return $explicitModel;
}

// Find a suitable model based on requirements
Expand Down
219 changes: 219 additions & 0 deletions tests/unit/Builders/PromptBuilderTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
use PHPUnit\Framework\TestCase;
use RuntimeException;
use WordPress\AiClient\Builders\PromptBuilder;
use WordPress\AiClient\Common\Exception\InvalidArgumentException as AiInvalidArgumentException;
use WordPress\AiClient\Files\DTO\File;
use WordPress\AiClient\Files\Enums\FileTypeEnum;
use WordPress\AiClient\Messages\DTO\Message;
Expand All @@ -18,6 +19,7 @@
use WordPress\AiClient\Messages\Enums\MessageRoleEnum;
use WordPress\AiClient\Messages\Enums\ModalityEnum;
use WordPress\AiClient\Providers\DTO\ProviderMetadata;
use WordPress\AiClient\Providers\DTO\ProviderModelsMetadata;
use WordPress\AiClient\Providers\Enums\ProviderTypeEnum;
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
Expand Down Expand Up @@ -61,6 +63,25 @@ private function createTestProviderMetadata(): ProviderMetadata
return new ProviderMetadata('test-provider', 'Test Provider', ProviderTypeEnum::cloud());
}

/**
* Creates text model metadata supporting any input modalities.
*
* @param string $id The model identifier.
* @return ModelMetadata
*/
private function createTextModelMetadataWithInputSupport(string $id): ModelMetadata
{
return new ModelMetadata(
$id,
'Test Text Model',
[CapabilityEnum::textGeneration()],
[
new SupportedOption(OptionEnum::inputModalities()),
new SupportedOption(OptionEnum::outputModalities()),
]
);
}

/**
* Creates a mock model that implements both ModelInterface and SpeechGenerationModelInterface.
*
Expand Down Expand Up @@ -587,6 +608,204 @@ public function testUsingModel(): void
$this->assertSame($model, $actualModel);
}

/**
* Tests usingModelPreference selects provided model instance when requirements are met.
*
* @return void
*/
public function testUsingModelPreferenceWithModelInstance(): void
{
$result = $this->createTestResult('Preferred model result');
$metadata = $this->createTextModelMetadataWithInputSupport('preferred-model');
$model = $this->createMockTextGenerationModel($result, $metadata);

$this->registry->expects($this->once())
->method('bindModelDependencies')
->with($model);

$this->registry->expects($this->never())
->method('findModelsMetadataForSupport');

$this->registry->expects($this->never())
->method('findProviderModelsMetadataForSupport');

$builder = new PromptBuilder($this->registry, 'Test prompt');
$builder->usingModelPreference($model);

$actualResult = $builder->generateTextResult();

$this->assertSame($result, $actualResult);

$reflection = new \ReflectionClass($builder);
$modelProperty = $reflection->getProperty('model');
$modelProperty->setAccessible(true);

$this->assertSame($model, $modelProperty->getValue($builder));
}

/**
* Tests usingModelPreference selects the first available model ID for the configured provider.
*
* @return void
*/
public function testUsingModelPreferencePrefersFirstAvailableModelId(): void
{
$result = $this->createTestResult('Preferred by ID');
$metadata = $this->createTextModelMetadataWithInputSupport('preferred-id');
$model = $this->createMockTextGenerationModel($result, $metadata);

$this->registry->expects($this->once())
->method('getProviderModel')
->with('test-provider', 'preferred-id', $this->isNull())
->willReturn($model);

$this->registry->expects($this->never())
->method('findModelsMetadataForSupport');

$this->registry->expects($this->once())
->method('bindModelDependencies')
->with($model);

$builder = new PromptBuilder($this->registry, 'Test prompt');
$builder->usingProvider('test-provider');
$builder->usingModelPreference('preferred-id', 'secondary-id');

$actualResult = $builder->generateTextResult();

$this->assertSame($result, $actualResult);
}

/**
* Tests usingModelPreference skips unavailable model IDs and falls back to the next preference.
*
* @return void
*/
public function testUsingModelPreferenceSkipsUnavailableModelId(): void
{
$result = $this->createTestResult('Fallback model result');
$metadata = $this->createTextModelMetadataWithInputSupport('fallback-id');
$model = $this->createMockTextGenerationModel($result, $metadata);

$this->registry->expects($this->exactly(2))
->method('getProviderModel')
->withConsecutive(
['test-provider', 'missing-id', $this->isNull()],
['test-provider', 'fallback-id', $this->isNull()]
)
->willReturnOnConsecutiveCalls(
$this->throwException(new AiInvalidArgumentException('missing model')),
$model
);

$this->registry->expects($this->never())
->method('findModelsMetadataForSupport');

$this->registry->expects($this->once())
->method('bindModelDependencies')
->with($model);

$builder = new PromptBuilder($this->registry, 'Test prompt');
$builder->usingProvider('test-provider');
$builder->usingModelPreference('missing-id', 'fallback-id');

$actualResult = $builder->generateTextResult();

$this->assertSame($result, $actualResult);
}

/**
* Tests usingModelPreference falls back to discovery when no preferences are available.
*
* @return void
*/
public function testUsingModelPreferenceFallsBackToDiscovery(): void
{
$result = $this->createTestResult('Discovered model result');
$metadata = $this->createTextModelMetadataWithInputSupport('discovered-id');
$providerMetadata = $this->createTestProviderMetadata();
$providerModelsMetadata = new ProviderModelsMetadata($providerMetadata, [$metadata]);

$model = $this->createMockTextGenerationModel($result, $metadata);

$this->registry->expects($this->once())
->method('getRegisteredProviderIds')
->willReturn([$providerMetadata->getId()]);

$this->registry->expects($this->once())
->method('findModelsMetadataForSupport')
->with($this->isInstanceOf(ModelRequirements::class))
->willReturn([$providerModelsMetadata]);

$this->registry->expects($this->exactly(2))
->method('getProviderModel')
->withConsecutive(
[$providerMetadata->getId(), 'unavailable-model', $this->isNull()],
[$providerMetadata->getId(), 'discovered-id', $this->isInstanceOf(ModelConfig::class)]
)
->willReturnOnConsecutiveCalls(
$this->throwException(new AiInvalidArgumentException('missing model')),
$model
);

$this->registry->expects($this->never())
->method('findProviderModelsMetadataForSupport');

$this->registry->expects($this->never())
->method('bindModelDependencies');

$builder = new PromptBuilder($this->registry, 'Test prompt');
$builder->usingModelPreference('unavailable-model');

$actualResult = $builder->generateTextResult();

$this->assertSame($result, $actualResult);
}

/**
* Tests usingModelPreference rejects invalid preference types.
*
* @return void
*/
public function testUsingModelPreferenceWithInvalidTypeThrowsException(): void
{
$builder = new PromptBuilder($this->registry);

$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Model preferences must be model identifiers or instances of ModelInterface.');

$builder->usingModelPreference(123);
}

/**
* Tests usingModelPreference rejects empty preference identifier strings.
*
* @return void
*/
public function testUsingModelPreferenceWithEmptyIdentifierThrowsException(): void
{
$builder = new PromptBuilder($this->registry);

$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Model preference identifiers cannot be empty.');

$builder->usingModelPreference(' ');
}

/**
* Tests usingModelPreference rejects calls without preferences.
*
* @return void
*/
public function testUsingModelPreferenceWithoutArgumentsThrowsException(): void
{
$builder = new PromptBuilder($this->registry);

$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('At least one model preference must be provided.');

$builder->usingModelPreference();
}

/**
* Tests usingModelConfig method.
*
Expand Down