Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/slang/slang-ast-iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ struct ASTIterator
iterator->maybeDispatchCallback(expr);

dispatchIfNotNull(expr->functionExpr);
dispatchIfNotNull(expr->originalFunctionExpr);
for (auto arg : expr->arguments)
dispatchIfNotNull(arg);
}
Expand Down
6 changes: 4 additions & 2 deletions source/slang/slang-ast-print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,7 @@ void ASTPrinter::addVal(Val* val)

/* static */ void ASTPrinter::appendDeclName(Decl* decl, StringBuilder& out)
{
decl = maybeGetInner(decl);
if (as<ConstructorDecl>(decl))
{
out << "init";
Expand Down Expand Up @@ -1231,8 +1232,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth)
}
else if (auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>())
{
ExtensionDecl* extensionDecl = as<ExtensionDecl>(parentDeclRef.getDecl());
Type* type = extensionDecl->targetType.type;
Type* type = getTargetType(m_astBuilder, extensionDeclRef);
if (m_optionFlags & OptionFlag::NoSpecializedExtensionTypeName)
{
if (auto unspecializedDeclRef = isDeclRefTypeOf<Decl>(type))
Expand Down Expand Up @@ -1522,6 +1522,8 @@ void ASTPrinter::addDeclKindPrefix(Decl* decl)
continue;
if (as<HLSLLayoutSemantic>(modifier))
continue;
if (as<ImplicitConversionModifier>(modifier))
continue;
}
// Don't print out attributes.
if (as<AttributeBase>(modifier))
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1252,8 +1252,8 @@ struct SemanticsVisitor : public SemanticsContext
TypeExp TranslateTypeNode(TypeExp const& typeExp);
Type* getRemovedModifierType(ModifiedType* type, ModifierVal* modifier);
Type* getConstantBufferType(Type* elementType, Type* layoutType);

DeclRefType* getExprDeclRefType(Expr* expr);
LookupResult lookupConstructorsInType(Type* type, Scope* sourceScope);

/// Is `decl` usable as a static member?
bool isDeclUsableAsStaticMember(Decl* decl);
Expand Down
35 changes: 25 additions & 10 deletions source/slang/slang-check-overload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2249,6 +2249,24 @@ DeclRef<Decl> SemanticsVisitor::inferGenericArguments(
return trySolveConstraintSystem(&constraints, genericDeclRef, knownGenericArgs, outBaseCost);
}

LookupResult SemanticsVisitor::lookupConstructorsInType(Type* type, Scope* sourceScope)
{
// Look up all the initializers on `type` by looking up
// its members named `$init`. All `__init` declarations are stored
// with the name `$init` internally to avoid potential conflicts
// if a user decided to name a field/method `__init`.
LookupOptions options =
LookupOptions(uint8_t(LookupOptions::IgnoreInheritance) | uint8_t(LookupOptions::NoDeref));
return lookUpMember(
m_astBuilder,
this,
getName("$init"),
type,
sourceScope,
LookupMask::Default,
options);
}

void SemanticsVisitor::AddTypeOverloadCandidates(Type* type, OverloadResolveContext& context)
{
// The code being checked is trying to apply `type` like a function.
Expand All @@ -2272,16 +2290,7 @@ void SemanticsVisitor::AddTypeOverloadCandidates(Type* type, OverloadResolveCont
// from a value of the same type. There is no need in Slang for
// "copy constructors" but the core module currently has to define
// some just to make code that does, e.g., `float(1.0f)` work.)
LookupOptions options =
LookupOptions(uint8_t(LookupOptions::IgnoreInheritance) | uint8_t(LookupOptions::NoDeref));
LookupResult initializers = lookUpMember(
m_astBuilder,
this,
getName("$init"),
type,
context.sourceScope,
LookupMask::Default,
options);
LookupResult initializers = lookupConstructorsInType(type, context.sourceScope);
AddOverloadCandidates(initializers, context);
}

Expand Down Expand Up @@ -2702,6 +2711,12 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
expr->arguments[0],
&tempSink,
&conversionCost);
if (auto resultInvokeExpr = as<InvokeExpr>(resultExpr))
{
resultInvokeExpr->originalFunctionExpr = expr->functionExpr;
resultInvokeExpr->argumentDelimeterLocs = expr->argumentDelimeterLocs;
resultInvokeExpr->loc = expr->loc;
}
if (coerceResult)
return resultExpr;
typeOverloadChecked = true;
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3764,6 +3764,8 @@ class Session : public RefObject, public slang::IGlobalSession
ComPtr<ISlangBlob> getAutodiffLibraryCode();
ComPtr<ISlangBlob> getGLSLLibraryCode();

void getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName);

