Skip to content

Commit

Permalink
Update Google Cloud Generative Model Hooks/Operators to bring parity …
Browse files Browse the repository at this point in the history
…with Vertex AI API (#40484)
  • Loading branch information
CYarros10 committed Jul 4, 2024
1 parent 9918f2a commit 9c97067
Show file tree
Hide file tree
Showing 10 changed files with 751 additions and 143 deletions.
141 changes: 141 additions & 0 deletions airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
from typing import Sequence

import vertexai
from deprecated import deprecated
from vertexai.generative_models import GenerativeModel, Part
from vertexai.language_models import TextEmbeddingModel, TextGenerationModel

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook


Expand Down Expand Up @@ -59,11 +61,23 @@ def get_generative_model(self, pretrained_model: str) -> GenerativeModel:
model = GenerativeModel(pretrained_model)
return model

@deprecated(
    reason=(
        "The `get_generative_model_part` method is deprecated and will be removed after 01.01.2025, "
        "please include `Part` objects in `contents` parameter of "
        "`airflow.providers.google.cloud.hooks.vertex_ai.generative_model."
        "GenerativeModelHook.generative_model_generate_content`"
    ),
    category=AirflowProviderDeprecationWarning,
)
def get_generative_model_part(self, content_gcs_path: str, content_mime_type: str | None = None) -> Part:
    """
    Return a Generative Model Part object.

    Deprecated: build ``Part`` objects directly and pass them in the ``contents``
    argument of ``generative_model_generate_content`` instead.

    :param content_gcs_path: URI handed to ``Part.from_uri`` (presumably a
        ``gs://`` Cloud Storage path, judging by the name -- confirm with callers).
    :param content_mime_type: Forwarded to ``Part.from_uri`` as ``mime_type``;
        ``None`` when not supplied.
    :return: The ``Part`` built by ``Part.from_uri``.
    """
    # NOTE: the replacement path in the deprecation message includes the
    # ``vertex_ai`` package, matching the module this hook actually lives in.
    return Part.from_uri(content_gcs_path, mime_type=content_mime_type)

@deprecated(
reason=(
"The `prompt_language_model` method is deprecated and will be removed after 01.01.2025, please use `airflow.providers.google.cloud.hooks.generative_model.GenerativeModelHook.text_generation_model_predict` method."
),
category=AirflowProviderDeprecationWarning,
)
@GoogleBaseHook.fallback_to_default_project_id
def prompt_language_model(
self,
Expand Down Expand Up @@ -112,6 +126,12 @@ def prompt_language_model(
)
return response.text

@deprecated(
reason=(
"The `generate_text_embeddings` method is deprecated and will be removed after 01.01.2025, please use `airflow.providers.google.cloud.hooks.generative_model.GenerativeModelHook.text_embedding_model_get_embeddings` method."
),
category=AirflowProviderDeprecationWarning,
)
@GoogleBaseHook.fallback_to_default_project_id
def generate_text_embeddings(
self,
Expand All @@ -136,6 +156,12 @@ def generate_text_embeddings(

return response.values

@deprecated(
reason=(
"The `prompt_multimodal_model` method is deprecated and will be removed after 01.01.2025, please use `airflow.providers.google.cloud.hooks.generative_model.GenerativeModelHook.generative_model_generate_content` method."
),
category=AirflowProviderDeprecationWarning,
)
@GoogleBaseHook.fallback_to_default_project_id
def prompt_multimodal_model(
self,
Expand Down Expand Up @@ -169,6 +195,12 @@ def prompt_multimodal_model(

return response.text

@deprecated(
reason=(
"The `prompt_multimodal_model_with_media` method is deprecated and will be removed after 01.01.2025, please use `airflow.providers.google.cloud.hooks.generative_model.GenerativeModelHook.generative_model_generate_content` method."
),
category=AirflowProviderDeprecationWarning,
)
@GoogleBaseHook.fallback_to_default_project_id
def prompt_multimodal_model_with_media(
self,
Expand Down Expand Up @@ -207,3 +239,112 @@ def prompt_multimodal_model_with_media(
)

return response.text

@GoogleBaseHook.fallback_to_default_project_id
def text_generation_model_predict(
    self,
    prompt: str,
    pretrained_model: str,
    temperature: float,
    max_output_tokens: int,
    top_p: float,
    top_k: int,
    location: str,
    project_id: str = PROVIDE_PROJECT_ID,
) -> str:
    """
    Generate natural language text with a Vertex AI PaLM text generation model.

    Initializes the Vertex AI SDK for the given project and location using the
    hook's credentials, loads the requested model and returns the text of its
    prediction for ``prompt``.

    :param prompt: Required. Inputs or queries that a user or a program gives
        to the Vertex AI PaLM API, in order to elicit a specific response.
    :param pretrained_model: Name of the pre-trained text generation model,
        resolved through ``get_text_generation_model``.
    :param temperature: Controls the degree of randomness in token selection.
    :param max_output_tokens: Token limit that bounds the amount of text output.
    :param top_p: Tokens are selected from most probable to least until the sum
        of their probabilities equals the top_p value.
    :param top_k: A top_k of 1 means the selected token is the most probable
        among all tokens.
    :param location: Required. The ID of the Google Cloud location that the service belongs to.
    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
    :return: The ``text`` attribute of the model's prediction response.
    """
    vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

    text_model = self.get_text_generation_model(pretrained_model)

    # Sampling settings are forwarded to the model as keyword arguments.
    prediction = text_model.predict(
        prompt=prompt,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
    )
    return prediction.text

@GoogleBaseHook.fallback_to_default_project_id
def text_embedding_model_get_embeddings(
    self,
    prompt: str,
    pretrained_model: str,
    location: str,
    project_id: str = PROVIDE_PROJECT_ID,
) -> list:
    """
    Generate text embeddings for a single prompt with a Vertex AI PaLM embedding model.

    Initializes the Vertex AI SDK for the given project and location using the
    hook's credentials, loads the requested embedding model and returns the
    embedding values computed for ``prompt``.

    :param prompt: Required. Inputs or queries that a user or a program gives
        to the Vertex AI PaLM API, in order to elicit a specific response.
    :param pretrained_model: Name of the pre-trained text embedding model,
        resolved through ``get_text_embedding_model``.
    :param location: Required. The ID of the Google Cloud location that the service belongs to.
    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
    :return: The ``values`` attribute of the embedding returned for ``prompt``.
    """
    vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

    embedding_model = self.get_text_embedding_model(pretrained_model)

    # The model is called with a one-element batch, so only the first
    # result corresponds to the prompt.
    embeddings = embedding_model.get_embeddings([prompt])
    return embeddings[0].values

@GoogleBaseHook.fallback_to_default_project_id
def generative_model_generate_content(
    self,
    contents: list,
    location: str,
    tools: list | None = None,
    generation_config: dict | None = None,
    safety_settings: dict | None = None,
    pretrained_model: str = "gemini-pro",
    project_id: str = PROVIDE_PROJECT_ID,
) -> str:
    """
    Generate content with a Vertex AI generative model (``gemini-pro`` by default).

    Initializes the Vertex AI SDK for the given project and location using the
    hook's credentials, loads the requested generative model and returns the
    text of its response to ``contents``.

    :param contents: Required. The multi-part content of a message that a user or a program
        gives to the generative model, in order to elicit a specific response.
    :param location: Required. The ID of the Google Cloud location that the service belongs to.
    :param tools: Optional. Passed through unchanged as the ``tools`` argument of
        the model's ``generate_content`` call.
    :param generation_config: Optional. Generation configuration settings.
    :param safety_settings: Optional. Per request settings for blocking unsafe content.
    :param pretrained_model: Name of the pre-trained model to use, resolved through
        ``get_generative_model``. Defaults to ``gemini-pro``.
    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
    :return: The ``text`` attribute of the model's response.
    """
    vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

    # Optional settings are forwarded as-is, including ``None`` when unset.
    request_kwargs = {
        "contents": contents,
        "tools": tools,
        "generation_config": generation_config,
        "safety_settings": safety_settings,
    }

    generative_model = self.get_generative_model(pretrained_model)
    result = generative_model.generate_content(**request_kwargs)

    return result.text
Loading

0 comments on commit 9c97067

Please sign in to comment.