-
Notifications
You must be signed in to change notification settings - Fork 35
Adds PromptBuilder preferred model support #110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
52d5889
ba25084
d21afd4
1e66a08
0aaed4c
722e4aa
0c76328
79a7a56
1b5fba3
2101ede
1590816
8771816
6dc4d61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| use WordPress\AiClient\Messages\Enums\ModalityEnum; | ||
| use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; | ||
| use WordPress\AiClient\Providers\Models\DTO\ModelConfig; | ||
| use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; | ||
| use WordPress\AiClient\Providers\Models\DTO\ModelRequirements; | ||
| use WordPress\AiClient\Providers\Models\DTO\RequiredOption; | ||
| use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; | ||
|
|
@@ -59,6 +60,11 @@ class PromptBuilder | |
| */ | ||
| protected ?ModelInterface $model = null; | ||
|
|
||
| /** | ||
| * @var list<ModelInterface|string|array{0:string,1:string}> Preferred models to evaluate in order. | ||
| */ | ||
| protected array $modelPreference = []; | ||
|
|
||
| /** | ||
| * @var string|null The provider ID or class name. | ||
| */ | ||
|
|
@@ -215,6 +221,81 @@ public function usingModel(ModelInterface $model): self | |
| return $this; | ||
| } | ||
|
|
||
| /** | ||
| * Sets preferred models to evaluate in order. | ||
| * | ||
| * @since n.e.x.t | ||
| * | ||
| * @param string|ModelInterface|array{0:string,1:string} ...$preferredModels The preferred models as model IDs, | ||
| * model instances, or [provider ID, model ID] tuples. | ||
| * @return self | ||
| * | ||
| * @throws InvalidArgumentException When a preferred model has an invalid type or identifier. | ||
| */ | ||
| public function usingModelPreference(...$preferredModels): self | ||
| { | ||
| if ($preferredModels === []) { | ||
| throw new InvalidArgumentException('At least one model preference must be provided.'); | ||
| } | ||
|
|
||
| $normalizedModels = []; | ||
|
|
||
| foreach ($preferredModels as $preferredModel) { | ||
| if (is_array($preferredModel)) { | ||
| if (!array_is_list($preferredModel) || count($preferredModel) !== 2) { | ||
| throw new InvalidArgumentException( | ||
| 'Model preference tuple must contain provider ID and model identifier.' | ||
| ); | ||
| } | ||
|
|
||
| [$providerId, $modelIdentifier] = $preferredModel; | ||
|
|
||
| if (!is_string($providerId)) { | ||
| throw new InvalidArgumentException('Model preference provider identifiers cannot be empty.'); | ||
| } | ||
| $providerId = trim($providerId); | ||
| if ($providerId === '') { | ||
| throw new InvalidArgumentException('Model preference provider identifiers cannot be empty.'); | ||
| } | ||
|
|
||
| if (!is_string($modelIdentifier)) { | ||
| throw new InvalidArgumentException( | ||
| 'Model preference tuple must contain provider ID and model identifier.' | ||
| ); | ||
| } | ||
| $modelIdentifier = trim($modelIdentifier); | ||
| if ($modelIdentifier === '') { | ||
| throw new InvalidArgumentException('Model preference identifiers cannot be empty.'); | ||
| } | ||
|
|
||
| $normalizedModels[] = [$providerId, $modelIdentifier]; | ||
| continue; | ||
| } | ||
|
|
||
| if ($preferredModel instanceof ModelInterface) { | ||
| $normalizedModels[] = $preferredModel; | ||
| continue; | ||
| } | ||
|
|
||
| if (is_string($preferredModel)) { | ||
| $trimmed = trim($preferredModel); | ||
| if ($trimmed === '') { | ||
| throw new InvalidArgumentException('Model preference identifiers cannot be empty.'); | ||
| } | ||
| $normalizedModels[] = $trimmed; | ||
| continue; | ||
| } | ||
|
|
||
| throw new InvalidArgumentException( | ||
| 'Model preferences must be model identifiers, instances of ModelInterface, or provider/model tuples.' | ||
| ); | ||
| } | ||
|
|
||
| $this->modelPreference = $normalizedModels; | ||
|
|
||
| return $this; | ||
| } | ||
|
|
||
| /** | ||
| * Sets the model configuration. | ||
| * | ||
|
|
@@ -1040,15 +1121,46 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface | |
| { | ||
| $requirements = ModelRequirements::fromPromptData($capability, $this->messages, $this->modelConfig); | ||
|
|
||
| // If a model has been explicitly set, return it | ||
| if ($this->model !== null) { | ||
| // Explicit model was provided via usingModel(); just update config and bind dependencies. | ||
| $this->model->setConfig($this->modelConfig); | ||
| $this->registry->bindModelDependencies($this->model); | ||
| return $this->model; | ||
| } | ||
|
|
||
| // Find a suitable model based on requirements | ||
| // Retrieve the models which satisfy the requirements. | ||
| $candidateModels = $this->getCandidateModels($requirements); | ||
|
|
||
| // Find the model which matches the preference, if any. | ||
| $selectedModel = $this->findModelByPreference($candidateModels); | ||
|
|
||
| if ($selectedModel === null) { | ||
| // No preference matched; fall back to the first candidate discovered. | ||
| [$providerId, $modelMetadata] = $candidateModels[0]; | ||
| $selectedModel = $this->registry->getProviderModel( | ||
| $providerId, | ||
| $modelMetadata->getId(), | ||
| $this->modelConfig | ||
| ); | ||
| } | ||
|
|
||
| return $selectedModel; | ||
| } | ||
|
|
||
| /** | ||
| * Builds the list of candidate provider/model combinations that satisfy the requirements. | ||
| * | ||
| * @since n.e.x.t | ||
| * | ||
| * @param ModelRequirements $requirements The requirements derived from the prompt. | ||
| * @return list<array{0:string,1:ModelMetadata}> The candidate provider/model tuples. | ||
| * | ||
| * @throws InvalidArgumentException If no suitable models are found. | ||
| */ | ||
| private function getCandidateModels(ModelRequirements $requirements): array | ||
| { | ||
| if ($this->providerIdOrClassName === null) { | ||
| // No provider locked in, gather all models across providers that meet requirements. | ||
| $providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements); | ||
|
|
||
| if (empty($providerModelsMetadata)) { | ||
|
|
@@ -1066,41 +1178,103 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface | |
| ); | ||
| } | ||
|
|
||
| $firstProviderModels = $providerModelsMetadata[0]; | ||
| $provider = $firstProviderModels->getProvider()->getId(); | ||
| $modelMetadata = $firstProviderModels->getModels()[0]; | ||
| } else { | ||
| $modelsMetadata = $this->registry->findProviderModelsMetadataForSupport( | ||
| $this->providerIdOrClassName, | ||
| $requirements | ||
| ); | ||
|
|
||
| if (empty($modelsMetadata)) { | ||
| throw new InvalidArgumentException( | ||
| sprintf( | ||
| 'No models found for %s that support the required capabilities and options for this prompt. ' . | ||
| 'Required capabilities: %s. Required options: %s', | ||
| $this->providerIdOrClassName, | ||
| implode(', ', array_map(function ($cap) { | ||
| return $cap->value; | ||
| }, $requirements->getRequiredCapabilities())), | ||
| implode(', ', array_map(function ($opt) { | ||
| return $opt->getName()->value . '=' . json_encode($opt->getValue()); | ||
| }, $requirements->getRequiredOptions())) | ||
| ) | ||
| ); | ||
| $candidateModels = []; | ||
| foreach ($providerModelsMetadata as $providerModels) { | ||
| $providerId = $providerModels->getProvider()->getId(); | ||
| foreach ($providerModels->getModels() as $modelMetadata) { | ||
| $candidateModels[] = [$providerId, $modelMetadata]; | ||
| } | ||
| } | ||
|
|
||
| $provider = $this->providerIdOrClassName; | ||
| $modelMetadata = $modelsMetadata[0]; | ||
| return $candidateModels; | ||
| } | ||
|
|
||
| // Get the model instance from the provider | ||
| return $this->registry->getProviderModel( | ||
| $provider, | ||
| $modelMetadata->getId(), | ||
| $this->modelConfig | ||
| // Provider set, only consider models from that provider. | ||
| $modelsMetadata = $this->registry->findProviderModelsMetadataForSupport( | ||
| $this->providerIdOrClassName, | ||
| $requirements | ||
| ); | ||
|
|
||
| if (empty($modelsMetadata)) { | ||
| throw new InvalidArgumentException( | ||
| sprintf( | ||
| 'No models found for %s that support the required capabilities and options for this prompt. ' . | ||
| 'Required capabilities: %s. Required options: %s', | ||
| $this->providerIdOrClassName, | ||
| implode(', ', array_map(function ($cap) { | ||
| return $cap->value; | ||
| }, $requirements->getRequiredCapabilities())), | ||
| implode(', ', array_map(function ($opt) { | ||
| return $opt->getName()->value . '=' . json_encode($opt->getValue()); | ||
| }, $requirements->getRequiredOptions())) | ||
| ) | ||
| ); | ||
| } | ||
|
|
||
| $candidateModels = []; | ||
| foreach ($modelsMetadata as $modelMetadata) { | ||
| $candidateModels[] = [$this->providerIdOrClassName, $modelMetadata]; | ||
|
||
| } | ||
|
|
||
| return $candidateModels; | ||
| } | ||
|
|
||
| /** | ||
| * Finds a preferred model among the given candidates. | ||
| * | ||
| * @since n.e.x.t | ||
| * | ||
| * @param list<array{0:string,1:ModelMetadata}> $candidateModels Candidate provider/model tuples. | ||
| * @return ModelInterface|null The model matching the highest-priority preference, or null if none. | ||
| */ | ||
| private function findModelByPreference(array $candidateModels): ?ModelInterface | ||
| { | ||
| foreach ($candidateModels as [$providerId, $modelMetadata]) { | ||
| foreach ($this->modelPreference as $preferenceIndex => $preferredModel) { | ||
| if (is_array($preferredModel)) { | ||
| [$preferredProviderId, $preferredModelId] = $preferredModel; | ||
| if ( | ||
| $providerId !== $preferredProviderId || | ||
| $modelMetadata->getId() !== $preferredModelId | ||
| ) { | ||
| continue; | ||
| } | ||
|
|
||
| return $this->registry->getProviderModel( | ||
| $providerId, | ||
| $modelMetadata->getId(), | ||
| $this->modelConfig | ||
| ); | ||
| } elseif ($preferredModel instanceof ModelInterface) { | ||
| $preferredProviderId = $preferredModel->providerMetadata()->getId(); | ||
| $preferredModelId = $preferredModel->metadata()->getId(); | ||
|
|
||
| if ( | ||
| $providerId !== $preferredProviderId || | ||
| $modelMetadata->getId() !== $preferredModelId | ||
| ) { | ||
| continue; | ||
| } | ||
|
|
||
| $preferredModel->setConfig($this->modelConfig); | ||
| $this->registry->bindModelDependencies($preferredModel); | ||
|
|
||
| return $preferredModel; | ||
| } else { | ||
| if ($modelMetadata->getId() !== $preferredModel) { | ||
| continue; | ||
| } | ||
|
|
||
| return $this->registry->getProviderModel( | ||
| $providerId, | ||
| $modelMetadata->getId(), | ||
| $this->modelConfig | ||
| ); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return null; | ||
| } | ||
|
|
||
| /** | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.