diff --git a/pylspclient/json_rpc_endpoint.py b/pylspclient/json_rpc_endpoint.py index 73ba20b..ccc234e 100644 --- a/pylspclient/json_rpc_endpoint.py +++ b/pylspclient/json_rpc_endpoint.py @@ -2,6 +2,7 @@ import json from pylspclient.lsp_errors import ErrorCodes, ResponseError import threading +from typing import Any, IO JSON_RPC_REQ_FORMAT = "Content-Length: {json_string_len}\r\n\r\n{json_string}" LEN_HEADER = "Content-Length: " @@ -15,7 +16,7 @@ class MyEncoder(json.JSONEncoder): """ Encodes an object in JSON """ - def default(self, o): # pylint: disable=E0202 + def default(self, o: Any): # pylint: disable=E0202 return o.__dict__ @@ -24,14 +25,14 @@ class JsonRpcEndpoint(object): Thread safe JSON RPC endpoint implementation. Responsible to recieve and send JSON RPC messages, as described in the protocol. More information can be found: https://www.jsonrpc.org/ ''' - def __init__(self, stdin, stdout): + def __init__(self, stdin: IO, stdout: IO): self.stdin = stdin self.stdout = stdout self.read_lock = threading.Lock() self.write_lock = threading.Lock() @staticmethod - def __add_header(json_string): + def __add_header(json_string: str) -> str: ''' Adds a header for the given json string @@ -41,7 +42,7 @@ def __add_header(json_string): return JSON_RPC_REQ_FORMAT.format(json_string_len=len(json_string), json_string=json_string) - def send_request(self, message): + def send_request(self, message: Any) -> None: ''' Sends the given message. @@ -54,7 +55,7 @@ def send_request(self, message): self.stdin.flush() - def recv_response(self): + def recv_response(self) -> Any: ''' Recives a message. diff --git a/pylspclient/lsp_client.py b/pylspclient/lsp_client.py index 3602a23..863cca3 100644 --- a/pylspclient/lsp_client.py +++ b/pylspclient/lsp_client.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from pydantic import ValidationError from pylspclient.lsp_endpoint import LspEndpoint @@ -15,7 +15,16 @@ def __init__(self, lsp_endpoint: LspEndpoint): self.lsp_endpoint = lsp_endpoint - def initialize(self, processId, rootPath, rootUri, initializationOptions, capabilities, trace, workspaceFolders): + def initialize( + self, + processId: Optional[int] = None, + rootPath: Optional[str] = None, + rootUri: Optional[str] = None, + initializationOptions: Optional[Any] = None, + capabilities: Optional[dict] = None, + trace: Optional[str] = None, + workspaceFolders: Optional[list] = None, + ): """ The initialize request is sent as the first request from the client to the server. If the server receives a request or notification before the initialize request it should act as follows: @@ -33,7 +42,7 @@ def initialize(self, processId, rootPath, rootUri, initializationOptions, capabi :param int processId: The process Id of the parent process that started the server. Is null if the process has not been started by another process. If the parent process is not alive then the server should exit (see exit notification) its process. :param str rootPath: The rootPath of the workspace. Is null if no folder is open. Deprecated in favour of rootUri. - :param DocumentUri rootUri: The rootUri of the workspace. Is null if no folder is open. If both `rootPath` and `rootUri` are set + :param str rootUri: The rootUri of the workspace. Is null if no folder is open. If both `rootPath` and `rootUri` are set `rootUri` wins. :param any initializationOptions: User provided initialization options. :param ClientCapabilities capabilities: The capabilities provided by the client (editor or tool). @@ -41,6 +50,8 @@ def initialize(self, processId, rootPath, rootUri, initializationOptions, capabi :param list workspaceFolders: The workspace folders configured in the client when the server starts. This property is only available if the client supports workspace folders. It can be `null` if the client supports workspace folders but none are configured. """ + if capabilities is None: + raise ValueError("capabilities is required") self.lsp_endpoint.start() return self.lsp_endpoint.call_method("initialize", processId=processId, rootPath=rootPath, rootUri=rootUri, initializationOptions=initializationOptions, capabilities=capabilities, trace=trace, workspaceFolders=workspaceFolders) diff --git a/pylspclient/lsp_endpoint.py b/pylspclient/lsp_endpoint.py index 184ceaa..09d56e4 100644 --- a/pylspclient/lsp_endpoint.py +++ b/pylspclient/lsp_endpoint.py @@ -1,22 +1,37 @@ from __future__ import print_function import threading from pylspclient.lsp_errors import ErrorCodes, ResponseError +from pylspclient import JsonRpcEndpoint +from typing import Any, Dict, Callable, Union, Optional, Tuple, TypeAlias, TypedDict + +ResultType: TypeAlias = Optional[Dict[str, Any]] + +class ErrorType(TypedDict): + code: ErrorCodes + message: str + data: Optional[Any] class LspEndpoint(threading.Thread): - def __init__(self, json_rpc_endpoint, method_callbacks={}, notify_callbacks={}, timeout=2): + def __init__( + self, + json_rpc_endpoint: JsonRpcEndpoint, + method_callbacks: Dict[str, Callable[[Any], Any]] = {}, + notify_callbacks: Dict[str, Callable[[Any], Any]] = {}, + timeout: int = 2 + ): threading.Thread.__init__(self) self.json_rpc_endpoint = json_rpc_endpoint self.notify_callbacks = notify_callbacks self.method_callbacks = method_callbacks - self.event_dict = {} - self.response_dict = {} - self.next_id = 0 - self._timeout = timeout - self.shutdown_flag = False + self.event_dict: dict = {} + self.response_dict: Dict[Union[str, int], Tuple[ResultType, ErrorType]] = {} + self.next_id: int = 0 + self._timeout: int = timeout + self.shutdown_flag: bool = False - def handle_result(self, rpc_id, result, error): + def handle_result(self, rpc_id: Union[str, int], result: ResultType, error: ErrorType): self.response_dict[rpc_id] = (result, error) cond = self.event_dict[rpc_id] cond.acquire() @@ -24,11 +39,11 @@ def handle_result(self, rpc_id, result, error): cond.release() - def stop(self): + def stop(self) -> None: self.shutdown_flag = True - def run(self): + def run(self) -> None: while not self.shutdown_flag: try: jsonrpc_message = self.json_rpc_endpoint.recv_response() @@ -61,8 +76,8 @@ def run(self): self.send_response(rpc_id, None, e) - def send_response(self, id, result, error): - message_dict = {} + def send_response(self, id: Union[str, int, None], result: Any = None, error: Optional[Exception] = None) -> None: + message_dict: dict = {} message_dict["jsonrpc"] = "2.0" message_dict["id"] = id if result: @@ -72,8 +87,8 @@ def send_response(self, id, result, error): self.json_rpc_endpoint.send_request(message_dict) - def send_message(self, method_name, params, id = None): - message_dict = {} + def send_message(self, method_name: str, params: dict, id = None) -> None: + message_dict: dict = {} message_dict["jsonrpc"] = "2.0" if id is not None: message_dict["id"] = id @@ -82,7 +97,7 @@ def send_message(self, method_name, params, id = None): self.json_rpc_endpoint.send_request(message_dict) - def call_method(self, method_name, **kwargs): + def call_method(self, method_name: str, **kwargs) -> Any: current_id = self.next_id self.next_id += 1 cond = threading.Condition() @@ -101,9 +116,9 @@ def call_method(self, method_name, **kwargs): self.event_dict.pop(current_id) result, error = self.response_dict.pop(current_id) if error: - raise ResponseError(error.get("code"), error.get("message"), error.get("data")) + raise ResponseError(error["code"], error["message"], error.get("data")) return result - def send_notification(self, method_name, **kwargs): + def send_notification(self, method_name: str, **kwargs): self.send_message(method_name, kwargs) diff --git a/pylspclient/lsp_errors.py b/pylspclient/lsp_errors.py index 3c66ecd..43faa80 100644 --- a/pylspclient/lsp_errors.py +++ b/pylspclient/lsp_errors.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from enum import IntEnum @@ -20,7 +20,7 @@ class ErrorCodes(IntEnum): class ResponseError(Exception): - def __init__(self, code: ErrorCodes, message: str, data: Any = None): + def __init__(self, code: ErrorCodes, message: str, data: Optional[Any] = None): self.code = code self.message = message if data: