Skip to content

Commit 870f28f

Browse files
committed
feat(rust): add support for better type in impl
1 parent 33e0499 commit 870f28f

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustAstBaseListener.kt

+27-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package chapi.ast.rustast
22

33
import chapi.ast.antlr.RustParser
4+
import chapi.ast.antlr.RustParser.TypePathSegmentContext
5+
import chapi.ast.antlr.RustParser.Type_Context
46
import chapi.ast.antlr.RustParserBaseListener
57
import chapi.domain.core.*
68
import org.antlr.v4.runtime.ParserRuleContext
79
import java.io.File
8-
import java.util.concurrent.atomic.AtomicInteger
910

1011

1112
open class RustAstBaseListener(private val fileName: String) : RustParserBaseListener() {
@@ -18,15 +19,13 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis
1819
protected open var currentFunction: CodeFunction = CodeFunction()
1920
protected var isEnteredImplementation: Boolean = false
2021
protected var isEnteredIndividualFunction: Boolean = false
21-
private var isEnteredImplementationFunction: Boolean = false
2222

2323
protected lateinit var currentIndividualFunction: CodeFunction
2424

2525
private val individualFunctions = mutableListOf<CodeFunction>()
2626
private val individualFields = mutableListOf<CodeField>()
2727

28-
var structMap = mutableMapOf<String, CodeDataStruct>()
29-
28+
private var structMap = mutableMapOf<String, CodeDataStruct>()
3029

3130
/**
3231
* packageName will parse from fileName, like:
@@ -98,7 +97,7 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis
9897
}
9998

10099
override fun enterFunction_(ctx: RustParser.Function_Context?) {
101-
if (isEnteredImplementation == false) {
100+
if (!isEnteredImplementation) {
102101
val functionName = ctx!!.identifier().text
103102
val function = CodeFunction(
104103
Name = functionName,
@@ -141,7 +140,7 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis
141140
}
142141

143142
override fun enterImplementation(ctx: RustParser.ImplementationContext?) {
144-
val nodeName = ctx?.inherentImpl()?.type_()?.text ?: return
143+
val nodeName = buildNodeName(ctx)
145144
if (structMap.containsKey(nodeName)) {
146145
currentNode = structMap[nodeName]!!
147146
} else {
@@ -155,6 +154,28 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis
155154
isEnteredImplementation = true
156155
}
157156

157+
private fun buildNodeName(ctx: RustParser.ImplementationContext?): String {
158+
// keep this for better to debug
159+
val types = when {
160+
ctx?.inherentImpl()?.type_() != null -> listOf(ctx.inherentImpl().type_())
161+
ctx?.traitImpl()?.type_() != null -> listOf(ctx.traitImpl().type_())
162+
else -> emptyList()
163+
}
164+
165+
val typePathSegment: List<TypePathSegmentContext>? = types.firstOrNull()
166+
?.typeNoBounds()
167+
?.traitObjectTypeOneBound()
168+
?.traitBound()
169+
?.typePath()
170+
?.typePathSegment()
171+
172+
val pathIdentSegmentContext = typePathSegment?.map {
173+
it.pathIdentSegment()
174+
}?.firstOrNull()
175+
176+
return pathIdentSegmentContext?.identifier()?.text ?: return ""
177+
}
178+
158179
override fun exitImplementation(ctx: RustParser.ImplementationContext?) {
159180
isEnteredImplementation = false
160181
structMap[currentNode.NodeName] = currentNode

chapi-ast-rust/src/test/kotlin/chapi/ast/rustast/RustFullIdentListenerTest.kt

+2-3
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ class RustFullIdentListenerTest {
103103
}
104104

105105
@Test
106-
@Disabled
107106
fun should_pass_for_multiple_impl() {
108107
val str = """
109108
use std::cmp::Ordering;
@@ -137,7 +136,7 @@ class RustFullIdentListenerTest {
137136

138137
val codeContainer = RustAnalyser().analysis(str, "test.rs")
139138
assertEquals(1, codeContainer.DataStructures.size)
140-
val functions = codeContainer.DataStructures[1].Functions
141-
assertEquals(3, functions.size)
139+
val functions = codeContainer.DataStructures[0].Functions
140+
assertEquals(2, functions.size)
142141
}
143142
}

0 commit comments

Comments
 (0)