Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/curly-kangaroos-accept.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@livekit/agents": patch
"@livekit/agents-plugin-deepgram": patch
---

fix resource cleanup
4 changes: 4 additions & 0 deletions agents/src/stt/stt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter<STTCal
* transcriptions
*/
abstract stream(): SpeechStream;

/**
 * Cleanup hook for this STT instance.
 *
 * The base implementation does nothing and resolves immediately.
 */
async close(): Promise<void> {
// No resources to release at this level.
return;
}
}

/**
Expand Down
4 changes: 4 additions & 0 deletions agents/src/tts/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ export abstract class TTS extends (EventEmitter as new () => TypedEmitter<TTSCal
* Returns a {@link SynthesizeStream} that can be used to push text and receive audio data
*/
abstract stream(): SynthesizeStream;

/**
 * Cleanup hook for this TTS instance.
 *
 * The base implementation does nothing and resolves immediately.
 */
async close(): Promise<void> {
// No resources to release at this level.
return;
}
}

/**
Expand Down
4 changes: 4 additions & 0 deletions agents/src/vad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ export abstract class VAD extends (EventEmitter as new () => TypedEmitter<VADCal
* Returns a {@link VADStream} that can be used to push audio frames and receive VAD events.
*/
abstract stream(): VADStream;

/**
 * Cleanup hook for this VAD instance.
 *
 * The base implementation does nothing and resolves immediately.
 */
async close(): Promise<void> {
// No resources to release at this level.
return;
}
}

