Skip to content

Commit

Permalink
refactor(goast): enhance local variable and printf handling in GoFull…
Browse files Browse the repository at this point in the history
…IdentListener

The commit updates the GoFullIdentListener to improve the handling of local variables and function calls, particularly for the `fmt.Printf` family. It includes:
- Adding a list of `fmt` functions to a dedicated variable.
- Modifying the `handleForPrimary` function to include a new parameter for local variable detection and enhance the logic for `fmt` function calls.
- Introducing a new private function `getValueFromPrintf` to extract the content from `fmt.Printf` calls.
- Adjusting the `nodeNameFromExpr` function to account for the new local variable flag.

This change is accompanied by additional test cases to ensure the correct identification of local variables and function calls.
  • Loading branch information
phodal committed Nov 9, 2024
1 parent c2f7335 commit 2c17637
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chapi.ast.goast
import chapi.ast.antlr.GoParser
import chapi.domain.core.*
import chapi.infra.Stack
import org.antlr.v4.runtime.RuleContext
import org.antlr.v4.runtime.tree.ParseTree
import org.antlr.v4.runtime.tree.TerminalNodeImpl

Expand Down Expand Up @@ -321,17 +322,44 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {
return calls
}

private fun handleForPrimary(child: GoParser.PrimaryExprContext): String? {
var goPrintFuncs: List<String> = listOf(
"fmt.Print",
"fmt.Printf",
"fmt.Println",

"fmt.Sprint",
"fmt.Sprintf",
"fmt.Sprintln",

"fmt.Fprint",
"fmt.Fprintf",
"fmt.Fprintln",

"fmt.Errorf"
)


private fun handleForPrimary(child: GoParser.PrimaryExprContext, isForLocalVar: Boolean = false): String? {
val nodeName = when (val first = child.getChild(0)) {
is GoParser.OperandContext -> {
first.text
}

is GoParser.PrimaryExprContext -> {
if (first.primaryExpr() != null) {
handleForPrimary(first).orEmpty()
handleForPrimary(first, isForLocalVar).orEmpty()
} else {
localVars.getOrDefault(first.text, first.text)
val parent = child.parent
if (isForLocalVar && first.text == "fmt" && parent.text.startsWith("fmt.")) {
if (parent is GoParser.PrimaryExprContext) {
val content = getValueFromPrintf(parent)
localVars.getOrDefault(first.text, content)
} else {
localVars.getOrDefault(first.text, first.text)
}
} else {
localVars.getOrDefault(first.text, first.text)
}
}
}

Expand Down Expand Up @@ -361,6 +389,20 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {
return nodeName
}

private fun getValueFromPrintf(parent: RuleContext): String {
val child = parent.getChild(1)
if (child !is GoParser.ArgumentsContext) {
return child.text.removePrefix("(").removeSuffix(")")
}

val first = child.getChild(1)
if (first is GoParser.ExpressionListContext) {
return first.getChild(0).text.removePrefix("(").removeSuffix(")")
}

return child.text.removePrefix("(").removeSuffix(")")
}

override fun enterVarDecl(ctx: GoParser.VarDeclContext?) {
ctx?.varSpec()?.forEach {
it.identifierList().IDENTIFIER().forEach { terminalNode ->
Expand All @@ -371,15 +413,15 @@ class GoFullIdentListener(var fileName: String) : GoAstListener() {

override fun enterShortVarDecl(ctx: GoParser.ShortVarDeclContext?) {
ctx?.identifierList()?.IDENTIFIER()?.forEach {
localVars[it.text] = nodeNameFromExpr(ctx)
localVars[it.text] = nodeNameFromExpr(ctx, isForLocalVar = true)
}
}

private fun nodeNameFromExpr(ctx: GoParser.ShortVarDeclContext): String {
private fun nodeNameFromExpr(ctx: GoParser.ShortVarDeclContext, isForLocalVar: Boolean): String {
ctx.expressionList().expression()?.forEach {
when (val firstChild = it.getChild(0)) {
is GoParser.PrimaryExprContext -> {
return handleForPrimary(firstChild).orEmpty()
return handleForPrimary(firstChild, isForLocalVar).orEmpty()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,4 +330,53 @@ func (d *Dao) QueryBuglyProjectList() (projectList []string, err error) {
assertEquals(functionCalls[1].NodeName, "Dao.db")
assertEquals(functionCalls[1].FunctionName, "Rows")
}

@Test
internal fun shouldIdentifyLocalVarWithText() {
@Language("Go")
val code= """
package dao
import (
"database/sql"
"fmt"
"go-common/app/admin/ep/merlin/model"
)
func (d *Dao) MobileMachineLendCount() (mobileMachinesUsageCount []*model.MobileMachineUsageCount, err error) {
var rows *sql.Rows
SQL := fmt.Sprintf("select b.id,b.name,count(b.name) as count from mobile_machine_logs as a "+
"left join mobile_machines as b on a.machine_id = b.id "+
"where a.operation_type='%s' and a.operation_result = '%s' "+
"group by b.id,b.`name` order by count desc", model.MBLendOutLog, model.OperationSuccessForMachineLog)
if rows, err = d.db.Raw(SQL).Rows(); err != nil {
return
}
defer rows.Close()
for rows.Next() {
mc := &model.MobileMachineUsageCount{}
if err = rows.Scan(&mc.MobileMachineID, &mc.MobileMachineName, &mc.Count); err != nil {
return
}
mobileMachinesUsageCount = append(mobileMachinesUsageCount, mc)
}
return
}
"""
val codeFile = GoAnalyser().analysis(code, "")
val functionCalls = codeFile.DataStructures[0].Functions[0].FunctionCalls
println(functionCalls)

assertEquals(functionCalls.size, 7)
assertEquals(functionCalls[0].NodeName, "fmt")

assertEquals(functionCalls[1].NodeName, "Dao.db")
assertEquals(functionCalls[1].FunctionName, "Raw")
assertEquals(functionCalls[1].Parameters.size, 1)
assertEquals(functionCalls[1].Parameters[0].TypeValue, "\"select b.id,b.name,count(b.name) as count from mobile_machine_logs as a \"+\"left join mobile_machines as b on a.machine_id = b.id \"+\"where a.operation_type='%s' and a.operation_result = '%s' \"+\"group by b.id,b.`name` order by count desc\"")
}
}

0 comments on commit 2c17637

Please sign in to comment.