From ec8a4f99daab907f3677fe499c2f47257eafdb1a Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Sat, 9 Nov 2024 22:53:54 +0800 Subject: [PATCH] refactor(ast): handle pointer types and simplify PrimaryExprContext usage - Enhance the GoFullIdentListener to recognize and process pointer types. - Replace occurrences of `GoParser.PrimaryExprContext` with `PrimaryExprContext` for consistency. - Add a new test case to cover function calls with pointer types and formatted SQL strings. --- .../chapi/ast/goast/GoFullIdentListener.kt | 38 ++++++---- .../ast/goast/GoFullIdentListenerTest.kt | 73 ++++++++++++++----- 2 files changed, 76 insertions(+), 35 deletions(-) diff --git a/chapi-ast-go/src/main/kotlin/chapi/ast/goast/GoFullIdentListener.kt b/chapi-ast-go/src/main/kotlin/chapi/ast/goast/GoFullIdentListener.kt index ad1750cc..383af1f0 100644 --- a/chapi-ast-go/src/main/kotlin/chapi/ast/goast/GoFullIdentListener.kt +++ b/chapi-ast-go/src/main/kotlin/chapi/ast/goast/GoFullIdentListener.kt @@ -147,6 +147,12 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() { tyType = "FunctionCall" } + // if value starts with & or * should be a pointer + if (value.startsWith("&") || value.startsWith("*")) { + tyType = "Pointer" + value = value.substring(1) + } + return Pair(value, tyType) } @@ -217,7 +223,7 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() { */ override fun enterExpression(ctx: GoParser.ExpressionContext?) { when (val firstChild = ctx?.getChild(0)) { - is GoParser.PrimaryExprContext -> { + is PrimaryExprContext -> { firstChild.getChild(1)?.let { val codeCall = this.handlePrimaryExprCall(firstChild) @@ -231,7 +237,7 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() { } } - private fun handlePrimaryExprCall(primaryExprCtx: GoParser.PrimaryExprContext): List { + private fun handlePrimaryExprCall(primaryExprCtx: PrimaryExprContext): List { return when (val arguments = primaryExprCtx.getChild(1)) { is GoParser.ArgumentsContext -> { codeCallFromExprList(primaryExprCtx.getChild(0), arguments) @@ -270,7 +276,7 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() { private fun parseArguments(child: GoParser.ArgumentsContext): List { return child.expressionList()?.expression()?.map { val (value, typetype) = processingStringType(it.text, "") - if (localVars.containsKey(value)) { + if (localVars.containsKey(value) && localVars[value] != "") { return@map CodeProperty(TypeValue = localVars[value]!!, TypeType = typetype) } @@ -310,12 +316,11 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() { } when (child) { - is GoParser.PrimaryExprContext -> { + is PrimaryExprContext -> { when (child.getChild(1)) { is TerminalNodeImpl -> { - - if (child.getChild(0) is GoParser.PrimaryExprContext && child.childCount > 2) { - val primaryCalls = handlePrimaryExprCall(child.getChild(0) as GoParser.PrimaryExprContext) + if (child.getChild(0) is PrimaryExprContext && child.childCount > 2) { + val primaryCalls = handlePrimaryExprCall(child.getChild(0) as PrimaryExprContext) calls.addAll(primaryCalls) } @@ -359,19 +364,19 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() { ) - private fun handleForPrimary(child: GoParser.PrimaryExprContext, isForLocalVar: Boolean = false): String? { + private fun handleForPrimary(child: PrimaryExprContext, isForLocalVar: Boolean = false): String? { val nodeName = when (val first = child.getChild(0)) { is GoParser.OperandContext -> { first.text } - is GoParser.PrimaryExprContext -> { + is PrimaryExprContext -> { if (first.primaryExpr() != null) { handleForPrimary(first, isForLocalVar).orEmpty() } else { val parent = child.parent if (isForLocalVar && first.text == "fmt" && parent.text.startsWith("fmt.")) { - if (parent is GoParser.PrimaryExprContext) { + if (parent is PrimaryExprContext) { val content = getValueFromPrintf(parent) localVars.getOrDefault(first.text, content) } else { @@ -426,21 +431,24 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() { override fun enterVarDecl(ctx: GoParser.VarDeclContext?) { ctx?.varSpec()?.forEach { it.identifierList().IDENTIFIER().forEach { terminalNode -> - localVars[terminalNode.text] = it.type_()?.text ?: "" + val nodeNameFromExpr = nodeNameFromExpr(it.expressionList(), isForLocalVar = true) + localVars[terminalNode.text] = nodeNameFromExpr.ifEmpty { + it.type_()?.text ?: "" + } } } } override fun enterShortVarDecl(ctx: GoParser.ShortVarDeclContext?) { ctx?.identifierList()?.IDENTIFIER()?.forEach { - localVars[it.text] = nodeNameFromExpr(ctx, isForLocalVar = true) + localVars[it.text] = nodeNameFromExpr(ctx.expressionList(), isForLocalVar = true) } } - private fun nodeNameFromExpr(ctx: GoParser.ShortVarDeclContext, isForLocalVar: Boolean): String { - ctx.expressionList().expression()?.forEach { + private fun nodeNameFromExpr(expressionListContext: GoParser.ExpressionListContext?, isForLocalVar: Boolean): String { + expressionListContext?.expression()?.forEach { when (val firstChild = it.getChild(0)) { - is GoParser.PrimaryExprContext -> { + is PrimaryExprContext -> { return handleForPrimary(firstChild, isForLocalVar).orEmpty() } } diff --git a/chapi-ast-go/src/test/kotlin/chapi/ast/goast/GoFullIdentListenerTest.kt b/chapi-ast-go/src/test/kotlin/chapi/ast/goast/GoFullIdentListenerTest.kt index 77f886ce..540bcc6d 100644 --- a/chapi-ast-go/src/test/kotlin/chapi/ast/goast/GoFullIdentListenerTest.kt +++ b/chapi-ast-go/src/test/kotlin/chapi/ast/goast/GoFullIdentListenerTest.kt @@ -8,7 +8,7 @@ import kotlin.test.assertEquals internal class GoFullIdentListenerTest { @Test internal fun shouldIdentifyPackageName() { - val code= """ + val code = """ package main """ @@ -19,7 +19,7 @@ package main @Test internal fun shouldIdentifySingleImport() { - val code= """ + val code = """ package main import "fmt" @@ -32,7 +32,7 @@ import "fmt" @Test internal fun shouldIdentifyMultipleLineImport() { - val code= """ + val code = """ package main import "fmt" @@ -48,7 +48,7 @@ import . "time" @Test internal fun shouldIdentifyMultipleTogetherImport() { - val code= """ + val code = """ package main import ( @@ -67,7 +67,7 @@ import ( @Test internal fun shouldIdentifyBasicStruct() { - val code= """ + val code = """ package main type School struct { @@ -87,7 +87,7 @@ type School struct { @Test internal fun shouldIdentifyBasicStructFunction() { - val code= """ + val code = """ package main import "fmt" @@ -112,7 +112,7 @@ func (a *Animal) Move() { @Test internal fun shouldIdentifyStructFunctionReturnType() { - val code= """ + val code = """ package main import "fmt" @@ -133,7 +133,7 @@ func (a *Animal) Move() string { @Test internal fun shouldIdentifyFunctionAsDefault() { - val code= """ + val code = """ package main func add(x int, y int) int { @@ -148,7 +148,7 @@ func add(x int, y int) int { @Test internal fun shouldIdentifyFunctionMultipleReturnType() { - val code= """ + val code = """ package main func get(x int, y int) (int, int) { @@ -163,7 +163,7 @@ func get(x int, y int) (int, int) { @Test internal fun shouldIdentifyFunctionParameters() { - val code= """ + val code = """ package main func get(x int, y int) (int, int) { @@ -180,7 +180,7 @@ func get(x int, y int) (int, int) { @Test internal fun shouldIdentifyStructFuncCall() { - val code= """ + val code = """ package main import "fmt" @@ -203,7 +203,7 @@ func (a *Animal) Move() { @Test internal fun shouldIdentifyFuncCall() { - val code= """ + val code = """ package main import "fmt" @@ -222,7 +222,7 @@ func main() { @Test internal fun shouldIdentifyFunctionLocalVars() { - val code= """ + val code = """ package main func VarDecls() { @@ -239,7 +239,7 @@ func VarDecls() { @Test internal fun shouldIdentifyFunctionShortVars() { - val code= """ + val code = """ package main func ShortDecls() { @@ -253,7 +253,7 @@ func ShortDecls() { @Test internal fun shouldIdentifyFunctionConstVars() { - val code= """ + val code = """ package main func ConstDecls() { @@ -273,7 +273,7 @@ func ConstDecls() { @Test internal fun shouldIdentifyStructFunctionLocalVars() { - val code= """ + val code = """ package main import "fmt" @@ -300,7 +300,7 @@ func (a *Animal) Move() { @Test internal fun shouldSuccessGetSqlOfNode() { @Language("Go") - val code= """ + val code = """ package dao import ( @@ -334,7 +334,7 @@ func (d *Dao) QueryBuglyProjectList() (projectList []string, err error) { @Test internal fun shouldIdentifyConstLocalVars() { @Language("Go") - val code= """ + val code = """ package dao @@ -382,7 +382,7 @@ func (d *Dao) CountPersonal(c context.Context, opt *common.BaseOptions) (count i @Test internal fun shouldIdentifyLocalVarWithText() { @Language("Go") - val code= """ + val code = """ package dao import ( @@ -431,7 +431,7 @@ func (d *Dao) MobileMachineLendCount() (mobileMachinesUsageCount []*model.Mobile @Test fun shouldIdentCallInSideCall() { @Language("Go") - val code= """ + val code = """ package dao const _chArcAddSQL = "INSERT INTO member_channel_video%d (mid,cid,aid,order_num,modify_time) VALUES %s" @@ -468,4 +468,37 @@ func (d *Dao) AddChannelArc(c context.Context, mid, cid int64, ts time.Time, chs val codeProperty = secondParameter.Parameters assertEquals(codeProperty[0].TypeValue, "\"INSERT INTO member_channel_video%d (mid,cid,aid,order_num,modify_time) VALUES %s\"") } + + @Test + fun shouldLoadFromAddress() { + @Language("Go") + val code = """ +package dao + +const _updateUserInfoMysql = "update capsule_info_%d set score = score + ? where uid = ? and type = ?"; + +func (d *Dao) UpdateCapsule() (affect int64, err error) { + var ( + sqlStr, uKey, iKey string + ) + sqlStr = fmt.Sprintf(_updateUserInfoMysql, getCapsuleTable(uid)) + affect, err = d.execSqlWithBindParams(ctx, &sqlStr, score, uid, CoinIdIntMap[coinId]) + return +} +""" + + val codeFile = GoAnalyser().analysis(code, "") + val functionCalls = codeFile.DataStructures[0].Functions[0].FunctionCalls + + assertEquals(functionCalls.size, 3) + + val getExecFunc = functionCalls[2] + assertEquals(getExecFunc.NodeName, "Dao") + assertEquals(getExecFunc.FunctionName, "execSqlWithBindParams") + assertEquals(getExecFunc.Parameters.size, 5) + + val firstParameter = getExecFunc.Parameters[1] + println(firstParameter) + assertEquals(firstParameter.TypeValue, "string") + } }