export abstract class VADStream implements AsyncIterableIterator<VADEvent> {
Expand Down
3 changes: 3 additions & 0 deletions agents/src/voice/agent_activity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2141,12 +2141,15 @@ export class AgentActivity implements RecognitionHooks {
}
if (this.stt instanceof STT) {
this.stt.off('metrics_collected', this.onMetricsCollected);
await this.stt.close();
}
if (this.tts instanceof TTS) {
this.tts.off('metrics_collected', this.onMetricsCollected);
await this.tts.close();
}
if (this.vad instanceof VAD) {
this.vad.off('metrics_collected', this.onMetricsCollected);
await this.vad.close();
}

this.detachAudioInput();
Expand Down
163 changes: 105 additions & 58 deletions plugins/deepgram/src/stt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import {
AudioByteStream,
AudioEnergyFilter,
Future,
Task,
log,
stt,
waitForAbort,
} from '@livekit/agents';
import type { AudioFrame } from '@livekit/rtc-node';
import { type RawData, WebSocket } from 'ws';
import { WebSocket } from 'ws';
import { PeriodicCollector } from './_utils.js';
import type { STTLanguages, STTModels } from './models.js';

Expand Down Expand Up @@ -62,6 +64,7 @@ export class STT extends stt.STT {
#opts: STTOptions;
#logger = log();
label = 'deepgram.STT';
private abortController = new AbortController();

constructor(opts: Partial<STTOptions> = defaultSTTOptions) {
super({
Expand Down Expand Up @@ -111,7 +114,11 @@ export class STT extends stt.STT {
}

stream(): SpeechStream {
return new SpeechStream(this, this.#opts);
return new SpeechStream(this, this.#opts, this.abortController);
}

/**
 * Aborts the AbortController held by this STT instance, which is the same
 * controller handed to each SpeechStream created by stream().
 *
 * NOTE(review): the controller is created once as a field initializer and is
 * not replaced in the visible code, so streams created after close() would
 * presumably receive an already-aborted controller — TODO confirm.
 */
async close() {
this.abortController.abort();
}
}

Expand All @@ -125,7 +132,11 @@ export class SpeechStream extends stt.SpeechStream {
#audioDurationCollector: PeriodicCollector<number>;
label = 'deepgram.SpeechStream';

constructor(stt: STT, opts: STTOptions) {
constructor(
stt: STT,
opts: STTOptions,
private abortController: AbortController,
) {
super(stt, opts.sampleRate);
this.#opts = opts;
this.closed = false;
Expand All @@ -140,7 +151,8 @@ export class SpeechStream extends stt.SpeechStream {
const maxRetry = 32;
let retries = 0;
let ws: WebSocket;
while (!this.input.closed) {

while (!this.input.closed && !this.closed) {
const streamURL = new URL(API_BASE_URL_V1);
const params = {
model: this.#opts.model,
Expand Down Expand Up @@ -185,17 +197,23 @@ export class SpeechStream extends stt.SpeechStream {

await this.#runWS(ws);
} catch (e) {
if (retries >= maxRetry) {
throw new Error(`failed to connect to Deepgram after ${retries} attempts: ${e}`);
}
if (!this.closed && !this.input.closed) {
if (retries >= maxRetry) {
throw new Error(`failed to connect to Deepgram after ${retries} attempts: ${e}`);
}

const delay = Math.min(retries * 5, 10);
retries++;
const delay = Math.min(retries * 5, 10);
retries++;

this.#logger.warn(
`failed to connect to Deepgram, retrying in ${delay} seconds: ${e} (${retries}/${maxRetry})`,
);
await new Promise((resolve) => setTimeout(resolve, delay * 1000));
this.#logger.warn(
`failed to connect to Deepgram, retrying in ${delay} seconds: ${e} (${retries}/${maxRetry})`,
);
await new Promise((resolve) => setTimeout(resolve, delay * 1000));
} else {
this.#logger.warn(
`Deepgram disconnected, connection is closed: ${e} (inputClosed: ${this.input.closed}, isClosed: ${this.closed})`,
);
}
}
}

Expand All @@ -220,6 +238,20 @@ export class SpeechStream extends stt.SpeechStream {
}
}, 5000);

// gets cancelled also when sendTask is complete
const wsMonitor = Task.from(async (controller) => {
const closed = new Promise<void>(async (_, reject) => {
ws.once('close', (code, reason) => {
if (!closing) {
this.#logger.error(`WebSocket closed with code ${code}: ${reason}`);
reject(new Error('WebSocket closed'));
}
});
});

await Promise.race([closed, waitForAbort(controller.signal)]);
});

const sendTask = async () => {
const samples100Ms = Math.floor(this.#opts.sampleRate / 10);
const stream = new AudioByteStream(
Expand All @@ -228,48 +260,52 @@ export class SpeechStream extends stt.SpeechStream {
samples100Ms,
);

for await (const data of this.input) {
let frames: AudioFrame[];
if (data === SpeechStream.FLUSH_SENTINEL) {
frames = stream.flush();
this.#audioDurationCollector.flush();
} else if (
data.sampleRate === this.#opts.sampleRate ||
data.channels === this.#opts.numChannels
) {
frames = stream.write(data.data.buffer);
} else {
throw new Error(`sample rate or channel count of frame does not match`);
}
try {
while (!this.closed) {
const result = await Promise.race([
this.input.next(),
waitForAbort(this.abortController.signal),
]);

if (result === undefined) return; // aborted
if (result.done) {
break;
}

const data = result.value;

let frames: AudioFrame[];
if (data === SpeechStream.FLUSH_SENTINEL) {
frames = stream.flush();
this.#audioDurationCollector.flush();
} else if (
data.sampleRate === this.#opts.sampleRate ||
data.channels === this.#opts.numChannels
) {
frames = stream.write(data.data.buffer as ArrayBuffer);
} else {
throw new Error(`sample rate or channel count of frame does not match`);
}

for await (const frame of frames) {
if (this.#audioEnergyFilter.pushFrame(frame)) {
const frameDuration = frame.samplesPerChannel / frame.sampleRate;
this.#audioDurationCollector.push(frameDuration);
ws.send(frame.data.buffer);
for await (const frame of frames) {
if (this.#audioEnergyFilter.pushFrame(frame)) {
const frameDuration = frame.samplesPerChannel / frame.sampleRate;
this.#audioDurationCollector.push(frameDuration);
ws.send(frame.data.buffer);
}
}
}
} finally {
closing = true;
ws.send(JSON.stringify({ type: 'CloseStream' }));
wsMonitor.cancel();
}

closing = true;
ws.send(JSON.stringify({ type: 'CloseStream' }));
};

const wsMonitor = new Promise<void>((_, reject) =>
ws.once('close', (code, reason) => {
if (!closing) {
this.#logger.error(`WebSocket closed with code ${code}: ${reason}`);
reject(new Error('WebSocket closed'));
}
}),
);

const listenTask = async () => {
while (!this.closed && !closing) {
try {
await new Promise<RawData>((resolve) => {
ws.once('message', (data) => resolve(data));
}).then((msg) => {
const listenTask = Task.from(async (controller) => {
const listenMessage = new Promise<void>((resolve, reject) => {
ws.on('message', (msg) => {
try {
const json = JSON.parse(msg.toString());
switch (json['type']) {
case 'SpeechStarted': {
Expand Down Expand Up @@ -300,7 +336,9 @@ export class SpeechStream extends stt.SpeechStream {
if (alternatives[0] && alternatives[0].text) {
if (!this.#speaking) {
this.#speaking = true;
this.queue.put({ type: stt.SpeechEventType.START_OF_SPEECH });
this.queue.put({
type: stt.SpeechEventType.START_OF_SPEECH,
});
}

if (isFinal) {
Expand Down Expand Up @@ -334,15 +372,24 @@ export class SpeechStream extends stt.SpeechStream {
break;
}
}
});
} catch (error) {
this.#logger.child({ error }).warn('unrecoverable error, exiting');
break;
}
}
};

await Promise.race([this.#resetWS.await, Promise.all([sendTask(), listenTask(), wsMonitor])]);
if (this.closed || closing) {
resolve();
}
} catch (err) {
this.#logger.error(`STT: Error processing message: ${msg}`);
reject(err);
}
});
});

await Promise.race([listenMessage, waitForAbort(controller.signal)]);
}, this.abortController);

await Promise.race([
this.#resetWS.await,
Promise.all([sendTask(), listenTask.result, wsMonitor]),
]);
closing = true;
ws.close();
clearInterval(keepalive);
Expand Down