Skip to content

Commit 0e30346

Browse files
feat: experimental opt-in for mediator live mode (#150)
Co-authored-by: Ahmed Moussa <[email protected]> Signed-off-by: Cristian G <[email protected]>
1 parent e1753ec commit 0e30346

File tree

6 files changed

+168
-35
lines changed

6 files changed

+168
-35
lines changed

Diff for: atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManager.kt

+17-14
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import kotlin.jvm.Throws
3535
* @property castor The instance of the Castor interface used for working with DIDs.
3636
* @property pluto The instance of the Pluto interface used for storing messages and connection information.
3737
* @property mediationHandler The instance of the MediationHandler interface used for handling mediation.
38+
* @property experimentLiveModeOptIn Flag to opt in or out of the experimental feature mediator live mode, using websockets.
3839
* @property pairings The mutable list of DIDPair representing the connections managed by the ConnectionManager.
3940
*/
4041
class ConnectionManager(
@@ -44,6 +45,7 @@ class ConnectionManager(
4445
internal val mediationHandler: MediationHandler,
4546
private var pairings: MutableList<DIDPair>,
4647
private val pollux: Pollux,
48+
private val experimentLiveModeOptIn: Boolean = false,
4749
private val scope: CoroutineScope = CoroutineScope(Dispatchers.IO)
4850
) : ConnectionsManager, DIDCommConnection {
4951

@@ -66,22 +68,23 @@ class ConnectionManager(
6668
// Resolve the DID document for the mediator
6769
val mediatorDidDoc = castor.resolveDID(currentMediatorDID.toString())
6870
var serviceEndpoint: String? = null
69-
70-
// Loop through the services in the DID document to find a WebSocket endpoint
71-
mediatorDidDoc.services.forEach {
72-
if (it.serviceEndpoint.uri.contains("wss://") || it.serviceEndpoint.uri.contains("ws://")) {
73-
serviceEndpoint = it.serviceEndpoint.uri
74-
return@forEach // Exit loop once the WebSocket endpoint is found
71+
if (experimentLiveModeOptIn) {
72+
// Loop through the services in the DID document to find a WebSocket endpoint
73+
mediatorDidDoc.services.forEach {
74+
if (it.serviceEndpoint.uri.contains("wss://") || it.serviceEndpoint.uri.contains("ws://")) {
75+
serviceEndpoint = it.serviceEndpoint.uri
76+
return@forEach // Exit loop once the WebSocket endpoint is found
77+
}
7578
}
76-
}
7779

78-
// If a WebSocket service endpoint is found
79-
serviceEndpoint?.let { serviceEndpointUrl ->
80-
// Listen for unread messages on the WebSocket endpoint
81-
mediationHandler.listenUnreadMessages(
82-
serviceEndpointUrl
83-
) { arrayMessages ->
84-
processMessages(arrayMessages)
80+
// If a WebSocket service endpoint is found
81+
serviceEndpoint?.let { serviceEndpointUrl ->
82+
// Listen for unread messages on the WebSocket endpoint
83+
mediationHandler.listenUnreadMessages(
84+
serviceEndpointUrl
85+
) { arrayMessages ->
86+
processMessages(arrayMessages)
87+
}
8588
}
8689
}
8790

Diff for: atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgent.kt

+12-4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import io.iohk.atala.prism.walletsdk.logger.PrismLoggerImpl
4747
import io.iohk.atala.prism.walletsdk.pollux.models.AnonCredential
4848
import io.iohk.atala.prism.walletsdk.pollux.models.CredentialRequestMeta
4949
import io.iohk.atala.prism.walletsdk.pollux.models.JWTCredential
50+
import io.iohk.atala.prism.walletsdk.prismagent.helpers.AgentOptions
5051
import io.iohk.atala.prism.walletsdk.prismagent.mediation.BasicMediatorHandler
5152
import io.iohk.atala.prism.walletsdk.prismagent.mediation.MediationHandler
5253
import io.iohk.atala.prism.walletsdk.prismagent.protocols.ProtocolType
@@ -119,6 +120,7 @@ class PrismAgent {
119120
private val api: Api
120121
private var connectionManager: ConnectionManager
121122
private var logger: PrismLogger
123+
private val agentOptions: AgentOptions
122124

123125
/**
124126
* Initializes the PrismAgent with the given dependencies.
@@ -133,6 +135,7 @@ class PrismAgent {
133135
* @param api An optional Api instance used by the PrismAgent if provided, otherwise a default ApiImpl will be used.
134136
* @param logger An optional PrismLogger instance used by the PrismAgent if provided, otherwise a PrismLoggerImpl with
135137
* LogComponent.PRISM_AGENT will be used.
138+
* @param agentOptions Options to configure certain features with in the prism agent.
136139
*/
137140
@JvmOverloads
138141
constructor(
@@ -144,7 +147,8 @@ class PrismAgent {
144147
connectionManager: ConnectionManager,
145148
seed: Seed?,
146149
api: Api?,
147-
logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT)
150+
logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT),
151+
agentOptions: AgentOptions
148152
) {
149153
prismAgentScope.launch {
150154
flowState.emit(State.STOPPED)
@@ -170,6 +174,7 @@ class PrismAgent {
170174
}
171175
)
172176
this.logger = logger
177+
this.agentOptions = agentOptions
173178
}
174179

175180
/**
@@ -184,6 +189,7 @@ class PrismAgent {
184189
* @param api The instance of the API. Default is null.
185190
* @param mediatorHandler The mediator handler.
186191
* @param logger The logger for PrismAgent. Default is PrismLoggerImpl with LogComponent.PRISM_AGENT.
192+
* @param agentOptions Options to configure certain features with in the prism agent.
187193
*/
188194
@JvmOverloads
189195
constructor(
@@ -195,7 +201,8 @@ class PrismAgent {
195201
seed: Seed? = null,
196202
api: Api? = null,
197203
mediatorHandler: MediationHandler,
198-
logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT)
204+
logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT),
205+
agentOptions: AgentOptions
199206
) {
200207
prismAgentScope.launch {
201208
flowState.emit(State.STOPPED)
@@ -220,9 +227,10 @@ class PrismAgent {
220227
}
221228
)
222229
this.logger = logger
230+
this.agentOptions = agentOptions
223231
// Pairing will be removed in the future
224232
this.connectionManager =
225-
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux)
233+
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux, agentOptions.experiments.liveMode)
226234
}
227235

228236
init {
@@ -462,7 +470,7 @@ class PrismAgent {
462470
fun setupMediatorHandler(mediatorHandler: MediationHandler) {
463471
stop()
464472
this.connectionManager =
465-
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux)
473+
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux, agentOptions.experiments.liveMode)
466474
}
467475

468476
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package io.iohk.atala.prism.walletsdk.prismagent.helpers
2+
3+
data class AgentOptions(val experiments: Experiments = Experiments())
4+
5+
data class Experiments(val liveMode: Boolean = false)

Diff for: atala-prism-sdk/src/commonTest/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManagerTest.kt

+98
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class ConnectionManagerTest {
6868
mediationHandler = basicMediatorHandlerMock,
6969
pairings = mutableListOf(),
7070
pollux = polluxMock,
71+
experimentLiveModeOptIn = true,
7172
scope = CoroutineScope(testDispatcher)
7273
)
7374
}
@@ -123,6 +124,103 @@ class ConnectionManagerTest {
123124
verify(basicMediatorHandlerMock).listenUnreadMessages(any(), any())
124125
}
125126

127+
@Test
128+
fun testStartFetchingMessages_whenServiceEndpointContainsWSSButOptInLiveModeFalse_thenRegunarlApi() = runTest {
129+
connectionManager = ConnectionManager(
130+
mercury = mercuryMock,
131+
castor = castorMock,
132+
pluto = plutoMock,
133+
mediationHandler = basicMediatorHandlerMock,
134+
pairings = mutableListOf(),
135+
pollux = polluxMock,
136+
experimentLiveModeOptIn = false,
137+
scope = CoroutineScope(testDispatcher)
138+
)
139+
140+
`when`(basicMediatorHandlerMock.mediatorDID)
141+
.thenReturn(DID("did:prism:b6c0c33d701ac1b9a262a14454d1bbde3d127d697a76950963c5fd930605:Cj8KPRI7CgdtYXN0ZXIwEAFKLgoJc2VmsxEiECSTjyV7sUfCr_ArpN9rvCwR9fRMAhcsr_S7ZRiJk4p5k"))
142+
143+
val vmAuthentication = DIDDocument.VerificationMethod(
144+
id = DIDUrl(DID("2", "1", "0")),
145+
controller = DID("2", "2", "0"),
146+
type = Curve.ED25519.value,
147+
publicKeyJwk = mapOf("crv" to Curve.ED25519.value, "x" to "")
148+
)
149+
150+
val vmKeyAgreement = DIDDocument.VerificationMethod(
151+
id = DIDUrl(DID("3", "1", "0")),
152+
controller = DID("3", "2", "0"),
153+
type = Curve.X25519.value,
154+
publicKeyJwk = mapOf("crv" to Curve.X25519.value, "x" to "")
155+
)
156+
157+
val vmService = DIDDocument.Service(
158+
id = UUID.randomUUID().toString(),
159+
type = emptyArray(),
160+
serviceEndpoint = DIDDocument.ServiceEndpoint(
161+
uri = "wss://serviceEndpoint"
162+
)
163+
)
164+
165+
val didDoc = DIDDocument(
166+
id = DID("did:prism:asdfasdf"),
167+
coreProperties = arrayOf(
168+
DIDDocument.Authentication(
169+
urls = emptyArray(),
170+
verificationMethods = arrayOf(vmAuthentication)
171+
),
172+
DIDDocument.KeyAgreement(
173+
urls = emptyArray(),
174+
verificationMethods = arrayOf(vmKeyAgreement)
175+
),
176+
DIDDocument.Services(
177+
values = arrayOf(vmService)
178+
)
179+
)
180+
)
181+
182+
`when`(castorMock.resolveDID(any())).thenReturn(didDoc)
183+
val messages = arrayOf(Pair("1234", Message(piuri = "", body = "")))
184+
`when`(basicMediatorHandlerMock.pickupUnreadMessages(any())).thenReturn(
185+
flow {
186+
emit(
187+
messages
188+
)
189+
}
190+
)
191+
val attachments: Array<AttachmentDescriptor> =
192+
arrayOf(
193+
AttachmentDescriptor(
194+
mediaType = "application/json",
195+
format = CredentialType.JWT.type,
196+
data = AttachmentBase64(base64 = "asdfasdfasdfasdfasdfasdfasdfasdfasdf".base64UrlEncoded)
197+
)
198+
)
199+
val listMessages = listOf(
200+
Message(
201+
piuri = ProtocolType.DidcommconnectionRequest.value,
202+
body = ""
203+
),
204+
Message(
205+
piuri = ProtocolType.DidcommIssueCredential.value,
206+
thid = UUID.randomUUID().toString(),
207+
from = DID("did:peer:asdf897a6sdf"),
208+
to = DID("did:peer:f706sg678ha"),
209+
attachments = attachments,
210+
body = """{}"""
211+
)
212+
)
213+
val messageList: Flow<List<Message>> = flow {
214+
emit(listMessages)
215+
}
216+
`when`(plutoMock.getAllMessages()).thenReturn(messageList)
217+
218+
connectionManager.startFetchingMessages()
219+
assertNotNull(connectionManager.fetchingMessagesJob)
220+
verify(basicMediatorHandlerMock).pickupUnreadMessages(10)
221+
verify(basicMediatorHandlerMock).registerMessagesAsRead(arrayOf("1234"))
222+
}
223+
126224
@Test
127225
fun testStartFetchingMessages_whenServiceEndpointNotContainsWSS_thenUseAPIRequest() =
128226
runBlockingTest {

0 commit comments

Comments
 (0)