Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 225 additions & 33 deletions src/Builders/PromptBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use WordPress\AiClient\Messages\Enums\ModalityEnum;
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
use WordPress\AiClient\Providers\Models\DTO\ModelRequirements;
use WordPress\AiClient\Providers\Models\DTO\RequiredOption;
use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum;
Expand Down Expand Up @@ -59,6 +60,12 @@ class PromptBuilder
*/
protected ?ModelInterface $model = null;

/**
* @var array<string, string|array{0:string,1:string}> Preferred model map keyed by preference key with
* provider-model mappings.
*/
protected array $modelPreferenceMap = [];

/**
* @var string|null The provider ID or class name.
*/
Expand Down Expand Up @@ -216,9 +223,74 @@ public function usingModel(ModelInterface $model): self
}

/**
* Sets the model configuration.
* Sets preferred models to evaluate in order.
*
* @since n.e.x.t
*
* Merges the provided configuration with the builder's configuration,
* @param string|ModelInterface|array{0:string,1:string} ...$preferredModels The preferred models as model IDs,
* model instances, or [model ID, provider ID] tuples.
* @return self
*
* @throws InvalidArgumentException When a preferred model has an invalid type or identifier.
*/
public function usingModelPreference(...$preferredModels): self
{
if ($preferredModels === []) {
throw new InvalidArgumentException('At least one model preference must be provided.');
}

$preferenceMap = [];

foreach ($preferredModels as $preferredModel) {
if (is_array($preferredModel)) {
// [model identifier, provider ID] tuple
if (!array_is_list($preferredModel) || count($preferredModel) !== 2) {
throw new InvalidArgumentException(
'Model preference tuple must contain model identifier and provider ID.'
);
}

[$modelIdentifier, $providerId] = $preferredModel;

$modelId = $this->normalizePreferenceIdentifier($modelIdentifier);
$providerId = $this->normalizePreferenceIdentifier(
$providerId,
'Model preference provider identifiers cannot be empty.'
);

$key = $this->createProviderModelPreferenceKey($providerId, $modelId);
$preferenceMap[$key] = [$modelId, $providerId];

} elseif ($preferredModel instanceof ModelInterface) {
// Model instance
$modelId = $preferredModel->metadata()->getId();
$providerId = $preferredModel->providerMetadata()->getId();
$providerModelKey = $this->createProviderModelPreferenceKey($providerId, $modelId);
$preferenceMap[$providerModelKey] = [$modelId, $providerId];

} elseif (is_string($preferredModel)) {
// Model ID
$modelId = $this->normalizePreferenceIdentifier($preferredModel);
$modelKey = $this->createModelPreferenceKey($modelId);
$preferenceMap[$modelKey] = $modelId;

} else {
// Invalid type
throw new InvalidArgumentException(
'Model preferences must be model identifiers, instances of ModelInterface, or provider/model tuples.'
);
}
}

$this->modelPreferenceMap = $preferenceMap;

return $this;
}

/**
* Sets the model configuration.
*
* Merges the provided configuration with the builder's configuration,
* with builder configuration taking precedence.
*
* @since 0.1.0
Expand Down Expand Up @@ -1040,15 +1112,46 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
{
$requirements = ModelRequirements::fromPromptData($capability, $this->messages, $this->modelConfig);

// If a model has been explicitly set, return it
if ($this->model !== null) {
// Explicit model was provided via usingModel(); just update config and bind dependencies.
$this->model->setConfig($this->modelConfig);
$this->registry->bindModelDependencies($this->model);
return $this->model;
}

// Find a suitable model based on requirements
// Retrieve the models which satisfy the requirements.
$candidateModels = $this->getCandidateModels($requirements);

// Find the model which matches the preference, if any.
$selectedModel = $this->findModelByPreference($candidateModels);

if ($selectedModel === null) {
// No preference matched; fall back to the first candidate discovered.
[$providerId, $modelMetadata] = $candidateModels[0];
$selectedModel = $this->registry->getProviderModel(
$providerId,
$modelMetadata->getId(),
$this->modelConfig
);
}

return $selectedModel;
}

/**
* Builds the list of candidate provider/model combinations that satisfy the requirements.
*
* @since n.e.x.t
*
* @param ModelRequirements $requirements The requirements derived from the prompt.
* @return list<array{0:string,1:ModelMetadata}> The candidate provider/model tuples.
*
* @throws InvalidArgumentException If no suitable models are found.
*/
private function getCandidateModels(ModelRequirements $requirements): array
{
if ($this->providerIdOrClassName === null) {
// No provider locked in, gather all models across providers that meet requirements.
$providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);

if (empty($providerModelsMetadata)) {
Expand All @@ -1066,41 +1169,130 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
);
}

