-
Notifications
You must be signed in to change notification settings - Fork 9.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for local ai speech to text (#3921)
Co-authored-by: Yeuoly <[email protected]>
- Loading branch information
1 parent
d51f52a
commit bb7c627
Showing
4 changed files
with
159 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
101 changes: 101 additions & 0 deletions
101
api/core/model_runtime/model_providers/localai/speech2text/speech2text.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from typing import IO, Optional | ||
|
||
from requests import Request, Session | ||
from yarl import URL | ||
|
||
from core.model_runtime.entities.common_entities import I18nObject | ||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType | ||
from core.model_runtime.errors.invoke import ( | ||
InvokeAuthorizationError, | ||
InvokeBadRequestError, | ||
InvokeConnectionError, | ||
InvokeError, | ||
InvokeRateLimitError, | ||
InvokeServerUnavailableError, | ||
) | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel | ||
|
||
|
||
class LocalAISpeech2text(Speech2TextModel):
    """
    Model class for LocalAI speech to text model.
    """

    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech to text model

        :param model: model name
        :param credentials: model credentials; ``server_url`` is used as the base URL
        :param file: audio file
        :param user: unique user id (not forwarded to the server)
        :return: text for given audio file
        :raises InvokeServerUnavailableError: if the response is not JSON, reports an
            error, or carries no ``text`` field
        """
        url = str(URL(credentials['server_url']) / "v1/audio/transcriptions")
        data = {"model": model}
        files = {"file": file}

        # Use the session as a context manager so its connections are released
        # even when sending or decoding fails.
        with Session() as session:
            request = Request("POST", url, data=data, files=files)
            prepared_request = session.prepare_request(request)
            response = session.send(prepared_request)

            # Decode the body once; a non-JSON body raises a ValueError subclass.
            try:
                payload = response.json()
            except ValueError as ex:
                raise InvokeServerUnavailableError("Response is not valid JSON") from ex

        if not isinstance(payload, dict):
            raise InvokeServerUnavailableError("Unexpected response format")

        # Surface the error reported by the server instead of hiding it.
        if 'error' in payload:
            raise InvokeServerUnavailableError(f"Server returned an error: {payload['error']}")

        if 'text' not in payload:
            raise InvokeServerUnavailableError("Empty response")

        return payload["text"]

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by transcribing the bundled demo audio file

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the demo invocation fails for any reason
        """
        try:
            audio_file_path = self._get_demo_file_path()

            with open(audio_file_path, 'rb') as audio_file:
                self._invoke(model, credentials, audio_file)
        except Exception as ex:
            # Keep the original exception as the cause for easier debugging.
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Mapping from invoke error types to the exception types associated with them.
        Each invoke error type is associated only with itself here.
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError
            ],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name, also used as the en_US label
        :param credentials: model credentials (not used when building the schema)
        :return: speech2text model entity with no model properties or parameter rules
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.SPEECH2TEXT,
            model_properties={},
            parameter_rules=[]
        )

        return entity
54 changes: 54 additions & 0 deletions
54
api/tests/integration_tests/model_runtime/localai/test_speech2text.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.localai.speech2text.speech2text import LocalAISpeech2text | ||
|
||
|
||
def test_validate_credentials():
    """Reject an unusable server URL, then accept the server URL from the environment."""
    speech2text = LocalAISpeech2text()
    model_name = 'whisper-1'

    # An unusable server URL must surface as a credentials validation failure.
    bad_credentials = {'server_url': 'invalid_url'}
    with pytest.raises(CredentialsValidateFailedError):
        speech2text.validate_credentials(model=model_name, credentials=bad_credentials)

    # The URL taken from LOCALAI_SERVER_URL is expected to validate without raising.
    good_credentials = {'server_url': os.environ.get('LOCALAI_SERVER_URL')}
    speech2text.validate_credentials(model=model_name, credentials=good_credentials)
|
||
|
||
def test_invoke_model():
    """Transcribe the audio asset through the LocalAI server and check the transcript."""
    speech2text = LocalAISpeech2text()

    # The audio fixture lives in the "assets" directory alongside this test's directory.
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    audio_file_path = os.path.join(os.path.dirname(tests_dir), 'assets', 'audio.mp3')

    credentials = {'server_url': os.environ.get('LOCALAI_SERVER_URL')}

    # Invoke while the file is still open so the model can read the audio bytes.
    with open(audio_file_path, 'rb') as audio_file:
        result = speech2text.invoke(
            model='whisper-1',
            credentials=credentials,
            file=audio_file,
            user="abc-123",
        )

    assert isinstance(result, str)
    assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'