Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 206 additions & 32 deletions src/Builders/PromptBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use WordPress\AiClient\Messages\Enums\ModalityEnum;
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
use WordPress\AiClient\Providers\Models\DTO\ModelRequirements;
use WordPress\AiClient\Providers\Models\DTO\RequiredOption;
use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum;
Expand Down Expand Up @@ -59,6 +60,11 @@ class PromptBuilder
*/
protected ?ModelInterface $model = null;

/**
* @var list<ModelInterface|string|array{0:string,1:string}> Preferred models to evaluate in order.
*/
protected array $modelPreference = [];

/**
* @var string|null The provider ID or class name.
*/
Expand Down Expand Up @@ -215,6 +221,81 @@ public function usingModel(ModelInterface $model): self
return $this;
}

/**
* Sets preferred models to evaluate in order.
*
* @since n.e.x.t
*
* @param string|ModelInterface|array{0:string,1:string} ...$preferredModels The preferred models as model IDs,
* model instances, or [provider ID, model ID] tuples.
* @return self
*
* @throws InvalidArgumentException When a preferred model has an invalid type or identifier.
*/
public function usingModelPreference(...$preferredModels): self
{
if ($preferredModels === []) {
throw new InvalidArgumentException('At least one model preference must be provided.');
}

$normalizedModels = [];

foreach ($preferredModels as $preferredModel) {
if (is_array($preferredModel)) {
if (!array_is_list($preferredModel) || count($preferredModel) !== 2) {
throw new InvalidArgumentException(
'Model preference tuple must contain provider ID and model identifier.'
);
}

[$providerId, $modelIdentifier] = $preferredModel;

if (!is_string($providerId)) {
throw new InvalidArgumentException('Model preference provider identifiers cannot be empty.');
}
$providerId = trim($providerId);
if ($providerId === '') {
throw new InvalidArgumentException('Model preference provider identifiers cannot be empty.');
}

if (!is_string($modelIdentifier)) {
throw new InvalidArgumentException(
'Model preference tuple must contain provider ID and model identifier.'
);
}
$modelIdentifier = trim($modelIdentifier);
if ($modelIdentifier === '') {
throw new InvalidArgumentException('Model preference identifiers cannot be empty.');
}

$normalizedModels[] = [$providerId, $modelIdentifier];
continue;
}

if ($preferredModel instanceof ModelInterface) {
$normalizedModels[] = $preferredModel;
continue;
}

if (is_string($preferredModel)) {
$trimmed = trim($preferredModel);
if ($trimmed === '') {
throw new InvalidArgumentException('Model preference identifiers cannot be empty.');
}
$normalizedModels[] = $trimmed;
continue;
}

throw new InvalidArgumentException(
'Model preferences must be model identifiers, instances of ModelInterface, or provider/model tuples.'
);
}

$this->modelPreference = $normalizedModels;

return $this;
}

/**
* Sets the model configuration.
*
Expand Down Expand Up @@ -1040,15 +1121,46 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
{
$requirements = ModelRequirements::fromPromptData($capability, $this->messages, $this->modelConfig);

// If a model has been explicitly set, return it
if ($this->model !== null) {
// Explicit model was provided via usingModel(); just update config and bind dependencies.
$this->model->setConfig($this->modelConfig);
$this->registry->bindModelDependencies($this->model);
return $this->model;
}

// Find a suitable model based on requirements
// Retrieve the models which satisfy the requirements.
$candidateModels = $this->getCandidateModels($requirements);

// Find the model which matches the preference, if any.
$selectedModel = $this->findModelByPreference($candidateModels);

if ($selectedModel === null) {
// No preference matched; fall back to the first candidate discovered.
[$providerId, $modelMetadata] = $candidateModels[0];
$selectedModel = $this->registry->getProviderModel(
$providerId,
$modelMetadata->getId(),
$this->modelConfig
);
}

return $selectedModel;
}

/**
* Builds the list of candidate provider/model combinations that satisfy the requirements.
*
* @since n.e.x.t
*
* @param ModelRequirements $requirements The requirements derived from the prompt.
* @return list<array{0:string,1:ModelMetadata}> The candidate provider/model tuples.
*
* @throws InvalidArgumentException If no suitable models are found.
*/
private function getCandidateModels(ModelRequirements $requirements): array
{
if ($this->providerIdOrClassName === null) {
// No provider locked in, gather all models across providers that meet requirements.
$providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);

if (empty($providerModelsMetadata)) {
Expand All @@ -1066,41 +1178,103 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
);
}

