diff --git a/internal/lsp/lsproto/jsonrpc.go b/internal/lsp/lsproto/jsonrpc.go index 952a3635b9..76fd4edc1e 100644 --- a/internal/lsp/lsproto/jsonrpc.go +++ b/internal/lsp/lsproto/jsonrpc.go @@ -43,19 +43,32 @@ func (id *ID) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, &id.int) } -// TODO(jakebailey): NotificationMessage? Use RequestMessage without ID? +type Message struct { + JSONRPC JSONRPCVersion `json:"jsonrpc"` +} + +type NotificationMessage struct { + Message + Method Method `json:"method"` + Params any `json:"params"` +} type RequestMessage struct { - JSONRPC JSONRPCVersion `json:"jsonrpc"` - ID *ID `json:"id"` - Method Method `json:"method"` - Params any `json:"params"` + Message + ID *ID `json:"id"` + Method Method `json:"method"` + Params any `json:"params"` } -func (r *RequestMessage) UnmarshalJSON(data []byte) error { +type RequestOrNotificationMessage struct { + NotificationMessage *NotificationMessage + RequestMessage *RequestMessage +} + +func (r *RequestOrNotificationMessage) UnmarshalJSON(data []byte) error { var raw struct { JSONRPC JSONRPCVersion `json:"jsonrpc"` - ID *ID `json:"id"` + ID *ID `json:"id,omitzero"` Method Method `json:"method"` Params json.RawMessage `json:"params"` } @@ -63,40 +76,59 @@ func (r *RequestMessage) UnmarshalJSON(data []byte) error { return fmt.Errorf("%w: %w", ErrInvalidRequest, err) } - r.ID = raw.ID - r.Method = raw.Method - if r.Method == MethodShutdown || r.Method == MethodExit { + params, err := unmarshalParams(raw.Method, raw.Params) + if err != nil { + return err + } + + if raw.ID != nil { + r.RequestMessage = &RequestMessage{ + ID: raw.ID, + Method: raw.Method, + Params: params, + } + } else { + r.NotificationMessage = &NotificationMessage{ + Method: raw.Method, + Params: params, + } + } + + return nil +} + +func unmarshalParams(rawMethod Method, rawParams []byte) (any, error) { + if rawMethod == MethodShutdown || rawMethod == MethodExit { // These methods have no params. - return nil + return nil, nil } var params any var err error - if unmarshalParams, ok := unmarshallers[raw.Method]; ok { - params, err = unmarshalParams(raw.Params) + if unmarshaller, ok := unmarshallers[rawMethod]; ok { + params, err = unmarshaller(rawParams) } else { // Fall back to default; it's probably an unknown message and we will probably not handle it. - err = json.Unmarshal(raw.Params, ¶ms) + err = json.Unmarshal(rawParams, ¶ms) } - r.Params = params if err != nil { - return fmt.Errorf("%w: %w", ErrInvalidRequest, err) + return nil, fmt.Errorf("%w: %w", ErrInvalidRequest, err) } - return nil + return params, nil } type ResponseMessage struct { - JSONRPC JSONRPCVersion `json:"jsonrpc"` - ID *ID `json:"id,omitempty"` - Result any `json:"result"` - Error *ResponseError `json:"error,omitempty"` + Message + ID *ID `json:"id,omitzero"` + Result any `json:"result"` + Error *ResponseError `json:"error,omitzero"` } type ResponseError struct { Code int32 `json:"code"` Message string `json:"message"` - Data any `json:"data,omitempty"` + Data any `json:"data,omitzero"` } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 7118c8b208..35b07fe0f8 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -105,15 +105,20 @@ func (s *Server) Run() error { } if s.initializeParams == nil { - if req.Method == lsproto.MethodInitialize { - if err := s.handleInitialize(req); err != nil { - return err - } - } else { - if err := s.sendError(req.ID, lsproto.ErrServerNotInitialized); err != nil { - return err + if req.RequestMessage != nil { + message := req.RequestMessage + + if message.Method == lsproto.MethodInitialize { + if err := s.handleInitialize(message); err != nil { + return err + } + } else { + if err := s.sendError(message.ID, lsproto.ErrServerNotInitialized); err != nil { + return err + } } } + continue } @@ -123,13 +128,13 @@ func (s *Server) Run() error { } } -func (s *Server) read() (*lsproto.RequestMessage, error) { +func (s *Server) read() (*lsproto.RequestOrNotificationMessage, error) { data, err := s.r.Read() if err != nil { return nil, err } - req := &lsproto.RequestMessage{} + req := &lsproto.RequestOrNotificationMessage{} if err := json.Unmarshal(data, req); err != nil { return nil, fmt.Errorf("%w: %w", lsproto.ErrInvalidRequest, err) } @@ -170,45 +175,45 @@ func (s *Server) sendResponse(resp *lsproto.ResponseMessage) error { return s.w.Write(data) } -func (s *Server) handleMessage(req *lsproto.RequestMessage) error { - s.requestTime = time.Now() - s.requestMethod = string(req.Method) - - params := req.Params - switch params.(type) { - case *lsproto.InitializeParams: - return s.sendError(req.ID, lsproto.ErrInvalidRequest) - case *lsproto.InitializedParams: - return s.handleInitialized(req) - case *lsproto.DidOpenTextDocumentParams: - return s.handleDidOpen(req) - case *lsproto.DidChangeTextDocumentParams: - return s.handleDidChange(req) - case *lsproto.DidSaveTextDocumentParams: - return s.handleDidSave(req) - case *lsproto.DidCloseTextDocumentParams: - return s.handleDidClose(req) - case *lsproto.DocumentDiagnosticParams: - return s.handleDocumentDiagnostic(req) - case *lsproto.HoverParams: - return s.handleHover(req) - case *lsproto.DefinitionParams: - return s.handleDefinition(req) - default: +func (s *Server) handleMessage(msg *lsproto.RequestOrNotificationMessage) error { + if req := msg.RequestMessage; req != nil { switch req.Method { + case lsproto.MethodInitialize: + return s.sendError(req.ID, lsproto.ErrInvalidRequest) + case lsproto.MethodTextDocumentDiagnostic: + return s.handleDocumentDiagnostic(req) + case lsproto.MethodTextDocumentHover: + return s.handleHover(req) + case lsproto.MethodTextDocumentDefinition: + return s.handleDefinition(req) case lsproto.MethodShutdown: s.projectService.Close() return s.sendResult(req.ID, nil) - case lsproto.MethodExit: - return nil default: s.Log("unknown method", req.Method) - if req.ID != nil { - return s.sendError(req.ID, lsproto.ErrInvalidRequest) - } + } + } else if notif := msg.NotificationMessage; notif != nil { + switch notif.Method { + case lsproto.MethodInitialized: + return s.handleInitialized() + case lsproto.MethodTextDocumentDidOpen: + return s.handleDidOpen(notif) + case lsproto.MethodTextDocumentDidChange: + return s.handleDidChange(notif) + case lsproto.MethodTextDocumentDidSave: + return s.handleDidSave(notif) + case lsproto.MethodTextDocumentDidClose: + return s.handleDidClose(notif) + case lsproto.MethodExit: return nil + default: + s.Log("unknown method", notif.Method) } + } else { + s.Log("Failed to parse unknown message") } + + return nil } func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { @@ -254,7 +259,7 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { }) } -func (s *Server) handleInitialized(req *lsproto.RequestMessage) error { +func (s *Server) handleInitialized() error { s.logger = project.NewLogger([]io.Writer{s.stderr}, project.LogLevelVerbose) s.projectService = project.NewService(s, project.ServiceOptions{ DefaultLibraryPath: s.defaultLibraryPath, @@ -269,24 +274,26 @@ func (s *Server) handleInitialized(req *lsproto.RequestMessage) error { return nil } -func (s *Server) handleDidOpen(req *lsproto.RequestMessage) error { +func (s *Server) handleDidOpen(req *lsproto.NotificationMessage) error { params := req.Params.(*lsproto.DidOpenTextDocumentParams) s.projectService.OpenFile(ls.DocumentURIToFileName(params.TextDocument.Uri), params.TextDocument.Text, ls.LanguageKindToScriptKind(params.TextDocument.LanguageId), "") return nil } -func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { +func (s *Server) handleDidChange(req *lsproto.NotificationMessage) error { params := req.Params.(*lsproto.DidChangeTextDocumentParams) scriptInfo := s.projectService.GetScriptInfo(ls.DocumentURIToFileName(params.TextDocument.Uri)) if scriptInfo == nil { - return s.sendError(req.ID, lsproto.ErrRequestFailed) + s.logger.Error("Failed to get script info") + return nil } changes := make([]ls.TextChange, len(params.ContentChanges)) for i, change := range params.ContentChanges { if partialChange := change.TextDocumentContentChangePartial; partialChange != nil { if textChange, err := s.converters.FromLSPTextChange(partialChange, scriptInfo.FileName()); err != nil { - return s.sendError(req.ID, err) + s.logger.Error(fmt.Sprintf("Error converting %v:", err)) + return nil } else { changes[i] = textChange } @@ -296,7 +303,8 @@ func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { NewText: wholeChange.Text, } } else { - return s.sendError(req.ID, lsproto.ErrInvalidRequest) + s.logger.Error(fmt.Sprintf("Invalid request")) + return nil } } @@ -304,13 +312,13 @@ func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { return nil } -func (s *Server) handleDidSave(req *lsproto.RequestMessage) error { +func (s *Server) handleDidSave(req *lsproto.NotificationMessage) error { params := req.Params.(*lsproto.DidSaveTextDocumentParams) s.projectService.MarkFileSaved(ls.DocumentURIToFileName(params.TextDocument.Uri), *params.Text) return nil } -func (s *Server) handleDidClose(req *lsproto.RequestMessage) error { +func (s *Server) handleDidClose(req *lsproto.NotificationMessage) error { params := req.Params.(*lsproto.DidCloseTextDocumentParams) s.projectService.CloseFile(ls.DocumentURIToFileName(params.TextDocument.Uri)) return nil