diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift index 7e1ceeb5751..8e6e61f35c4 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift @@ -48,10 +48,7 @@ struct LiveSessionTests { ), ]), ] - private let textConfig = LiveGenerationConfig( - responseModalities: [.text] - ) - private let audioConfig = LiveGenerationConfig( + private let generationConfig = LiveGenerationConfig( responseModalities: [.audio], outputAudioTranscription: AudioTranscriptionConfig() ) @@ -76,8 +73,8 @@ struct LiveSessionTests { role: "system", parts: """ When you receive a message, if the message is a single word, assume it's the first name of a \ - person, and call the getLastName tool to get the last name of said person. Only respond with \ - the last name. + person, and call the getLastName tool to get the last name of said person. Once you get the \ + response, say the response. """.trimmingCharacters(in: .whitespacesAndNewlines) ) @@ -95,7 +92,7 @@ struct LiveSessionTests { modelName: String) async throws { let model = FirebaseAI.componentInstance(config).liveModel( modelName: modelName, - generationConfig: audioConfig, + generationConfig: generationConfig, systemInstruction: SystemInstructions.helloGoodbye ) @@ -119,15 +116,12 @@ struct LiveSessionTests { #expect(modelResponse == "goodbye") } - @Test( - .disabled("Temporarily disabled"), - .bug("https://github.com/firebase/firebase-ios-sdk/issues/15640"), - arguments: arguments - ) - func sendVideoRealtime_receiveText(_ config: InstanceConfig, modelName: String) async throws { + @Test(arguments: arguments) + func sendVideoRealtime_receiveAudioOutputTranscripts(_ config: InstanceConfig, + modelName: String) async throws { let model = FirebaseAI.componentInstance(config).liveModel( modelName: modelName, - generationConfig: textConfig, + generationConfig: generationConfig, systemInstruction: SystemInstructions.animalInVideo ) @@ -152,7 +146,7 @@ struct LiveSessionTests { await session.sendAudioRealtime(audioFile.data) await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) - let text = try await session.collectNextTextResponse() + let text = try await session.collectNextAudioOutputTranscript() await session.close() let modelResponse = text @@ -164,15 +158,11 @@ struct LiveSessionTests { #expect(["kitten", "cat", "kitty"].contains(modelResponse)) } - @Test( - .disabled("Temporarily disabled"), - .bug("https://github.com/firebase/firebase-ios-sdk/issues/15640"), - arguments: arguments - ) + @Test(arguments: arguments) func realtime_functionCalling(_ config: InstanceConfig, modelName: String) async throws { let model = FirebaseAI.componentInstance(config).liveModel( modelName: modelName, - generationConfig: textConfig, + generationConfig: generationConfig, tools: tools, systemInstruction: SystemInstructions.lastNames ) @@ -200,11 +190,9 @@ struct LiveSessionTests { functionId: functionCall.functionId ), ]) - - var text = try await session.collectNextTextResponse() + var text = try await session.collectNextAudioOutputTranscript() if text.isEmpty { - // The model sometimes sends an empty text response first - text = try await session.collectNextTextResponse() + text = try await session.collectNextAudioOutputTranscript() } await session.close() @@ -217,8 +205,6 @@ struct LiveSessionTests { } @Test( - .disabled("Temporarily disabled"), - .bug("https://github.com/firebase/firebase-ios-sdk/issues/15640"), arguments: arguments.filter { // TODO: (b/450982184) Remove when Vertex AI adds support for Function IDs and Cancellation switch $0.0.apiConfig.service { @@ -233,7 +219,7 @@ struct LiveSessionTests { modelName: String) async throws { let model = FirebaseAI.componentInstance(config).liveModel( modelName: modelName, - generationConfig: textConfig, + generationConfig: generationConfig, tools: tools, systemInstruction: SystemInstructions.lastNames ) @@ -266,7 +252,7 @@ struct LiveSessionTests { func realtime_interruption(_ config: InstanceConfig, modelName: String) async throws { let model = FirebaseAI.componentInstance(config).liveModel( modelName: modelName, - generationConfig: audioConfig + generationConfig: generationConfig ) let audioFile = try #require( @@ -295,15 +281,11 @@ struct LiveSessionTests { } } - @Test( - .disabled("Temporarily disabled"), - .bug("https://github.com/firebase/firebase-ios-sdk/issues/15640"), - arguments: arguments - ) + @Test(arguments: arguments) func incremental_works(_ config: InstanceConfig, modelName: String) async throws { let model = FirebaseAI.componentInstance(config).liveModel( modelName: modelName, - generationConfig: textConfig, + generationConfig: generationConfig, systemInstruction: SystemInstructions.yesOrNo ) @@ -311,7 +293,11 @@ struct LiveSessionTests { await session.sendContent("Does five plus") await session.sendContent(" five equal ten?", turnComplete: true) - let text = try await session.collectNextTextResponse() + var text = try await session.collectNextAudioOutputTranscript() + if text.isEmpty { + // The model sometimes sends an empty text response first + text = try await session.collectNextAudioOutputTranscript() + } await session.close() let modelResponse = text @@ -339,26 +325,6 @@ struct LiveSessionTests { } private extension LiveSession { - /// Collects the text that the model sends for the next turn. - /// - /// Will listen for `LiveServerContent` messages from the model, - /// incrementally keeping track of any `TextPart`s it sends. Once - /// the model signals that its turn is complete, the function will return - /// a string concatenated of all the `TextPart`s. - func collectNextTextResponse() async throws -> String { - var text = "" - - for try await content in responsesOf(LiveServerContent.self) { - text += content.modelTurn?.allText() ?? "" - - if content.isTurnComplete { - break - } - } - - return text - } - /// Collects the audio output transcripts that the model sends for the next turn. /// /// Will listen for `LiveServerContent` messages from the model, @@ -395,11 +361,7 @@ private extension LiveSession { case let .toolCall(toolCall): return toolCall case let .content(content): - if let text = content.modelTurn?.allText() { - error += text - } else { - error += content.outputAudioText() - } + error += content.outputAudioText() if content.isTurnComplete { Issue.record("The model didn't send a tool call. Text received: \(error)") @@ -464,16 +426,6 @@ private struct NoInterruptionError: Error, var description: String { "The model never sent an interrupted message." } } -private extension ModelContent { - /// A collection of text from all parts. - /// - /// If this doesn't contain any `TextPart`, then an empty - /// string will be returned instead. - func allText() -> String { - parts.compactMap { ($0 as? TextPart)?.text }.joined() - } -} - extension LiveServerContent { /// Text of the output `LiveAudioTranscript`, or an empty string if it's missing. func outputAudioText() -> String {