Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 223 additions & 46 deletions src/Builders/PromptBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use WordPress\AiClient\Messages\Enums\ModalityEnum;
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
use WordPress\AiClient\Providers\Models\DTO\ModelRequirements;
use WordPress\AiClient\Providers\Models\DTO\RequiredOption;
use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum;
Expand Down Expand Up @@ -59,6 +60,11 @@ class PromptBuilder
*/
protected ?ModelInterface $model = null;

/**
* @var list<string> Ordered list of preference keys to check when selecting a model.
*/
protected array $modelPreferenceKeys = [];

/**
* @var string|null The provider ID or class name.
*/
Expand Down Expand Up @@ -216,9 +222,73 @@ public function usingModel(ModelInterface $model): self
}

/**
* Sets the model configuration.
* Sets preferred models to evaluate in order.
*
* @since n.e.x.t
*
* Merges the provided configuration with the builder's configuration,
* @param string|ModelInterface|array{0:string,1:string} ...$preferredModels The preferred models as model IDs,
* model instances, or [model ID, provider ID] tuples.
* @return self
*
* @throws InvalidArgumentException When a preferred model has an invalid type or identifier.
*/
public function usingModelPreference(...$preferredModels): self
{
if ($preferredModels === []) {
throw new InvalidArgumentException('At least one model preference must be provided.');
}

$preferenceKeys = [];

foreach ($preferredModels as $preferredModel) {
if (is_array($preferredModel)) {
// [model identifier, provider ID] tuple
if (!array_is_list($preferredModel) || count($preferredModel) !== 2) {
throw new InvalidArgumentException(
'Model preference tuple must contain model identifier and provider ID.'
);
}

[$modelIdentifier, $providerId] = $preferredModel;

$modelId = $this->normalizePreferenceIdentifier($modelIdentifier);
$providerId = $this->normalizePreferenceIdentifier(
$providerId,
'Model preference provider identifiers cannot be empty.'
);

$preferenceKey = $this->createProviderModelPreferenceKey($providerId, $modelId);
} elseif ($preferredModel instanceof ModelInterface) {
// Model instance
$modelId = $preferredModel->metadata()->getId();
$providerId = $preferredModel->providerMetadata()->getId();

$preferenceKey = $this->createProviderModelPreferenceKey($providerId, $modelId);
} elseif (is_string($preferredModel)) {
// Model ID
$modelId = $this->normalizePreferenceIdentifier($preferredModel);

$preferenceKey = $this->createModelPreferenceKey($modelId);
} else {
// Invalid type
throw new InvalidArgumentException(
'Model preferences must be model identifiers, instances of ModelInterface, ' .
'or provider/model tuples.'
);
}

$preferenceKeys[] = $preferenceKey;
}

$this->modelPreferenceKeys = $preferenceKeys;

return $this;
}