RefPtr<SharedASTBuilder> m_sharedASTBuilder;

SPIRVCoreGrammarInfo& getSPIRVCoreGrammarInfo()
Expand Down
10 changes: 0 additions & 10 deletions source/slang/slang-language-server-ast-lookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,6 @@ struct ASTLookupExprVisitor : public ExprVisitor<ASTLookupExprVisitor, bool>
return dispatchIfNotNull(expr->originalExpr);
}

bool visitTypeCastExpr(TypeCastExpr* expr)
{
if (dispatchIfNotNull(expr->functionExpr))
return true;
for (auto arg : expr->arguments)
if (dispatchIfNotNull(arg))
return true;
return false;
}

bool visitDerefExpr(DerefExpr* expr) { return dispatchIfNotNull(expr->base); }
bool visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr)
{
Expand Down
177 changes: 168 additions & 9 deletions source/slang/slang-language-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,37 @@ HumaneSourceLoc getModuleLoc(SourceManager* manager, ContainerDecl* moduleDecl)
return location;
}

// When user code has `Foo(123)` where `Foo` is a `struct`, goto-definition on
// `Foo` should redirect to the constructor of `Foo` instead of the type declaration of `Foo`.
// This function will check if the `declRefExpr` is a reference to a type declaration,
// but the declRefExpr is referenced from an `InvokeExpr::originalFunctionExpr` that is now
// resolved to a constructor. If so we will return the declRef of the constructor.
//
DeclRef<Decl> maybeRedirectToConstructor(DeclRefExpr* declRefExpr, const List<SyntaxNode*>& path)
{
if (path.getCount() < 2)
return declRefExpr->declRef;
if (!as<AggTypeDecl>(declRefExpr->declRef))
return declRefExpr->declRef;
auto invokeExpr = as<InvokeExpr>(path[path.getCount() - 2]);
if (!invokeExpr)
return declRefExpr->declRef;
if (!invokeExpr->originalFunctionExpr)
return declRefExpr->declRef;
auto originalFuncExpr = invokeExpr->originalFunctionExpr;
if (originalFuncExpr != declRefExpr)
return declRefExpr->declRef;
// If the invoke expression is the same as the decl ref expression,
// it means we are looking at a constructor call.
auto resolvedFuncExpr = as<DeclRefExpr>(invokeExpr->functionExpr);
if (!resolvedFuncExpr)
return declRefExpr->declRef;
auto ctorDecl = as<ConstructorDecl>(resolvedFuncExpr->declRef);
if (ctorDecl)
return ctorDecl;
return declRefExpr->declRef;
}

