Skip to content

Commit 823b8b3

Browse files
feat: support for mediator live mode (websocket) (#147)
Co-authored-by: Ahmed Moussa <[email protected]> Signed-off-by: Cristian G <[email protected]>
1 parent 5810477 commit 823b8b3

File tree

9 files changed

+382
-35
lines changed

9 files changed

+382
-35
lines changed

Diff for: atala-prism-sdk/build.gradle.kts

+5
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ kotlin {
106106
implementation("io.ktor:ktor-client-content-negotiation:2.3.4")
107107
implementation("io.ktor:ktor-serialization-kotlinx-json:2.3.4")
108108
implementation("io.ktor:ktor-client-logging:2.3.4")
109+
implementation("io.ktor:ktor-websockets:2.3.4")
109110

110111
implementation("io.iohk.atala.prism.didcomm:didpeer:$didpeerVersion")
111112

@@ -135,12 +136,15 @@ kotlin {
135136
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.8.0")
136137
implementation("io.ktor:ktor-client-mock:2.3.4")
137138
implementation("junit:junit:4.13.2")
139+
implementation("org.mockito:mockito-core:4.4.0")
140+
implementation("org.mockito.kotlin:mockito-kotlin:4.0.0")
138141
}
139142
}
140143
val jvmMain by getting {
141144
dependencies {
142145
implementation("io.ktor:ktor-client-okhttp:2.3.4")
143146
implementation("app.cash.sqldelight:sqlite-driver:2.0.1")
147+
implementation("io.ktor:ktor-client-java:2.3.4")
144148
}
145149
}
146150
val jvmTest by getting
@@ -149,6 +153,7 @@ kotlin {
149153
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.8.0")
150154
implementation("io.ktor:ktor-client-okhttp:2.3.4")
151155
implementation("app.cash.sqldelight:android-driver:2.0.1")
156+
implementation("io.ktor:ktor-client-android:2.3.4")
152157
}
153158
}
154159
val androidInstrumentedTest by getting {

Diff for: atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManager.kt

+99-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
@file:Suppress("ktlint:standard:import-ordering")
2+
13
package io.iohk.atala.prism.walletsdk.prismagent
24

35
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Castor
@@ -9,8 +11,14 @@ import io.iohk.atala.prism.walletsdk.domain.models.Message
911
import io.iohk.atala.prism.walletsdk.prismagent.connectionsmanager.ConnectionsManager
1012
import io.iohk.atala.prism.walletsdk.prismagent.connectionsmanager.DIDCommConnection
1113
import io.iohk.atala.prism.walletsdk.prismagent.mediation.MediationHandler
14+
import java.time.Duration
15+
import kotlinx.coroutines.CoroutineScope
16+
import kotlinx.coroutines.Dispatchers
17+
import kotlinx.coroutines.Job
18+
import kotlinx.coroutines.delay
1219
import kotlinx.coroutines.flow.Flow
1320
import kotlinx.coroutines.flow.first
21+
import kotlinx.coroutines.launch
1422
import kotlin.jvm.Throws
1523

1624
/**
@@ -27,9 +35,99 @@ class ConnectionManager(
2735
private val castor: Castor,
2836
private val pluto: Pluto,
2937
internal val mediationHandler: MediationHandler,
30-
private var pairings: MutableList<DIDPair>
38+
private var pairings: MutableList<DIDPair>,
39+
private val scope: CoroutineScope = CoroutineScope(Dispatchers.IO)
3140
) : ConnectionsManager, DIDCommConnection {
3241

42+
var fetchingMessagesJob: Job? = null
43+
44+
/**
45+
* Starts the process of fetching messages at a regular interval.
46+
*
47+
* @param requestInterval The time interval (in seconds) between message fetch requests.
48+
* Defaults to 5 seconds if not specified.
49+
*/
50+
@JvmOverloads
51+
fun startFetchingMessages(requestInterval: Int = 5) {
52+
// Check if the job for fetching messages is already running
53+
if (fetchingMessagesJob == null) {
54+
// Launch a coroutine in the provided scope
55+
fetchingMessagesJob = scope.launch {
56+
// Retrieve the current mediator DID
57+
val currentMediatorDID = mediationHandler.mediatorDID
58+
// Resolve the DID document for the mediator
59+
val mediatorDidDoc = castor.resolveDID(currentMediatorDID.toString())
60+
var serviceEndpoint: String? = null
61+
62+
// Loop through the services in the DID document to find a WebSocket endpoint
63+
mediatorDidDoc.services.forEach {
64+
if (it.serviceEndpoint.uri.contains("wss://") || it.serviceEndpoint.uri.contains("ws://")) {
65+
serviceEndpoint = it.serviceEndpoint.uri
66+
return@forEach // Exit loop once the WebSocket endpoint is found
67+
}
68+
}
69+
70+
// If a WebSocket service endpoint is found
71+
serviceEndpoint?.let { serviceEndpointUrl ->
72+
// Listen for unread messages on the WebSocket endpoint
73+
mediationHandler.listenUnreadMessages(
74+
serviceEndpointUrl
75+
) { arrayMessages ->
76+
// Process the received messages
77+
val messagesIds = mutableListOf<String>()
78+
val messages = mutableListOf<Message>()
79+
arrayMessages.map { pair ->
80+
messagesIds.add(pair.first)
81+
messages.add(pair.second)
82+
}
83+
// If there are any messages, mark them as read and store them
84+
scope.launch {
85+
if (messagesIds.isNotEmpty()) {
86+
mediationHandler.registerMessagesAsRead(
87+
messagesIds.toTypedArray()
88+
)
89+
pluto.storeMessages(messages)
90+
}
91+
}
92+
}
93+
}
94+
95+
// Fallback mechanism if no WebSocket service endpoint is available
96+
if (serviceEndpoint == null) {
97+
while (true) {
98+
// Continuously await and process new messages
99+
awaitMessages().collect { array ->
100+
val messagesIds = mutableListOf<String>()
101+
val messages = mutableListOf<Message>()
102+
array.map { pair ->
103+
messagesIds.add(pair.first)
104+
messages.add(pair.second)
105+
}
106+
if (messagesIds.isNotEmpty()) {
107+
mediationHandler.registerMessagesAsRead(
108+
messagesIds.toTypedArray()
109+
)
110+
pluto.storeMessages(messages)
111+
}
112+
}
113+
// Wait for the specified request interval before fetching new messages
114+
delay(Duration.ofSeconds(requestInterval.toLong()).toMillis())
115+
}
116+
}
117+
}
118+
119+
// Start the coroutine if it's not already active
120+
fetchingMessagesJob?.let {
121+
if (it.isActive) return
122+
it.start()
123+
}
124+
}
125+
}
126+
127+
fun stopConnection() {
128+
fetchingMessagesJob?.cancel()
129+
}
130+
33131
/**
34132
* Suspends the current coroutine and boots the registered mediator associated with the mediator handler.
35133
* If no mediator is available, a [PrismAgentError.NoMediatorAvailableError] is thrown.

Diff for: atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgent.kt

+2-32
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,8 @@ import io.ktor.http.HttpMethod
6767
import io.ktor.http.Url
6868
import io.ktor.serialization.kotlinx.json.json
6969
import java.net.UnknownHostException
70-
import java.time.Duration
7170
import kotlinx.coroutines.CoroutineScope
7271
import kotlinx.coroutines.Dispatchers
73-
import kotlinx.coroutines.Job
74-
import kotlinx.coroutines.delay
7572
import kotlinx.coroutines.flow.Flow
7673
import kotlinx.coroutines.flow.MutableSharedFlow
7774
import kotlinx.coroutines.flow.first
@@ -116,7 +113,6 @@ class PrismAgent {
116113
val pluto: Pluto
117114
val mercury: Mercury
118115
val pollux: Pollux
119-
var fetchingMessagesJob: Job? = null
120116
val flowState = MutableSharedFlow<State>()
121117

122118
private val prismAgentScope: CoroutineScope = CoroutineScope(Dispatchers.Default)
@@ -298,7 +294,6 @@ class PrismAgent {
298294
}
299295
logger.info(message = "Stoping agent")
300296
state = State.STOPPING
301-
fetchingMessagesJob?.cancel()
302297
state = State.STOPPED
303298
logger.info(message = "Agent not running")
304299
}
@@ -724,40 +719,15 @@ class PrismAgent {
724719
*/
725720
@JvmOverloads
726721
fun startFetchingMessages(requestInterval: Int = 5) {
727-
if (fetchingMessagesJob == null) {
728-
logger.info(message = "Start streaming new unread messages")
729-
fetchingMessagesJob = prismAgentScope.launch {
730-
while (true) {
731-
connectionManager.awaitMessages().collect { array ->
732-
val messagesIds = mutableListOf<String>()
733-
val messages = mutableListOf<Message>()
734-
array.map { pair ->
735-
messagesIds.add(pair.first)
736-
messages.add(pair.second)
737-
}
738-
if (messagesIds.isNotEmpty()) {
739-
connectionManager.mediationHandler.registerMessagesAsRead(
740-
messagesIds.toTypedArray()
741-
)
742-
pluto.storeMessages(messages)
743-
}
744-
}
745-
delay(Duration.ofSeconds(requestInterval.toLong()).toMillis())
746-
}
747-
}
748-
}
749-
fetchingMessagesJob?.let {
750-
if (it.isActive) return
751-
it.start()
752-
}
722+
connectionManager.startFetchingMessages(requestInterval)
753723
}
754724

755725
/**
756726
* Stop fetching messages
757727
*/
758728
fun stopFetchingMessages() {
759729
logger.info(message = "Stop streaming new unread messages")
760-
fetchingMessagesJob?.cancel()
730+
connectionManager.stopConnection()
761731
}
762732

763733
/**

Diff for: atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/mediation/BasicMediatorHandler.kt

+83-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
@file:Suppress("ktlint:standard:import-ordering")
2+
13
package io.iohk.atala.prism.walletsdk.prismagent.mediation
24

35
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Mercury
@@ -7,16 +9,24 @@ import io.iohk.atala.prism.walletsdk.domain.models.Mediator
79
import io.iohk.atala.prism.walletsdk.domain.models.Message
810
import io.iohk.atala.prism.walletsdk.domain.models.UnknownError
911
import io.iohk.atala.prism.walletsdk.prismagent.PrismAgentError
12+
import io.iohk.atala.prism.walletsdk.prismagent.protocols.ProtocolType
1013
import io.iohk.atala.prism.walletsdk.prismagent.protocols.mediation.MediationGrant
1114
import io.iohk.atala.prism.walletsdk.prismagent.protocols.mediation.MediationKeysUpdateList
1215
import io.iohk.atala.prism.walletsdk.prismagent.protocols.mediation.MediationRequest
1316
import io.iohk.atala.prism.walletsdk.prismagent.protocols.pickup.PickupReceived
1417
import io.iohk.atala.prism.walletsdk.prismagent.protocols.pickup.PickupRequest
1518
import io.iohk.atala.prism.walletsdk.prismagent.protocols.pickup.PickupRunner
19+
import io.ktor.client.HttpClient
20+
import io.ktor.client.plugins.HttpTimeout
21+
import io.ktor.client.plugins.websocket.WebSockets
22+
import io.ktor.client.plugins.websocket.webSocket
23+
import io.ktor.websocket.Frame
24+
import io.ktor.websocket.readText
1625
import kotlinx.coroutines.flow.Flow
1726
import kotlinx.coroutines.flow.first
1827
import kotlinx.coroutines.flow.flow
1928
import java.util.UUID
29+
import kotlinx.coroutines.isActive
2030

2131
/**
2232
* A class that provides an implementation of [MediationHandler] using a Pluto instance and a Mercury instance. It can
@@ -84,7 +94,8 @@ class BasicMediatorHandler(
8494
val registeredMediator = bootRegisteredMediator()
8595
if (registeredMediator == null) {
8696
try {
87-
val requestMessage = MediationRequest(from = host, to = mediatorDID).makeMessage()
97+
val requestMessage =
98+
MediationRequest(from = host, to = mediatorDID).makeMessage()
8899
val message = mercury.sendMessageParseResponse(message = requestMessage)
89100
?: throw UnknownError.SomethingWentWrongError(
90101
message = "BasicMediatorHandler => mercury.sendMessageParseResponse returned null"
@@ -167,4 +178,75 @@ class BasicMediatorHandler(
167178
} ?: throw PrismAgentError.NoMediatorAvailableError()
168179
mercury.sendMessage(requestMessage)
169180
}
181+
182+
/**
183+
* Listens for unread messages from a specified WebSocket service endpoint.
184+
*
185+
* This function creates a WebSocket connection to the provided service endpoint URI
186+
* and listens for incoming messages. Upon receiving messages, it processes and
187+
* dispatches them to the specified callback function.
188+
*
189+
* @param serviceEndpointUri The URI of the service endpoint. It should be a valid WebSocket URI.
190+
* @param onMessageCallback A callback function that is invoked when a message is received.
191+
* This function is responsible for handling the incoming message.
192+
*/
193+
override suspend fun listenUnreadMessages(
194+
serviceEndpointUri: String,
195+
onMessageCallback: OnMessageCallback
196+
) {
197+
val client = HttpClient {
198+
install(WebSockets)
199+
install(HttpTimeout) {
200+
requestTimeoutMillis = WEBSOCKET_TIMEOUT
201+
connectTimeoutMillis = WEBSOCKET_TIMEOUT
202+
socketTimeoutMillis = WEBSOCKET_TIMEOUT
203+
}
204+
}
205+
if (serviceEndpointUri.contains("wss://") || serviceEndpointUri.contains("ws://")) {
206+
client.webSocket(serviceEndpointUri) {
207+
if (isActive) {
208+
val liveDeliveryMessage = Message(
209+
body = "{\"live_delivery\":true}",
210+
piuri = ProtocolType.LiveDeliveryChange.value,
211+
id = UUID.randomUUID().toString(),
212+
from = mediator?.hostDID,
213+
to = mediatorDID
214+
)
215+
val packedMessage = mercury.packMessage(liveDeliveryMessage)
216+
send(Frame.Text(packedMessage))
217+
}
218+
while (isActive) {
219+
try {
220+
for (frame in incoming) {
221+
if (frame is Frame.Text) {
222+
val messages =
223+
handleReceivedMessagesFromSockets(frame.readText())
224+
onMessageCallback.onMessage(messages)
225+
}
226+
}
227+
} catch (e: Exception) {
228+
e.printStackTrace()
229+
continue
230+
}
231+
}
232+
}
233+
}
234+
}
235+
236+
private suspend fun handleReceivedMessagesFromSockets(text: String): Array<Pair<String, Message>> {
237+
val decryptedMessage = mercury.unpackMessage(text)
238+
if (decryptedMessage.piuri == ProtocolType.PickupStatus.value ||
239+
decryptedMessage.piuri == ProtocolType.PickupDelivery.value
240+
) {
241+
return PickupRunner(decryptedMessage, mercury).run()
242+
} else {
243+
return emptyArray()
244+
}
245+
}
170246
}
247+
248+
fun interface OnMessageCallback {
249+
fun onMessage(messages: Array<Pair<String, Message>>)
250+
}
251+
252+
const val WEBSOCKET_TIMEOUT: Long = 15_000

