@@ -82,27 +82,126 @@ class AudioSpeechStreamResponse(BaseModel):
8282
8383 model_config = ConfigDict (arbitrary_types_allowed = True )
8484
def stream_to_file(
    self, file_path: str, response_format: AudioResponseFormat | str | None = None
) -> None:
    """
    Save the audio response to a file.

    For non-streaming responses, writes the complete payload as received.
    For streaming responses, collects every binary chunk (raw bytes, or
    base64 payloads carried in SSE/JSON events), concatenates them and
    writes them out according to ``response_format``.

    Args:
        file_path: Path where the audio file should be saved.
        response_format: Format of the audio (wav, mp3, or raw). If not
            provided, it is inferred from the file extension and falls
            back to wav when the extension is missing or unknown.

    Raises:
        ValueError: If ``response_format`` is a string that is not a valid
            ``AudioResponseFormat``, if a streaming response yields no
            audio data, or if MP3 output was requested but the received
            bytes do not look like MP3. In the last two cases the
            destination file is not touched.
    """
    if response_format is None:
        # No explicit format: infer it from the file extension.
        ext = file_path.lower().split(".")[-1] if "." in file_path else ""
        if ext in ("mp3", "mpeg"):
            response_format = AudioResponseFormat.MP3
        elif ext in ("raw", "pcm"):
            response_format = AudioResponseFormat.RAW
        else:
            # "wav" and any unknown or missing extension default to WAV.
            response_format = AudioResponseFormat.WAV

    if isinstance(response_format, str):
        response_format = AudioResponseFormat(response_format)

    if isinstance(self.response, TogetherResponse):
        # Non-streaming: the payload is already a complete file.
        with open(file_path, "wb") as f:
            f.write(self.response.data)
        return

    if not isinstance(self.response, Iterator):
        return

    # Streaming: gather all audio bytes first. The WAV header needs the
    # total payload size, and validation must happen before the output
    # file is opened.
    audio_chunks = []
    for chunk in self.response:
        if isinstance(chunk.data, bytes):
            audio_chunks.append(chunk.data)
        elif isinstance(chunk.data, dict):
            # SSE format with JSON/base64
            try:
                stream_event = AudioSpeechStreamEventResponse(
                    response={"data": chunk.data}
                )
                if isinstance(stream_event.response, StreamSentinel):
                    break
                audio_chunks.append(
                    base64.b64decode(stream_event.response.data.b64)
                )
            except Exception:
                # Deliberate best-effort: a malformed event is skipped so
                # one bad chunk does not discard the rest of the stream.
                continue

    if not audio_chunks:
        raise ValueError("No audio data received in streaming response")

    audio_data = b"".join(audio_chunks)

    if response_format == AudioResponseFormat.MP3:
        # An MP3 stream starts either with an ID3 tag or with a frame
        # header whose first 11 bits are all set (0xFF, then top three
        # bits of the next byte). Indexing bytes yields ints, so compare
        # against 0xFF directly.
        is_mp3 = audio_data.startswith(b"ID3") or (
            len(audio_data) >= 2
            and audio_data[0] == 0xFF
            and audio_data[1] & 0xE0 == 0xE0
        )
        if not is_mp3:
            # Raised before open() so an existing file is not truncated.
            raise ValueError("Invalid MP3 data received.")

    with open(file_path, "wb") as f:
        if response_format == AudioResponseFormat.WAV and not audio_data.startswith(
            b"RIFF"
        ):
            # Raw PCM: prepend a WAV header so the file is playable.
            self._write_wav_header(f, audio_data)
        else:
            # Already a WAV container, validated MP3, or RAW: write as-is.
            f.write(audio_data)
173+
174+ @staticmethod
175+ def _write_wav_header (file_handle : BinaryIO , audio_data : bytes ) -> None :
176+ """
177+ Write WAV file header for raw PCM audio data.
178+
179+ Uses default TTS parameters: 16-bit PCM, mono, 24000 Hz sample rate.
180+ """
181+ import struct
182+
183+ sample_rate = 24000
184+ num_channels = 1
185+ bits_per_sample = 16
186+ byte_rate = sample_rate * num_channels * bits_per_sample // 8
187+ block_align = num_channels * bits_per_sample // 8
188+ data_size = len (audio_data )
189+
190+ # Write WAV header
191+ file_handle .write (b"RIFF" )
192+ file_handle .write (struct .pack ("<I" , 36 + data_size )) # File size - 8
193+ file_handle .write (b"WAVE" )
194+ file_handle .write (b"fmt " )
195+ file_handle .write (struct .pack ("<I" , 16 )) # fmt chunk size
196+ file_handle .write (struct .pack ("<H" , 1 )) # Audio format (1 = PCM)
197+ file_handle .write (struct .pack ("<H" , num_channels ))
198+ file_handle .write (struct .pack ("<I" , sample_rate ))
199+ file_handle .write (struct .pack ("<I" , byte_rate ))
200+ file_handle .write (struct .pack ("<H" , block_align ))
201+ file_handle .write (struct .pack ("<H" , bits_per_sample ))
202+ file_handle .write (b"data" )
203+ file_handle .write (struct .pack ("<I" , data_size ))
204+ file_handle .write (audio_data )
106205
107206
108207class AudioTranscriptionResponseFormat (str , Enum ):
0 commit comments