Skip to content

Commit

Permalink
[extensions] 2/n update asr model parser with input schema
Browse files Browse the repository at this point in the history
With a new input attachment type `AttachmentDataWithStringValue` defined, this diff updates the ASR model parser to validate that the input is in the expected form/type.

Also made some updates to make input attachment validation clearer.

## Testplan


### Dependencies
Sapling removed the dependency PR. Depends on
#929
  • Loading branch information
Ankush Pala [email protected] committed Jan 16, 2024
1 parent ef75543 commit 17b7a09
Showing 1 changed file with 12 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiconfig_extension_hugging_face.local_inference.util import get_hf_model
from aiconfig import ModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
from aiconfig.schema import AttachmentDataWithStringValue, Prompt, Output, ExecuteResult, Attachment, PromptInput

if TYPE_CHECKING:
from aiconfig import AIConfigRuntime
Expand Down Expand Up @@ -154,7 +154,7 @@ def validate_attachment_type_is_audio(attachment: Attachment):

def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]:
"""
Retrieves the audio uri's from each attachment in the prompt input.
Retrieves the audio uri's or base64 from each attachment in the prompt input.
Throws an exception if
- attachment is not audio
Expand All @@ -163,21 +163,24 @@ def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]:
- operation fails for any reason
"""

if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
if not isinstance(prompt.input, PromptInput):
raise ValueError(f"Prompt input is of type {type(prompt.input) }. Please specify a PromptInput with attachments for prompt {prompt.name}.")

if prompt.input.attachments is None or len(prompt.input.attachments) == 0:
raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input.")

audio_uris: list[str] = []
audio_inputs: list[str] = []

for i, attachment in enumerate(prompt.input.attachments):
validate_attachment_type_is_audio(attachment)

if not isinstance(attachment.data, str):
# See todo above, but for now only support uri's
raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the audio attachment in prompt {prompt.name}.")
if not isinstance(attachment.data, AttachmentDataWithStringValue):
raise ValueError(f"""Attachment data must be of type `AttachmentDataWithStringValue` with a `kind` and `value` field.
Please specify a uri for the audio attachment in prompt {prompt.name}.""")

audio_uris.append(attachment.data)
audio_inputs.append(attachment.data.value)

return audio_uris
return audio_inputs


def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
Expand Down

0 comments on commit 17b7a09

Please sign in to comment.