diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 094a9c1a2bf..3cce8df5941 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -133,6 +133,7 @@ struct ASTIterator iterator->maybeDispatchCallback(expr); dispatchIfNotNull(expr->functionExpr); + dispatchIfNotNull(expr->originalFunctionExpr); for (auto arg : expr->arguments) dispatchIfNotNull(arg); } diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index 4b5a69f15d8..ee747a4c291 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -1129,6 +1129,7 @@ void ASTPrinter::addVal(Val* val) /* static */ void ASTPrinter::appendDeclName(Decl* decl, StringBuilder& out) { + decl = maybeGetInner(decl); if (as(decl)) { out << "init"; @@ -1231,8 +1232,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef& declRef, Index depth) } else if (auto extensionDeclRef = parentDeclRef.as()) { - ExtensionDecl* extensionDecl = as(parentDeclRef.getDecl()); - Type* type = extensionDecl->targetType.type; + Type* type = getTargetType(m_astBuilder, extensionDeclRef); if (m_optionFlags & OptionFlag::NoSpecializedExtensionTypeName) { if (auto unspecializedDeclRef = isDeclRefTypeOf(type)) @@ -1522,6 +1522,8 @@ void ASTPrinter::addDeclKindPrefix(Decl* decl) continue; if (as(modifier)) continue; + if (as(modifier)) + continue; } // Don't print out attributes. if (as(modifier)) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 7ddec20fbd6..30e31740110 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1252,8 +1252,8 @@ struct SemanticsVisitor : public SemanticsContext TypeExp TranslateTypeNode(TypeExp const& typeExp); Type* getRemovedModifierType(ModifiedType* type, ModifierVal* modifier); Type* getConstantBufferType(Type* elementType, Type* layoutType); - DeclRefType* getExprDeclRefType(Expr* expr); + LookupResult lookupConstructorsInType(Type* type, Scope* sourceScope); /// Is `decl` usable as a static member? bool isDeclUsableAsStaticMember(Decl* decl); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 6c0a7f18430..41aba267465 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2249,6 +2249,24 @@ DeclRef SemanticsVisitor::inferGenericArguments( return trySolveConstraintSystem(&constraints, genericDeclRef, knownGenericArgs, outBaseCost); } +LookupResult SemanticsVisitor::lookupConstructorsInType(Type* type, Scope* sourceScope) +{ + // Look up all the initializers on `type` by looking up + // its members named `$init`. All `__init` declarations are stored + // with the name `$init` internally to avoid potential conflicts + // if a user decided to name a field/method `__init`. + LookupOptions options = + LookupOptions(uint8_t(LookupOptions::IgnoreInheritance) | uint8_t(LookupOptions::NoDeref)); + return lookUpMember( + m_astBuilder, + this, + getName("$init"), + type, + sourceScope, + LookupMask::Default, + options); +} + void SemanticsVisitor::AddTypeOverloadCandidates(Type* type, OverloadResolveContext& context) { // The code being checked is trying to apply `type` like a function. @@ -2272,16 +2290,7 @@ void SemanticsVisitor::AddTypeOverloadCandidates(Type* type, OverloadResolveCont // from a value of the same type. There is no need in Slang for // "copy constructors" but the core module currently has to define // some just to make code that does, e.g., `float(1.0f)` work.) - LookupOptions options = - LookupOptions(uint8_t(LookupOptions::IgnoreInheritance) | uint8_t(LookupOptions::NoDeref)); - LookupResult initializers = lookUpMember( - m_astBuilder, - this, - getName("$init"), - type, - context.sourceScope, - LookupMask::Default, - options); + LookupResult initializers = lookupConstructorsInType(type, context.sourceScope); AddOverloadCandidates(initializers, context); } @@ -2702,6 +2711,12 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) expr->arguments[0], &tempSink, &conversionCost); + if (auto resultInvokeExpr = as(resultExpr)) + { + resultInvokeExpr->originalFunctionExpr = expr->functionExpr; + resultInvokeExpr->argumentDelimeterLocs = expr->argumentDelimeterLocs; + resultInvokeExpr->loc = expr->loc; + } if (coerceResult) return resultExpr; typeOverloadChecked = true; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 60f6cc92f9d..7cdd1614ccd 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -3764,6 +3764,8 @@ class Session : public RefObject, public slang::IGlobalSession ComPtr getAutodiffLibraryCode(); ComPtr getGLSLLibraryCode(); + void getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName); + RefPtr m_sharedASTBuilder; SPIRVCoreGrammarInfo& getSPIRVCoreGrammarInfo() diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index f394e7d7a40..1ae9e4a0bb0 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -235,16 +235,6 @@ struct ASTLookupExprVisitor : public ExprVisitor return dispatchIfNotNull(expr->originalExpr); } - bool visitTypeCastExpr(TypeCastExpr* expr) - { - if (dispatchIfNotNull(expr->functionExpr)) - return true; - for (auto arg : expr->arguments) - if (dispatchIfNotNull(arg)) - return true; - return false; - } - bool visitDerefExpr(DerefExpr* expr) { return dispatchIfNotNull(expr->base); } bool visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr) { diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index d52c5d8551f..ba2722faef8 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -565,6 +565,37 @@ HumaneSourceLoc getModuleLoc(SourceManager* manager, ContainerDecl* moduleDecl) return location; } +// When user code has `Foo(123)` where `Foo` is a `struct`, goto-definition on +// `Foo` should redirect to the constructor of `Foo` instead of the type declaration of `Foo`. +// This function will check if the `declRefExpr` is a reference to a type declaration, +// but the declRefExpr is referenced from an `InvokeExpr::originalFunctionExpr` that is now +// resolved to a constructor. If so we will return the declRef of the constructor. +// +DeclRef maybeRedirectToConstructor(DeclRefExpr* declRefExpr, const List& path) +{ + if (path.getCount() < 2) + return declRefExpr->declRef; + if (!as(declRefExpr->declRef)) + return declRefExpr->declRef; + auto invokeExpr = as(path[path.getCount() - 2]); + if (!invokeExpr) + return declRefExpr->declRef; + if (!invokeExpr->originalFunctionExpr) + return declRefExpr->declRef; + auto originalFuncExpr = invokeExpr->originalFunctionExpr; + if (originalFuncExpr != declRefExpr) + return declRefExpr->declRef; + // If the invoke expression is the same as the decl ref expression, + // it means we are looking at a constructor call. + auto resolvedFuncExpr = as(invokeExpr->functionExpr); + if (!resolvedFuncExpr) + return declRefExpr->declRef; + auto ctorDecl = as(resolvedFuncExpr->declRef); + if (ctorDecl) + return ctorDecl; + return declRefExpr->declRef; +} + SlangResult LanguageServer::hover( const LanguageServerProtocol::HoverParams& args, const JSONValue& responseId) @@ -828,7 +859,8 @@ LanguageServerResult LanguageServerCore::hover( }; if (auto declRefExpr = as(leafNode)) { - fillDeclRefHoverInfo(declRefExpr->declRef, declRefExpr->name); + auto resolvedDeclRef = maybeRedirectToConstructor(declRefExpr, findResult[0].path); + fillDeclRefHoverInfo(resolvedDeclRef, declRefExpr->name); } else if (auto overloadedExpr = as(leafNode)) { @@ -1004,11 +1036,12 @@ LanguageServerResult> LanguageServerCore: { if (declRefExpr->declRef.getDecl()) { + auto declRef = declRefExpr->declRef; + declRef = maybeRedirectToConstructor(declRefExpr, findResult[0].path); auto location = version->linkage->getSourceManager()->getHumaneLoc( - declRefExpr->declRef.getNameLoc().isValid() ? declRefExpr->declRef.getNameLoc() - : declRefExpr->declRef.getLoc(), + declRef.getNameLoc().isValid() ? declRef.getNameLoc() : declRef.getLoc(), SourceLocType::Actual); - auto name = declRefExpr->declRef.getName(); + auto name = declRef.getName(); locations.add(LocationResult{ location, name ? (int)UTF8Util::calcUTF16CharCount(name->text.getUnownedSlice()) : 0}); @@ -1076,6 +1109,14 @@ LanguageServerResult> LanguageServerCore: { result.uri = URI::fromLocalFilePath(loc.loc.pathInfo.foundPath.getUnownedSlice()).uri; + } + else if (loc.loc.pathInfo.getName() == "core" || loc.loc.pathInfo.getName() == "glsl") + { + result.uri = StringBuilder() << "slang-synth://" << loc.loc.pathInfo.getName() + << "/" << loc.loc.pathInfo.getName() << ".builtin"; + } + if (result.uri.getLength() != 0) + { doc->oneBasedUTF8LocToZeroBasedUTF16Loc( loc.loc.line, loc.loc.column, @@ -1504,6 +1545,75 @@ SlangResult LanguageServer::signatureHelp( return SLANG_OK; } +// Heuristical cost for determining the best candidate to use as the active signature. +// We will always use the candidate that has the most matched parameters to the current argument +// list. If there are multiple candidates with the same number of matched parameters, we will +// use the one with the least number of unmatched parameters. If there are still multiple +// candidates with the same number of unmatched parameters, we will use the one with the least +// maximum argument conversion cost. +// +struct CallCandidateMatchCost +{ + Index matchedArgCount = 0; + Index excessArgCount = 0; + Index unmatchedParamCount = 0; + ConversionCost maxArgConversionCost = 0; + + bool isBetterThan(const CallCandidateMatchCost& other) const + { + if (excessArgCount < other.excessArgCount) + return true; + else if (excessArgCount > other.excessArgCount) + return false; + if (matchedArgCount > other.matchedArgCount) + return true; + else if (matchedArgCount < other.matchedArgCount) + return false; + + if (unmatchedParamCount < other.unmatchedParamCount) + return true; + else if (unmatchedParamCount > other.unmatchedParamCount) + return false; + return maxArgConversionCost < other.maxArgConversionCost; + } +}; + +// Given a callable decl and an AppExprBase containing the arguments used to call it, +// return the match cost for the candidate. +static CallCandidateMatchCost getCallCandidateMatchCost( + DeclRef callableDeclRef, + AppExprBase* appExpr, + SemanticsVisitor& semanticsVisitor, + WorkspaceVersion* version) +{ + CallCandidateMatchCost result; + auto astBuilder = version->linkage->getASTBuilder(); + auto paramList = getMembersOfType(astBuilder, callableDeclRef).toArray(); + + for (Index argId = 0; argId < appExpr->arguments.getCount(); argId++) + { + auto arg = appExpr->arguments[argId]; + if (!arg) + continue; + if (!arg->type.type) + continue; + if (argId < paramList.getCount()) + { + auto paramType = getType(version->linkage->getASTBuilder(), paramList[argId]); + ConversionCost argCost = 0; + if (paramType && semanticsVisitor.canCoerce(paramType, arg->type.type, arg, &argCost)) + { + result.matchedArgCount++; + result.maxArgConversionCost = Math::Max(result.maxArgConversionCost, argCost); + } + } + } + result.excessArgCount = + Math::Max((Index)0, (appExpr->argumentDelimeterLocs.getCount() - 1) - paramList.getCount()); + result.unmatchedParamCount = paramList.getCount() - result.matchedArgCount; + return result; +} + LanguageServerResult LanguageServerCore::signatureHelp( const LanguageServerProtocol::SignatureHelpParams& args) { @@ -1594,11 +1704,41 @@ LanguageServerResult LanguageServerCore:: } SignatureHelp response; + response.activeSignature = 0; + + CallCandidateMatchCost bestCandidateMatchCost; + + // We will use an ad-hoc semantics visitor to check for argument-to-parameter conversions + // and to determine the best candidate signature. + // In the ideal design, this info should be gathered during the normal type checking + // process, but that require a lot of refactoring in the current code base, and may + // risk slowing down type checking for non-language-server use cases since we won't be + // able to do as many early returns. + // So instead we will do a separate ad-hoc checking here to do a best-effort guess + // on the best candidate. + // + DiagnosticSink sink; + SharedSemanticsContext semanticsContext(version->linkage, nullptr, &sink); + SemanticsVisitor semanticsVisitor(&semanticsContext); + auto addDeclRef = [&](DeclRef declRef) { if (!declRef.getDecl()) return; + // If we have a better match than the current best, we will update response.activeSignature + // to this signature. + if (auto callableDeclRef = declRef.as()) + { + auto matchCost = + getCallCandidateMatchCost(callableDeclRef, appExpr, semanticsVisitor, version); + if (matchCost.isBetterThan(bestCandidateMatchCost)) + { + bestCandidateMatchCost = matchCost; + response.activeSignature = (uint32_t)response.signatures.getCount(); + } + } + SignatureInformation sigInfo; List> paramRanges; @@ -1675,13 +1815,14 @@ LanguageServerResult LanguageServerCore:: if (auto declRefExpr = as(funcExpr)) { - if (auto aggDeclRef = as(declRefExpr->declRef)) + if (auto typeType = as(declRefExpr->type.type)) { // Look for initializers - for (auto member : - getMembersOfType(version->linkage->getASTBuilder(), aggDeclRef)) + auto ctors = + semanticsVisitor.lookupConstructorsInType(typeType->getType(), declRefExpr->scope); + for (auto ctor : ctors) { - addDeclRef(member); + addDeclRef(ctor.declRef); } } else @@ -1711,7 +1852,6 @@ LanguageServerResult LanguageServerCore:: { addFuncType(funcType); } - response.activeSignature = 0; response.activeParameter = 0; for (int i = 1; i < appExpr->argumentDelimeterLocs.getCount(); i++) { @@ -2828,4 +2968,23 @@ SLANG_API SlangResult runLanguageServer(Slang::LanguageServerStartupOptions opti return SLANG_OK; } +SLANG_API SlangResult +getBuiltinModuleSource(const UnownedStringSlice& moduleName, slang::IBlob** blob) +{ + ComPtr globalSession; + slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()); + Slang::Session* session = static_cast(globalSession.get()); + StringBuilder sb; + if (moduleName.startsWith("core")) + { + session->getBuiltinModuleSource(sb, slang::BuiltinModuleName::Core); + } + else if (moduleName.startsWith("glsl")) + { + session->getBuiltinModuleSource(sb, slang::BuiltinModuleName::GLSL); + } + *blob = StringBlob::moveCreate(sb.produceString()).detach(); + return SLANG_OK; +} + } // namespace Slang diff --git a/source/slang/slang-language-server.h b/source/slang/slang-language-server.h index 43c3521eb57..31b7114c0c8 100644 --- a/source/slang/slang-language-server.h +++ b/source/slang/slang-language-server.h @@ -275,4 +275,7 @@ inline bool _isIdentifierChar(char ch) } SLANG_API SlangResult runLanguageServer(LanguageServerStartupOptions options); +SLANG_API SlangResult +getBuiltinModuleSource(const UnownedStringSlice& moduleName, slang::IBlob** blob); + } // namespace Slang diff --git a/source/slang/slang-workspace-version.cpp b/source/slang/slang-workspace-version.cpp index b1a3dec3451..63bf7ed5300 100644 --- a/source/slang/slang-workspace-version.cpp +++ b/source/slang/slang-workspace-version.cpp @@ -487,7 +487,10 @@ void DocumentVersion::oneBasedUTF8LocToZeroBasedUTF16Loc( Index rsLine = inLine - 1; auto bounds = getUTF16Boundaries(inLine); outLine = rsLine; - outCol = std::lower_bound(bounds.begin(), bounds.end(), inCol - 1) - bounds.begin(); + if (bounds.getCount() != 0) + outCol = std::lower_bound(bounds.begin(), bounds.end(), inCol - 1) - bounds.begin(); + else + outCol = inCol - 1; } void DocumentVersion::oneBasedUTF8LocToZeroBasedUTF16Loc( diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index e98187c850a..ebc1d113a5a 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -435,6 +435,21 @@ SlangResult Session::compileCoreModule(slang::CompileCoreModuleFlags compileFlag return compileBuiltinModule(slang::BuiltinModuleName::Core, compileFlags); } +void Session::getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName) +{ + switch (moduleName) + { + case slang::BuiltinModuleName::Core: + sb << (const char*)getCoreLibraryCode()->getBufferPointer() + << (const char*)getHLSLLibraryCode()->getBufferPointer() + << (const char*)getAutodiffLibraryCode()->getBufferPointer(); + break; + case slang::BuiltinModuleName::GLSL: + sb << (const char*)getGLSLLibraryCode()->getBufferPointer(); + break; + } +} + SlangResult Session::compileBuiltinModule( slang::BuiltinModuleName moduleName, slang::CompileCoreModuleFlags compileFlags) @@ -460,17 +475,7 @@ SlangResult Session::compileBuiltinModule( } StringBuilder moduleSrcBuilder; - switch (moduleName) - { - case slang::BuiltinModuleName::Core: - moduleSrcBuilder << (const char*)getCoreLibraryCode()->getBufferPointer() - << (const char*)getHLSLLibraryCode()->getBufferPointer() - << (const char*)getAutodiffLibraryCode()->getBufferPointer(); - break; - case slang::BuiltinModuleName::GLSL: - moduleSrcBuilder << (const char*)getGLSLLibraryCode()->getBufferPointer(); - break; - } + getBuiltinModuleSource(moduleSrcBuilder, moduleName); // TODO(JS): Could make this return a SlangResult as opposed to exception auto moduleSrcBlob = StringBlob::moveCreate(moduleSrcBuilder.produceString()); diff --git a/tests/language-server/ctor-hover.slang b/tests/language-server/ctor-hover.slang new file mode 100644 index 00000000000..186d373cffd --- /dev/null +++ b/tests/language-server/ctor-hover.slang @@ -0,0 +1,15 @@ +//TEST:LANG_SERVER(filecheck=CHECK): +struct MyType +{ + __init(int x) {} +} + +void test() +{ +//HOVER:10,18 + let obj = MyType(5); + // ^^^^^ + // Hover here should show info for the ctor, not the type. +} + +//CHECK: MyType.init diff --git a/tests/language-server/ctor-signature.slang b/tests/language-server/ctor-signature.slang new file mode 100644 index 00000000000..6fc7b3844d8 --- /dev/null +++ b/tests/language-server/ctor-signature.slang @@ -0,0 +1,9 @@ +//TEST:LANG_SERVER(filecheck=CHECK): + +void test() +{ +//SIGNATURE:6,25 + let v = float3(1.0, ) +} + +// CHECK: (selected) float3.init(float x, float diff --git a/tests/language-server/smoke.slang.expected.txt b/tests/language-server/smoke.slang.expected.txt index 7e0e8def758..451fa0655b4 100644 --- a/tests/language-server/smoke.slang.expected.txt +++ b/tests/language-server/smoke.slang.expected.txt @@ -16,10 +16,8 @@ content: -------- activeParameter: 0 activeSignature: 0 -func T.getSum() -> int: +(selected) func T.getSum() -> int: Returns the sum of the contents. -{REDACTED}.slang(10) - - +{REDACTED}.slang(10) \ No newline at end of file diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index c22b6dbf676..5e17513d501 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -2141,8 +2141,13 @@ TestResult runLanguageServerTest(TestContext* context, TestInput& input) { actualOutputSB << "activeParameter: " << sigInfo.activeParameter << "\n"; actualOutputSB << "activeSignature: " << sigInfo.activeSignature << "\n"; - for (auto item : sigInfo.signatures) + for (Index i = 0; i < sigInfo.signatures.getCount(); ++i) { + auto& item = sigInfo.signatures[i]; + if (i == sigInfo.activeSignature) + { + actualOutputSB << "(selected) "; + } actualOutputSB << item.label << ":"; for (auto param : item.parameters) { diff --git a/tools/slangd/main.cpp b/tools/slangd/main.cpp index 2409922730c..f2b9622530a 100644 --- a/tools/slangd/main.cpp +++ b/tools/slangd/main.cpp @@ -16,6 +16,14 @@ int main(int argc, const char* const* argv) { isDebug = true; } + else if (Slang::UnownedStringSlice(argv[i]) == "--print-builtin-module" && i < argc - 1) + { + Slang::UnownedStringSlice moduleName = Slang::UnownedStringSlice(argv[++i]); + Slang::ComPtr code; + Slang::getBuiltinModuleSource(moduleName, code.writeRef()); + printf("%s\n", (const char*)code->getBufferPointer()); + return 0; + } } if (isDebug) {