|  | 
|  | 1 | +package dotty.tools.pc | 
|  | 2 | + | 
|  | 3 | +import scala.util.Try | 
|  | 4 | + | 
|  | 5 | +import dotty.tools.dotc.ast.Trees.ValDef | 
|  | 6 | +import dotty.tools.dotc.ast.tpd.* | 
|  | 7 | +import dotty.tools.dotc.core.Contexts.Context | 
|  | 8 | +import dotty.tools.dotc.core.Flags | 
|  | 9 | +import dotty.tools.dotc.core.Flags.Method | 
|  | 10 | +import dotty.tools.dotc.core.Names.Name | 
|  | 11 | +import dotty.tools.dotc.core.StdNames.* | 
|  | 12 | +import dotty.tools.dotc.core.SymDenotations.NoDenotation | 
|  | 13 | +import dotty.tools.dotc.core.Symbols.defn | 
|  | 14 | +import dotty.tools.dotc.core.Symbols.NoSymbol | 
|  | 15 | +import dotty.tools.dotc.core.Symbols.Symbol | 
|  | 16 | +import dotty.tools.dotc.core.Types.* | 
|  | 17 | +import dotty.tools.pc.IndexedContext | 
|  | 18 | +import dotty.tools.pc.utils.InteractiveEnrichments.* | 
|  | 19 | +import scala.annotation.tailrec | 
|  | 20 | +import dotty.tools.dotc.core.Denotations.SingleDenotation | 
|  | 21 | +import dotty.tools.dotc.core.Denotations.MultiDenotation | 
|  | 22 | +import dotty.tools.dotc.util.Spans.Span | 
|  | 23 | + | 
|  | 24 | +object ApplyExtractor: | 
|  | 25 | +  def unapply(path: List[Tree])(using Context): Option[Apply] = | 
|  | 26 | +    path match | 
|  | 27 | +      case ValDef(_, _, _) :: Block(_, app: Apply) :: _ | 
|  | 28 | +          if !app.fun.isInfix => Some(app) | 
|  | 29 | +      case rest => | 
|  | 30 | +        def getApplyForContextFunctionParam(path: List[Tree]): Option[Apply] = | 
|  | 31 | +          path match | 
|  | 32 | +            // fun(arg@@) | 
|  | 33 | +            case (app: Apply) :: _ => Some(app) | 
|  | 34 | +            // fun(arg@@), where fun(argn: Context ?=> SomeType) | 
|  | 35 | +            // recursively matched for multiple context arguments, e.g. Context1 ?=> Context2 ?=> SomeType | 
|  | 36 | +            case (_: DefDef) :: Block(List(_), _: Closure) :: rest => | 
|  | 37 | +              getApplyForContextFunctionParam(rest) | 
|  | 38 | +            case _ => None | 
|  | 39 | +        for | 
|  | 40 | +          app <- getApplyForContextFunctionParam(rest) | 
|  | 41 | +          if !app.fun.isInfix | 
|  | 42 | +        yield app | 
|  | 43 | +    end match | 
|  | 44 | + | 
|  | 45 | + | 
|  | 46 | +object ApplyArgsExtractor: | 
|  | 47 | +  def getArgsAndParams( | 
|  | 48 | +    optIndexedContext: Option[IndexedContext], | 
|  | 49 | +    apply: Apply, | 
|  | 50 | +    span: Span | 
|  | 51 | +  )(using Context): List[(List[Tree], List[ParamSymbol])] = | 
|  | 52 | +    def collectArgss(a: Apply): List[List[Tree]] = | 
|  | 53 | +      def stripContextFuntionArgument(argument: Tree): List[Tree] = | 
|  | 54 | +        argument match | 
|  | 55 | +          case Block(List(d: DefDef), _: Closure) => | 
|  | 56 | +            d.rhs match | 
|  | 57 | +              case app: Apply => | 
|  | 58 | +                app.args | 
|  | 59 | +              case b @ Block(List(_: DefDef), _: Closure) => | 
|  | 60 | +                stripContextFuntionArgument(b) | 
|  | 61 | +              case _ => Nil | 
|  | 62 | +          case v => List(v) | 
|  | 63 | + | 
|  | 64 | +      val args = a.args.flatMap(stripContextFuntionArgument) | 
|  | 65 | +      a.fun match | 
|  | 66 | +        case app: Apply => collectArgss(app) :+ args | 
|  | 67 | +        case _ => List(args) | 
|  | 68 | +    end collectArgss | 
|  | 69 | + | 
|  | 70 | +    val method = apply.fun | 
|  | 71 | + | 
|  | 72 | +    val argss = collectArgss(apply) | 
|  | 73 | + | 
|  | 74 | +    def fallbackFindApply(sym: Symbol) = | 
|  | 75 | +      sym.info.member(nme.apply) match | 
|  | 76 | +        case NoDenotation => Nil | 
|  | 77 | +        case den => List(den.symbol) | 
|  | 78 | + | 
|  | 79 | +      // fallback for when multiple overloaded methods match the supplied args | 
|  | 80 | +    def fallbackFindMatchingMethods() = | 
|  | 81 | +      def matchingMethodsSymbols( | 
|  | 82 | +        indexedContext: IndexedContext, | 
|  | 83 | +        method: Tree | 
|  | 84 | +      ): List[Symbol] = | 
|  | 85 | +        method match | 
|  | 86 | +          case Ident(name) => indexedContext.findSymbol(name).getOrElse(Nil) | 
|  | 87 | +          case Select(This(_), name) => indexedContext.findSymbol(name).getOrElse(Nil) | 
|  | 88 | +          case sel @ Select(from, name) => | 
|  | 89 | +            val symbol = from.symbol | 
|  | 90 | +            val ownerSymbol = | 
|  | 91 | +              if symbol.is(Method) && symbol.owner.isClass then | 
|  | 92 | +                Some(symbol.owner) | 
|  | 93 | +              else Try(symbol.info.classSymbol).toOption | 
|  | 94 | +            ownerSymbol.map(sym =>  sym.info.member(name)).collect{ | 
|  | 95 | +              case single: SingleDenotation => List(single.symbol) | 
|  | 96 | +              case multi: MultiDenotation => multi.allSymbols | 
|  | 97 | +            }.getOrElse(Nil) | 
|  | 98 | +          case Apply(fun, _) => matchingMethodsSymbols(indexedContext, fun) | 
|  | 99 | +          case _ => Nil | 
|  | 100 | +      val matchingMethods = | 
|  | 101 | +        for | 
|  | 102 | +          indexedContext <- optIndexedContext.toList | 
|  | 103 | +          potentialMatch <- matchingMethodsSymbols(indexedContext, method) | 
|  | 104 | +          if potentialMatch.is(Flags.Method) && | 
|  | 105 | +                potentialMatch.vparamss.length >= argss.length && | 
|  | 106 | +                Try(potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption | 
|  | 107 | +                  .getOrElse(false) && | 
|  | 108 | +                potentialMatch.vparamss | 
|  | 109 | +                  .zip(argss) | 
|  | 110 | +                  .reverse | 
|  | 111 | +                  .zipWithIndex | 
|  | 112 | +                  .forall { case (pair, index) => | 
|  | 113 | +                    FuzzyArgMatcher(potentialMatch.tparams) | 
|  | 114 | +                      .doMatch(allArgsProvided = index != 0, span) | 
|  | 115 | +                      .tupled(pair) | 
|  | 116 | +                    } | 
|  | 117 | +        yield potentialMatch | 
|  | 118 | +      matchingMethods | 
|  | 119 | +    end fallbackFindMatchingMethods | 
|  | 120 | + | 
|  | 121 | +    val matchingMethods: List[Symbol] = | 
|  | 122 | +      if method.symbol.paramSymss.nonEmpty then | 
|  | 123 | +        val allArgsAreSupplied = | 
|  | 124 | +          val vparamss = method.symbol.vparamss | 
|  | 125 | +          vparamss.length == argss.length && vparamss | 
|  | 126 | +            .zip(argss) | 
|  | 127 | +            .lastOption | 
|  | 128 | +            .exists { case (baseParams, baseArgs) => | 
|  | 129 | +              baseArgs.length == baseParams.length | 
|  | 130 | +            } | 
|  | 131 | +        // ``` | 
|  | 132 | +        //  m(arg : Int) | 
|  | 133 | +        //  m(arg : Int, anotherArg : Int) | 
|  | 134 | +        //  m(a@@) | 
|  | 135 | +        // ``` | 
|  | 136 | +        // complier will choose the first `m`, so we need to manually look for the other one | 
|  | 137 | +        if allArgsAreSupplied then | 
|  | 138 | +          val foundPotential = fallbackFindMatchingMethods() | 
|  | 139 | +          if foundPotential.contains(method.symbol) then foundPotential | 
|  | 140 | +          else method.symbol :: foundPotential | 
|  | 141 | +        else List(method.symbol) | 
|  | 142 | +      else if method.symbol.is(Method) || method.symbol == NoSymbol then | 
|  | 143 | +        fallbackFindMatchingMethods() | 
|  | 144 | +      else fallbackFindApply(method.symbol) | 
|  | 145 | +      end if | 
|  | 146 | +    end matchingMethods | 
|  | 147 | + | 
|  | 148 | +    matchingMethods.map { methodSym => | 
|  | 149 | +      val vparamss = methodSym.vparamss | 
|  | 150 | + | 
|  | 151 | +      // get params and args we are interested in | 
|  | 152 | +      // e.g. | 
|  | 153 | +      // in the following case, the interesting args and params are | 
|  | 154 | +      // - params: [apple, banana] | 
|  | 155 | +      // - args: [apple, b] | 
|  | 156 | +      // ``` | 
|  | 157 | +      // def curry(x: Int)(apple: String, banana: String) = ??? | 
|  | 158 | +      // curry(1)(apple = "test", b@@) | 
|  | 159 | +      // ``` | 
|  | 160 | +      val (baseParams0, baseArgs) = | 
|  | 161 | +        vparamss.zip(argss).lastOption.getOrElse((Nil, Nil)) | 
|  | 162 | + | 
|  | 163 | +      val baseParams: List[ParamSymbol] = | 
|  | 164 | +        def defaultBaseParams = baseParams0.map(JustSymbol(_)) | 
|  | 165 | +        @tailrec | 
|  | 166 | +        def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] = | 
|  | 167 | +          if level > 0 then | 
|  | 168 | +            val resultTypeOpt = | 
|  | 169 | +              refinedType match | 
|  | 170 | +                case RefinedType(AppliedType(_, args), _, _) => args.lastOption | 
|  | 171 | +                case AppliedType(_, args) => args.lastOption | 
|  | 172 | +                case _ => None | 
|  | 173 | +            resultTypeOpt match | 
|  | 174 | +              case Some(resultType) => getRefinedParams(resultType, level - 1) | 
|  | 175 | +              case _ => defaultBaseParams | 
|  | 176 | +          else | 
|  | 177 | +            refinedType match | 
|  | 178 | +              case RefinedType(AppliedType(_, args), _, MethodType(ri)) => | 
|  | 179 | +                baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) => | 
|  | 180 | +                  RefinedSymbol(sym, name, arg) | 
|  | 181 | +                } | 
|  | 182 | +              case _ => defaultBaseParams | 
|  | 183 | +        // finds param refinements for lambda expressions | 
|  | 184 | +        // val hello: (x: Int, y: Int) => Unit = (x, _) => println(x) | 
|  | 185 | +        @tailrec | 
|  | 186 | +        def refineParams(method: Tree, level: Int): List[ParamSymbol] = | 
|  | 187 | +          method match | 
|  | 188 | +            case Select(Apply(f, _), _) => refineParams(f, level + 1) | 
|  | 189 | +            case Select(h, name) => | 
|  | 190 | +              // for Select(foo, name = apply) we want `foo.symbol` | 
|  | 191 | +              if name == nme.apply then getRefinedParams(h.symbol.info, level) | 
|  | 192 | +              else getRefinedParams(method.symbol.info, level) | 
|  | 193 | +            case Apply(f, _) => | 
|  | 194 | +              refineParams(f, level + 1) | 
|  | 195 | +            case _ => getRefinedParams(method.symbol.info, level) | 
|  | 196 | +        refineParams(method, 0) | 
|  | 197 | +      end baseParams | 
|  | 198 | +      (baseArgs, baseParams) | 
|  | 199 | +    } | 
|  | 200 | + | 
|  | 201 | +  extension (method: Symbol) | 
|  | 202 | +    def vparamss(using Context) = method.filteredParamss(_.isTerm) | 
|  | 203 | +    def tparams(using Context) = method.filteredParamss(_.isType).flatten | 
|  | 204 | +    def filteredParamss(f: Symbol => Boolean)(using Context) = | 
|  | 205 | +      method.paramSymss.filter(params => params.forall(f)) | 
|  | 206 | +sealed trait ParamSymbol: | 
|  | 207 | +  def name: Name | 
|  | 208 | +  def info: Type | 
|  | 209 | +  def symbol: Symbol | 
|  | 210 | +  def nameBackticked(using Context) = name.decoded.backticked | 
|  | 211 | + | 
|  | 212 | +case class JustSymbol(symbol: Symbol)(using Context) extends ParamSymbol: | 
|  | 213 | +  def name: Name = symbol.name | 
|  | 214 | +  def info: Type = symbol.info | 
|  | 215 | + | 
|  | 216 | +case class RefinedSymbol(symbol: Symbol, name: Name, info: Type) | 
|  | 217 | +    extends ParamSymbol | 
|  | 218 | + | 
|  | 219 | + | 
|  | 220 | +class FuzzyArgMatcher(tparams: List[Symbol])(using Context): | 
|  | 221 | + | 
|  | 222 | +  /** | 
|  | 223 | +   * A heuristic for checking if the passed arguments match the method's arguments' types. | 
|  | 224 | +   * For non-polymorphic methods we use the subtype relation (`<:<`) | 
|  | 225 | +   * and for polymorphic methods we use a heuristic. | 
|  | 226 | +   * We check the args types not the result type. | 
|  | 227 | +   */ | 
|  | 228 | +  def doMatch( | 
|  | 229 | +      allArgsProvided: Boolean, | 
|  | 230 | +      span: Span | 
|  | 231 | +  )(expectedArgs: List[Symbol], actualArgs: List[Tree]) = | 
|  | 232 | +    (expectedArgs.length == actualArgs.length || | 
|  | 233 | +      (!allArgsProvided && expectedArgs.length >= actualArgs.length)) && | 
|  | 234 | +      actualArgs.zipWithIndex.forall { | 
|  | 235 | +        case (arg: Ident, _) if arg.span.contains(span) => true | 
|  | 236 | +        case (NamedArg(name, arg), _) => | 
|  | 237 | +          expectedArgs.exists { expected => | 
|  | 238 | +            expected.name == name && (!arg.hasType || arg.typeOpt.unfold | 
|  | 239 | +              .fuzzyArg_<:<(expected.info)) | 
|  | 240 | +          } | 
|  | 241 | +        case (arg, i) => | 
|  | 242 | +          !arg.hasType || arg.typeOpt.unfold.fuzzyArg_<:<(expectedArgs(i).info) | 
|  | 243 | +      } | 
|  | 244 | + | 
|  | 245 | +  extension (arg: Type) | 
|  | 246 | +    def fuzzyArg_<:<(expected: Type) = | 
|  | 247 | +      if tparams.isEmpty then arg <:< expected | 
|  | 248 | +      else arg <:< substituteTypeParams(expected) | 
|  | 249 | +    def unfold = | 
|  | 250 | +      arg match | 
|  | 251 | +        case arg: TermRef => arg.underlying | 
|  | 252 | +        case e => e | 
|  | 253 | + | 
|  | 254 | +  private def substituteTypeParams(t: Type): Type = | 
|  | 255 | +    t match | 
|  | 256 | +      case e if tparams.exists(_ == e.typeSymbol) => | 
|  | 257 | +        val matchingParam = tparams.find(_ == e.typeSymbol).get | 
|  | 258 | +        matchingParam.info match | 
|  | 259 | +          case b @ TypeBounds(_, _) => WildcardType(b) | 
|  | 260 | +          case _ => WildcardType | 
|  | 261 | +      case o @ OrType(e1, e2) => | 
|  | 262 | +        OrType(substituteTypeParams(e1), substituteTypeParams(e2), o.isSoft) | 
|  | 263 | +      case AndType(e1, e2) => | 
|  | 264 | +        AndType(substituteTypeParams(e1), substituteTypeParams(e2)) | 
|  | 265 | +      case AppliedType(et, eparams) => | 
|  | 266 | +        AppliedType(et, eparams.map(substituteTypeParams)) | 
|  | 267 | +      case _ => t | 
|  | 268 | + | 
|  | 269 | +end FuzzyArgMatcher | 
0 commit comments