From 7e2e9935a9fd3c5b7f70f06385449fe711e07116 Mon Sep 17 00:00:00 2001 From: Danylo Boiko <55975773+danylo-boiko@users.noreply.github.com> Date: Wed, 16 Oct 2024 19:12:19 +0300 Subject: [PATCH] toolkit: add text to speech synthesis (#801) * Add synthesizer * Format assistants web * Add message_id to StreamEnd * Add icons * Add synthesize message client endpoint * Add useSynthesize hook * Minor refactoring * Add integration tests * Add unit tests * Refactor unit tests * Update templates * Update back-end to use google cloud * Format tests * Add tts to experimental_features * Regenerate client web * Add experimental features hook * Update front-end to use google cloud * Add loading spinner * Refactor useSynthesizer hook * Fix typo and remove is ascii check * Add api key validation * Update exception text * Fix typecheck --- .env-template | 4 + poetry.lock | 44 +++++-- pyproject.toml | 1 + src/backend/config/secrets.template.yaml | 4 +- src/backend/config/settings.py | 8 ++ src/backend/crud/message.py | 21 ++++ src/backend/routers/conversation.py | 45 +++++++ src/backend/routers/experimental_features.py | 4 +- src/backend/schemas/chat.py | 1 + src/backend/services/chat.py | 1 + src/backend/services/synthesizer.py | 79 +++++++++++++ .../integration/routers/test_conversation.py | 56 +++++++++ src/backend/tests/unit/crud/test_message.py | 16 +++ .../assistants_web/src/assets/icons/Stop.tsx | 18 +++ .../src/assets/icons/Volume.tsx | 22 ++++ .../assistants_web/src/assets/icons/index.ts | 2 + .../src/cohere-client/client.ts | 14 +++ .../cohere-client/generated/schemas.gen.ts | 110 ++---------------- .../cohere-client/generated/services.gen.ts | 74 ++++++------ .../src/cohere-client/generated/types.gen.ts | 62 +++++----- .../src/components/MessageRow/MessageRow.tsx | 54 ++++++--- .../MessagingContainer/MessagingContainer.tsx | 4 + .../src/components/UI/Button.tsx | 6 +- .../assistants_web/src/components/UI/Icon.tsx | 14 +++ .../src/components/UI/IconButton.tsx | 4 + .../assistants_web/src/hooks/use-chat.ts | 1 + .../src/hooks/use-experimentalFeatures.ts | 12 ++ .../src/hooks/use-synthesizer.ts | 109 +++++++++++++++++ .../assistants_web/src/types/message.ts | 1 + .../assistants_web/src/utils/conversation.ts | 2 + 30 files changed, 603 insertions(+), 190 deletions(-) create mode 100644 src/backend/services/synthesizer.py create mode 100644 src/interfaces/assistants_web/src/assets/icons/Stop.tsx create mode 100644 src/interfaces/assistants_web/src/assets/icons/Volume.tsx create mode 100644 src/interfaces/assistants_web/src/hooks/use-experimentalFeatures.ts create mode 100644 src/interfaces/assistants_web/src/hooks/use-synthesizer.ts diff --git a/.env-template b/.env-template index 08dfe7f648..96fc44d57a 100644 --- a/.env-template +++ b/.env-template @@ -91,3 +91,7 @@ GOOGLE_DRIVE_CLIENT_ID= GOOGLE_DRIVE_CLIENT_SECRET= NEXT_PUBLIC_GOOGLE_DRIVE_CLIENT_ID=${GOOGLE_DRIVE_CLIENT_ID} NEXT_PUBLIC_GOOGLE_DRIVE_DEVELOPER_KEY= + +# Google Cloud + +GOOGLE_CLOUD_API_KEY= diff --git a/poetry.lock b/poetry.lock index b3d4dc8fd1..362a985c29 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1339,6 +1339,8 @@ files = [ [package.dependencies] google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" +grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} +grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -1422,6 +1424,23 @@ requests-oauthlib = ">=0.7.0" [package.extras] tool = ["click (>=6.0.0)"] +[[package]] +name = "google-cloud-texttospeech" +version = "2.18.0" +description = "Google Cloud Texttospeech API client library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_cloud_texttospeech-2.18.0-py2.py3-none-any.whl", hash = "sha256:178eb686bd439c46cc1bfee2fab3960914e08655630a52e224469873152d1418"}, + {file = "google_cloud_texttospeech-2.18.0.tar.gz", hash = "sha256:8d1f7577a6f86ed48335e10814680261e2976da54888c8a1d586044ee19196f4"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" + [[package]] name = "googleapis-common-protos" version = "1.65.0" @@ -1576,6 +1595,22 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.66.1)"] +[[package]] +name = "grpcio-status" +version = "1.66.1" +description = "Status proto mapping for gRPC" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio_status-1.66.1-py3-none-any.whl", hash = "sha256:cf9ed0b4a83adbe9297211c95cb5488b0cd065707e812145b842c85c4782ff02"}, + {file = "grpcio_status-1.66.1.tar.gz", hash = "sha256:b3f7d34ccc46d83fea5261eea3786174459f763c31f6e34f1d24eba6d515d024"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.66.1" +protobuf = ">=5.26.1,<6.0dev" + [[package]] name = "h11" version = "0.14.0" @@ -5932,11 +5967,6 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, - {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, - {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, - {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, - {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, - {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -6671,4 +6701,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "~3.11" -content-hash = "02ea066b85656db579426761977ac24d66a1a0001d5e14a21024ab4908f00eb6" +content-hash = "077321dcaebc0346ae1669450ad6415aaa9b8c117e7f7154ed412eb127d75ecc" diff --git a/pyproject.toml b/pyproject.toml index 65549e6db9..10bea3ea2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ hyperframe = "^6.0.1" llama-index = "^0.11.10" llama-index-llms-cohere = "^0.3.0" llama-index-embeddings-cohere = "^0.2.1" +google-cloud-texttospeech = "^2.18.0" [tool.poetry.group.dev] optional = true diff --git a/src/backend/config/secrets.template.yaml b/src/backend/config/secrets.template.yaml index c89acb52b8..dbd9248b9c 100644 --- a/src/backend/config/secrets.template.yaml +++ b/src/backend/config/secrets.template.yaml @@ -42,4 +42,6 @@ auth: oidc: client_id: client_secret: - well_known_endpoint: \ No newline at end of file + well_known_endpoint: +google_cloud: + api_key: \ No newline at end of file diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index f32ed8ac77..4e32b6bb91 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -218,6 +218,13 @@ class RedisSettings(BaseSettings, BaseModel): ) +class GoogleCloudSettings(BaseSettings, BaseModel): + model_config = SETTINGS_CONFIG + api_key: Optional[str] = Field( + default=None, validation_alias=AliasChoices("GOOGLE_CLOUD_API_KEY", "api_key") + ) + + class SageMakerSettings(BaseSettings, BaseModel): model_config = SETTINGS_CONFIG endpoint_name: Optional[str] = Field( @@ -331,6 +338,7 @@ class Settings(BaseSettings): tools: Optional[ToolSettings] = Field(default=ToolSettings()) database: Optional[DatabaseSettings] = Field(default=DatabaseSettings()) redis: Optional[RedisSettings] = Field(default=RedisSettings()) + google_cloud: Optional[GoogleCloudSettings] = Field(default=GoogleCloudSettings()) deployments: Optional[DeploymentSettings] = Field(default=DeploymentSettings()) logger: Optional[LoggerSettings] = Field(default=LoggerSettings()) diff --git a/src/backend/crud/message.py b/src/backend/crud/message.py index abacf15fd5..62ed28396f 100644 --- a/src/backend/crud/message.py +++ b/src/backend/crud/message.py @@ -68,6 +68,27 @@ def get_messages( ) +@validate_transaction +def get_conversation_message(db: Session, conversation_id: str, message_id: str, user_id: str) -> Message | None: + """ + Get a message based on the conversation ID, message ID, and user ID. + + Args: + db (Session): Database session. + conversation_id (str): Conversation ID. + message_id (str): Message ID. + user_id (str): User ID. + + Returns: + Message | None: Message with the given conversation ID, message ID, and user ID or None if not found. + """ + return ( + db.query(Message) + .filter(Message.conversation_id == conversation_id, Message.id == message_id, Message.user_id == user_id) + .first() + ) + + @validate_transaction def get_messages_by_conversation_id( db: Session, conversation_id: str, user_id: str diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index 8da4e25622..a59baf3b2f 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -3,11 +3,13 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request from fastapi import File as RequestFile from fastapi import UploadFile as FastAPIUploadFile +from starlette.responses import Response from backend.chat.custom.utils import get_deployment from backend.config.routers import RouterName from backend.crud import agent as agent_crud from backend.crud import conversation as conversation_crud +from backend.crud import message as message_crud from backend.database_models import Conversation as ConversationModel from backend.database_models.database import DBSessionDep from backend.schemas.agent import Agent @@ -39,6 +41,7 @@ get_file_service, validate_file, ) +from backend.services.synthesizer import synthesize router = APIRouter( prefix="/v1/conversations", @@ -543,3 +546,45 @@ async def generate_title( title=title, error=error, ) + + +# SYNTHESIZE +@router.get("/{conversation_id}/synthesize/{message_id}") +async def synthesize_message( + conversation_id: str, + message_id: str, + session: DBSessionDep, + ctx: Context = Depends(get_context), +) -> Response: + """ + Generate a synthesized audio for a specific message in a conversation. + + Args: + conversation_id (str): Conversation ID. + message_id (str): Message ID. + session (DBSessionDep): Database session. + ctx (Context): Context object. + + Returns: + Response: Synthesized audio file. + + Raises: + HTTPException: If the message with the given ID is not found or synthesis fails. + """ + user_id = ctx.get_user_id() + message = message_crud.get_conversation_message(session, conversation_id, message_id, user_id) + + if not message: + raise HTTPException( + status_code=404, + detail=f"Message with ID: {message_id} not found.", + ) + + try: + synthesized_audio = synthesize(message.text) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error while message synthesis: {e}" + ) + + return Response(synthesized_audio, media_type="audio/mp3") diff --git a/src/backend/routers/experimental_features.py b/src/backend/routers/experimental_features.py index 24eb21cd11..55f6450efe 100644 --- a/src/backend/routers/experimental_features.py +++ b/src/backend/routers/experimental_features.py @@ -13,7 +13,7 @@ @router.get("/") -def list_experimental_features(ctx: Context = Depends(get_context)): +def list_experimental_features(ctx: Context = Depends(get_context)) -> dict[str, bool]: """ List all experimental features and if they are enabled @@ -22,8 +22,8 @@ def list_experimental_features(ctx: Context = Depends(get_context)): Returns: Dict[str, bool]: Experimental feature and their isEnabled state """ - experimental_features = { "USE_AGENTS_VIEW": Settings().feature_flags.use_agents_view, + "USE_TEXT_TO_SPEECH_SYNTHESIS": bool(Settings().google_cloud.api_key), } return experimental_features diff --git a/src/backend/schemas/chat.py b/src/backend/schemas/chat.py index 28cbc2229e..185d1ddf88 100644 --- a/src/backend/schemas/chat.py +++ b/src/backend/schemas/chat.py @@ -173,6 +173,7 @@ class StreamToolCallsGeneration(ChatResponse): class StreamEnd(ChatResponse): + message_id: str | None = Field(default=None) response_id: str | None = Field(default=None) event_type: ClassVar[StreamEvent] = StreamEvent.STREAM_END generation_id: str | None = Field(default=None) diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index 3b004eb1e1..276f190c16 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -659,6 +659,7 @@ async def generate_chat_stream( user_id = ctx.get_user_id() stream_end_data = { + "message_id": response_message.id, "conversation_id": conversation_id, "response_id": ctx.get_trace_id(), "text": "", diff --git a/src/backend/services/synthesizer.py b/src/backend/services/synthesizer.py new file mode 100644 index 0000000000..04c376b778 --- /dev/null +++ b/src/backend/services/synthesizer.py @@ -0,0 +1,79 @@ +from google.cloud.texttospeech import ( + AudioConfig, + AudioEncoding, + SynthesisInput, + TextToSpeechClient, + VoiceSelectionParams, +) +from googleapiclient.discovery import build + +from backend.config import Settings + + +def synthesize(text: str) -> bytes: + """ + Synthesizes speech from the input text. + + Args: + text (str): The input text to be synthesized into speech. + + Returns: + bytes: The audio content generated from the input text in MP3 format. + + Raises: + ValueError: If the Google Cloud API key from the settings is not valid. + """ + client = TextToSpeechClient(client_options={ + "api_key": _validate_google_cloud_api_key() + }) + + language = detect_language(text) + + response = client.synthesize_speech( + input=SynthesisInput(text=text), + voice=VoiceSelectionParams(language_code=language), + audio_config=AudioConfig(audio_encoding=AudioEncoding.MP3) + ) + + return response.audio_content + + +def detect_language(text: str) -> str: + """ + Detect the language of the given text. + + Args: + text (str): The text for which the language needs to be detected. + + Returns: + str: The language code of the detected language (e.g., 'en', 'es'). + + Raises: + ValueError: If the Google Cloud API key from the settings is not valid. + """ + client = build("translate", "v2", developerKey=_validate_google_cloud_api_key()) + + response = client.detections().list(q=text).execute() + + return response["detections"][0][0]["language"] + + +def _validate_google_cloud_api_key() -> str: + """ + Validates the Google Cloud API key from the settings. + + Returns: + str: The validated API key. + + Raises: + ValueError: If the API key is not found in the settings or is empty. + """ + google_cloud = Settings().google_cloud + + if not google_cloud: + raise ValueError("google_cloud in secrets.yaml is missing.") + + if not google_cloud.api_key: + raise ValueError("google_cloud.api_key in secrets.yaml is missing.") + + return google_cloud.api_key diff --git a/src/backend/tests/integration/routers/test_conversation.py b/src/backend/tests/integration/routers/test_conversation.py index d7cce31caf..7b330a819c 100644 --- a/src/backend/tests/integration/routers/test_conversation.py +++ b/src/backend/tests/integration/routers/test_conversation.py @@ -4,6 +4,7 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session +from backend.config import Settings from backend.config.deployments import ModelDeploymentName from backend.database_models import Conversation from backend.schemas.user import User @@ -157,3 +158,58 @@ def test_generate_title_error_invalid_model( == "status_code: 404, body: {'message': \"model 'invalid' not found, make sure the correct model ID was used and that you have access to the model.\"}" ) assert response["title"] == "" + + +# SYNTHESIZE + + +is_google_cloud_api_key_set = bool(Settings().google_cloud.api_key) + + +@pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test") +def test_synthesize_english_message( + session_client: TestClient, + session: Session, + user: User, +) -> None: + conversation = get_factory("Conversation", session).create(user_id=user.id) + message = get_factory("Message", session).create( + id="1", text="Hello world!", conversation_id=conversation.id, user_id=user.id + ) + response = session_client.get( + f"/v1/conversations/{conversation.id}/synthesize/{message.id}", + headers={"User-Id": conversation.user_id}, + ) + assert response.status_code == 200 + assert response.headers["Content-Type"] == "audio/mp3" + + +@pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test") +def test_synthesize_non_english_message( + session_client: TestClient, + session: Session, + user: User, +) -> None: + conversation = get_factory("Conversation", session).create(user_id=user.id) + message = get_factory("Message", session).create( + id="1", text="Bonjour le monde!", conversation_id=conversation.id, user_id=user.id + ) + response = session_client.get( + f"/v1/conversations/{conversation.id}/synthesize/{message.id}", + headers={"User-Id": conversation.user_id}, + ) + assert response.status_code == 200 + assert response.headers["Content-Type"] == "audio/mp3" + + +def test_fail_synthesize_message_nonexistent_message( + session_client: TestClient, + session: Session, + user: User, +) -> None: + response = session_client.get( + "/v1/conversations/123/synthesize/456", + headers={"User-Id": user.id}, + ) + assert response.status_code == 404 + assert response.json() == {"detail": "Message with ID: 456 not found."} diff --git a/src/backend/tests/unit/crud/test_message.py b/src/backend/tests/unit/crud/test_message.py index 44b9649446..9266d85032 100644 --- a/src/backend/tests/unit/crud/test_message.py +++ b/src/backend/tests/unit/crud/test_message.py @@ -48,6 +48,22 @@ def test_fail_get_nonexistent_message(session, user): assert message is None +def test_get_conversation_message(session, conversation, user): + _ = get_factory("Message", session).create( + id="1", text="Hello, World!", conversation_id=conversation.id, user_id=user.id + ) + + message = message_crud.get_conversation_message(session, conversation.id, "1", user.id) + assert message.conversation_id == conversation.id + assert message.id == "1" + assert message.text == "Hello, World!" + + +def test_fail_get_nonexistent_conversation_message(session, user): + message = message_crud.get_conversation_message(session, "123", "456", user.id) + assert message is None + + def test_list_messages(session, conversation, user): _ = get_factory("Message", session).create( text="Hello, World!", conversation_id=conversation.id, user_id=user.id diff --git a/src/interfaces/assistants_web/src/assets/icons/Stop.tsx b/src/interfaces/assistants_web/src/assets/icons/Stop.tsx new file mode 100644 index 0000000000..78bc3ed5f0 --- /dev/null +++ b/src/interfaces/assistants_web/src/assets/icons/Stop.tsx @@ -0,0 +1,18 @@ +import * as React from 'react'; +import { SVGProps } from 'react'; + +import { cn } from '@/utils'; + +export const Stop: React.FC> = ({ className, ...props }) => ( + + + +); diff --git a/src/interfaces/assistants_web/src/assets/icons/Volume.tsx b/src/interfaces/assistants_web/src/assets/icons/Volume.tsx new file mode 100644 index 0000000000..793dc2b8cc --- /dev/null +++ b/src/interfaces/assistants_web/src/assets/icons/Volume.tsx @@ -0,0 +1,22 @@ +import * as React from 'react'; +import { SVGProps } from 'react'; + +import { cn } from '@/utils'; + +export const Volume: React.FC> = ({ className, ...props }) => ( + + + + +); diff --git a/src/interfaces/assistants_web/src/assets/icons/index.ts b/src/interfaces/assistants_web/src/assets/icons/index.ts index d69f2fe418..e285d82e8f 100644 --- a/src/interfaces/assistants_web/src/assets/icons/index.ts +++ b/src/interfaces/assistants_web/src/assets/icons/index.ts @@ -47,6 +47,7 @@ export * from './Share'; export * from './Show'; export * from './SignOut'; export * from './Sparkle'; +export * from './Stop'; export * from './Subtract'; export * from './Sun'; export * from './ThumbsDown'; @@ -54,5 +55,6 @@ export * from './ThumbsUp'; export * from './Trash'; export * from './Upload'; export * from './UsersThree'; +export * from './Volume'; export * from './Warning'; export * from './Web'; diff --git a/src/interfaces/assistants_web/src/cohere-client/client.ts b/src/interfaces/assistants_web/src/cohere-client/client.ts index afb0393e2d..d362daee42 100644 --- a/src/interfaces/assistants_web/src/cohere-client/client.ts +++ b/src/interfaces/assistants_web/src/cohere-client/client.ts @@ -3,6 +3,7 @@ import { FetchEventSourceInit, fetchEventSource } from '@microsoft/fetch-event-s import { Body_batch_upload_file_v1_agents_batch_upload_file_post, Body_batch_upload_file_v1_conversations_batch_upload_file_post, + CancelablePromise, CohereChatRequest, CohereClientGenerated, CohereNetworkError, @@ -156,6 +157,19 @@ export class CohereClient { ); } + public async synthesizeMessage(conversationId: string, messageId: string) { + return this.cohereService.default.synthesizeMessageV1ConversationsConversationIdSynthesizeMessageIdGet( + { + conversationId, + messageId, + } + ) as CancelablePromise; + } + + public async getExperimentalFeatures() { + return this.cohereService.default.listExperimentalFeaturesV1ExperimentalFeaturesGet(); + } + public listTools({ agentId }: { agentId?: string | null }) { return this.cohereService.default.listToolsV1ToolsGet({ agentId }); } diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts index a8761fbcf6..a9ab607cea 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts @@ -556,7 +556,7 @@ export const $CohereChatRequest = { }, ], title: 'The model to use for generating the response.', - default: 'command-r', + default: 'command-r-plus', }, temperature: { anyOf: [ @@ -1751,103 +1751,6 @@ export const $JWTResponse = { title: 'JWTResponse', } as const; -export const $LangchainChatRequest = { - properties: { - message: { - type: 'string', - title: 'The message to send to the chatbot.', - }, - chat_history: { - anyOf: [ - { - items: { - $ref: '#/components/schemas/ChatMessage', - }, - type: 'array', - }, - { - type: 'null', - }, - ], - title: - 'A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state.', - }, - conversation_id: { - type: 'string', - title: - 'To store a conversation then create a conversation id and use it for every related request', - }, - tools: { - anyOf: [ - { - items: { - $ref: '#/components/schemas/Tool', - }, - type: 'array', - }, - { - type: 'null', - }, - ], - title: ` - List of custom or managed tools to use for the response. - If passing in managed tools, you only need to provide the name of the tool. - If passing in custom tools, you need to provide the name, description, and optionally parameter defintions of the tool. - Passing a mix of custom and managed tools is not supported. - - Managed Tools Examples: - tools=[ - { - "name": "Wiki Retriever - LangChain", - }, - { - "name": "Calculator", - } - ] - - Custom Tools Examples: - tools=[ - { - "name": "movie_title_generator", - "description": "tool to generate a cool movie title", - "parameter_definitions": { - "synopsis": { - "description": "short synopsis of the movie", - "type": "str", - "required": true - } - } - }, - { - "name": "random_number_generator", - "description": "tool to generate a random number between min and max", - "parameter_definitions": { - "min": { - "description": "minimum number", - "type": "int", - "required": true - }, - "max": { - "description": "maximum number", - "type": "int", - "required": true - } - } - }, - { - "name": "joke_generator", - "description": "tool to generate a random joke", - } - ] - `, - }, - }, - type: 'object', - required: ['message'], - title: 'LangchainChatRequest', - description: 'Request shape for Langchain Streamed Chat.', -} as const; - export const $ListAuthStrategy = { properties: { strategy: { @@ -2830,6 +2733,17 @@ export const $StreamCitationGeneration = { export const $StreamEnd = { properties: { + message_id: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Message Id', + }, response_id: { anyOf: [ { diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts index f2d89636df..d957c4fc2b 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts @@ -87,8 +87,6 @@ import type { GetUsersScimV2UsersGetData, GetUsersScimV2UsersGetResponse, HealthHealthGetResponse, - LangchainChatStreamV1LangchainChatPostData, - LangchainChatStreamV1LangchainChatPostResponse, ListAgentToolMetadataV1AgentsAgentIdToolMetadataGetData, ListAgentToolMetadataV1AgentsAgentIdToolMetadataGetResponse, ListAgentsV1AgentsGetData, @@ -121,6 +119,8 @@ import type { SearchConversationsV1ConversationsSearchGetResponse, SetEnvVarsV1DeploymentsNameSetEnvVarsPostData, SetEnvVarsV1DeploymentsNameSetEnvVarsPostResponse, + SynthesizeMessageV1ConversationsConversationIdSynthesizeMessageIdGetData, + SynthesizeMessageV1ConversationsConversationIdSynthesizeMessageIdGetResponse, ToggleConversationPinV1ConversationsConversationIdTogglePinPutData, ToggleConversationPinV1ConversationsConversationIdTogglePinPutResponse, ToolAuthV1ToolAuthGetResponse, @@ -414,37 +414,6 @@ export class DefaultService { }); } - /** - * Langchain Chat Stream - * Stream chat endpoint to handle user messages and return chatbot responses using langchain. - * - * Args: - * session (DBSessionDep): Database session. - * chat_request (LangchainChatRequest): Chat request data. - * request (Request): Request object. - * ctx (Context): Context object. - * - * Returns: - * EventSourceResponse: Server-sent event response with chatbot responses. - * @param data The data for the request. - * @param data.requestBody - * @returns unknown Successful Response - * @throws ApiError - */ - public langchainChatStreamV1LangchainChatPost( - data: LangchainChatStreamV1LangchainChatPostData - ): CancelablePromise { - return this.httpRequest.request({ - method: 'POST', - url: '/v1/langchain-chat', - body: data.requestBody, - mediaType: 'application/json', - errors: { - 422: 'Validation Error', - }, - }); - } - /** * Create User * Create a new user. @@ -974,6 +943,43 @@ export class DefaultService { }); } + /** + * Synthesize Message + * Generate a synthesized audio for a specific message in a conversation. + * + * Args: + * conversation_id (str): Conversation ID. + * message_id (str): Message ID. + * session (DBSessionDep): Database session. + * ctx (Context): Context object. + * + * Returns: + * Response: Synthesized audio file. + * + * Raises: + * HTTPException: If the message with the given ID is not found. + * @param data The data for the request. + * @param data.conversationId + * @param data.messageId + * @returns unknown Successful Response + * @throws ApiError + */ + public synthesizeMessageV1ConversationsConversationIdSynthesizeMessageIdGet( + data: SynthesizeMessageV1ConversationsConversationIdSynthesizeMessageIdGetData + ): CancelablePromise { + return this.httpRequest.request({ + method: 'GET', + url: '/v1/conversations/{conversation_id}/synthesize/{message_id}', + path: { + conversation_id: data.conversationId, + message_id: data.messageId, + }, + errors: { + 422: 'Validation Error', + }, + }); + } + /** * List Tools * List all available tools. @@ -1203,7 +1209,7 @@ export class DefaultService { * ctx (Context): Context object. * Returns: * Dict[str, bool]: Experimental feature and their isEnabled state - * @returns unknown Successful Response + * @returns boolean Successful Response * @throws ApiError */ public listExperimentalFeaturesV1ExperimentalFeaturesGet(): CancelablePromise { diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts index 40ca068004..7263d09a1a 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts @@ -352,16 +352,6 @@ export type JWTResponse = { token: string; }; -/** - * Request shape for Langchain Streamed Chat. - */ -export type LangchainChatRequest = { - message: string; - chat_history?: Array | null; - conversation_id?: string; - tools?: Array | null; -}; - export type ListAuthStrategy = { strategy: string; client_id: string | null; @@ -562,6 +552,7 @@ export type StreamCitationGeneration = { }; export type StreamEnd = { + message_id?: string | null; response_id?: string | null; generation_id?: string | null; conversation_id?: string | null; @@ -848,12 +839,6 @@ export type ChatV1ChatPostData = { export type ChatV1ChatPostResponse = NonStreamedChatResponse; -export type LangchainChatStreamV1LangchainChatPostData = { - requestBody: LangchainChatRequest; -}; - -export type LangchainChatStreamV1LangchainChatPostResponse = unknown; - export type CreateUserV1UsersPostData = { requestBody: backend__schemas__user__CreateUser; }; @@ -962,6 +947,13 @@ export type GenerateTitleV1ConversationsConversationIdGenerateTitlePostData = { export type GenerateTitleV1ConversationsConversationIdGenerateTitlePostResponse = GenerateTitleResponse; +export type SynthesizeMessageV1ConversationsConversationIdSynthesizeMessageIdGetData = { + conversationId: string; + messageId: string; +}; + +export type SynthesizeMessageV1ConversationsConversationIdSynthesizeMessageIdGetResponse = unknown; + export type ListToolsV1ToolsGetData = { agentId?: string | null; }; @@ -1006,7 +998,9 @@ export type SetEnvVarsV1DeploymentsNameSetEnvVarsPostData = { export type SetEnvVarsV1DeploymentsNameSetEnvVarsPostResponse = unknown; -export type ListExperimentalFeaturesV1ExperimentalFeaturesGetResponse = unknown; +export type ListExperimentalFeaturesV1ExperimentalFeaturesGetResponse = { + [key: string]: boolean; +}; export type CreateAgentV1AgentsPostData = { requestBody: CreateAgentRequest; @@ -1377,21 +1371,6 @@ export type $OpenApiTs = { }; }; }; - '/v1/langchain-chat': { - post: { - req: LangchainChatStreamV1LangchainChatPostData; - res: { - /** - * Successful Response - */ - 200: unknown; - /** - * Validation Error - */ - 422: HTTPValidationError; - }; - }; - }; '/v1/users': { post: { req: CreateUserV1UsersPostData; @@ -1607,6 +1586,21 @@ export type $OpenApiTs = { }; }; }; + '/v1/conversations/{conversation_id}/synthesize/{message_id}': { + get: { + req: SynthesizeMessageV1ConversationsConversationIdSynthesizeMessageIdGetData; + res: { + /** + * Successful Response + */ + 200: unknown; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; '/v1/tools': { get: { req: ListToolsV1ToolsGetData; @@ -1712,7 +1706,9 @@ export type $OpenApiTs = { /** * Successful Response */ - 200: unknown; + 200: { + [key: string]: boolean; + }; }; }; }; diff --git a/src/interfaces/assistants_web/src/components/MessageRow/MessageRow.tsx b/src/interfaces/assistants_web/src/components/MessageRow/MessageRow.tsx index aac7f9f72b..d8d5fd3959 100644 --- a/src/interfaces/assistants_web/src/components/MessageRow/MessageRow.tsx +++ b/src/interfaces/assistants_web/src/components/MessageRow/MessageRow.tsx @@ -12,6 +12,8 @@ import { LongPressMenu, } from '@/components/UI'; import { Breakpoint, useBreakpoint } from '@/hooks'; +import { useExperimentalFeatures } from '@/hooks/use-experimentalFeatures'; +import { SynthesisStatus } from '@/hooks/use-synthesizer'; import { type ChatMessage, isAbortedMessage, @@ -28,11 +30,13 @@ type Props = { message: ChatMessage; isStreamingToolEvents: boolean; isReadOnly?: boolean; + synthesisStatus?: SynthesisStatus; delay?: boolean; className?: string; onCopy?: VoidFunction; onRetry?: VoidFunction; onRegenerate?: VoidFunction; + onToggleSynthesis?: VoidFunction; }; /** @@ -45,10 +49,12 @@ export const MessageRow = forwardRef(function MessageRowI isLast, isStreamingToolEvents, isReadOnly = false, + synthesisStatus, className = '', onCopy, onRetry, onRegenerate, + onToggleSynthesis, }, ref ) { @@ -56,29 +62,36 @@ export const MessageRow = forwardRef(function MessageRowI const [isShowing, setIsShowing] = useState(false); const [isLongPressMenuOpen, setIsLongPressMenuOpen] = useState(false); - const [isStepsExpanded, setIsStepsExpanded] = useState(true); + const [isStepsExpanded, setIsStepsExpanded] = useState(true); + + const { data: experimentalFeatures } = useExperimentalFeatures(); + const { longPressProps } = useLongPress({ + onLongPress: () => setIsLongPressMenuOpen(true), + }); + + const getMessageText = () => { + return isFulfilledMessage(message) ? message.originalText : message.text; + }; + const hasSteps = (isFulfilledOrTypingMessage(message) || isErroredMessage(message) || isAbortedMessage(message)) && !!message.toolEvents && message.toolEvents.length > 0; - const isRegenerationEnabled = - isLast && !isReadOnly && isBotMessage(message) && !isErroredMessage(message); - - const getMessageText = () => { - if (isFulfilledMessage(message)) { - return message.originalText; - } - - return message.text; - }; const enableLongPress = (isFulfilledMessage(message) || isUserMessage(message)) && breakpoint === Breakpoint.sm; - const { longPressProps } = useLongPress({ - onLongPress: () => setIsLongPressMenuOpen(true), - }); + + const isSynthesisEnabled = + !!onToggleSynthesis && + !!experimentalFeatures?.USE_TEXT_TO_SPEECH_SYNTHESIS && + !!message.id && + isBotMessage(message) && + !isErroredMessage(message); + + const isRegenerationEnabled = + isLast && !isReadOnly && isBotMessage(message) && !isErroredMessage(message); // Delay the appearance of the message to make it feel more natural. useEffect(() => { @@ -148,6 +161,19 @@ export const MessageRow = forwardRef(function MessageRowI 'hidden md:invisible md:flex md:group-hover:visible': !isLast, })} > + {isSynthesisEnabled && ( + + )} {hasSteps && ( = ({ isStreamingToolEvents, }) => { const isChatEmpty = messages.length === 0; + const { synthesisStatus, toggleSynthesis } = useSynthesizer(); if (isChatEmpty) { return ( @@ -153,6 +155,7 @@ const Messages: React.FC = ({ message={m} isLast={isLastInList && !streamingMessage} isStreamingToolEvents={isStreamingToolEvents} + synthesisStatus={synthesisStatus(m.id!)} className={cn({ // Hide the last message if it is the same as the separate streamed message // to avoid a flash of duplicate messages. @@ -165,6 +168,7 @@ const Messages: React.FC = ({ })} onRetry={onRetry} onRegenerate={onRegenerate} + onToggleSynthesis={() => toggleSynthesis(m.id!)} /> ); })} diff --git a/src/interfaces/assistants_web/src/components/UI/Button.tsx b/src/interfaces/assistants_web/src/components/UI/Button.tsx index 00697f1a12..2a51d7985f 100644 --- a/src/interfaces/assistants_web/src/components/UI/Button.tsx +++ b/src/interfaces/assistants_web/src/components/UI/Button.tsx @@ -140,6 +140,9 @@ export type ButtonProps = { kind?: 'default' | 'outline'; customIcon?: React.ReactNode; }; + spinnerOptions?: { + className?: string; + }; buttonType?: 'submit' | 'reset' | 'button'; onClick?: React.MouseEventHandler; href?: string; @@ -160,6 +163,7 @@ export const Button: React.FC = ({ isLoading = false, className, iconOptions, + spinnerOptions, buttonType, onClick, href, @@ -177,7 +181,7 @@ export const Button: React.FC = ({ }); const iconElement = isLoading ? ( - + ) : icon || kind === 'cell' ? ( { ), + ['stop']: ( + + + + ), ['subtract']: ( @@ -444,6 +453,11 @@ const getIcon = (name: IconName, kind: IconKind): React.ReactNode => { ), + ['volume']: ( + + + + ), ['warning']: ( diff --git a/src/interfaces/assistants_web/src/components/UI/IconButton.tsx b/src/interfaces/assistants_web/src/components/UI/IconButton.tsx index e1bb5c0bd3..b301277022 100644 --- a/src/interfaces/assistants_web/src/components/UI/IconButton.tsx +++ b/src/interfaces/assistants_web/src/components/UI/IconButton.tsx @@ -20,6 +20,7 @@ type Props = { iconKind?: 'default' | 'outline'; iconClassName?: string; disabled?: boolean; + isLoading?: boolean; className?: string; outline?: boolean; onClick?: (e: React.MouseEvent) => void; @@ -37,6 +38,7 @@ export const IconButton: React.FC = ({ iconClassName, className, disabled, + isLoading, href, target, outline = false, @@ -46,11 +48,13 @@ export const IconButton: React.FC = ({