refactor: Migrate RemoteWhisperTranscriber to OpenAI SDK. #6149
Changes from all commits: 8eb09d3, 14aefd2, 0deb512, 10f6ba2, d04fcb1, c57b7ff, bc13703, 8e4eebf, 25ecec5, d069035, 212c94f, 3714122, 5fbbbb5
Changes to `RemoteWhisperTranscriber`:

```diff
@@ -1,20 +1,17 @@
-from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args, Sequence
-
-import os
-import json
-import logging
-from pathlib import Path
-
-from haystack.preview.utils import request_with_retry
-from haystack.preview import component, Document, default_to_dict
+import io
+import logging
+import os
+from typing import Any, Dict, List, Optional
+
+import openai
+
+from haystack.preview import Document, component, default_from_dict, default_to_dict
+from haystack.preview.dataclasses import ByteStream
 
 logger = logging.getLogger(__name__)
 
-OPENAI_TIMEOUT = float(os.environ.get("HAYSTACK_OPENAI_TIMEOUT_SEC", 600))
-
-WhisperRemoteModel = Literal["whisper-1"]
+API_BASE_URL = "https://api.openai.com/v1"
 
 
 @component
@@ -30,108 +27,112 @@ class RemoteWhisperTranscriber:
 
     def __init__(
         self,
-        api_key: str,
-        model_name: WhisperRemoteModel = "whisper-1",
-        api_base: str = "https://api.openai.com/v1",
-        whisper_params: Optional[Dict[str, Any]] = None,
+        api_key: Optional[str] = None,
+        model_name: str = "whisper-1",
+        organization: Optional[str] = None,
+        api_base_url: str = API_BASE_URL,
+        **kwargs,
     ):
         """
         Transcribes a list of audio files into a list of Documents.
 
         :param api_key: OpenAI API key.
         :param model_name: Name of the model to use. It now accepts only `whisper-1`.
+        :param organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
         :param api_base: OpenAI base URL, defaults to `"https://api.openai.com/v1"`.
+        :param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
+            endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/audio) for more
+            details. Some of the supported parameters:
+            - `language`: The language of the input audio. Supplying the input language in ISO-639-1 format
+              will improve accuracy and latency.
+            - `prompt`: An optional text to guide the model's style or continue a previous audio segment.
+              The prompt should match the audio language.
+            - `response_format`: The format of the transcript output, in one of these options: json, text, srt,
+              verbose_json, or vtt. Defaults to "json". Currently only "json" is supported.
+            - `temperature`: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
+              output more random, while lower values like 0.2 will make it more focused and deterministic.
+              If set to 0, the model will use log probability to automatically increase the temperature until
+              certain thresholds are hit.
         """
-        if model_name not in get_args(WhisperRemoteModel):
-            raise ValueError(
-                f"Model name not recognized. Choose one among: " f"{', '.join(get_args(WhisperRemoteModel))}."
-            )
-        if not api_key:
-            raise ValueError("API key is None.")
+        # if the user does not provide the API key, check if it is set in the module client
+        api_key = api_key or openai.api_key
+        if api_key is None:
+            try:
+                api_key = os.environ["OPENAI_API_KEY"]
+            except KeyError as e:
+                raise ValueError(
+                    "RemoteWhisperTranscriber expects an OpenAI API key. "
+                    "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
+                ) from e
+        openai.api_key = api_key
 
+        self.organization = organization
         self.model_name = model_name
-        self.api_key = api_key
-        self.api_base = api_base
-        self.whisper_params = whisper_params or {}
+        self.api_base_url = api_base_url
+
+        # Only response_format = "json" is supported
+        whisper_params = kwargs
+        if whisper_params.get("response_format") != "json":
+            logger.warning(
+                "RemoteWhisperTranscriber only supports 'response_format: json'. This parameter will be overwritten."
+            )
+        whisper_params["response_format"] = "json"
+        self.whisper_params = whisper_params
+
+        if organization is not None:
+            openai.organization = organization
 
-    @component.output_types(documents=List[Document])
-    def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
-        """
-        Transcribe the audio files into a list of Documents, one for each input file.
-
-        For the supported audio formats, languages, and other parameters, see the
-        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
-        [github repo](https://github.com/openai/whisper).
-
-        :param audio_files: a list of paths or binary streams to transcribe
-        :returns: a list of Documents, one for each file. The content of the document is the transcription text,
-            while the document's metadata contains all the other values returned by the Whisper model, such as the
-            alignment data. Another key called `audio_file` contains the path to the audio file used for the
-            transcription.
-        """
-        if whisper_params is None:
-            whisper_params = self.whisper_params
-
-        documents = self.transcribe(audio_files, **whisper_params)
-        return {"documents": documents}
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        This method overrides the default serializer in order to
+        avoid leaking the `api_key` value passed to the constructor.
+        """
+        return default_to_dict(
+            self,
+            model_name=self.model_name,
+            organization=self.organization,
+            api_base_url=self.api_base_url,
+            **self.whisper_params,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
 
-    def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
-        """
-        Transcribe the audio files into a list of Documents, one for each input file.
-
-        For the supported audio formats, languages, and other parameters, see the
-        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
-        [github repo](https://github.com/openai/whisper).
-
-        :param audio_files: a list of paths or binary streams to transcribe
-        :returns: a list of transcriptions.
-        """
-        transcriptions = self._raw_transcribe(audio_files=audio_files, **kwargs)
-        documents = []
-        for audio, transcript in zip(audio_files, transcriptions):
-            content = transcript.pop("text")
-            if not isinstance(audio, (str, Path)):
-                audio = "<<binary stream>>"
-            doc = Document(text=content, metadata={"audio_file": audio, **transcript})
-            documents.append(doc)
-        return documents
+    @component.output_types(documents=List[Document])
+    def run(self, streams: List[ByteStream]):
+        """
+        Transcribe the audio files into a list of Documents, one for each input file.
+
+        For the supported audio formats, languages, and other parameters, see the
+        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
+        [github repo](https://github.com/openai/whisper).
+
+        :param audio_files: a list of ByteStream objects to transcribe.
+        :returns: a list of Documents, one for each file. The content of the document is the transcription text.
+        """
+        documents = []
+
+        for stream in streams:
+            file = io.BytesIO(stream.data)
+            try:
+                file.name = stream.metadata["file_path"]
```
Review discussion on `file.name = stream.metadata["file_path"]`:

**Member:** @awinml yes, let's do a check here if `stream.metadata["file_path"]` is present. If it is, use it. If not, just use a random name.

**Contributor (Author):** I tried the API without a file extension; it does not allow that. I updated the example notebook with a test. I think we can use something like this:

```python
for stream in streams:
    file = io.BytesIO(stream.data)
    try:
        file.name = stream.metadata["file_path"]
    except KeyError as e:
        file.name = "audio_input.wav"
        logger.warning("Did not find 'file_path' in metadata (%s), setting file name to 'audio_input.wav'.", e)
```

**Member:** *(suggests dropping the warning)*

**Contributor (Author):** Makes sense, I will push the changes without the warning then.
```diff
+            except KeyError:
+                file.name = "audio_input.wav"
 
-    def _raw_transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Dict[str, Any]]:
-        """
-        Transcribe the given audio files. Returns a list of strings.
-
-        For the supported audio formats, languages, and other parameters, see the
-        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
-        [github repo](https://github.com/openai/whisper).
-
-        :param audio_files: a list of paths or binary streams to transcribe.
-        :param kwargs: any other parameters that Whisper API can understand.
-        :returns: a list of transcriptions as they are produced by the Whisper API (JSON).
-        """
-        translate = kwargs.pop("translate", False)
-        url = f"{self.api_base}/audio/{'translations' if translate else 'transcriptions'}"
-        data = {"model": self.model_name, **kwargs}
-        headers = {"Authorization": f"Bearer {self.api_key}"}
-
-        transcriptions = []
-        for audio_file in audio_files:
-            if isinstance(audio_file, (str, Path)):
-                audio_file = open(audio_file, "rb")
-
-            request_files = ("file", (audio_file.name, audio_file, "application/octet-stream"))
-            response = request_with_retry(
-                method="post", url=url, data=data, headers=headers, files=[request_files], timeout=OPENAI_TIMEOUT
-            )
-            transcription = json.loads(response.content)
-
-            transcriptions.append(transcription)
-        return transcriptions
+
+            content = openai.Audio.transcribe(file=file, model=self.model_name, **self.whisper_params)
+            doc = Document(text=content["text"], metadata=stream.metadata)
+            documents.append(doc)
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
-        to the constructor.
-        """
-        return default_to_dict(
-            self, model_name=self.model_name, api_base=self.api_base, whisper_params=self.whisper_params
-        )
+
+        return {"documents": documents}
```
New release note:

```diff
@@ -0,0 +1,4 @@
+---
+preview:
+  - |
+    Migrate RemoteWhisperTranscriber to OpenAI SDK.
```
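As a quick check of the serialization change, a round-trip sketch (same assumed import path, placeholder key, and assuming `default_to_dict` nests the constructor arguments under `init_parameters`). `to_dict` deliberately omits `api_key`, so the restored component re-reads it from the environment:

```python
import os

from haystack.preview.components.audio import RemoteWhisperTranscriber  # assumed import path

os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # fake value, for the sketch only

transcriber = RemoteWhisperTranscriber(language="en")
data = transcriber.to_dict()

# The API key must not be serialized; the extra Whisper parameters are flattened
# into the init parameters alongside model_name, organization and api_base_url.
assert "api_key" not in data["init_parameters"]
assert data["init_parameters"]["language"] == "en"

restored = RemoteWhisperTranscriber.from_dict(data)
```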
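Finally, a hypothetical test sketch for the fallback that came out of the review thread above (not part of this PR's test suite). It mocks `openai.Audio.transcribe` and checks that a `ByteStream` without `file_path` metadata still gets the default upload name:

```python
from unittest.mock import patch

from haystack.preview.components.audio import RemoteWhisperTranscriber  # assumed import path
from haystack.preview.dataclasses import ByteStream


def test_missing_file_path_falls_back_to_default_name():
    transcriber = RemoteWhisperTranscriber(api_key="sk-test")
    stream = ByteStream(data=b"\x00\x01", metadata={})  # no "file_path" key

    with patch("openai.Audio.transcribe", return_value={"text": "hello"}) as mock_transcribe:
        result = transcriber.run(streams=[stream])

    # The BytesIO handed to the SDK received the fallback name.
    sent_file = mock_transcribe.call_args.kwargs["file"]
    assert sent_file.name == "audio_input.wav"
    assert result["documents"][0].text == "hello"
```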