Skip to content

Commit 7d7a408

Browse files
committed
Support http forward
1 parent 0a3e454 commit 7d7a408

File tree

2 files changed

+281
-0
lines changed
  • mirai-api-http/src

2 files changed

+281
-0
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright 2023 Mamoe Technologies and contributors.
3+
*
4+
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
5+
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
6+
*
7+
* https://github.com/mamoe/mirai/blob/master/LICENSE
8+
*/
9+
10+
package net.mamoe.mirai.api.http.adapter.http.plugin
11+
12+
import io.ktor.http.*
13+
import io.ktor.http.content.*
14+
import io.ktor.server.application.*
15+
import io.ktor.server.plugins.*
16+
import io.ktor.server.request.*
17+
import io.ktor.util.*
18+
import io.ktor.util.pipeline.*
19+
import io.ktor.util.reflect.*
20+
import kotlinx.serialization.InternalSerializationApi
21+
import kotlinx.serialization.encodeToString
22+
import kotlinx.serialization.json.Json
23+
import kotlinx.serialization.json.JsonElement
24+
import kotlinx.serialization.json.JsonNull
25+
import kotlinx.serialization.serializer
26+
27+
28+
internal val HttpForwardAttributeKey = AttributeKey<HttpForwardContext>("HttpForward")
29+
internal val HttpForwardPhase = PipelinePhase("Forward")
30+
val HttpForward = createApplicationPlugin("HttpForward", ::HttpForwardConfig) {
31+
application.insertPhaseAfter(ApplicationCallPipeline.Call, HttpForwardPhase)
32+
33+
application.intercept(HttpForwardPhase) {
34+
val forwardContext = call.attributes.getOrNull(HttpForwardAttributeKey)
35+
if (forwardContext != null && !forwardContext.forwarded) {
36+
forwardContext.forwarded = true
37+
forwardContext.convertors = this@createApplicationPlugin.pluginConfig.getConvertors()
38+
finish()
39+
application.execute(ApplicationForwardCall(call, forwardContext))
40+
}
41+
}
42+
}
43+
44+
typealias BodyConvertor = (Any, TypeInfo) -> Any?
45+
46+
class HttpForwardConfig {
47+
private val convertors: MutableList<BodyConvertor> = mutableListOf(DefaultBodyConvertor)
48+
49+
fun addConvertor(convertor: BodyConvertor) {
50+
convertors.add(convertor)
51+
}
52+
53+
fun getConvertors(): List<BodyConvertor> = convertors
54+
55+
@OptIn(InternalSerializationApi::class)
56+
fun jsonElementBodyConvertor(json: Json) {
57+
addConvertor { body, typeInfo ->
58+
val b = if (body == NullBody) JsonNull else body
59+
when {
60+
b !is JsonElement -> null
61+
typeInfo.type == String::class -> json.encodeToString(b)
62+
else -> json.decodeFromJsonElement(typeInfo.type.serializer(), b)
63+
}
64+
}
65+
}
66+
}
67+
68+
val DefaultBodyConvertor: (Any, TypeInfo) -> Any? = { body, typeInfo ->
69+
if (typeInfo.type.isInstance(body)) body else null
70+
}
71+
72+
internal data class HttpForwardContext(val router: String, val body: Any?) {
73+
var forwarded = false
74+
var convertors = emptyList<BodyConvertor>()
75+
}
76+
77+
fun ApplicationCall.forward(forward: String) {
78+
attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, null))
79+
}
80+
81+
fun ApplicationCall.forward(forward: String, body: Any?) {
82+
attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, body ?: NullBody))
83+
}
84+
85+
internal fun forwardReceivePipeline(convertors: List<BodyConvertor>, body: Any): ApplicationReceivePipeline =
86+
ApplicationReceivePipeline().apply {
87+
intercept(ApplicationReceivePipeline.Transform) {
88+
proceedWith(convertors.firstNotNullOfOrNull { it.invoke(body, context.receiveType) }
89+
?: throw CannotTransformContentToTypeException(context.receiveType.kotlinType!!))
90+
}
91+
}
92+
93+
internal class ApplicationForwardCall(
94+
val delegate: ApplicationCall, val context: HttpForwardContext
95+
) : ApplicationCall by delegate {
96+
override val request: ApplicationRequest = DelegateApplicationRequest(this, context.router, context.body)
97+
}
98+
99+
internal class DelegateApplicationRequest(
100+
override val call: ApplicationForwardCall, forward: String, body: Any?
101+
) : ApplicationRequest by call.delegate.request {
102+
private val _pipeline by lazy {
103+
body?.let { forwardReceivePipeline(call.context.convertors, it) } ?: call.delegate.request.pipeline
104+
}
105+
override val local = DelegateRequestConnectionPoint(call.delegate.request.local, forward)
106+
override val pipeline: ApplicationReceivePipeline = _pipeline
107+
}
108+
109+
internal class DelegateRequestConnectionPoint(
110+
private val delegate: RequestConnectionPoint, override val uri: String
111+
) : RequestConnectionPoint by delegate
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/*
2+
* Copyright 2023 Mamoe Technologies and contributors.
3+
*
4+
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
5+
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
6+
*
7+
* https://github.com/mamoe/mirai/blob/master/LICENSE
8+
*/
9+
package net.mamoe.mirai.api.http.adapter.http.plugin
10+
11+
import io.ktor.client.request.*
12+
import io.ktor.client.statement.*
13+
import io.ktor.http.*
14+
import io.ktor.serialization.kotlinx.json.*
15+
import io.ktor.server.application.*
16+
import io.ktor.server.plugins.contentnegotiation.*
17+
import io.ktor.server.plugins.doublereceive.*
18+
import io.ktor.server.request.*
19+
import io.ktor.server.response.*
20+
import io.ktor.server.routing.*
21+
import io.ktor.server.testing.*
22+
import kotlinx.serialization.Serializable
23+
import kotlinx.serialization.json.JsonElement
24+
import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.LongTargetDTO
25+
import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.NudgeDTO
26+
import net.mamoe.mirai.api.http.adapter.internal.serializer.BuiltinJsonSerializer
27+
import kotlin.test.Test
28+
import kotlin.test.assertEquals
29+
30+
class HttpForwardTest {
31+
32+
@Test
33+
fun testGetRequestForward() = testApplication {
34+
routing {
35+
get("/test") {
36+
call.forward("/forward")
37+
}
38+
39+
get("/forward") {
40+
call.respondText(call.parameters["key"] ?: "null")
41+
}
42+
}
43+
44+
client.get("/test") {
45+
parameter("key", "value")
46+
}.also {
47+
assertEquals(HttpStatusCode.OK, it.status)
48+
assertEquals("value", it.bodyAsText())
49+
}
50+
}
51+
52+
@Test
53+
fun testPostRequestForwardReceiveBody() = testApplication {
54+
install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }
55+
56+
routing {
57+
post("/test") {
58+
call.forward("/forward")
59+
}
60+
61+
post("/forward") {
62+
val receive = call.receive<LongTargetDTO>()
63+
call.respondText(receive.target.toString())
64+
}
65+
}
66+
67+
client.post("/test") {
68+
contentType(ContentType.Application.Json)
69+
setBody("""{"target":123}""")
70+
}.also {
71+
assertEquals(HttpStatusCode.OK, it.status)
72+
assertEquals("123", it.bodyAsText())
73+
}
74+
}
75+
76+
@Test
77+
fun testPostRequestForwardDoubleReceiveBody() = testApplication {
78+
install(DoubleReceive)
79+
install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }
80+
81+
routing {
82+
post("/test") {
83+
val receive = call.receive<LongTargetDTO>()
84+
assertEquals(123, receive.target)
85+
call.forward("/forward")
86+
}
87+
88+
post("/forward") {
89+
val receive = call.receive<LongTargetDTO>()
90+
call.respondText(receive.target.toString())
91+
}
92+
}
93+
94+
client.post("/test") {
95+
contentType(ContentType.Application.Json)
96+
setBody("""{"target":123}""")
97+
}.also {
98+
assertEquals(HttpStatusCode.OK, it.status)
99+
assertEquals("123", it.bodyAsText())
100+
}
101+
}
102+
103+
@Test
104+
fun testPostRequestForwardResetBody() = testApplication {
105+
install(DoubleReceive)
106+
install(HttpRouterMonitor)
107+
install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }
108+
109+
routing {
110+
post("/test") {
111+
val receive = call.receive<LongTargetDTO>()
112+
assertEquals(123, receive.target)
113+
call.forward("/forward", NudgeDTO(321, 321, "kind"))
114+
}
115+
116+
post("/forward") {
117+
val receive = call.receive<NudgeDTO>()
118+
call.respondText(receive.target.toString())
119+
}
120+
}
121+
122+
client.post("/test") {
123+
contentType(ContentType.Application.Json)
124+
setBody("""{"target":123}""")
125+
}.also {
126+
assertEquals(HttpStatusCode.OK, it.status)
127+
assertEquals("321", it.bodyAsText())
128+
}
129+
}
130+
131+
132+
@Serializable
133+
private data class NestedDto(
134+
val router: String,
135+
val body: JsonElement,
136+
)
137+
138+
@Test
139+
fun testPostRequestForwardNestedBody() = testApplication {
140+
val json = BuiltinJsonSerializer.buildJson()
141+
142+
install(DoubleReceive)
143+
install(HttpRouterMonitor)
144+
install(ContentNegotiation) { json(json) }
145+
install(HttpForward) { jsonElementBodyConvertor(json) }
146+
147+
routing {
148+
post("/test") {
149+
val receive = call.receive<NestedDto>()
150+
assertEquals("/forward", receive.router)
151+
call.forward("/forward", receive.body)
152+
153+
call.respond(HttpStatusCode.OK)
154+
}
155+
156+
post("/forward") {
157+
val receive = call.receive<LongTargetDTO>()
158+
call.respondText(receive.target.toString())
159+
}
160+
}
161+
162+
client.post("/test") {
163+
contentType(ContentType.Application.Json)
164+
setBody("""{"router":"/forward","body":{"target":321}}""")
165+
}.also {
166+
assertEquals(HttpStatusCode.OK, it.status)
167+
assertEquals("321", it.bodyAsText())
168+
}
169+
}
170+
}

0 commit comments

Comments
 (0)