Skip to content

Commit

Permalink
Add first and last received times to TranscriptionSegment (#485)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidliu authored Aug 21, 2024
1 parent 66a231a commit 11a8a67
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 17 deletions.
5 changes: 5 additions & 0 deletions .changeset/sour-needles-drop.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"client-sdk-android": minor
---

Add first and last received times to TranscriptionSegment
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import livekit.LivekitRtc
import livekit.org.webrtc.*
import livekit.org.webrtc.audio.AudioDeviceModule
import java.net.URI
import java.util.Date
import javax.inject.Named

class Room
Expand Down Expand Up @@ -271,6 +272,8 @@ constructor(
private var regionUrlProvider: RegionUrlProvider? = null
private var regionUrl: String? = null

private var transcriptionReceivedTimes = mutableMapOf<String, Long>()

private fun getCurrentRoomOptions(): RoomOptions =
RoomOptions(
adaptiveStream = adaptiveStream,
Expand Down Expand Up @@ -1131,10 +1134,24 @@ constructor(
* @suppress
*/
override fun onTranscriptionReceived(transcription: LivekitModels.Transcription) {
if (transcription.segmentsList.isEmpty()) {
LKLog.d { "Received transcription segments are empty." }
return
}

val participant = getParticipantByIdentity(transcription.transcribedParticipantIdentity)
val publication = participant?.trackPublications?.get(transcription.trackId)
val segments = transcription.segmentsList
.map { it.toSDKType() }
.map { it.toSDKType(firstReceivedTime = transcriptionReceivedTimes[it.id] ?: Date().time) }

// Update receive times
for (segment in segments) {
if (segment.final) {
transcriptionReceivedTimes.remove(segment.id)
} else {
transcriptionReceivedTimes[segment.id] = segment.firstReceivedTime
}
}

val event = RoomEvent.TranscriptionReceived(
room = this,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,69 @@ package io.livekit.android.room.types

import io.livekit.android.util.LKLog
import livekit.LivekitModels
import java.util.Date

data class TranscriptionSegment(
/**
* The id of the transcription segment.
*/
val id: String,
/**
* The text of the transcription.
*/
val text: String,
/**
* Language
*/
val language: String,
val startTime: Long,
val endTime: Long,
/**
* If false, the user can expect this transcription to update in the future.
*/
val final: Boolean,
/**
* When this client first locally received this segment.
*
* Defined as milliseconds from epoch date (using [Date.getTime])
*/
val firstReceivedTime: Long = Date().time,
/**
* When this client last locally received this segment.
*
* Defined as milliseconds from epoch date (using [Date.getTime])
*/
val lastReceivedTime: Long = Date().time,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
override fun hashCode(): Int {
return id.hashCode()
}
}

other as TranscriptionSegment
/**
* Merges [newSegment] info into this segment if the ids are equal.
*
* Returns `this` if a different segment is passed.
*/
fun TranscriptionSegment?.merge(newSegment: TranscriptionSegment): TranscriptionSegment {
if (this == null) {
return newSegment
}

return id == other.id
if (this.id != newSegment.id) {
return this
}

override fun hashCode(): Int {
return id.hashCode()
if (this.final) {
LKLog.d { "new segment for $id overwriting final segment?" }
}

return copy(
id = this.id,
text = newSegment.text,
language = newSegment.language,
final = newSegment.final,
firstReceivedTime = this.firstReceivedTime,
lastReceivedTime = newSegment.lastReceivedTime,
)
}

/**
Expand All @@ -47,22 +89,18 @@ data class TranscriptionSegment(
fun MutableMap<String, TranscriptionSegment>.mergeNewSegments(newSegments: Collection<TranscriptionSegment>) {
    // Fold each incoming segment into whatever is already stored under its id.
    // merge() returns the new segment untouched when nothing is stored yet, keeps the
    // stored firstReceivedTime otherwise, and already logs when a final segment is
    // being overwritten, so neither a separate log nor an unconditional put is needed.
    for (segment in newSegments) {
        this[segment.id] = this[segment.id].merge(segment)
    }
}

/**
* @suppress
*/
fun LivekitModels.TranscriptionSegment.toSDKType() =
fun LivekitModels.TranscriptionSegment.toSDKType(firstReceivedTime: Long = Date().time) =
TranscriptionSegment(
id = id,
text = text,
language = language,
startTime = startTime,
endTime = endTime,
final = final,
firstReceivedTime = firstReceivedTime,
)
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ import io.livekit.android.test.mock.MockPeerConnection
import io.livekit.android.test.mock.TestData
import io.livekit.android.test.util.toDataChannelBuffer
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Test

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down Expand Up @@ -97,4 +100,60 @@ class RoomTranscriptionMockE2ETest : MockE2ETest() {
assertIsClass(TrackPublicationEvent.TranscriptionReceived::class.java, publicationEvents[0])
}
}

/**
 * Delivers a non-final transcription segment and then a second packet carrying the same
 * segment id, and verifies that the room reports an unchanged first received time while
 * the last received time moves forward.
 */
@Test
fun transcriptionFirstReceivedStaysSame() = runTest {
    connect()
    room.localParticipant.publishAudioTrack(
        LocalAudioTrack(
            name = "",
            mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
            options = LocalAudioTrackOptions(),
            audioProcessingController = MockAudioProcessingController(),
            dispatcher = coroutineRule.dispatcher,
        ),
        options = AudioTrackPublishOptions(
            source = Track.Source.MICROPHONE,
        ),
    )
    val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection
    val subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL)
    subPeerConnection.observer?.onDataChannel(subDataChannel)

    val roomCollector = EventCollector(room.events, coroutineRule.scope)

    // First packet: the stock test segment (same id) rewritten as a non-final update.
    val firstDataBuffer = with(TestData.DATA_PACKET_TRANSCRIPTION.toBuilder()) {
        transcription = with(transcription.toBuilder()) {
            val firstSegment = with(getSegments(0).toBuilder()) {
                language = "first_enUS"
                text = "This is a not a final transcription."
                final = false
                build()
            }
            clearSegments()
            addSegments(firstSegment)
            build()
        }
        build()
    }.toDataChannelBuffer()
    subDataChannel.observer?.onMessage(firstDataBuffer)

    // Block on real wall-clock time so the two receive timestamps (taken from Date().time)
    // are guaranteed to differ.
    // NOTE(review): presumably a plain delay() inside runTest would run on virtual time and
    // not advance the wall clock - confirm before simplifying this.
    runBlocking {
        delay(2)
    }

    // Second packet: the unmodified stock segment with the same id.
    val dataBuffer = TestData.DATA_PACKET_TRANSCRIPTION.toDataChannelBuffer()
    subDataChannel.observer?.onMessage(dataBuffer)

    val roomEvents = roomCollector.stopCollecting()

    assertEquals(2, roomEvents.size)

    val first = (roomEvents[0] as RoomEvent.TranscriptionReceived).transcriptionSegments[0]
    val latest = (roomEvents[1] as RoomEvent.TranscriptionReceived).transcriptionSegments[0]
    val expectedSegment = TestData.DATA_PACKET_TRANSCRIPTION.transcription.getSegments(0)
    assertEquals(expectedSegment.id, latest.id)
    // The first received time is fixed by the first packet and must carry over unchanged.
    assertEquals(first.firstReceivedTime, latest.firstReceivedTime)
    assertTrue(latest.lastReceivedTime > latest.firstReceivedTime)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.livekit.android.room.types

import org.junit.Assert.assertEquals
import org.junit.Test

/**
 * Unit tests for merging [TranscriptionSegment] updates.
 */
class TranscriptionSegmentTest {

    /**
     * Merging a later segment with the same id takes text, language, final flag and last
     * received time from the later segment, and keeps the earlier first received time.
     */
    @Test
    fun mergeSegments() {
        val earlier = segmentReceivedAt(
            time = 0,
            text = "text",
            language = "language",
            final = false,
        )
        val later = segmentReceivedAt(
            time = 100,
            text = "newtext",
            language = "newlanguage",
            final = true,
        )

        val actual = earlier.merge(later)

        val expected = TranscriptionSegment(
            id = "1",
            text = "newtext",
            language = "newlanguage",
            final = true,
            firstReceivedTime = 0,
            lastReceivedTime = 100,
        )
        assertEquals(expected, actual)
    }

    /**
     * Builds a segment with id "1" whose first and last received times are both [time].
     */
    private fun segmentReceivedAt(
        time: Long,
        text: String,
        language: String,
        final: Boolean,
    ) = TranscriptionSegment(
        id = "1",
        text = text,
        language = language,
        final = final,
        firstReceivedTime = time,
        lastReceivedTime = time,
    )
}

0 comments on commit 11a8a67

Please sign in to comment.