/**
* Sets the model configuration.
*
* Merges the provided configuration with the builder's configuration,
* with builder configuration taking precedence.
*
* @since 0.1.0
Expand Down Expand Up @@ -1040,67 +1110,174 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
{
$requirements = ModelRequirements::fromPromptData($capability, $this->messages, $this->modelConfig);

// If a model has been explicitly set, return it
if ($this->model !== null) {
// Explicit model was provided via usingModel(); just update config and bind dependencies.
$this->model->setConfig($this->modelConfig);
$this->registry->bindModelDependencies($this->model);
return $this->model;
}

// Find a suitable model based on requirements
if ($this->providerIdOrClassName === null) {
$providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);
// Retrieve the candidate models map which satisfies the requirements.
$candidateMap = $this->getCandidateModelsMap($requirements);

if (empty($providerModelsMetadata)) {
throw new InvalidArgumentException(
sprintf(
'No models found that support the required capabilities and options for this prompt. ' .
'Required capabilities: %s. Required options: %s',
implode(', ', array_map(function ($cap) {
return $cap->value;
}, $requirements->getRequiredCapabilities())),
implode(', ', array_map(function ($opt) {
return $opt->getName()->value . '=' . json_encode($opt->getValue());
}, $requirements->getRequiredOptions()))
)
if (empty($candidateMap)) {
$message = sprintf(
'No models found that support %s for this prompt.',
$capability->value
);

if ($this->providerIdOrClassName !== null) {
$message = sprintf(
'No models found for provider "%s" that support %s for this prompt.',
$this->providerIdOrClassName,
$capability->value
);
}

$firstProviderModels = $providerModelsMetadata[0];
$provider = $firstProviderModels->getProvider()->getId();
$modelMetadata = $firstProviderModels->getModels()[0];
} else {
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
$this->providerIdOrClassName,
$requirements
throw new InvalidArgumentException($message);
}

// Check if any preferred models match the candidates, in priority order.
if (!empty($this->modelPreferenceKeys)) {
// Find preferences that match available candidates, preserving preference order.
$matchingPreferences = array_intersect_key(
array_flip($this->modelPreferenceKeys),
$candidateMap
);

if (empty($modelsMetadata)) {
throw new InvalidArgumentException(
sprintf(
'No models found for %s that support the required capabilities and options for this prompt. ' .
'Required capabilities: %s. Required options: %s',
$this->providerIdOrClassName,
implode(', ', array_map(function ($cap) {
return $cap->value;
}, $requirements->getRequiredCapabilities())),
implode(', ', array_map(function ($opt) {
return $opt->getName()->value . '=' . json_encode($opt->getValue());
}, $requirements->getRequiredOptions()))
)
);
if (!empty($matchingPreferences)) {
// Get the first matching preference key
$firstMatchKey = key($matchingPreferences);
[$providerId, $modelId] = $candidateMap[$firstMatchKey];

return $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig);
}
}

// No preference matched; fall back to the first candidate discovered.
[$providerId, $modelId] = reset($candidateMap);

$provider = $this->providerIdOrClassName;
$modelMetadata = $modelsMetadata[0];
return $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig);
}

/**
* Builds a map of candidate models that satisfy the requirements for efficient lookup.
*
* @since n.e.x.t
*
* @param ModelRequirements $requirements The requirements derived from the prompt.
* @return array<string, array{0:string,1:string}> Map of preference keys to [providerId, modelId] tuples.
*/
private function getCandidateModelsMap(ModelRequirements $requirements): array
{
if ($this->providerIdOrClassName === null) {
// No provider locked in, gather all models across providers that meet requirements.
$providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);

$candidateMap = [];
foreach ($providerModelsMetadata as $providerModels) {
$providerId = $providerModels->getProvider()->getId();
$providerMap = $this->generateMapFromCandidates($providerId, $providerModels->getModels());

// Use + operator to merge, preserving keys from $candidateMap (first provider wins for model-only keys)
$candidateMap = $candidateMap + $providerMap;
}

return $candidateMap;
}

// Get the model instance from the provider
return $this->registry->getProviderModel(
$provider,
$modelMetadata->getId(),
$this->modelConfig
// Provider set, only consider models from that provider.
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
$this->providerIdOrClassName,
$requirements
);

// Ensure we pass the provider ID, not the class name
$providerId = $this->registry->getProviderId($this->providerIdOrClassName);

return $this->generateMapFromCandidates($providerId, $modelsMetadata);
}

/**
* Generates a candidate map from model metadata with both provider-specific and model-only keys.
*
* @since n.e.x.t
*
* @param string $providerId The provider ID.
* @param list<ModelMetadata> $modelsMetadata The models metadata to map.
* @return array<string, array{0:string,1:string}> Map of preference keys to [providerId, modelId] tuples.
*/
private function generateMapFromCandidates(string $providerId, array $modelsMetadata): array
{
$map = [];

foreach ($modelsMetadata as $modelMetadata) {
$modelId = $modelMetadata->getId();

// Add provider-specific key
$providerModelKey = $this->createProviderModelPreferenceKey($providerId, $modelId);
$map[$providerModelKey] = [$providerId, $modelId];

// Add model-only key
$modelKey = $this->createModelPreferenceKey($modelId);
$map[$modelKey] = [$providerId, $modelId];
}

return $map;
}

