Skip to content

Commit 2b1338f

Browse files
Add changes for streaming mode in TTS models (#389)
1 parent 342b970 commit 2b1338f

File tree

2 files changed

+156
-16
lines changed

2 files changed

+156
-16
lines changed

src/together/abstract/api_requestor.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -619,14 +619,29 @@ def _interpret_response(
619619
) -> Tuple[TogetherResponse | Iterator[TogetherResponse], bool]:
620620
"""Returns the response(s) and a bool indicating whether it is a stream."""
621621
content_type = result.headers.get("Content-Type", "")
622+
622623
if stream and "text/event-stream" in content_type:
624+
# SSE format streaming
623625
return (
624626
self._interpret_response_line(
625627
line, result.status_code, result.headers, stream=True
626628
)
627629
for line in parse_stream(result.iter_lines())
628630
), True
631+
elif stream and content_type in [
632+
"audio/wav",
633+
"audio/mpeg",
634+
"application/octet-stream",
635+
]:
636+
# Binary audio streaming - return chunks as binary data
637+
def binary_stream_generator() -> Iterator[TogetherResponse]:
638+
for chunk in result.iter_content(chunk_size=8192):
639+
if chunk: # Skip empty chunks
640+
yield TogetherResponse(chunk, dict(result.headers))
641+
642+
return binary_stream_generator(), True
629643
else:
644+
# Non-streaming response
630645
if content_type in ["application/octet-stream", "audio/wav", "audio/mpeg"]:
631646
content = result.content
632647
else:
@@ -648,23 +663,49 @@ async def _interpret_async_response(
648663
| tuple[TogetherResponse, bool]
649664
):
650665
"""Returns the response(s) and a bool indicating whether it is a stream."""
651-
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
666+
content_type = result.headers.get("Content-Type", "")
667+
668+
if stream and "text/event-stream" in content_type:
669+
# SSE format streaming
652670
return (
653671
self._interpret_response_line(
654672
line, result.status, result.headers, stream=True
655673
)
656674
async for line in parse_stream_async(result.content)
657675
), True
676+
elif stream and content_type in [
677+
"audio/wav",
678+
"audio/mpeg",
679+
"application/octet-stream",
680+
]:
681+
# Binary audio streaming - return chunks as binary data
682+
async def binary_stream_generator() -> (
683+
AsyncGenerator[TogetherResponse, None]
684+
):
685+
async for chunk in result.content.iter_chunked(8192):
686+
if chunk: # Skip empty chunks
687+
yield TogetherResponse(chunk, dict(result.headers))
688+
689+
return binary_stream_generator(), True
658690
else:
691+
# Non-streaming response
659692
try:
660-
await result.read()
693+
content = await result.read()
661694
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
662695
raise error.Timeout("Request timed out") from e
663696
except aiohttp.ClientError as e:
664697
utils.log_warn(e, body=result.content)
698+
699+
if content_type in ["application/octet-stream", "audio/wav", "audio/mpeg"]:
700+
# Binary content - keep as bytes
701+
response_content: str | bytes = content
702+
else:
703+
# Text content - decode to string
704+
response_content = content.decode("utf-8")
705+
665706
return (
666707
self._interpret_response_line(
667-
(await result.read()).decode("utf-8"),
708+
response_content,
668709
result.status,
669710
result.headers,
670711
stream=False,

src/together/types/audio_speech.py

Lines changed: 112 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,27 +82,126 @@ class AudioSpeechStreamResponse(BaseModel):
8282

8383
model_config = ConfigDict(arbitrary_types_allowed=True)
8484

85-
def stream_to_file(self, file_path: str) -> None:
85+
def stream_to_file(
86+
self, file_path: str, response_format: AudioResponseFormat | str | None = None
87+
) -> None:
88+
"""
89+
Save the audio response to a file.
90+
91+
For non-streaming responses, writes the complete file as received.
92+
For streaming responses, collects binary chunks and constructs a valid
93+
file format based on the response_format parameter.
94+
95+
Args:
96+
file_path: Path where the audio file should be saved.
97+
response_format: Format of the audio (wav, mp3, or raw). If not provided,
98+
will attempt to infer from file extension or default to wav.
99+
"""
100+
# Determine response format
101+
if response_format is None:
102+
# Infer from file extension
103+
ext = file_path.lower().split(".")[-1] if "." in file_path else ""
104+
if ext in ["wav"]:
105+
response_format = AudioResponseFormat.WAV
106+
elif ext in ["mp3", "mpeg"]:
107+
response_format = AudioResponseFormat.MP3
108+
elif ext in ["raw", "pcm"]:
109+
response_format = AudioResponseFormat.RAW
110+
else:
111+
# Default to WAV if unknown
112+
response_format = AudioResponseFormat.WAV
113+
114+
if isinstance(response_format, str):
115+
response_format = AudioResponseFormat(response_format)
116+
86117
if isinstance(self.response, TogetherResponse):
87-
# save response to file
118+
# Non-streaming: save complete file
88119
with open(file_path, "wb") as f:
89120
f.write(self.response.data)
90121

91122
elif isinstance(self.response, Iterator):
123+
# Streaming: collect binary chunks
124+
audio_chunks = []
125+
for chunk in self.response:
126+
if isinstance(chunk.data, bytes):
127+
audio_chunks.append(chunk.data)
128+
elif isinstance(chunk.data, dict):
129+
# SSE format with JSON/base64
130+
try:
131+
stream_event = AudioSpeechStreamEventResponse(
132+
response={"data": chunk.data}
133+
)
134+
if isinstance(stream_event.response, StreamSentinel):
135+
break
136+
audio_chunks.append(
137+
base64.b64decode(stream_event.response.data.b64)
138+
)
139+
except Exception:
140+
continue # Skip malformed chunks
141+
142+
if not audio_chunks:
143+
raise ValueError("No audio data received in streaming response")
144+
145+
# Concatenate all chunks
146+
audio_data = b"".join(audio_chunks)
147+
92148
with open(file_path, "wb") as f:
93-
for chunk in self.response:
94-
# Try to parse as stream chunk
95-
stream_event_response = AudioSpeechStreamEventResponse(
96-
response={"data": chunk.data}
149+
if response_format == AudioResponseFormat.WAV:
150+
if audio_data.startswith(b"RIFF"):
151+
# Already a valid WAV file
152+
f.write(audio_data)
153+
else:
154+
# Raw PCM - add WAV header
155+
self._write_wav_header(f, audio_data)
156+
elif response_format == AudioResponseFormat.MP3:
157+
# MP3 format: Check if data is actually MP3 or raw PCM
158+
# MP3 files start with ID3 tag or sync word (0xFF 0xFB/0xFA/0xF3/0xF2)
159+
is_mp3 = audio_data.startswith(b"ID3") or (
160+
len(audio_data) > 0
161+
and audio_data[0:1] == b"\xff"
162+
and len(audio_data) > 1
163+
and audio_data[1] & 0xE0 == 0xE0
97164
)
98165

99-
if isinstance(stream_event_response.response, StreamSentinel):
100-
break
101-
102-
# decode base64
103-
audio = base64.b64decode(stream_event_response.response.data.b64)
104-
105-
f.write(audio)
166+
if is_mp3:
167+
f.write(audio_data)
168+
else:
169+
raise ValueError("Invalid MP3 data received.")
170+
else:
171+
# RAW format: write PCM data as-is
172+
f.write(audio_data)
173+
174+
@staticmethod
175+
def _write_wav_header(file_handle: BinaryIO, audio_data: bytes) -> None:
176+
"""
177+
Write WAV file header for raw PCM audio data.
178+
179+
Uses default TTS parameters: 16-bit PCM, mono, 24000 Hz sample rate.
180+
"""
181+
import struct
182+
183+
sample_rate = 24000
184+
num_channels = 1
185+
bits_per_sample = 16
186+
byte_rate = sample_rate * num_channels * bits_per_sample // 8
187+
block_align = num_channels * bits_per_sample // 8
188+
data_size = len(audio_data)
189+
190+
# Write WAV header
191+
file_handle.write(b"RIFF")
192+
file_handle.write(struct.pack("<I", 36 + data_size)) # File size - 8
193+
file_handle.write(b"WAVE")
194+
file_handle.write(b"fmt ")
195+
file_handle.write(struct.pack("<I", 16)) # fmt chunk size
196+
file_handle.write(struct.pack("<H", 1)) # Audio format (1 = PCM)
197+
file_handle.write(struct.pack("<H", num_channels))
198+
file_handle.write(struct.pack("<I", sample_rate))
199+
file_handle.write(struct.pack("<I", byte_rate))
200+
file_handle.write(struct.pack("<H", block_align))
201+
file_handle.write(struct.pack("<H", bits_per_sample))
202+
file_handle.write(b"data")
203+
file_handle.write(struct.pack("<I", data_size))
204+
file_handle.write(audio_data)
106205

107206

108207
class AudioTranscriptionResponseFormat(str, Enum):

0 commit comments

Comments
 (0)