Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions google/genai/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,49 @@ def _t_tool_response(
) from e


def _t_speech_config(
    origin: Union[types.SpeechConfigUnionDict, Any],
) -> Optional[types.SpeechConfig]:
  """Normalizes a user-supplied speech config into a `types.SpeechConfig`.

  Args:
    origin: The speech config as passed by the caller. May be a
      `types.SpeechConfig` instance, a dict, or a falsy value.

  Returns:
    `None` when `origin` is falsy (e.g. `None` or an empty dict); `origin`
    itself when it already is a `types.SpeechConfig`; otherwise a new
    `types.SpeechConfig` built from the dict. Only
    `voice_config.prebuilt_voice_config.voice_name` and `language_code` are
    read from the dict; any other keys are ignored.

  Raises:
    ValueError: If `origin` is a string (it is ambiguous whether it names a
      voice or a language code), if it is a dict from which neither a voice
      config nor a language code could be read, or if it is any other
      unsupported type.
  """
  if not origin:
    return None
  if isinstance(origin, types.SpeechConfig):
    return origin
  if isinstance(origin, str):
    # There is no way to know if the string is a voice name or a language code.
    raise ValueError(
        f'Unsupported speechConfig type: {type(origin)}. There is no way to'
        ' know if the string is a voice name or a language code.'
    )
  if not isinstance(origin, dict):
    raise ValueError(f'Unsupported speechConfig type: {type(origin)}')

  speech_config = types.SpeechConfig()

  # Walk voice_config -> prebuilt_voice_config -> voice_name, stopping at the
  # first level that is missing or None.
  voice_config = origin.get('voice_config')
  if voice_config is not None and 'prebuilt_voice_config' in voice_config:
    prebuilt = voice_config['prebuilt_voice_config']
    if prebuilt is not None and 'voice_name' in prebuilt:
      speech_config.voice_config = types.VoiceConfig(
          prebuilt_voice_config=types.PrebuiltVoiceConfig(
              voice_name=prebuilt.get('voice_name')
          )
      )

  language_code = origin.get('language_code')
  if language_code is not None:
    speech_config.language_code = language_code

  if (
      speech_config.voice_config is None
      and speech_config.language_code is None
  ):
    # The first literal must be an f-string so that the actual type is
    # interpolated instead of the raw `{type(origin)}` placeholder text.
    raise ValueError(
        f'Unsupported speechConfig type: {type(origin)}. At least one of'
        ' voice_config or language_code must be set.'
    )
  return speech_config


class AsyncLive(_api_module.BaseModule):
"""AsyncLive. The live module is experimental."""

Expand Down Expand Up @@ -1056,18 +1099,14 @@ def _LiveSetup_to_mldev(
if getv(to_object, ['generationConfig']) is not None:
to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
self._api_client,
t.t_speech_config(
self._api_client, getv(config, ['speech_config'])
),
_t_speech_config(getv(config, ['speech_config'])),
to_object,
)
else:
to_object['generationConfig'] = {
'speechConfig': _SpeechConfig_to_mldev(
self._api_client,
t.t_speech_config(
self._api_client, getv(config, ['speech_config'])
),
_t_speech_config(getv(config, ['speech_config'])),
to_object,
)
}
Expand Down Expand Up @@ -1169,18 +1208,14 @@ def _LiveSetup_to_vertex(
if getv(to_object, ['generationConfig']) is not None:
to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
self._api_client,
t.t_speech_config(
self._api_client, getv(config, ['speech_config'])
),
_t_speech_config(getv(config, ['speech_config'])),
to_object,
)
else:
to_object['generationConfig'] = {
'speechConfig': _SpeechConfig_to_vertex(
self._api_client,
t.t_speech_config(
self._api_client, getv(config, ['speech_config'])
),
_t_speech_config(getv(config, ['speech_config'])),
to_object,
)
}
Expand Down
6 changes: 6 additions & 0 deletions google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,9 @@ def _SpeechConfig_to_mldev(
),
)

if getv(from_object, ['language_code']) is not None:
setv(to_object, ['languageCode'], getv(from_object, ['language_code']))

return to_object


Expand Down Expand Up @@ -1490,6 +1493,9 @@ def _SpeechConfig_to_vertex(
),
)

if getv(from_object, ['language_code']) is not None:
setv(to_object, ['languageCode'], getv(from_object, ['language_code']))

return to_object


Expand Down
12 changes: 8 additions & 4 deletions google/genai/tests/live/test_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ... import live
from ... import types


def exception_if_mldev(vertexai, exception_type: type[Exception]):
if vertexai:
return contextlib.nullcontext()
Expand Down Expand Up @@ -490,7 +491,8 @@ async def test_bidi_setup_to_api_speech_config(vertexai):
'speechConfig': {
'voiceConfig': {
'prebuiltVoiceConfig': {'voiceName': 'en-default'}
}
},
'languageCode': 'en-US',
},
'temperature': 0.7,
'topP': 0.8,
Expand All @@ -512,14 +514,15 @@ async def test_bidi_setup_to_api_speech_config(vertexai):
expected_result['setup']['model'] = 'projects/test_project/locations/us-central1/publishers/google/models/test_model'
expected_result['setup']['generationConfig']['responseModalities'] = ['AUDIO']
else:
expected_result['setup']['model'] = 'models/test_model'
expected_result['setup']['model'] = 'models/test_model'

# Test for mldev, config is a dict
config_dict = {
'speech_config': {
'voice_config': {
'prebuilt_voice_config': {'voice_name': 'en-default'}
}
},
'language_code': 'en-US',
},
'temperature': 0.7,
'top_p': 0.8,
Expand All @@ -542,7 +545,8 @@ async def test_bidi_setup_to_api_speech_config(vertexai):
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name='en-default'
)
)
),
language_code='en-US',
),
temperature=0.7,
top_p=0.8,
Expand Down
11 changes: 11 additions & 0 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1541,6 +1541,12 @@ class SpeechConfig(_common.BaseModel):
description="""The configuration for the speaker to use.
""",
)
language_code: Optional[str] = Field(
default=None,
description="""Language code (ISO 639. e.g. en-US) for the speech synthesization.
Only available for Live API.
""",
)


class SpeechConfigDict(TypedDict, total=False):
Expand All @@ -1550,6 +1556,11 @@ class SpeechConfigDict(TypedDict, total=False):
"""The configuration for the speaker to use.
"""

language_code: Optional[str]
"""Language code (ISO 639. e.g. en-US) for the speech synthesization.
Only available for Live API.
"""


SpeechConfigOrDict = Union[SpeechConfig, SpeechConfigDict]

Expand Down