Skip to content

Commit

Permalink
refactor(ast): handle pointer types and simplify PrimaryExprContext u…
Browse files Browse the repository at this point in the history
…sage

- Enhance the GoFullIdentListener to recognize and process pointer types.
- Replace occurrences of `GoParser.PrimaryExprContext` with `PrimaryExprContext` for consistency.
- Add a new test case to cover function calls with pointer types and formatted SQL strings.
  • Loading branch information
phodal committed Nov 9, 2024
1 parent fef025f commit ec8a4f9
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 35 deletions.
38 changes: 23 additions & 15 deletions chapi-ast-go/src/main/kotlin/chapi/ast/goast/GoFullIdentListener.kt
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {
tyType = "FunctionCall"
}

// if value starts with & or * should be a pointer
if (value.startsWith("&") || value.startsWith("*")) {
tyType = "Pointer"
value = value.substring(1)
}

return Pair(value, tyType)
}

Expand Down Expand Up @@ -217,7 +223,7 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {
*/
override fun enterExpression(ctx: GoParser.ExpressionContext?) {
when (val firstChild = ctx?.getChild(0)) {
is GoParser.PrimaryExprContext -> {
is PrimaryExprContext -> {
firstChild.getChild(1)?.let {
val codeCall = this.handlePrimaryExprCall(firstChild)

Expand All @@ -231,7 +237,7 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {
}
}

private fun handlePrimaryExprCall(primaryExprCtx: GoParser.PrimaryExprContext): List<CodeCall> {
private fun handlePrimaryExprCall(primaryExprCtx: PrimaryExprContext): List<CodeCall> {
return when (val arguments = primaryExprCtx.getChild(1)) {
is GoParser.ArgumentsContext -> {
codeCallFromExprList(primaryExprCtx.getChild(0), arguments)
Expand Down Expand Up @@ -270,7 +276,7 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {
private fun parseArguments(child: GoParser.ArgumentsContext): List<CodeProperty> {
return child.expressionList()?.expression()?.map {
val (value, typetype) = processingStringType(it.text, "")
if (localVars.containsKey(value)) {
if (localVars.containsKey(value) && localVars[value] != "") {
return@map CodeProperty(TypeValue = localVars[value]!!, TypeType = typetype)
}

Expand Down Expand Up @@ -310,12 +316,11 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {
}

when (child) {
is GoParser.PrimaryExprContext -> {
is PrimaryExprContext -> {
when (child.getChild(1)) {
is TerminalNodeImpl -> {

if (child.getChild(0) is GoParser.PrimaryExprContext && child.childCount > 2) {
val primaryCalls = handlePrimaryExprCall(child.getChild(0) as GoParser.PrimaryExprContext)
if (child.getChild(0) is PrimaryExprContext && child.childCount > 2) {
val primaryCalls = handlePrimaryExprCall(child.getChild(0) as PrimaryExprContext)
calls.addAll(primaryCalls)
}

Expand Down Expand Up @@ -359,19 +364,19 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {
)


private fun handleForPrimary(child: GoParser.PrimaryExprContext, isForLocalVar: Boolean = false): String? {
private fun handleForPrimary(child: PrimaryExprContext, isForLocalVar: Boolean = false): String? {
val nodeName = when (val first = child.getChild(0)) {
is GoParser.OperandContext -> {
first.text
}

is GoParser.PrimaryExprContext -> {
is PrimaryExprContext -> {
if (first.primaryExpr() != null) {
handleForPrimary(first, isForLocalVar).orEmpty()
} else {
val parent = child.parent
if (isForLocalVar && first.text == "fmt" && parent.text.startsWith("fmt.")) {
if (parent is GoParser.PrimaryExprContext) {
if (parent is PrimaryExprContext) {
val content = getValueFromPrintf(parent)
localVars.getOrDefault(first.text, content)
} else {
Expand Down Expand Up @@ -426,21 +431,24 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {
override fun enterVarDecl(ctx: GoParser.VarDeclContext?) {
ctx?.varSpec()?.forEach {
it.identifierList().IDENTIFIER().forEach { terminalNode ->
localVars[terminalNode.text] = it.type_()?.text ?: ""
val nodeNameFromExpr = nodeNameFromExpr(it.expressionList(), isForLocalVar = true)
localVars[terminalNode.text] = nodeNameFromExpr.ifEmpty {
it.type_()?.text ?: ""
}
}
}
}

override fun enterShortVarDecl(ctx: GoParser.ShortVarDeclContext?) {
ctx?.identifierList()?.IDENTIFIER()?.forEach {
localVars[it.text] = nodeNameFromExpr(ctx, isForLocalVar = true)
localVars[it.text] = nodeNameFromExpr(ctx.expressionList(), isForLocalVar = true)
}
}

private fun nodeNameFromExpr(ctx: GoParser.ShortVarDeclContext, isForLocalVar: Boolean): String {
ctx.expressionList().expression()?.forEach {
private fun nodeNameFromExpr(expressionListContext: GoParser.ExpressionListContext?, isForLocalVar: Boolean): String {
expressionListContext?.expression()?.forEach {
when (val firstChild = it.getChild(0)) {
is GoParser.PrimaryExprContext -> {
is PrimaryExprContext -> {
return handleForPrimary(firstChild, isForLocalVar).orEmpty()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import kotlin.test.assertEquals
internal class GoFullIdentListenerTest {
@Test
internal fun shouldIdentifyPackageName() {
val code= """
val code = """
package main
"""

Expand All @@ -19,7 +19,7 @@ package main

@Test
internal fun shouldIdentifySingleImport() {
val code= """
val code = """
package main
import "fmt"
Expand All @@ -32,7 +32,7 @@ import "fmt"

@Test
internal fun shouldIdentifyMultipleLineImport() {
val code= """
val code = """
package main
import "fmt"
Expand All @@ -48,7 +48,7 @@ import . "time"

@Test
internal fun shouldIdentifyMultipleTogetherImport() {
val code= """
val code = """
package main
import (
Expand All @@ -67,7 +67,7 @@ import (

@Test
internal fun shouldIdentifyBasicStruct() {
val code= """
val code = """
package main
type School struct {
Expand All @@ -87,7 +87,7 @@ type School struct {

@Test
internal fun shouldIdentifyBasicStructFunction() {
val code= """
val code = """
package main
import "fmt"
Expand All @@ -112,7 +112,7 @@ func (a *Animal) Move() {

@Test
internal fun shouldIdentifyStructFunctionReturnType() {
val code= """
val code = """
package main
import "fmt"
Expand All @@ -133,7 +133,7 @@ func (a *Animal) Move() string {

@Test
internal fun shouldIdentifyFunctionAsDefault() {
val code= """
val code = """
package main
func add(x int, y int) int {
Expand All @@ -148,7 +148,7 @@ func add(x int, y int) int {

@Test
internal fun shouldIdentifyFunctionMultipleReturnType() {
val code= """
val code = """
package main
func get(x int, y int) (int, int) {
Expand All @@ -163,7 +163,7 @@ func get(x int, y int) (int, int) {

@Test
internal fun shouldIdentifyFunctionParameters() {
val code= """
val code = """
package main
func get(x int, y int) (int, int) {
Expand All @@ -180,7 +180,7 @@ func get(x int, y int) (int, int) {

@Test
internal fun shouldIdentifyStructFuncCall() {
val code= """
val code = """
package main
import "fmt"
Expand All @@ -203,7 +203,7 @@ func (a *Animal) Move() {

@Test
internal fun shouldIdentifyFuncCall() {
val code= """
val code = """
package main
import "fmt"
Expand All @@ -222,7 +222,7 @@ func main() {

@Test
internal fun shouldIdentifyFunctionLocalVars() {
val code= """
val code = """
package main
func VarDecls() {
Expand All @@ -239,7 +239,7 @@ func VarDecls() {

@Test
internal fun shouldIdentifyFunctionShortVars() {
val code= """
val code = """
package main
func ShortDecls() {
Expand All @@ -253,7 +253,7 @@ func ShortDecls() {

@Test
internal fun shouldIdentifyFunctionConstVars() {
val code= """
val code = """
package main
func ConstDecls() {
Expand All @@ -273,7 +273,7 @@ func ConstDecls() {

@Test
internal fun shouldIdentifyStructFunctionLocalVars() {
val code= """
val code = """
package main
import "fmt"
Expand All @@ -300,7 +300,7 @@ func (a *Animal) Move() {
@Test
internal fun shouldSuccessGetSqlOfNode() {
@Language("Go")
val code= """
val code = """
package dao
import (
Expand Down Expand Up @@ -334,7 +334,7 @@ func (d *Dao) QueryBuglyProjectList() (projectList []string, err error) {
@Test
internal fun shouldIdentifyConstLocalVars() {
@Language("Go")
val code= """
val code = """
package dao
Expand Down Expand Up @@ -382,7 +382,7 @@ func (d *Dao) CountPersonal(c context.Context, opt *common.BaseOptions) (count i
@Test
internal fun shouldIdentifyLocalVarWithText() {
@Language("Go")
val code= """
val code = """
package dao
import (
Expand Down Expand Up @@ -431,7 +431,7 @@ func (d *Dao) MobileMachineLendCount() (mobileMachinesUsageCount []*model.Mobile
@Test
fun shouldIdentCallInSideCall() {
@Language("Go")
val code= """
val code = """
package dao
const _chArcAddSQL = "INSERT INTO member_channel_video%d (mid,cid,aid,order_num,modify_time) VALUES %s"
Expand Down Expand Up @@ -468,4 +468,37 @@ func (d *Dao) AddChannelArc(c context.Context, mid, cid int64, ts time.Time, chs
val codeProperty = secondParameter.Parameters
assertEquals(codeProperty[0].TypeValue, "\"INSERT INTO member_channel_video%d (mid,cid,aid,order_num,modify_time) VALUES %s\"")
}

@Test
fun shouldLoadFromAddress() {
@Language("Go")
val code = """
package dao
const _updateUserInfoMysql = "update capsule_info_%d set score = score + ? where uid = ? and type = ?";
func (d *Dao) UpdateCapsule() (affect int64, err error) {
var (
sqlStr, uKey, iKey string
)
sqlStr = fmt.Sprintf(_updateUserInfoMysql, getCapsuleTable(uid))
affect, err = d.execSqlWithBindParams(ctx, &sqlStr, score, uid, CoinIdIntMap[coinId])
return
}
"""

val codeFile = GoAnalyser().analysis(code, "")
val functionCalls = codeFile.DataStructures[0].Functions[0].FunctionCalls

assertEquals(functionCalls.size, 3)

val getExecFunc = functionCalls[2]
assertEquals(getExecFunc.NodeName, "Dao")
assertEquals(getExecFunc.FunctionName, "execSqlWithBindParams")
assertEquals(getExecFunc.Parameters.size, 5)

val firstParameter = getExecFunc.Parameters[1]
println(firstParameter)
assertEquals(firstParameter.TypeValue, "string")
}
}

0 comments on commit ec8a4f9

Please sign in to comment.