|
| 1 | +package ai.privado.languageEngine.c.passes |
| 2 | + |
| 3 | +import io.joern.x2cpg.Defines |
| 4 | +import io.joern.x2cpg.passes.frontend.{ |
| 5 | + CallAlias, |
| 6 | + LocalVar, |
| 7 | + RecoverForXCompilationUnit, |
| 8 | + XTypeRecovery, |
| 9 | + XTypeRecoveryConfig, |
| 10 | + XTypeRecoveryPassGenerator, |
| 11 | + XTypeRecoveryState |
| 12 | +} |
| 13 | +import io.shiftleft.codepropertygraph.Cpg |
| 14 | +import io.shiftleft.codepropertygraph.generated.nodes.* |
| 15 | +import overflowdb.traversal.Traversal |
| 16 | +import io.shiftleft.semanticcpg.language.* |
| 17 | +import overflowdb.BatchedUpdate.DiffGraphBuilder |
| 18 | +import io.shiftleft.semanticcpg.language.importresolver.* |
| 19 | +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes |
| 20 | +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess} |
| 21 | +import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromIteratorExt |
| 22 | +import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt |
| 23 | +import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} |
| 24 | + |
| 25 | +class CTypeRecoveryPassGenerator(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) |
| 26 | + extends XTypeRecoveryPassGenerator[File](cpg, config) { |
| 27 | + override protected def generateRecoveryPass(state: XTypeRecoveryState, iteration: Int): XTypeRecovery[File] = |
| 28 | + new CTypeRecovery(cpg, state, iteration) |
| 29 | +} |
| 30 | + |
| 31 | +private class CTypeRecovery(cpg: Cpg, state: XTypeRecoveryState, iteration: Int) |
| 32 | + extends XTypeRecovery[File](cpg, state, iteration) { |
| 33 | + override def compilationUnits: Traversal[File] = cpg.file.iterator |
| 34 | + |
| 35 | + override def generateRecoveryForCompilationUnitTask( |
| 36 | + unit: File, |
| 37 | + builder: DiffGraphBuilder |
| 38 | + ): RecoverForXCompilationUnit[File] = { |
| 39 | + new RecoverForCFile(cpg, unit, builder, state) |
| 40 | + } |
| 41 | +} |
| 42 | + |
| 43 | +private class RecoverForCFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState) |
| 44 | + extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) { |
| 45 | + |
| 46 | + /** A heuristic method to determine if a call is a constructor or not. |
| 47 | + */ |
| 48 | + override protected def isConstructor(c: Call): Boolean = { |
| 49 | + isConstructor(c.name) |
| 50 | + } |
| 51 | + |
| 52 | + /** A heuristic method to determine if a call name is a constructor or not. |
| 53 | + */ |
| 54 | + override protected def isConstructor(name: String): Boolean = |
| 55 | + !name.isBlank && (name.equals("new") || name.equals("<init>")) |
| 56 | + |
| 57 | + override protected def importNodes: Iterator[Import] = cu match { |
| 58 | + case x: File => cpg.imports.where(_.file.name(x.name)) |
| 59 | + case _ => super.importNodes |
| 60 | + } |
| 61 | + |
| 62 | + override protected def visitImport(i: Import): Unit = for { |
| 63 | + resolvedImport <- i.tag |
| 64 | + alias <- i.importedAs |
| 65 | + } { |
| 66 | + import scala.util.Try |
| 67 | + Try(EvaluatedImport.tagToEvaluatedImport(resolvedImport)).toOption |
| 68 | + .getOrElse(Option(UnknownMethod("random", "random", Option("random")))) |
| 69 | + .foreach { |
| 70 | + case ResolvedMethod(fullName, alias, receiver, _) => |
| 71 | + symbolTable.append(CallAlias(alias, receiver), fullName) |
| 72 | + case ResolvedTypeDecl(fullName, _) => |
| 73 | + symbolTable.append(LocalVar(alias), fullName) |
| 74 | + case ResolvedMember(basePath, memberName, _) => |
| 75 | + val matchingIdentifiers = cpg.method.fullNameExact(basePath).local |
| 76 | + val matchingMembers = cpg.typeDecl.fullNameExact(basePath).member |
| 77 | + val memberTypes = (matchingMembers ++ matchingIdentifiers) |
| 78 | + .nameExact(memberName) |
| 79 | + .getKnownTypes |
| 80 | + symbolTable.append(LocalVar(alias), memberTypes) |
| 81 | + case UnknownMethod(fullName, alias, receiver, _) => |
| 82 | + symbolTable.append(CallAlias(alias, receiver), fullName) |
| 83 | + case UnknownTypeDecl(fullName, _) => |
| 84 | + symbolTable.append(LocalVar(alias), fullName) |
| 85 | + case UnknownImport(path, _) => |
| 86 | + symbolTable.append(CallAlias(alias), path) |
| 87 | + symbolTable.append(LocalVar(alias), path) |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + override protected def hasTypes(node: AstNode): Boolean = node match { |
| 92 | + case x: Call if !x.methodFullName.startsWith("<operator>") => |
| 93 | + !x.methodFullName.toLowerCase().matches("(<unknownfullname>|any)") && !x.methodFullName.equals(x.name) |
| 94 | + case x => x.getKnownTypes.nonEmpty |
| 95 | + } |
| 96 | + |
| 97 | + override protected def setCallMethodFullNameFromBase(c: Call): Set[String] = { |
| 98 | + val recTypes = c.argument.headOption |
| 99 | + .map { |
| 100 | + case ifa: Call |
| 101 | + if (ifa.name.equals("<operator>.indirectFieldAccess") || ifa.name.equals( |
| 102 | + "<operator>.fieldAccess" |
| 103 | + )) && ifa.argument.headOption.exists(symbolTable.contains) => |
| 104 | + getTypeFromArgument(ifa.argument.headOption, c) |
| 105 | + case x => getTypeFromArgument(Some(x), c) |
| 106 | + } |
| 107 | + .getOrElse(Set.empty[String]) |
| 108 | + val callTypes = recTypes.map(_.stripSuffix("*").concat(s"$pathSep${c.name}")) |
| 109 | + symbolTable.append(c, callTypes) |
| 110 | + } |
| 111 | + |
| 112 | + private def getTypeFromArgument(headArgument: Option[Expression], c: Call): Set[String] = { |
| 113 | + headArgument |
| 114 | + .map { |
| 115 | + case x: Call if x.typeFullName != "ANY" && x.typeFullName != "<empty>" => |
| 116 | + Set(x.typeFullName) |
| 117 | + case x: Call => |
| 118 | + val returns = cpg.method.fullNameExact(c.methodFullName).methodReturn.typeFullNameNot("ANY") |
| 119 | + val returnWithPossibleTypes = cpg.method.fullNameExact(c.methodFullName).methodReturn.where(_.possibleTypes) |
| 120 | + val fullNames = returns.typeFullName ++ returnWithPossibleTypes.possibleTypes |
| 121 | + fullNames.toSet match { |
| 122 | + case xs if xs.nonEmpty => xs |
| 123 | + case _ => |
| 124 | + val returns = cpg.method.fullNameExact(x.methodFullName).methodReturn.typeFullNameNot("ANY") |
| 125 | + val returnWithPossibleTypes = |
| 126 | + cpg.method.fullNameExact(x.methodFullName).methodReturn.where(_.possibleTypes) |
| 127 | + val fullNames = returns.typeFullName ++ returnWithPossibleTypes.possibleTypes |
| 128 | + fullNames.toSet match { |
| 129 | + case xs if xs.nonEmpty => xs |
| 130 | + case _ => symbolTable.get(x).map(t => Seq(t, XTypeRecovery.DummyReturnType).mkString(pathSep)) |
| 131 | + } |
| 132 | + } |
| 133 | + case x => |
| 134 | + symbolTable.get(x) |
| 135 | + } |
| 136 | + .getOrElse(Set.empty[String]) |
| 137 | + } |
| 138 | + |
| 139 | + override protected def setTypeInformation(): Unit = { |
| 140 | + cu.ast |
| 141 | + .collect { |
| 142 | + case n: Local => n |
| 143 | + case n: Call => n |
| 144 | + case n: Expression => n |
| 145 | + case n: MethodParameterIn if state.isFinalIteration => n |
| 146 | + case n: MethodReturn if state.isFinalIteration => n |
| 147 | + } |
| 148 | + .foreach { |
| 149 | + case x: Local if symbolTable.contains(x) => storeNodeTypeInfo(x, symbolTable.get(x).toSeq) |
| 150 | + case x: MethodParameterIn => setTypeFromTypeHints(x) |
| 151 | + case x: MethodReturn => |
| 152 | + setTypeFromTypeHints(x) |
| 153 | + case x: Identifier if symbolTable.contains(x) => |
| 154 | + setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) |
| 155 | + case x: Call if symbolTable.contains(x) => |
| 156 | + val typs = |
| 157 | + if (state.enableDummyTypesForThisIteration) symbolTable.get(x).toSeq |
| 158 | + else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq |
| 159 | + storeCallTypeInfo(x, typs) |
| 160 | + case x: Identifier if symbolTable.contains(CallAlias(x.name)) && x.inCall.nonEmpty => |
| 161 | + setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) |
| 162 | + case x: Call |
| 163 | + if x.argument.headOption.isCall.exists(c => |
| 164 | + c.name.equals("<operator>.indirectFieldAccess") || c.name.equals("<operator>.fieldAccess") |
| 165 | + ) && x.argument.headOption.isCall.argument.headOption.exists(c => symbolTable.contains(c)) => |
| 166 | + setCallMethodFullNameFromBase(x) |
| 167 | + val typs = |
| 168 | + if (state.enableDummyTypesForThisIteration) symbolTable.get(x).toSeq |
| 169 | + else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq |
| 170 | + storeCallTypeInfo(x, typs) |
| 171 | + case x: Call if x.argument.headOption.exists(symbolTable.contains) => |
| 172 | + setTypeInformationForRecCall(x, Option(x), x.argument.l) |
| 173 | + case _ => |
| 174 | + } |
| 175 | + // Set types in an atomic way |
| 176 | + newTypesForMembers.foreach { case (m, ts) => storeDefaultTypeInfo(m, ts.toSeq) } |
| 177 | + } |
| 178 | + |
| 179 | + private def storeNodeTypeInfo(storedNode: StoredNode, types: Seq[String]): Unit = { |
| 180 | + lazy val existingTypes = storedNode.getKnownTypes |
| 181 | + |
| 182 | + val hasUnknownTypeFullName = storedNode |
| 183 | + .property(PropertyNames.TYPE_FULL_NAME, Defines.Any) |
| 184 | + .matches(XTypeRecovery.unknownTypePattern.pattern.pattern()) |
| 185 | + |
| 186 | + if (types.nonEmpty && (hasUnknownTypeFullName || types.toSet != existingTypes)) { |
| 187 | + storedNode match { |
| 188 | + case m: Member => |
| 189 | + // To avoid overwriting member updates, we store them elsewhere until the end |
| 190 | + newTypesForMembers.updateWith(m) { |
| 191 | + case Some(ts) => Option(ts ++ types) |
| 192 | + case None => Option(types.toSet) |
| 193 | + } |
| 194 | + case i: Identifier => storeIdentifierTypeInfo(i, types) |
| 195 | + case l: Local => storeLocalTypeInfo(l, types) |
| 196 | + case c: Call if !c.name.startsWith("<operator>") => storeCallTypeInfo(c, types) |
| 197 | + case _: Call => |
| 198 | + case n => |
| 199 | + setTypes(n, types) |
| 200 | + } |
| 201 | + } |
| 202 | + } |
| 203 | + |
| 204 | + private def setTypeInformationForRecCall(x: AstNode, n: Option[Call], ms: List[AstNode]): Unit = { |
| 205 | + (n, ms) match { |
| 206 | + // Case 1: 'call' is an assignment from some dynamic dispatch call |
| 207 | + case (Some(call: Call), ::(i: Identifier, ::(c: Call, _))) if call.name == Operators.assignment => |
| 208 | + setTypeForIdentifierAssignedToCall(call, i, c) |
| 209 | + // Case 1: 'call' is an assignment from some other data structure |
| 210 | + case (Some(call: Call), ::(i: Identifier, _)) if call.name == Operators.assignment => |
| 211 | + setTypeForIdentifierAssignedToDefault(call, i) |
| 212 | + // Case 2: 'i' is the receiver of 'call' |
| 213 | + case (Some(call: Call), ::(i: Identifier, _)) if call.name != Operators.fieldAccess => |
| 214 | + setTypeForDynamicDispatchCall(call, i) |
| 215 | + // Case 3: 'i' is the receiver for a field access on member 'f' |
| 216 | + case (Some(fieldAccess: Call), ::(i: Identifier, ::(f: FieldIdentifier, _))) |
| 217 | + if fieldAccess.name == Operators.fieldAccess => |
| 218 | + setTypeForFieldAccess(fieldAccess.asInstanceOf[FieldAccess], i, f) |
| 219 | + case _ => |
| 220 | + } |
| 221 | + // Handle the node itself |
| 222 | + x match { |
| 223 | + case c: Call if c.name.startsWith("<operator") => |
| 224 | + case _ => persistType(x, symbolTable.get(x)) |
| 225 | + } |
| 226 | + } |
| 227 | +} |
0 commit comments