SlangResult LanguageServer::hover(
const LanguageServerProtocol::HoverParams& args,
const JSONValue& responseId)
Expand Down Expand Up @@ -828,7 +859,8 @@ LanguageServerResult<LanguageServerProtocol::Hover> LanguageServerCore::hover(
};
if (auto declRefExpr = as<DeclRefExpr>(leafNode))
{
fillDeclRefHoverInfo(declRefExpr->declRef, declRefExpr->name);
auto resolvedDeclRef = maybeRedirectToConstructor(declRefExpr, findResult[0].path);
fillDeclRefHoverInfo(resolvedDeclRef, declRefExpr->name);
}
else if (auto overloadedExpr = as<OverloadedExpr>(leafNode))
{
Expand Down Expand Up @@ -1004,11 +1036,12 @@ LanguageServerResult<List<LanguageServerProtocol::Location>> LanguageServerCore:
{
if (declRefExpr->declRef.getDecl())
{
auto declRef = declRefExpr->declRef;
declRef = maybeRedirectToConstructor(declRefExpr, findResult[0].path);
auto location = version->linkage->getSourceManager()->getHumaneLoc(
declRefExpr->declRef.getNameLoc().isValid() ? declRefExpr->declRef.getNameLoc()
: declRefExpr->declRef.getLoc(),
declRef.getNameLoc().isValid() ? declRef.getNameLoc() : declRef.getLoc(),
SourceLocType::Actual);
auto name = declRefExpr->declRef.getName();
auto name = declRef.getName();
locations.add(LocationResult{
location,
name ? (int)UTF8Util::calcUTF16CharCount(name->text.getUnownedSlice()) : 0});
Expand Down Expand Up @@ -1076,6 +1109,14 @@ LanguageServerResult<List<LanguageServerProtocol::Location>> LanguageServerCore:
{
result.uri =
URI::fromLocalFilePath(loc.loc.pathInfo.foundPath.getUnownedSlice()).uri;
}
else if (loc.loc.pathInfo.getName() == "core" || loc.loc.pathInfo.getName() == "glsl")
{
result.uri = StringBuilder() << "slang-synth://" << loc.loc.pathInfo.getName()
<< "/" << loc.loc.pathInfo.getName() << ".builtin";
}
if (result.uri.getLength() != 0)
{
doc->oneBasedUTF8LocToZeroBasedUTF16Loc(
loc.loc.line,
loc.loc.column,
Expand Down Expand Up @@ -1504,6 +1545,75 @@ SlangResult LanguageServer::signatureHelp(
return SLANG_OK;
}

// Heuristical cost for determining the best candidate to use as the active signature.
// We will always use the candidate that has the most matched parameters to the current argument
// list. If there are multiple candidates with the same number of matched parameters, we will
// use the one with the least number of unmatched parameters. If there are still multiple
// candidates with the same number of unmatched parameters, we will use the one with the least
// maximum argument conversion cost.
//
struct CallCandidateMatchCost
{
Index matchedArgCount = 0;
Index excessArgCount = 0;
Index unmatchedParamCount = 0;
ConversionCost maxArgConversionCost = 0;

bool isBetterThan(const CallCandidateMatchCost& other) const
{
if (excessArgCount < other.excessArgCount)
return true;
else if (excessArgCount > other.excessArgCount)
return false;
if (matchedArgCount > other.matchedArgCount)
return true;
else if (matchedArgCount < other.matchedArgCount)
return false;

if (unmatchedParamCount < other.unmatchedParamCount)
return true;
else if (unmatchedParamCount > other.unmatchedParamCount)
return false;
return maxArgConversionCost < other.maxArgConversionCost;
}
};

// Given a callable decl and an AppExprBase containing the arguments used to call it,
// return the match cost for the candidate.
static CallCandidateMatchCost getCallCandidateMatchCost(
DeclRef<CallableDecl> callableDeclRef,
AppExprBase* appExpr,
SemanticsVisitor& semanticsVisitor,
WorkspaceVersion* version)
{
CallCandidateMatchCost result;
auto astBuilder = version->linkage->getASTBuilder();
auto paramList = getMembersOfType<ParamDecl>(astBuilder, callableDeclRef).toArray();

for (Index argId = 0; argId < appExpr->arguments.getCount(); argId++)
{
auto arg = appExpr->arguments[argId];
if (!arg)
continue;
if (!arg->type.type)
continue;
if (argId < paramList.getCount())
{
auto paramType = getType(version->linkage->getASTBuilder(), paramList[argId]);
ConversionCost argCost = 0;
if (paramType && semanticsVisitor.canCoerce(paramType, arg->type.type, arg, &argCost))
{
result.matchedArgCount++;
result.maxArgConversionCost = Math::Max(result.maxArgConversionCost, argCost);
}
}
}
result.excessArgCount =
Math::Max((Index)0, (appExpr->argumentDelimeterLocs.getCount() - 1) - paramList.getCount());
result.unmatchedParamCount = paramList.getCount() - result.matchedArgCount;
return result;
}

LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore::signatureHelp(
const LanguageServerProtocol::SignatureHelpParams& args)
{
Expand Down Expand Up @@ -1594,11 +1704,41 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore::
}

SignatureHelp response;
response.activeSignature = 0;

CallCandidateMatchCost bestCandidateMatchCost;

// We will use an ad-hoc semantics visitor to check for argument-to-parameter conversions
// and to determine the best candidate signature.
// In the ideal design, this info should be gathered during the normal type checking
// process, but that require a lot of refactoring in the current code base, and may
// risk slowing down type checking for non-language-server use cases since we won't be
// able to do as many early returns.
// So instead we will do a separate ad-hoc checking here to do a best-effort guess
// on the best candidate.
//
DiagnosticSink sink;
SharedSemanticsContext semanticsContext(version->linkage, nullptr, &sink);
SemanticsVisitor semanticsVisitor(&semanticsContext);

auto addDeclRef = [&](DeclRef<Decl> declRef)
{
if (!declRef.getDecl())
return;

// If we have a better match than the current best, we will update response.activeSignature
// to this signature.
if (auto callableDeclRef = declRef.as<CallableDecl>())
{
auto matchCost =
getCallCandidateMatchCost(callableDeclRef, appExpr, semanticsVisitor, version);
if (matchCost.isBetterThan(bestCandidateMatchCost))
{
bestCandidateMatchCost = matchCost;
response.activeSignature = (uint32_t)response.signatures.getCount();
}
}

SignatureInformation sigInfo;

List<Slang::Range<Index>> paramRanges;
Expand Down Expand Up @@ -1675,13 +1815,14 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore::

if (auto declRefExpr = as<DeclRefExpr>(funcExpr))
{
if (auto aggDeclRef = as<AggTypeDecl>(declRefExpr->declRef))
if (auto typeType = as<TypeType>(declRefExpr->type.type))
{
// Look for initializers
for (auto member :
getMembersOfType<ConstructorDecl>(version->linkage->getASTBuilder(), aggDeclRef))
auto ctors =
semanticsVisitor.lookupConstructorsInType(typeType->getType(), declRefExpr->scope);
for (auto ctor : ctors)
{
addDeclRef(member);
addDeclRef(ctor.declRef);
}
}
else
Expand Down Expand Up @@ -1711,7 +1852,6 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore::
{
addFuncType(funcType);
}
response.activeSignature = 0;
response.activeParameter = 0;
for (int i = 1; i < appExpr->argumentDelimeterLocs.getCount(); i++)
{
Expand Down Expand Up @@ -2828,4 +2968,23 @@ SLANG_API SlangResult runLanguageServer(Slang::LanguageServerStartupOptions opti
return SLANG_OK;
}

SLANG_API SlangResult
getBuiltinModuleSource(const UnownedStringSlice& moduleName, slang::IBlob** blob)
{
ComPtr<slang::IGlobalSession> globalSession;
slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef());
Slang::Session* session = static_cast<Slang::Session*>(globalSession.get());
StringBuilder sb;
if (moduleName.startsWith("core"))
{
session->getBuiltinModuleSource(sb, slang::BuiltinModuleName::Core);
}
else if (moduleName.startsWith("glsl"))
{
session->getBuiltinModuleSource(sb, slang::BuiltinModuleName::GLSL);
}
*blob = StringBlob::moveCreate(sb.produceString()).detach();
return SLANG_OK;
}

} // namespace Slang
3 changes: 3 additions & 0 deletions source/slang/slang-language-server.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,7 @@ inline bool _isIdentifierChar(char ch)
}

SLANG_API SlangResult runLanguageServer(LanguageServerStartupOptions options);
SLANG_API SlangResult
getBuiltinModuleSource(const UnownedStringSlice& moduleName, slang::IBlob** blob);

} // namespace Slang
5 changes: 4 additions & 1 deletion source/slang/slang-workspace-version.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,10 @@ void DocumentVersion::oneBasedUTF8LocToZeroBasedUTF16Loc(
Index rsLine = inLine - 1;
auto bounds = getUTF16Boundaries(inLine);
outLine = rsLine;
outCol = std::lower_bound(bounds.begin(), bounds.end(), inCol - 1) - bounds.begin();
if (bounds.getCount() != 0)
outCol = std::lower_bound(bounds.begin(), bounds.end(), inCol - 1) - bounds.begin();
else
outCol = inCol - 1;
}

void DocumentVersion::oneBasedUTF8LocToZeroBasedUTF16Loc(
Expand Down
Loading