Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Pydantic settings config #6392

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/samples/concepts/chat_completion/chat_gpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
kernel = Kernel()

service_id = "chat-gpt"
kernel.add_service(OpenAIChatCompletion(service_id=service_id))
kernel.add_service(OpenAIChatCompletion(service_id=service_id, ai_model_id="gpt-3.5-turbo"))

settings = kernel.get_prompt_execution_settings_from_service_id(service_id)
settings.max_tokens = 2000
Expand Down
39 changes: 8 additions & 31 deletions python/samples/concepts/memory/azure_cognitive_search_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.open_ai import AzureTextCompletion, AzureTextEmbedding
from semantic_kernel.connectors.memory.azure_cognitive_search import AzureCognitiveSearchMemoryStore
from semantic_kernel.connectors.memory.azure_cognitive_search.azure_ai_search_settings import AzureAISearchSettings
from semantic_kernel.core_plugins import TextMemoryPlugin
from semantic_kernel.memory import SemanticTextMemory

Expand Down Expand Up @@ -43,41 +42,19 @@ async def search_acs_memory_questions(memory: SemanticTextMemory) -> None:
async def main() -> None:
kernel = Kernel()

azure_ai_search_settings = AzureAISearchSettings()

vector_size = 1536

# Setting up OpenAI services for text completion and text embedding
text_complete_service_id = "dv"
kernel.add_service(
AzureTextCompletion(
service_id=text_complete_service_id,
),
)
embedding_service_id = "ada"
embedding_gen = AzureTextEmbedding(
service_id=embedding_service_id,
)
kernel.add_service(
embedding_gen,
)

acs_connector = AzureCognitiveSearchMemoryStore(
vector_size=vector_size,
search_endpoint=azure_ai_search_settings.endpoint,
admin_key=azure_ai_search_settings.api_key,
)

memory = SemanticTextMemory(storage=acs_connector, embeddings_generator=embedding_gen)
kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin")

print("Populating memory...")
await populate_memory(memory)
kernel.add_service(AzureTextCompletion(service_id="dv"))
async with AzureCognitiveSearchMemoryStore(vector_size=vector_size) as acs_connector:
memory = SemanticTextMemory(storage=acs_connector, embeddings_generator=AzureTextEmbedding(service_id="ada"))
kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin")

print("Asking questions... (manually)")
await search_acs_memory_questions(memory)
print("Populating memory...")
await populate_memory(memory)

await acs_connector.close()
print("Asking questions... (manually)")
await search_acs_memory_questions(memory)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,14 @@
# }

# Create the data source settings
azure_ai_search_settings = AzureAISearchSettings.create()
azure_ai_search_settings = AzureAISearchSettings.create(env_file_path=".env")

az_source = AzureAISearchDataSource(parameters=azure_ai_search_settings.model_dump())
az_source = AzureAISearchDataSource.from_azure_ai_search_settings(azure_ai_search_settings=azure_ai_search_settings)
extra = ExtraBody(data_sources=[az_source])
req_settings = AzureChatPromptExecutionSettings(service_id="default", extra_body=extra)

# When using data, use the 2024-02-15-preview API version.
chat_service = AzureChatCompletion(
service_id="chat-gpt",
)
chat_service = AzureChatCompletion(service_id="chat-gpt")
kernel.add_service(chat_service)