Diff for: atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/mediation/MediationHandler.kt

+12
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,16 @@ interface MediationHandler {
5757
* @param ids An array of message IDs to register as read.
5858
*/
5959
suspend fun registerMessagesAsRead(ids: Array<String>)
60+
61+
/**
62+
* Listens for unread messages from a specified WebSocket service endpoint.
63+
*
64+
* @param serviceEndpointUri The URI of the service endpoint. It should be a valid WebSocket URI.
65+
* @param onMessageCallback A callback function that is invoked when a message is received.
66+
* This function is responsible for handling the incoming message.
67+
*/
68+
suspend fun listenUnreadMessages(
69+
serviceEndpointUri: String,
70+
onMessageCallback: OnMessageCallback
71+
)
6072
}

Diff for: atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/protocols/ProtocolType.kt

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ enum class ProtocolType(val value: String) {
3434
PickupDelivery("https://didcomm.org/messagepickup/3.0/delivery"),
3535
PickupStatus("https://didcomm.org/messagepickup/3.0/status"),
3636
PickupReceived("https://didcomm.org/messagepickup/3.0/messages-received"),
37+
LiveDeliveryChange("https://didcomm.org/messagepickup/3.0/live-delivery-change"),
3738
None("");
3839

3940
companion object {

0 commit comments

Comments
 (0)