$firstProviderModels = $providerModelsMetadata[0];
$provider = $firstProviderModels->getProvider()->getId();
$modelMetadata = $firstProviderModels->getModels()[0];
} else {
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
$this->providerIdOrClassName,
$requirements
);

if (empty($modelsMetadata)) {
throw new InvalidArgumentException(
sprintf(
'No models found for %s that support the required capabilities and options for this prompt. ' .
'Required capabilities: %s. Required options: %s',
$this->providerIdOrClassName,
implode(', ', array_map(function ($cap) {
return $cap->value;
}, $requirements->getRequiredCapabilities())),
implode(', ', array_map(function ($opt) {
return $opt->getName()->value . '=' . json_encode($opt->getValue());
}, $requirements->getRequiredOptions()))
)
);
$candidateModels = [];
foreach ($providerModelsMetadata as $providerModels) {
$providerId = $providerModels->getProvider()->getId();
foreach ($providerModels->getModels() as $modelMetadata) {
$candidateModels[] = [$providerId, $modelMetadata];
}
}

$provider = $this->providerIdOrClassName;
$modelMetadata = $modelsMetadata[0];
return $candidateModels;
}

// Get the model instance from the provider
return $this->registry->getProviderModel(
$provider,
$modelMetadata->getId(),
$this->modelConfig
// Provider set, only consider models from that provider.
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
$this->providerIdOrClassName,
$requirements
);

if (empty($modelsMetadata)) {
throw new InvalidArgumentException(
sprintf(
'No models found for %s that support the required capabilities and options for this prompt. ' .
'Required capabilities: %s. Required options: %s',
$this->providerIdOrClassName,
implode(', ', array_map(function ($cap) {
return $cap->value;
}, $requirements->getRequiredCapabilities())),
implode(', ', array_map(function ($opt) {
return $opt->getName()->value . '=' . json_encode($opt->getValue());
}, $requirements->getRequiredOptions()))
)
);
}

$candidateModels = [];
foreach ($modelsMetadata as $modelMetadata) {
$candidateModels[] = [$this->providerIdOrClassName, $modelMetadata];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this can be an ID or a class name, we need to normalize this to always be the ID, in order for the lookup to work.

We already ProviderRegistry::getProviderClassName($id). Maybe we should also have ProviderRegistry::getProviderId($className), and that we could then use here if the value is a class name. That should be fairly straightforward to add.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored in 79a7a56

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you actually addressed the concern I mentioned above?

}

return $candidateModels;
}

/**
* Finds a preferred model among the given candidates.
*
* @since n.e.x.t
*
* @param list<array{0:string,1:ModelMetadata}> $candidateModels Candidate provider/model tuples.
* @return ModelInterface|null The model matching the highest-priority preference, or null if none.
*/
private function findModelByPreference(array $candidateModels): ?ModelInterface
{
foreach ($candidateModels as [$providerId, $modelMetadata]) {
foreach ($this->modelPreference as $preferenceIndex => $preferredModel) {
if (is_array($preferredModel)) {
[$preferredProviderId, $preferredModelId] = $preferredModel;
if (
$providerId !== $preferredProviderId ||
$modelMetadata->getId() !== $preferredModelId
) {
continue;
}

return $this->registry->getProviderModel(
$providerId,
$modelMetadata->getId(),
$this->modelConfig
);
} elseif ($preferredModel instanceof ModelInterface) {
$preferredProviderId = $preferredModel->providerMetadata()->getId();
$preferredModelId = $preferredModel->metadata()->getId();

if (
$providerId !== $preferredProviderId ||
$modelMetadata->getId() !== $preferredModelId
) {
continue;
}

$preferredModel->setConfig($this->modelConfig);
$this->registry->bindModelDependencies($preferredModel);

return $preferredModel;
} else {
if ($modelMetadata->getId() !== $preferredModel) {
continue;
}

return $this->registry->getProviderModel(
$providerId,
$modelMetadata->getId(),
$this->modelConfig
);
}
}
}

return null;
}

/**
Expand Down
Loading
Loading