prompt_template_config = PromptTemplateConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import google.generativeai as palm
from google.generativeai.types import ChatResponse, MessageDict
from pydantic import PrivateAttr, StringConstraints, ValidationError
from pydantic import PrivateAttr, StringConstraints

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.google_palm.gp_prompt_execution_settings import (
Expand Down Expand Up @@ -37,6 +37,7 @@ def __init__(
api_key: str | None = None,
message_history: ChatHistory | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
):
"""Initializes a new instance of the GooglePalmChatCompletion class.

Expand All @@ -48,25 +49,17 @@ def __init__(
message_history (ChatHistory | None): The message history to use for context. (Optional)
env_file_path (str | None): Use the environment settings file as a fallback to
environment variables. (Optional)
env_file_encoding (str | None): The encoding of the environment settings file. (Optional)
"""
google_palm_settings = None
try:
google_palm_settings = GooglePalmSettings.create(env_file_path=env_file_path)
except ValidationError as e:
logger.warning(f"Error loading Google Palm pydantic settings: {e}")

api_key = api_key or (
google_palm_settings.api_key.get_secret_value()
if google_palm_settings and google_palm_settings.api_key
else None
)
ai_model_id = ai_model_id or (
google_palm_settings.chat_model_id if google_palm_settings and google_palm_settings.chat_model_id else None
google_palm_settings = GooglePalmSettings.create(
api_key=api_key,
chat_model_id=ai_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)

super().__init__(
ai_model_id=ai_model_id,
api_key=api_key,
ai_model_id=google_palm_settings.chat_model_id,
api_key=google_palm_settings.api_key.get_secret_value() if google_palm_settings.api_key else None,
)
self._message_history = message_history

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import google.generativeai as palm
from google.generativeai.types import Completion
from google.generativeai.types.text_types import TextCompletion
from pydantic import StringConstraints, ValidationError
from pydantic import StringConstraints

from semantic_kernel.connectors.ai.google_palm.gp_prompt_execution_settings import GooglePalmTextPromptExecutionSettings
from semantic_kernel.connectors.ai.google_palm.settings.google_palm_settings import GooglePalmSettings
Expand All @@ -21,7 +21,13 @@
class GooglePalmTextCompletion(TextCompletionClientBase):
api_key: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]

def __init__(self, ai_model_id: str, api_key: str | None = None, env_file_path: str | None = None):
def __init__(
self,
ai_model_id: str,
api_key: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
):
"""Initializes a new instance of the GooglePalmTextCompletion class.

Args:
Expand All @@ -31,23 +37,19 @@ def __init__(self, ai_model_id: str, api_key: str | None = None, env_file_path:
read from either the env vars or the .env settings file.
env_file_path (str | None): Use the environment settings file as a
fallback to environment variables. (Optional)
env_file_encoding (str | None): The encoding of the environment settings file. (Optional)
"""
try:
google_palm_settings = GooglePalmSettings.create(env_file_path=env_file_path)
except ValidationError as e:
logger.warning(f"Error loading Google Palm pydantic settings: {e}")

api_key = api_key or (
google_palm_settings.api_key.get_secret_value()
if google_palm_settings and google_palm_settings.api_key
else None
google_palm_settings = GooglePalmSettings.create(
api_key=api_key,
text_model_id=ai_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
ai_model_id = ai_model_id or (
google_palm_settings.text_model_id if google_palm_settings and google_palm_settings.text_model_id else None
super().__init__(
ai_model_id=google_palm_settings.text_model_id,
api_key=google_palm_settings.api_key.get_secret_value() if google_palm_settings.api_key else None,
)

super().__init__(ai_model_id=ai_model_id, api_key=api_key)

async def get_text_contents(
self, prompt: str, settings: GooglePalmTextPromptExecutionSettings
) -> list[TextContent]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import google.generativeai as palm
from numpy import array, ndarray
from pydantic import StringConstraints, ValidationError
from pydantic import StringConstraints

from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.google_palm.settings.google_palm_settings import GooglePalmSettings
Expand All @@ -19,7 +19,13 @@
class GooglePalmTextEmbedding(EmbeddingGeneratorBase):
api_key: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]

def __init__(self, ai_model_id: str, api_key: str | None = None, env_file_path: str | None = None) -> None:
def __init__(
self,
ai_model_id: str,
api_key: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initializes a new instance of the GooglePalmTextEmbedding class.

Args:
Expand All @@ -29,23 +35,18 @@ def __init__(self, ai_model_id: str, api_key: str | None = None, env_file_path:
read from either the env vars or the .env settings file.
env_file_path (str | None): Use the environment settings file
as a fallback to environment variables. (Optional)
env_file_encoding (str | None): The encoding of the environment settings file. (Optional)
"""
try:
google_palm_settings = GooglePalmSettings.create(env_file_path=env_file_path)
except ValidationError as e:
logger.error(f"Error loading Google Palm pydantic settings: {e}")

api_key = api_key or (
google_palm_settings.api_key.get_secret_value()
if google_palm_settings and google_palm_settings.api_key
else None
google_palm_settings = GooglePalmSettings.create(
api_key=api_key,
embedding_model_id=ai_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
ai_model_id = ai_model_id or (
google_palm_settings.embedding_model_id
if google_palm_settings and google_palm_settings.embedding_model_id
else None
super().__init__(
ai_model_id=google_palm_settings.embedding_model_id,
api_key=google_palm_settings.api_key.get_secret_value() if google_palm_settings.api_key else None,
)
super().__init__(ai_model_id=ai_model_id, api_key=api_key)

async def generate_embeddings(self, texts: list[str], **kwargs: Any) -> ndarray:
"""Generates embeddings for the given list of texts."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import ClassVar

from pydantic import SecretStr
from pydantic_settings import BaseSettings

from semantic_kernel.kernel_pydantic import KernelBaseSettings


class GooglePalmSettings(BaseSettings):
class GooglePalmSettings(KernelBaseSettings):
"""Google Palm model settings.

The settings are first loaded from environment variables with the prefix 'GOOGLE_PALM_'. If the
Expand All @@ -24,26 +27,9 @@ class GooglePalmSettings(BaseSettings):
(Env var GOOGLE_PALM_EMBEDDING_MODEL_ID)
"""

env_file_path: str | None = None
env_prefix: ClassVar[str] = "GOOGLE_PALM_"

api_key: SecretStr | None = None
chat_model_id: str | None = None
text_model_id: str | None = None
embedding_model_id: str | None = None

class Config:
"""Pydantic configuration settings."""

env_prefix = "GOOGLE_PALM_"
env_file = None
env_file_encoding = "utf-8"
extra = "ignore"
case_sensitive = False

@classmethod
def create(cls, **kwargs):
"""Create the settings object."""
if "env_file_path" in kwargs and kwargs["env_file_path"]:
cls.Config.env_file = kwargs["env_file_path"]
else:
cls.Config.env_file = None
return cls(**kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.memory.azure_cognitive_search.azure_ai_search_settings import AzureAISearchSettings
from semantic_kernel.kernel_pydantic import KernelBaseModel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,6 +83,18 @@ class AzureAISearchDataSource(AzureChatRequestBase):
type: Literal["azure_search"] = "azure_search"
parameters: Annotated[dict, AzureAISearchDataSourceParameters]

@classmethod
def from_azure_ai_search_settings(cls, azure_ai_search_settings: AzureAISearchSettings, **kwargs: Any):
    """Build a data source instance from the given Azure AI Search settings.

    Extra keyword arguments are forwarded to the constructor unchanged; the
    ``parameters`` keyword is always overwritten with the endpoint, index
    name, and API-key authentication derived from *azure_ai_search_settings*.
    """
    settings = azure_ai_search_settings
    # api_key is an optional secret; unwrap it only when it is present.
    secret = settings.api_key
    key_value = secret.get_secret_value() if secret else None
    kwargs["parameters"] = {
        "endpoint": str(settings.endpoint),
        "index_name": settings.index_name,
        "authentication": {"key": key_value},
    }
    return cls(**kwargs)


DataSource = Annotated[Union[AzureAISearchDataSource, AzureCosmosDBDataSource], Field(discriminator="type")]

Expand Down
Loading
Loading