$firstProviderModels = $providerModelsMetadata[0];
$provider = $firstProviderModels->getProvider()->getId();
$modelMetadata = $firstProviderModels->getModels()[0];
} else {
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
$this->providerIdOrClassName,
$requirements
$candidateModels = [];
foreach ($providerModelsMetadata as $providerModels) {
$providerId = $providerModels->getProvider()->getId();
foreach ($providerModels->getModels() as $modelMetadata) {
$candidateModels[] = [$providerId, $modelMetadata];
}
}

return $candidateModels;
}

// Provider set, only consider models from that provider.
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
$this->providerIdOrClassName,
$requirements
);

if (empty($modelsMetadata)) {
throw new InvalidArgumentException(
sprintf(
'No models found for %s that support the required capabilities and options for this prompt. ' .
'Required capabilities: %s. Required options: %s',
$this->providerIdOrClassName,
implode(', ', array_map(function ($cap) {
return $cap->value;
}, $requirements->getRequiredCapabilities())),
implode(', ', array_map(function ($opt) {
return $opt->getName()->value . '=' . json_encode($opt->getValue());
}, $requirements->getRequiredOptions()))
)
);
}

if (empty($modelsMetadata)) {
throw new InvalidArgumentException(
sprintf(
'No models found for %s that support the required capabilities and options for this prompt. ' .
'Required capabilities: %s. Required options: %s',
$this->providerIdOrClassName,
implode(', ', array_map(function ($cap) {
return $cap->value;
}, $requirements->getRequiredCapabilities())),
implode(', ', array_map(function ($opt) {
return $opt->getName()->value . '=' . json_encode($opt->getValue());
}, $requirements->getRequiredOptions()))
)
);
$candidateModels = [];
foreach ($modelsMetadata as $modelMetadata) {
$candidateModels[] = [$this->providerIdOrClassName, $modelMetadata];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this can be an ID or a class name, we need to normalize this to always be the ID, in order for the lookup to work.

We already ProviderRegistry::getProviderClassName($id). Maybe we should also have ProviderRegistry::getProviderId($className), and that we could then use here if the value is a class name. That should be fairly straightforward to add.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored in 79a7a56

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you actually addressed the concern I mentioned above?

}

return $candidateModels;
}

/**
* Finds a preferred model among the given candidates.
*
* @since n.e.x.t
*
* @param list<array{0:string,1:ModelMetadata}> $candidateModels Candidate provider/model tuples.
* @return ModelInterface|null The model matching the highest-priority preference, or null if none.
*/
private function findModelByPreference(array $candidateModels): ?ModelInterface
{
if ($this->modelPreferenceMap === []) {
return null;
}

foreach ($candidateModels as [$providerId, $modelMetadata]) {
$modelId = $modelMetadata->getId();

$providerModelKey = $this->createProviderModelPreferenceKey($providerId, $modelId);
if (isset($this->modelPreferenceMap[$providerModelKey])) {
return $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig);
}

$provider = $this->providerIdOrClassName;
$modelMetadata = $modelsMetadata[0];
$modelKey = $this->createModelPreferenceKey($modelId);
if (isset($this->modelPreferenceMap[$modelKey])) {
return $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig);
}
}

// Get the model instance from the provider
return $this->registry->getProviderModel(
$provider,
$modelMetadata->getId(),
$this->modelConfig
);
return null;
}

/**
* Normalizes and validates a preference identifier string.
*
* @since n.e.x.t
*
* @param mixed $value The value to normalize.
* @param string $emptyMessage The message for empty or invalid values.
* @return string The normalized identifier.
*
* @throws InvalidArgumentException If the value is not a non-empty string.
*/
private function normalizePreferenceIdentifier(
$value,
string $emptyMessage = 'Model preference identifiers cannot be empty.'
): string {
if (!is_string($value)) {
throw new InvalidArgumentException($emptyMessage);
}

$trimmed = trim($value);
if ($trimmed === '') {
throw new InvalidArgumentException($emptyMessage);
}

return $trimmed;
}

/**
* Creates a preference key for a provider/model combination.
*
* @since n.e.x.t
*
* @param string $providerId The provider identifier.
* @param string $modelId The model identifier.
* @return string The generated preference key.
*/
private function createProviderModelPreferenceKey(string $providerId, string $modelId): string
{
return 'providerModel::' . $providerId . '::' . $modelId;
}

/**
* Creates a preference key for a model identifier.
*
* @since n.e.x.t
*
* @param string $modelId The model identifier.
* @return string The generated preference key.
*/
private function createModelPreferenceKey(string $modelId): string
{
return 'model::' . $modelId;
}

/**
Expand Down
Loading
Loading