diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt index 2ebd7ede..c52a5b9b 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt @@ -65,7 +65,7 @@ internal constructor( private val pctFactory: PeerConnectionTransport.Factory, @Named(InjectionNames.DISPATCHER_IO) private val ioDispatcher: CoroutineDispatcher, -) : SignalClient.Listener, DataChannel.Observer { +) : SignalClient.Listener { internal var listener: Listener? = null /** @@ -218,7 +218,7 @@ internal constructor( LOSSY_DATA_CHANNEL_LABEL -> lossyDataChannelSub = dataChannel else -> return@onDataChannel } - dataChannel.registerObserver(this) + dataChannel.registerObserver(DataChannelObserver(dataChannel)) } subscriberObserver.connectionChangeListener = connectionStateListener @@ -239,7 +239,9 @@ internal constructor( createDataChannel( RELIABLE_DATA_CHANNEL_LABEL, reliableInit, - ).apply { registerObserver(this@RTCEngine) } + ).also { dataChannel -> + dataChannel.registerObserver(DataChannelObserver(dataChannel)) + } } val lossyInit = DataChannel.Init() @@ -249,7 +251,9 @@ internal constructor( createDataChannel( LOSSY_DATA_CHANNEL_LABEL, lossyInit, - ).apply { registerObserver(this@RTCEngine) } + ).also { dataChannel -> + dataChannel.registerObserver(DataChannelObserver(dataChannel)) + } } } @@ -684,8 +688,11 @@ internal constructor( } companion object { - private const val RELIABLE_DATA_CHANNEL_LABEL = "_reliable" - private const val LOSSY_DATA_CHANNEL_LABEL = "_lossy" + @VisibleForTesting + internal const val RELIABLE_DATA_CHANNEL_LABEL = "_reliable" + + @VisibleForTesting + internal const val LOSSY_DATA_CHANNEL_LABEL = "_lossy" internal const val MAX_DATA_PACKET_SIZE = 15000 private const val MAX_RECONNECT_RETRIES = 10 private const val MAX_RECONNECT_TIMEOUT = 60 * 1000 @@ -883,13 +890,13 @@ internal constructor( // --------------------------------- DataChannel.Observer ------------------------------------// - override fun onBufferedAmountChange(previousAmount: Long) { + fun onBufferedAmountChange(dataChannel: DataChannel, previousAmount: Long) { } - override fun onStateChange() { + fun onStateChange(dataChannel: DataChannel) { } - override fun onMessage(buffer: DataChannel.Buffer?) { + fun onMessage(dataChannel: DataChannel, buffer: DataChannel.Buffer?) { if (buffer == null) { return } @@ -911,6 +918,20 @@ internal constructor( } } + private inner class DataChannelObserver(val dataChannel: DataChannel) : DataChannel.Observer { + override fun onBufferedAmountChange(p0: Long) { + this@RTCEngine.onBufferedAmountChange(dataChannel, p0) + } + + override fun onStateChange() { + this@RTCEngine.onStateChange(dataChannel) + } + + override fun onMessage(p0: DataChannel.Buffer) { + this@RTCEngine.onMessage(dataChannel, p0) + } + } + fun sendSyncState( subscription: LivekitRtc.UpdateSubscription, publishedTracks: List, diff --git a/livekit-android-sdk/src/test/java/io/livekit/android/MockE2ETest.kt b/livekit-android-sdk/src/test/java/io/livekit/android/MockE2ETest.kt index f8a1188b..c223496a 100644 --- a/livekit-android-sdk/src/test/java/io/livekit/android/MockE2ETest.kt +++ b/livekit-android-sdk/src/test/java/io/livekit/android/MockE2ETest.kt @@ -35,9 +35,12 @@ import okhttp3.Request import okhttp3.Response import okio.ByteString import org.junit.Before +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner import org.webrtc.PeerConnection @ExperimentalCoroutinesApi +@RunWith(RobolectricTestRunner::class) abstract class MockE2ETest : BaseTest() { internal lateinit var component: TestLiveKitComponent diff --git a/livekit-android-sdk/src/test/java/io/livekit/android/room/RoomDataMockE2ETest.kt b/livekit-android-sdk/src/test/java/io/livekit/android/room/RoomDataMockE2ETest.kt new file mode 100644 index 00000000..615f27db --- /dev/null +++ b/livekit-android-sdk/src/test/java/io/livekit/android/room/RoomDataMockE2ETest.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2023 LiveKit, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.livekit.android.room + +import com.google.protobuf.ByteString +import io.livekit.android.MockE2ETest +import io.livekit.android.assert.assertIsClass +import io.livekit.android.events.EventCollector +import io.livekit.android.events.RoomEvent +import io.livekit.android.mock.MockDataChannel +import io.livekit.android.mock.MockPeerConnection +import kotlinx.coroutines.ExperimentalCoroutinesApi +import livekit.LivekitModels.DataPacket +import livekit.LivekitModels.UserPacket +import org.junit.Assert.assertEquals +import org.junit.Test +import org.webrtc.DataChannel +import java.nio.ByteBuffer + +@OptIn(ExperimentalCoroutinesApi::class) +class RoomDataMockE2ETest : MockE2ETest() { + @Test + fun dataReceivedEvent() = runTest { + connect() + val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection + val subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL) + subPeerConnection.observer?.onDataChannel(subDataChannel) + + val collector = EventCollector(room.events, coroutineRule.scope) + val dataPacket = with(DataPacket.newBuilder()) { + user = with(UserPacket.newBuilder()) { + payload = ByteString.copyFrom("hello", Charsets.UTF_8) + build() + } + build() + } + val dataBuffer = DataChannel.Buffer( + ByteBuffer.wrap(dataPacket.toByteArray()), + true + ) + + subDataChannel.observer?.onMessage(dataBuffer) + val events = collector.stopCollecting() + + assertEquals(1, events.size) + assertIsClass(RoomEvent.DataReceived::class.java, events[0]) + + val event = events[0] as RoomEvent.DataReceived + assertEquals("hello", event.data.decodeToString()) + } +}