/**
* Normalizes and validates a preference identifier string.
*
* @since n.e.x.t
*
* @param mixed $value The value to normalize.
* @param string $emptyMessage The message for empty or invalid values.
* @return string The normalized identifier.
*
* @throws InvalidArgumentException If the value is not a non-empty string.
*/
private function normalizePreferenceIdentifier(
$value,
string $emptyMessage = 'Model preference identifiers cannot be empty.'
): string {
if (!is_string($value)) {
throw new InvalidArgumentException($emptyMessage);
}

$trimmed = trim($value);
if ($trimmed === '') {
throw new InvalidArgumentException($emptyMessage);
}

return $trimmed;
}

/**
* Creates a preference key for a provider/model combination.
*
* @since n.e.x.t
*
* @param string $providerId The provider identifier.
* @param string $modelId The model identifier.
* @return string The generated preference key.
*/
private function createProviderModelPreferenceKey(string $providerId, string $modelId): string
{
return 'providerModel::' . $providerId . '::' . $modelId;
}

/**
* Creates a preference key for a model identifier.
*
* @since n.e.x.t
*
* @param string $modelId The model identifier.
* @return string The generated preference key.
*/
private function createModelPreferenceKey(string $modelId): string
{
return 'model::' . $modelId;
}

/**
Expand Down
31 changes: 31 additions & 0 deletions src/Providers/ProviderRegistry.php
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,37 @@ public function getProviderClassName(string $id): string
return $this->providerClassNames[$id];
}

/**
* Gets the provider ID for a registered provider.
*
* @since n.e.x.t
*
* @param string|class-string<ProviderInterface> $idOrClassName The provider ID or class name.
* @return string The provider ID.
* @throws InvalidArgumentException If the provider is not registered.
*/
public function getProviderId(string $idOrClassName): string
{
// If it's already an ID, return it
if (isset($this->providerClassNames[$idOrClassName])) {
return $idOrClassName;
}

// If it's a class name, find its ID
if (isset($this->registeredClassNames[$idOrClassName])) {
foreach ($this->providerClassNames as $id => $className) {
if ($className === $idOrClassName) {
return $id;
}
}
}

// Not found
throw new InvalidArgumentException(
sprintf('Provider not registered: %s', $idOrClassName)
);
}

/**
* Checks if a provider is properly configured.
*
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/AiClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ public function testGenerateResultWithNullModelDelegatesToPromptBuilder(): void

// This should delegate to PromptBuilder's intelligent discovery
$this->expectException(\InvalidArgumentException::class);
$this->expectExceptionMessage('No models found that support the required capabilities');
$this->expectExceptionMessage('No models found that support text_generation for this prompt.');

AiClient::generateResult($prompt, null, $this->createMockEmptyRegistry());
}
Expand Down Expand Up @@ -441,7 +441,7 @@ public function testGenerateResultWithModelConfigDelegatesToPromptBuilder(): voi
$config->setMaxTokens(100);

$this->expectException(\InvalidArgumentException::class);
$this->expectExceptionMessage('No models found that support the required capabilities');
$this->expectExceptionMessage('No models found that support text_generation for this prompt.');

AiClient::generateResult($prompt, $config, $this->createMockEmptyRegistry());
}
Expand Down Expand Up @@ -613,7 +613,7 @@ function () {
$this->fail("Expected InvalidArgumentException for configuration $index");
} catch (\InvalidArgumentException $e) {
$this->assertStringContainsString(
'No models found that support the required capabilities',
'No models found that support text_generation for this prompt.',
$e->getMessage(),
"Configuration $index should delegate to PromptBuilder properly"
);
Expand All @@ -630,7 +630,7 @@ public function testEmptyModelConfig(): void
$emptyConfig = new ModelConfig();

$this->expectException(\InvalidArgumentException::class);
$this->expectExceptionMessage('No models found that support the required capabilities');
$this->expectExceptionMessage('No models found that support text_generation for this prompt.');

AiClient::generateResult($prompt, $emptyConfig, $this->createMockEmptyRegistry());
}
Expand Down
Loading