Skip to content

Commit

Permalink
feat: call hierarchy provider (#735)
Browse files Browse the repository at this point in the history
Closes #680

### Summary of Changes

Implement a call hierarchy provider to get incoming & outgoing calls.
  • Loading branch information
lars-reimann authored Nov 7, 2023
1 parent c40347c commit 168d098
Show file tree
Hide file tree
Showing 12 changed files with 634 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { AstNode, type AstNodeLocator, getDocument, streamAllContents, WorkspaceCache } from 'langium';
import { isSdsCall, type SdsCall } from '../generated/ast.js';
import type { SafeDsNodeMapper } from '../helpers/safe-ds-node-mapper.js';
import type { SafeDsServices } from '../safe-ds-module.js';

export class SafeDsCallGraphComputer {
private readonly astNodeLocator: AstNodeLocator;
private readonly nodeMapper: SafeDsNodeMapper;

/**
* Stores the calls inside the node with the given ID.
*/
private readonly callCache: WorkspaceCache<string, SdsCall[]>;

constructor(services: SafeDsServices) {
this.astNodeLocator = services.workspace.AstNodeLocator;
this.nodeMapper = services.helpers.NodeMapper;

this.callCache = new WorkspaceCache(services.shared);
}

getCalls(node: AstNode): SdsCall[] {
const key = this.getNodeId(node);
return this.callCache.get(key, () => streamAllContents(node).filter(isSdsCall).toArray());
}

private getNodeId(node: AstNode) {
const documentUri = getDocument(node).uri.toString();
const nodePath = this.astNodeLocator.getAstNodePath(node);
return `${documentUri}~${nodePath}`;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import {
AbstractCallHierarchyProvider,
type AstNode,
type CstNode,
findLeafNodeAtOffset,
getContainerOfType,
getDocument,
type NodeKindProvider,
type ReferenceDescription,
type Stream,
} from 'langium';
import type {
CallHierarchyIncomingCall,
CallHierarchyOutgoingCall,
Range,
SymbolKind,
SymbolTag,
} from 'vscode-languageserver';
import type { SafeDsCallGraphComputer } from '../flow/safe-ds-call-graph-computer.js';
import {
isSdsDeclaration,
isSdsParameter,
type SdsCall,
type SdsCallable,
type SdsDeclaration,
} from '../generated/ast.js';
import type { SafeDsNodeMapper } from '../helpers/safe-ds-node-mapper.js';
import type { SafeDsServices } from '../safe-ds-module.js';
import type { SafeDsNodeInfoProvider } from './safe-ds-node-info-provider.js';

export class SafeDsCallHierarchyProvider extends AbstractCallHierarchyProvider {
private readonly callGraphComputer: SafeDsCallGraphComputer;
private readonly nodeInfoProvider: SafeDsNodeInfoProvider;
private readonly nodeKindProvider: NodeKindProvider;
private readonly nodeMapper: SafeDsNodeMapper;

constructor(services: SafeDsServices) {
super(services);

this.callGraphComputer = services.flow.CallGraphComputer;
this.nodeInfoProvider = services.lsp.NodeInfoProvider;
this.nodeKindProvider = services.shared.lsp.NodeKindProvider;
this.nodeMapper = services.helpers.NodeMapper;
}

protected override getCallHierarchyItem(targetNode: AstNode): {
kind: SymbolKind;
tags?: SymbolTag[];
detail?: string;
} {
return {
kind: this.nodeKindProvider.getSymbolKind(targetNode),
tags: this.nodeInfoProvider.getTags(targetNode),
detail: this.nodeInfoProvider.getDetails(targetNode),
};
}

protected getIncomingCalls(
node: AstNode,
references: Stream<ReferenceDescription>,
): CallHierarchyIncomingCall[] | undefined {
const result: CallHierarchyIncomingCall[] = [];

this.getUniquePotentialCallers(references).forEach((caller) => {
if (!caller.$cstNode) {
/* c8 ignore next 2 */
return;
}

const callerNameCstNode = this.nameProvider.getNameNode(caller);
if (!callerNameCstNode) {
/* c8 ignore next 2 */
return;
}

// Find all calls inside the caller that refer to the given node. This can also handle aliases.
const callsOfNode = this.getCallsOf(caller, node);
if (callsOfNode.length === 0 || callsOfNode.some((it) => !it.$cstNode)) {
return;
}

const callerDocumentUri = getDocument(caller).uri.toString();

result.push({
from: {
name: callerNameCstNode.text,
range: caller.$cstNode.range,
selectionRange: callerNameCstNode.range,
uri: callerDocumentUri,
...this.getCallHierarchyItem(caller),
},
fromRanges: callsOfNode.map((it) => it.$cstNode!.range),
});
});

if (result.length === 0) {
return undefined;
}

return result;
}

/**
* Returns all declarations that contain at least one of the given references. Some of them might not be actual
* callers, since the references might not occur in a call. This has to be checked later.
*/
private getUniquePotentialCallers(references: Stream<ReferenceDescription>): Stream<SdsDeclaration> {
return references
.map((it) => {
const document = this.documents.getOrCreateDocument(it.sourceUri);
const rootNode = document.parseResult.value;
if (!rootNode.$cstNode) {
/* c8 ignore next 2 */
return undefined;
}

const targetNode = findLeafNodeAtOffset(rootNode.$cstNode, it.segment.offset);
if (!targetNode) {
/* c8 ignore next 2 */
return undefined;
}

const containingDeclaration = getContainerOfType(targetNode.astNode, isSdsDeclaration);
if (isSdsParameter(containingDeclaration)) {
// For parameters, we return their containing callable instead
return getContainerOfType(containingDeclaration.$container, isSdsDeclaration);
} else {
return containingDeclaration;
}
})
.distinct()
.filter(isSdsDeclaration);
}

private getCallsOf(caller: AstNode, callee: AstNode): SdsCall[] {
return this.callGraphComputer
.getCalls(caller)
.filter((call) => this.nodeMapper.callToCallable(call) === callee);
}

protected getOutgoingCalls(node: AstNode): CallHierarchyOutgoingCall[] | undefined {
const calls = this.callGraphComputer.getCalls(node);
const callsGroupedByCallable = new Map<
string,
{ callable: SdsCallable; callableNameCstNode: CstNode; callableDocumentUri: string; fromRanges: Range[] }
>();

// Group calls by the callable they refer to
calls.forEach((call) => {
const callCstNode = call.$cstNode;
if (!callCstNode) {
/* c8 ignore next 2 */
return;
}

const callable = this.nodeMapper.callToCallable(call);
if (!callable?.$cstNode) {
/* c8 ignore next 2 */
return;
}

const callableNameCstNode = this.nameProvider.getNameNode(callable);
if (!callableNameCstNode) {
/* c8 ignore next 2 */
return;
}

const callableDocumentUri = getDocument(callable).uri.toString();
const callableId = callableDocumentUri + '~' + callableNameCstNode.text;

const previousFromRanges = callsGroupedByCallable.get(callableId)?.fromRanges ?? [];
callsGroupedByCallable.set(callableId, {
callable,
callableNameCstNode,
fromRanges: [...previousFromRanges, callCstNode.range],
callableDocumentUri,
});
});

if (callsGroupedByCallable.size === 0) {
return undefined;
}

return Array.from(callsGroupedByCallable.values()).map((call) => ({
to: {
name: call.callableNameCstNode.text,
range: call.callable.$cstNode!.range,
selectionRange: call.callableNameCstNode.range,
uri: call.callableDocumentUri,
...this.getCallHierarchyItem(call.callable),
},
fromRanges: call.fromRanges,
}));
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import { AstNode, DefaultDocumentSymbolProvider, LangiumDocument } from 'langium';
import { DocumentSymbol, SymbolTag } from 'vscode-languageserver';
import { SafeDsServices } from '../safe-ds-module.js';
import { SafeDsAnnotations } from '../builtins/safe-ds-annotations.js';
import { type AstNode, DefaultDocumentSymbolProvider, type LangiumDocument } from 'langium';
import type { DocumentSymbol } from 'vscode-languageserver';
import type { SafeDsAnnotations } from '../builtins/safe-ds-annotations.js';
import {
isSdsAnnotatedObject,
isSdsAnnotation,
isSdsAttribute,
isSdsClass,
Expand All @@ -12,16 +10,20 @@ import {
isSdsPipeline,
isSdsSegment,
} from '../generated/ast.js';
import { SafeDsTypeComputer } from '../typing/safe-ds-type-computer.js';
import type { SafeDsServices } from '../safe-ds-module.js';
import type { SafeDsTypeComputer } from '../typing/safe-ds-type-computer.js';
import type { SafeDsNodeInfoProvider } from './safe-ds-node-info-provider.js';

export class SafeDsDocumentSymbolProvider extends DefaultDocumentSymbolProvider {
private readonly builtinAnnotations: SafeDsAnnotations;
private readonly nodeInfoProvider: SafeDsNodeInfoProvider;
private readonly typeComputer: SafeDsTypeComputer;

constructor(services: SafeDsServices) {
super(services);

this.builtinAnnotations = services.builtins.Annotations;
this.nodeInfoProvider = services.lsp.NodeInfoProvider;
this.typeComputer = services.types.TypeComputer;
}

Expand All @@ -34,8 +36,8 @@ export class SafeDsDocumentSymbolProvider extends DefaultDocumentSymbolProvider
{
name: name ?? nameNode.text,
kind: this.nodeKindProvider.getSymbolKind(node),
tags: this.getTags(node),
detail: this.getDetails(node),
tags: this.nodeInfoProvider.getTags(node),
detail: this.nodeInfoProvider.getDetails(node),
range: cstNode.range,
selectionRange: nameNode.range,
children: this.getChildSymbols(document, node),
Expand All @@ -60,22 +62,6 @@ export class SafeDsDocumentSymbolProvider extends DefaultDocumentSymbolProvider
}
}

private getDetails(node: AstNode): string | undefined {
if (isSdsFunction(node) || isSdsSegment(node)) {
const type = this.typeComputer.computeType(node);
return type?.toString();
}
return undefined;
}

private getTags(node: AstNode): SymbolTag[] | undefined {
if (isSdsAnnotatedObject(node) && this.builtinAnnotations.isDeprecated(node)) {
return [SymbolTag.Deprecated];
} else {
return undefined;
}
}

private isLeaf(node: AstNode): boolean {
return (
isSdsAnnotation(node) ||
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import { AstNode } from 'langium';
import { SymbolTag } from 'vscode-languageserver';
import type { SafeDsAnnotations } from '../builtins/safe-ds-annotations.js';
import { isSdsAnnotatedObject, isSdsFunction, isSdsSegment } from '../generated/ast.js';
import type { SafeDsServices } from '../safe-ds-module.js';
import { SafeDsTypeComputer } from '../typing/safe-ds-type-computer.js';

export class SafeDsNodeInfoProvider {
private readonly builtinAnnotations: SafeDsAnnotations;
private readonly typeComputer: SafeDsTypeComputer;

constructor(services: SafeDsServices) {
this.builtinAnnotations = services.builtins.Annotations;
this.typeComputer = services.types.TypeComputer;
}

/**
* Returns the detail string for the given node. This can be used, for example, to provide document symbols or call
* hierarchies.
*/
getDetails(node: AstNode): string | undefined {
if (isSdsFunction(node) || isSdsSegment(node)) {
const type = this.typeComputer.computeType(node);
return type?.toString();
}
return undefined;
}

/**
* Returns the tags for the given node. This can be used, for example, to provide document symbols or call
* hierarchies.
*/
getTags(node: AstNode): SymbolTag[] | undefined {
if (isSdsAnnotatedObject(node) && this.builtinAnnotations.isDeprecated(node)) {
return [SymbolTag.Deprecated];
} else {
return undefined;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,7 @@ export class SafeDsPartialEvaluator {
}

// Try to evaluate the node without parameter substitutions and cache the result
const documentUri = getDocument(node).uri.toString();
const nodePath = this.astNodeLocator.getAstNodePath(node);
const key = `${documentUri}~${nodePath}`;
const resultWithoutSubstitutions = this.cache.get(key, () =>
const resultWithoutSubstitutions = this.cache.get(this.getNodeId(node), () =>
this.doEvaluateWithSubstitutions(node, NO_SUBSTITUTIONS),
);
if (resultWithoutSubstitutions.isFullyEvaluated || isEmpty(substitutions)) {
Expand All @@ -96,6 +93,12 @@ export class SafeDsPartialEvaluator {
} /* c8 ignore stop */
}

private getNodeId(node: AstNode) {
const documentUri = getDocument(node).uri.toString();
const nodePath = this.astNodeLocator.getAstNodePath(node);
return `${documentUri}~${nodePath}`;
}

private doEvaluateWithSubstitutions(
node: AstNode | undefined,
substitutions: ParameterSubstitutions,
Expand Down
Loading

0 comments on commit 168d098

Please sign in to comment.