diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 0000000..25032f2 --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,32 @@ +name: Format + +on: + push: + branches: + - "main" + pull_request: + branches: + - "main" + +jobs: + format: + name: Stylua + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: date +%W > weekly + + - name: Restore cache + id: cache + uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin + key: ${{ runner.os }}-cargo-${{ hashFiles('weekly') }} + + - name: Install + if: steps.cache.outputs.cache-hit != 'true' + run: cargo install stylua + + - name: Format + run: stylua --check lua/ --config-path=.stylua.toml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..5081cf8 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,24 @@ +name: Lint + +on: + push: + branches: + - "main" + pull_request: + branches: + - "main" + +jobs: + lint: + name: Luacheck + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup + run: | + sudo apt-get update + sudo apt-get install luarocks -y + sudo luarocks install luacheck + + - name: Lint + run: luacheck lua/ --globals vim --codes diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..5fe411b --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,30 @@ +name: Test + +on: + push: + branches: + - "main" + pull_request: + branches: + - "main" + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + nvim-versions: ['stable', 'nightly'] + os: [ubuntu-latest, windows-latest, macos-latest] + fail-fast: false + name: Plenary Tests + steps: + - name: checkout + uses: actions/checkout@v4 + + - uses: rhysd/action-setup-vim@v1 + with: + neovim: true + version: ${{ matrix.nvim-versions }} + + - name: run tests + run: make test diff --git a/.github/workflows/typecheck.yml b/.github/workflows/typecheck.yml new file mode 100644 index 0000000..b40e7bb --- /dev/null +++ b/.github/workflows/typecheck.yml @@ -0,0 +1,21 @@ +name: lua_ls-typecheck + +on: + push: + branches: + - "main" + pull_request: + branches: + - "main" + +jobs: + build: + name: Type Check Code Base + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: stevearc/nvim-typecheck-action@v2 + with: + level: Warning + configpath: ".luarc.json" diff --git a/.gitignore b/.gitignore index 6fd0a37..70ee672 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,4 @@ luac.out *.x86_64 *.hex +deps/ diff --git a/.luacheckrc b/.luacheckrc new file mode 100644 index 0000000..016ffaa --- /dev/null +++ b/.luacheckrc @@ -0,0 +1,24 @@ +-- vim: ft=lua tw=80 + +stds.nvim = { + read_globals = { "jit" }, +} +std = "lua51+nvim" + +-- Don't report unused self arguments of methods. +self = false + +-- Rerun tests only if their modification time changed. +cache = true + +ignore = { + "631", -- max_line_length + "212/_.*", -- unused argument, for vars with "_" prefix + "214", -- used variable with unused hint ("_" prefix) + "581", -- negation of a relational operator- operator can be flipped (not for tables) +} + +-- Global objects defined by the C code +read_globals = { + "vim", +} diff --git a/.luarc.json b/.luarc.json new file mode 100644 index 0000000..7b95ad3 --- /dev/null +++ b/.luarc.json @@ -0,0 +1,16 @@ +{ + "runtime": { + "version": "LuaJIT", + "pathStrict": true + }, + "type": { + "checkTableShape": true + }, + "diagnostics.globals": [ + "describe", + "it", + "before_each", + "MiniTest", + "after_each" + ] +} diff --git a/.stylua.toml b/.stylua.toml new file mode 100644 index 0000000..c537618 --- /dev/null +++ b/.stylua.toml @@ -0,0 +1,5 @@ +column_width = 120 +line_endings = "Unix" +indent_type = "Spaces" +indent_width = 4 +quote_style = "AutoPreferDouble" diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6e1d681 --- /dev/null +++ b/Makefile @@ -0,0 +1,12 @@ +# Run all test files +test: deps/mini.nvim + nvim --headless --noplugin -u ./scripts/minimal_init.lua -c "lua MiniTest.run()" + +# Run test from file at `$FILE` environment variable +test_file: deps/mini.nvim + nvim --headless --noplugin -u ./scripts/minimal_init.lua -c "lua MiniTest.run_file('$(FILE)')" + +# Download 'mini.nvim' to use its 'mini.test' testing module +deps/mini.nvim: + @mkdir -p deps + git clone --filter=blob:none https://github.com/echasnovski/mini.test.git $@ diff --git a/README.md b/README.md index afea232..9fd5716 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,6 @@ Features may be incomplete, bugs are likely to occur and breaking changes may oc - `curl` in your `PATH` - `github/copilot.vim` or `zbirenbaum/copilot.lua`: we don't depend on them directly, but you should sign in first -- `nvim-lua/plenary.nvim` # Installation @@ -22,10 +21,6 @@ Features may be incomplete, bugs are likely to occur and breaking changes may oc ```lua { 'Xuyuanp/nes.nvim', - event = 'VeryLazy', - dependencies = { - 'nvim-lua/plenary.nvim', - }, opts = {}, } diff --git a/lsp/nes.lua b/lsp/nes.lua new file mode 100644 index 0000000..184e724 --- /dev/null +++ b/lsp/nes.lua @@ -0,0 +1,11 @@ +---@type vim.lsp.Config +return { + cmd = function(dispatchers) + local server = require("nes.lsp.server").new(dispatchers) + return server:new_public_client() + end, + root_dir = vim.uv.cwd(), + capabilities = { + workspace = { workspaceFolders = true }, + }, +} diff --git a/lua/nes.lua b/lua/nes.lua new file mode 100644 index 0000000..b6516d7 --- /dev/null +++ b/lua/nes.lua @@ -0,0 +1,29 @@ +local M = { + configs = { + provider = { + name = "copilot", + }, + }, +} + +function M.setup(opts) + local _ = opts or {} + M.configs = vim.tbl_deep_extend("force", M.configs, opts or {}) + + vim.lsp.enable("nes", true) +end + +setmetatable(M, { + __index = function(_, key) + if vim.startswith(key, "_") then + -- hide private function + return + end + local core = require("nes.core") + if core[key] then + return core[key] + end + end, +}) + +return M diff --git a/lua/nes/api.lua b/lua/nes/api.lua deleted file mode 100644 index f12ec9d..0000000 --- a/lua/nes/api.lua +++ /dev/null @@ -1,114 +0,0 @@ -local curl = require("plenary.curl") - -local nvim_version = vim.version() - -local M = {} - -local _oauth_token -local _api_token - -local function get_oauth_token() - if _oauth_token then - return _oauth_token - end - - local config_dir = vim.env.XDG_CONFIG_HOME or vim.fs.joinpath(vim.env.HOME, "/.config") - - local config_paths = { - "github-copilot/apps.json", - "github-copilot/hosts.json", - } - - for _, path in pairs(config_paths) do - local config_path = vim.fs.joinpath(config_dir, path) - if vim.uv.fs_stat(config_path) then - local data = vim.fn.readfile(config_path, "") - if vim.islist(data) then - data = table.concat(data, "\n") - end - local apps = vim.json.decode(data) - for key, value in pairs(apps) do - if vim.startswith(key, "github.com") then - _oauth_token = value.oauth_token - return _oauth_token - end - end - end - end -end - -local function get_api_token() - if _api_token and _api_token.expires_at > os.time() + 60000 then - return _api_token - end - - local oauth_token = get_oauth_token() - if not oauth_token then - error("OAuth token not found") - end - - local request = curl.get("https://api.github.com/copilot_internal/v2/token", { - headers = { - Authorization = "Bearer " .. oauth_token, - ["Accept"] = "application/json", - ["User-Agent"] = "vscode-chat/dev", - }, - on_error = function(err) - error("token request error: " .. err) - end, - }) - _api_token = vim.json.decode(request.body) - return _api_token -end - -function M.call(payload, callback) - local api_token = get_api_token() - local base_url = api_token.endpoints.proxy or api_token.endpoints.api - - local output = "" - - local _request = curl.post(base_url .. "/chat/completions", { - headers = { - Authorization = "Bearer " .. api_token.token, - ["User-Agent"] = "vscode-chat/dev", - ["Content-Type"] = "application/json", - ["Copilot-Integration-Id"] = "vscode-chat", - ["editor-version"] = ("Neovim/%d.%d.%d"):format(nvim_version.major, nvim_version.minor, nvim_version.patch), - ["editor-plugin-version"] = "nes/0.1.0", - }, - on_error = function(err) - error("api request error: " .. err) - end, - body = vim.json.encode(payload), - stream = function(_, chunk) - if not chunk then - return - end - if vim.startswith(chunk, "data: ") then - chunk = chunk:sub(6) - end - if chunk == "[DONE]" then - return - end - local ok, event = pcall(vim.json.decode, chunk) - if not ok then - return - end - if event and event.choices and event.choices[1] then - local choice = event.choices[1] - if choice.delta and choice.delta.content then - output = output .. choice.delta.content - end - end - end, - callback = function() - callback(output) - end, - }) -end - -function M.debug() - vim.print(get_api_token()) -end - -return M diff --git a/lua/nes/api/codecompanion.lua b/lua/nes/api/codecompanion.lua new file mode 100644 index 0000000..701783b --- /dev/null +++ b/lua/nes/api/codecompanion.lua @@ -0,0 +1,56 @@ +---@class nes.api.provider.CodeCompanion +---@field private _adapter table +---@field private _client table +local CodeCompanion = {} +CodeCompanion.__index = CodeCompanion + +---@return nes.api.provider.CodeCompanion +function CodeCompanion.new(opts) + opts = opts or {} + local adapters = require("codecompanion.adapters") + + local name = opts.adapter or "openai" + local adapter = adapters.resolve(name) + if opts.extend then + adapter = adapters.extend(adapter, opts.extend) + end + adapter.features.tokens = false + + local settings = adapter:map_schema_to_params(adapter:make_from_schema()) + local client = require("codecompanion.http").new({ adapter = settings }) + + local self = { + _adapter = adapter, + _client = client, + } + return setmetatable(self, CodeCompanion) +end + +---@param messages nes.api.chat_completions.Message[] +---@param callback nes.api.Callback +---@return fun() cancel +function CodeCompanion:call(messages, callback) + local output = {} + local job = self._client:request({ messages = messages }, { + callback = function(err, data) + if err or not data then + return + end + local result = self._adapter.handlers.chat_output(self._adapter, data) + if result and result.status == "success" then + table.insert(output, result.output.content) + end + end, + done = function() + callback(nil, table.concat(output, "")) + end, + }, { silent = true }) + return function() + if job then + job:shutdown(-1, 114) + job = nil + end + end +end + +return CodeCompanion diff --git a/lua/nes/api/copilot.lua b/lua/nes/api/copilot.lua new file mode 100644 index 0000000..b810a91 --- /dev/null +++ b/lua/nes/api/copilot.lua @@ -0,0 +1,205 @@ +local nvim_version = vim.version() +local Curl = require("nes.util").Curl + +---@private +---@class ApiToken +---@field token string +---@field endpoints {proxy: string?, api: string} +---@field expires_at integer + +---@class nes.api.provider.Copilot +---@field private _opts? table +---@field private _oauth_token? string +---@field private _api_token? ApiToken +local Copilot = {} +Copilot.__index = Copilot + +local default_opts = { + token_endpoint = "https://api.github.com/copilot_internal/v2/token", + params = { + model = "copilot-nes-v", + temperature = 0, + top_p = 1, + n = 1, + stream = true, + snippy = { + enabled = false, + }, + }, +} + +---@return nes.api.provider.Copilot +function Copilot.new(opts) + opts = vim.tbl_deep_extend("force", default_opts, opts or {}) + local self = { + _opts = opts, + } + setmetatable(self, Copilot) + return self +end + +---@private +---@param messages nes.api.chat_completions.Message[] +---@return table +function Copilot:_payload(messages) + local payload = vim.deepcopy(self._opts.params) + payload.messages = messages + return payload +end + +---@private +---@return string? +function Copilot:_get_oauth_token() + if self._oauth_token then + return self._oauth_token + end + + local config_dir = vim.env.XDG_CONFIG_HOME or vim.fs.joinpath(vim.env.HOME, "/.config") + + local config_paths = { + "github-copilot/apps.json", + "github-copilot/hosts.json", + } + + for _, path in pairs(config_paths) do + local config_path = vim.fs.joinpath(config_dir, path) + if vim.uv.fs_stat(config_path) then + local data = vim.fn.readfile(config_path, "") + if vim.islist(data) then + data = table.concat(data, "\n") + end + local apps = vim.json.decode(data) + for key, value in pairs(apps) do + if vim.startswith(key, "github.com") then + self._oauth_token = value.oauth_token + return self._oauth_token + end + end + end + end +end + +---@private +---@param cb fun(err: string?, api_token?: ApiToken) +function Copilot:_with_token(cb) + if self._api_token and self._api_token.expires_at > os.time() + 5 then + cb(nil, self._api_token) + return + end + + local oauth_token = self:_get_oauth_token() + if not oauth_token then + cb("OAuth token not found") + return + end + + return Curl.get(self._opts.token_endpoint, { + headers = { + Authorization = "Bearer " .. oauth_token, + ["Accept"] = "application/json", + ["User-Agent"] = "vscode-chat/dev", + }, + on_exit = function(out) + if out.code ~= 0 then + cb(out.stderr or out.stdout or ("code: " .. out.code)) + return + end + self._api_token = vim.json.decode(out.stdout) + cb(nil, self._api_token) + end, + }) +end + +function Copilot:_call(base_url, api_key, messages, callback) + return Curl.post(base_url .. "/chat/completions", { + headers = { + Authorization = "Bearer " .. api_key, + ["User-Agent"] = "vscode-chat/dev", + ["Content-Type"] = "application/json", + ["Copilot-Integration-Id"] = "vscode-chat", + ["editor-version"] = ("Neovim/%d.%d.%d"):format(nvim_version.major, nvim_version.minor, nvim_version.patch), + ["editor-plugin-version"] = "nes/0.1.0", + }, + body = vim.json.encode(self:_payload(messages)), + on_exit = function(out) + if out.code ~= 0 then + callback({ message = out.stderr or ("code: " .. out.code) }) + return + end + + local stdout = out.stdout + + if not self._opts.params.stream then + local rsp = vim.json.decode(stdout) + if rsp.choices and rsp.choices[1] then + local choice = rsp.choices[1] + if choice.message and choice.message.content then + callback(nil, choice.message.content) + else + callback({ message = "No content in response" }) + end + else + callback({ message = "Invalid response format" }) + end + return + end + + local lines = vim.split(stdout, "\n", { plain = true }) + local chunks = {} + for _, line in ipairs(lines) do + line = vim.trim(line) + if line ~= "" then + if vim.startswith(line, "data: ") then + line = line:sub(7) + end + if line ~= "[DONE]" then + table.insert(chunks, line) + end + end + end + local json_chunks = string.format("[%s]", table.concat(chunks, ",")) + local ok, events = pcall(vim.json.decode, json_chunks) + if not ok then + callback({ message = "Failed to decode json: " .. (events or "unknown error") }) + return + end + + local output = "" + for _, event in ipairs(events) do + if event.choices and event.choices[1] then + local choice = event.choices[1] + if choice.delta and choice.delta.content then + output = output .. choice.delta.content + end + end + end + callback(nil, output) + end, + }) +end + +---@param messages nes.api.chat_completions.Message[] +---@param callback nes.api.Callback +---@return fun() cancel +function Copilot:call(messages, callback) + local job + job = self:_with_token(vim.schedule_wrap(function(err, api_token) + job = nil + if err then + callback(err) + return + end + --TODO: deal with nil api_token + + local base_url = api_token.endpoints.proxy or api_token.endpoints.api + job = self:_call(base_url, api_token.token, messages, callback) + end)) + return function() + if job then + job:kill(-1) + job = nil + end + end +end + +return Copilot diff --git a/lua/nes/api/init.lua b/lua/nes/api/init.lua new file mode 100644 index 0000000..97e9e27 --- /dev/null +++ b/lua/nes/api/init.lua @@ -0,0 +1,31 @@ +local M = {} + +---@alias nes.api.Callback fun(err?: any, output?: string) + +---@class nes.api.chat_completions.Message +---@field role 'system' | 'assistant' | 'user' +---@field content string + +---@class nes.api.Client +---@field call fun(messages: nes.api.chat_completions.Message[], callback: nes.api.Callback): fun() + +---@return nes.api.Client +function M.new_client(opts) + opts = vim.tbl_deep_extend("force", require("nes").configs.provider, opts or {}) + local provider_name = opts.name or "copilot" + + local lib = "nes.api." .. provider_name + local ok, cls = pcall(require, lib) + if not ok then + error("Invalid provider name: " .. provider_name) + end + local provider = cls.new(opts[provider_name] or {}) + + return { + call = function(messages, callback) + return provider:call(messages, callback) + end, + } +end + +return M diff --git a/lua/nes/context.lua b/lua/nes/context.lua index 2a40781..433de69 100644 --- a/lua/nes/context.lua +++ b/lua/nes/context.lua @@ -18,6 +18,7 @@ When responding to the programmer, you must follow these rules: - Do not alter method signatures, add or remove return values, or modify existing logic unless explicitly instructed. - The current cursor position is indicated by <|cursor|>. You MUST keep the cursor position the same in your response. - DO NOT REMOVE <|cursor|>. +- Avoid adding unnecessary text, such as comments. - You must ONLY reply using the tag: . ]] @@ -50,8 +51,7 @@ what I will do next. Do not skip any lines. Do not be lazy. ]] ---@class nes.Context ----@field bufnr number ----@field cursor [integer, integer] +---@field cursor [integer, integer] (1,0)-indexed ---@field original_code string ---@field edits string ---@field current_version table @@ -60,98 +60,128 @@ what I will do next. Do not skip any lines. Do not be lazy. local Context = {} Context.__index = Context +---@param filename string +---@param original_code string +---@param current_code string +---@param cursor [integer, integer] (row, col), (1,0)-indexed +---@param lang string ---@return nes.Context -function Context.new(bufnr) - local filename = vim.fn.fnamemodify(vim.api.nvim_buf_get_name(bufnr), ":") - local original_code = vim.fn.readfile(filename) - local current_version = Context.get_current_version(bufnr) - local self = { - bufnr = bufnr, - cursor = current_version.cursor, - original_code = table.concat( - vim.iter(original_code) - :enumerate() - :map(function(i, line) - return string.format("%d│%s", i, line) - end) - :totable(), - "\n" - ), - edits = vim.diff( - table.concat(original_code, "\n"), - table.concat(vim.api.nvim_buf_get_lines(bufnr, 0, -1, false), "\n"), - { algorithm = "minimal" } - ), - filename = filename, - current_version = current_version, - filetype = vim.bo[bufnr].filetype, - } - setmetatable(self, Context) - return self +function Context.new(filename, original_code, current_code, cursor, lang) + local self = { + cursor = cursor, + original_code = table.concat( + vim.iter(vim.split(original_code, "\n", { plain = true })) + :enumerate() + :map(function(i, line) + return string.format("%d│%s", i, line) + end) + :totable(), + "\n" + ), + edits = vim.diff(original_code, current_code, { algorithm = "minimal" }), + filename = filename, + current_version = Context._get_current_version(current_code, cursor), + filetype = lang, + } + setmetatable(self, Context) + ---@diagnostic disable-next-line: return-type-mismatch + return self end function Context:payload() - -- copy from vscode - return { - messages = { - { - role = "system", - content = SystemPrompt, - }, - { - role = "user", - content = UserPromptTemplate:format( - self.filename, - self.original_code, - self.filename, - self.filename, - self.edits, - self.filename, - self.filetype, - self.current_version.text - ), - }, - }, - model = "copilot-nes-v", - temperature = 0, - top_p = 1, - prediction = { - type = "content", - content = string.format( - "\n```%s\n%s\n```\n", - self.filetype, - self.current_version.text - ), - }, - n = 1, - stream = true, - snippy = { - enabled = false, - }, - } + return { + messages = { + { + role = "system", + content = SystemPrompt, + }, + { + role = "user", + content = self:user_prompt(), + }, + }, + prediction = { + type = "content", + content = string.format( + "\n```%s\n%s\n```\n", + self.filetype, + self.current_version.text + ), + }, + } end -function Context.get_current_version(bufnr) - local cursor = vim.api.nvim_win_get_cursor(0) - local row, col = cursor[1] - 1, cursor[2] - local start_row = row - 20 - if start_row < 0 then - start_row = 0 - end - local end_row = row + 20 - if end_row >= vim.api.nvim_buf_line_count(bufnr) then - end_row = vim.api.nvim_buf_line_count(bufnr) - 1 - end - local end_col = vim.api.nvim_buf_get_lines(bufnr, end_row, end_row + 1, false)[1]:len() - - local before_cursor = vim.api.nvim_buf_get_text(bufnr, start_row, 0, row, col, {}) - local after_cursor = vim.api.nvim_buf_get_text(bufnr, row, col, end_row, end_col, {}) - return { - cursor = cursor, - start_row = start_row, - end_row = end_row, - text = string.format("%s<|cursor|>%s", table.concat(before_cursor, "\n"), table.concat(after_cursor, "\n")), - } +---@return string +function Context:user_prompt() + return UserPromptTemplate:format( + self.filename, + self.original_code, + self.filename, + self.filename, + self.edits, + self.filename, + self.filetype, + self.current_version.text + ) +end + +function Context._get_current_version(text, cursor) + local row, col = cursor[1] - 1, cursor[2] + local lines = vim.split(text, "\n", { plain = true }) + local start_row = math.max(row - 20, 0) + local end_row = math.min(row + 20, #lines) + local start_col = 0 + local end_col = lines[end_row]:len() + + local before_cursor_lines = vim.list_slice(lines, start_row + 1, row) + local after_cursor_lines = vim.list_slice(lines, row + 2, end_row + 1) + local before_cursor_text = lines[row + 1]:sub(1, col) + local after_cursor_text = lines[row + 1]:sub(col + 1) + + local res = { + cursor = cursor, + start_row = start_row, + end_row = end_row, + start_col = start_col, + end_col = end_col, + text = string.format( + "%s%s<|cursor|>%s%s", + #before_cursor_lines > 0 and (table.concat(before_cursor_lines, "\n") .. "\n") or "", + before_cursor_text, + after_cursor_text, + #after_cursor_lines > 0 and ("\n" .. table.concat(after_cursor_lines, "\n")) or "" + ), + } + return res +end + +---@return lsp.TextEdit[]? +function Context:generate_edits(next_version) + if not vim.startswith(next_version, "") then + return + end + local old_version = self.current_version.text:gsub("<|cursor|>", "") + + -- have to ignore the cursor tag, because the response doesn't have it most of the time, even if I force it in system prompt + next_version = next_version:gsub("<|cursor|>", "") + local new_lines = vim.split(next_version, "\n") + if vim.startswith(new_lines[1], "") then + table.remove(new_lines, 1) + end + if vim.startswith(new_lines[1], "```") then + table.remove(new_lines, 1) + end + if #new_lines > 0 and vim.startswith(new_lines[#new_lines], "") then + table.remove(new_lines, #new_lines) + end + if #new_lines > 0 and vim.startswith(new_lines[#new_lines], "```") then + table.remove(new_lines, #new_lines) + end + next_version = table.concat(new_lines, "\n") + + return require("nes.util").text_edits_from_diff(old_version, next_version, { + line_offset = self.current_version.start_row, + }) end return Context diff --git a/lua/nes/core.lua b/lua/nes/core.lua index 3a9a07e..40dbd6f 100644 --- a/lua/nes/core.lua +++ b/lua/nes/core.lua @@ -2,333 +2,41 @@ local Context = require("nes.context") local M = {} -local ns_id = vim.api.nvim_create_namespace("nes") -local hl_ns_id = vim.api.nvim_create_namespace("nes_highlight") - ----@class nes.EditSuggestionUI ----@field preview_winnr? integer ----@field added_extmark_id? integer ----@field deleted_extmark_id? integer - ----@class nes.EditSuggestion ----@field text_edit lsp.TextEdit ----@field ui? nes.EditSuggestionUI - ----@class nes.BufState ----@field line_offset integer ----@field suggestions nes.EditSuggestion[] ----@field accepted_cursor? [integer, integer] - ----@class nes.Apply.Opts ----@field jump? boolean | { hl_timeout: integer? } auto jump to the end of the new edit ----@field trigger? boolean auto trigger the next edit suggestion - ----@private ----@param bufnr integer ----@param suggestion_ui nes.EditSuggestionUI -function M._dismiss_suggestion_ui(bufnr, suggestion_ui) - pcall(vim.api.nvim_win_close, suggestion_ui.preview_winnr, true) - pcall(vim.api.nvim_buf_del_extmark, bufnr, ns_id, suggestion_ui.added_extmark_id) - pcall(vim.api.nvim_buf_del_extmark, bufnr, ns_id, suggestion_ui.deleted_extmark_id) -end - ----@private ----@param bufnr integer ----@param line_offset integer ----@param suggestion nes.EditSuggestion ----@param opts? nes.Apply.Opts ----@return integer offset ----@return [integer, integer]? new_cursor if jump is true -function M._apply_suggestion(bufnr, line_offset, suggestion, opts) - opts = opts or {} - local text_edit = vim.deepcopy(suggestion.text_edit) - text_edit.range.start.line = text_edit.range.start.line + line_offset - text_edit.range["end"].line = text_edit.range["end"].line + line_offset - - -- apply the text edit - vim.lsp.util.apply_text_edits({ text_edit }, bufnr, "utf-8") - - if suggestion.ui then - M._dismiss_suggestion_ui(bufnr, suggestion.ui) - end - - local deleted_lines_count = text_edit.range["end"].line - text_edit.range.start.line - local added_lines = vim.split(text_edit.newText, "\n") - local added_lines_count = text_edit.newText == "" and 0 or #added_lines - 1 - - local new_cursor - if opts.jump and added_lines_count > 0 then - local start_line = text_edit.range.start.line - new_cursor = { start_line + added_lines_count, #added_lines[#added_lines - 1] } - vim.api.nvim_win_set_cursor(0, new_cursor) - - local hl_timeout = type(opts.jump) == "table" and opts.jump.hl_timeout or 800 - - if hl_timeout > 0 then - vim.defer_fn(function() - vim.hl.range(bufnr, hl_ns_id, "NesApply", { - start_line, - 0, - }, { - start_line + added_lines_count, - #added_lines[#added_lines - 1], - }, { timeout = hl_timeout }) - end, 10) - end - end - - return added_lines_count - deleted_lines_count, new_cursor -end - ----@private ----@param bufnr integer ----@param state nes.BufState ----@return nes.BufState -function M._apply_next_suggestion(bufnr, state, opts) - if not state.suggestions or #state.suggestions == 0 then - return state - end - local suggestion = state.suggestions[1] - local offset, new_cursor = M._apply_suggestion(bufnr, state.line_offset, suggestion, opts) - - state.accepted_cursor = new_cursor - state.line_offset = state.line_offset + offset - table.remove(state.suggestions, 1) - return state -end - ----@private ----@param bufnr integer ----@param state nes.BufState ----@return nes.BufState -function M._display_next_suggestion(bufnr, state) - local win_id = vim.fn.win_findbuf(bufnr)[1] - if not state.suggestions or #state.suggestions == 0 then - return state - end - local suggestion = state.suggestions[1] - if suggestion.ui then - return state - end - - local ui = {} - local deleted_lines_count = suggestion.text_edit.range["end"].line - suggestion.text_edit.range.start.line - if deleted_lines_count > 0 then - ui.deleted_extmark_id = - vim.api.nvim_buf_set_extmark(bufnr, ns_id, state.line_offset + suggestion.text_edit.range.start.line, 0, { - hl_group = "NesDelete", - end_line = state.line_offset + suggestion.text_edit.range["end"].line, - }) - end - local added_lines = vim.split(suggestion.text_edit.newText, "\n") - local added_lines_count = suggestion.text_edit.newText == "" and 0 or #added_lines - 1 - if added_lines_count > 0 then - local virt_lines = {} - for _i = 1, added_lines_count do - table.insert(virt_lines, { - { "", "Normal" }, - }) - end - local line = state.line_offset + suggestion.text_edit.range.start.line + deleted_lines_count - 1 - - -- tricky part: - -- 1. set empty virtual lines to offset the content of the rest - -- 2. open a borderless floating window to show the added lines - -- 3. use treesitter to highlight the added lines - ui.added_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, ns_id, line, 0, { - virt_lines = virt_lines, - }) - - local preview_bufnr = vim.api.nvim_create_buf(false, true) - vim.api.nvim_buf_set_lines(preview_bufnr, 0, -1, false, added_lines) - vim.bo[preview_bufnr].modifiable = false - vim.bo[preview_bufnr].buflisted = false - vim.bo[preview_bufnr].bufhidden = "wipe" - vim.bo[preview_bufnr].filetype = vim.bo[bufnr].filetype - - local cursor = vim.api.nvim_win_get_cursor(win_id) - local win_width = vim.api.nvim_win_get_width(win_id) - local offset = vim.fn.getwininfo(win_id)[1].textoff - local preview_winnr = vim.api.nvim_open_win(preview_bufnr, false, { - relative = "cursor", - width = win_width - offset, - height = #added_lines - 1, - row = state.line_offset + suggestion.text_edit.range["end"].line - cursor[1] + 1, - col = 0, - style = "minimal", - border = "none", - }) - vim.wo[preview_winnr].number = false - vim.wo[preview_winnr].winhighlight = "Normal:NesAdd" - vim.wo[preview_winnr].winblend = 0 - - ui.preview_winnr = preview_winnr - end - - suggestion.ui = ui - state.suggestions[1] = suggestion - - vim.api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI" }, { - buffer = bufnr, - callback = function() - if not vim.b.nes_state then - return true - end - - local accepted_cursor = vim.b.nes_state.accepted_cursor - if accepted_cursor then - local cursor = vim.api.nvim_win_get_cursor(win_id) - if cursor[1] == accepted_cursor[1] and cursor[2] == accepted_cursor[2] then - return - end - end - - M.clear_suggestion(bufnr) - return true - end, - }) - - return state -end - ----@param ctx nes.Context ----@param next_version string -local function parse_suggestion(ctx, next_version) - -- force clear the suggestion first, in case of duplicated request - M.clear_suggestion(ctx.bufnr) - - if not vim.startswith(next_version, "") then - vim.print("not found") - return - end - local old_version = ctx.current_version.text:gsub("<|cursor|>", "") - - -- have to ignore the cursor tag, because the response doesn't have it most of the time, even if I force it in system prompt - next_version = next_version:gsub("<|cursor|>", "") - local new_lines = vim.split(next_version, "\n") - if vim.startswith(new_lines[1], "") then - table.remove(new_lines, 1) - end - if vim.startswith(new_lines[1], "```") then - table.remove(new_lines, 1) - end - if vim.startswith(new_lines[#new_lines], "") then - table.remove(new_lines, #new_lines) - end - if vim.startswith(new_lines[#new_lines], "```") then - table.remove(new_lines, #new_lines) - end - next_version = table.concat(new_lines, "\n") - - local chunks = vim.diff(old_version, next_version, { - algorithm = "minimal", - ignore_cr_at_eol = true, - ignore_whitespace_change_at_eol = true, - ignore_blank_lines = true, - ignore_whitespace = true, - result_type = "indices", - }) - assert(type(chunks) == "table", "nes: invalid diff result") - if not chunks or #chunks == 0 then - return - end - - local state = { line_offset = ctx.current_version.start_row, suggestions = {} } - for _, next_edit in ipairs(chunks) do - local start_a, count_a = next_edit[1], next_edit[2] - local start_b, count_b = next_edit[3], next_edit[4] - - ---@type lsp.TextEdit - local text_edit = { - range = { - start = { - line = start_a, - character = 0, - }, - ["end"] = { - line = 0, -- leave it empty for now - character = 0, - }, - }, - newText = "", - } - - if count_a > 0 then - text_edit.range["start"].line = start_a - 1 - text_edit.range["end"].line = start_a + count_a - 1 - else - text_edit.range["end"].line = start_a - end - if count_b > 0 then - local added_lines = {} - for i = start_b, start_b + count_b - 1 do - table.insert(added_lines, new_lines[i]) - end - text_edit.newText = table.concat(added_lines, "\n") .. "\n" - end - table.insert(state.suggestions, { text_edit = text_edit }) - end - - state = M._display_next_suggestion(ctx.bufnr, state) - vim.b[ctx.bufnr].nes_state = state -end - ----@param bufnr? integer -function M.get_suggestion(bufnr) - bufnr = bufnr and bufnr > 0 and bufnr or vim.api.nvim_get_current_buf() - local ctx = Context.new(bufnr) - local payload = ctx:payload() - require("nes.api").call(payload, function(stdout) - local next_version = vim.trim(stdout) - assert(next_version) - if not vim.startswith(next_version, "") then - return - end - vim.schedule(function() - parse_suggestion(ctx, next_version) - end) - end) -end - ----@param bufnr? integer ----@param opts? nes.Apply.Opts -function M.apply_suggestion(bufnr, opts) - opts = opts or {} - - bufnr = bufnr and bufnr > 0 and bufnr or vim.api.nvim_get_current_buf() - - local state = vim.b[bufnr].nes_state - if not state then - return - end - local new_state = M._apply_next_suggestion(bufnr, state, opts) - vim.b[bufnr].nes_state = new_state - if #new_state.suggestions > 0 then - -- vim.schedule(function() - vim.b[bufnr].nes_state = M._display_next_suggestion(bufnr, new_state) - -- end) - elseif opts.trigger then - vim.schedule(function() - M.get_suggestion(bufnr) - end) - end -end - ----@param bufnr? integer -function M.clear_suggestion(bufnr) - bufnr = bufnr and bufnr > 0 and bufnr or vim.api.nvim_get_current_buf() - vim.api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1) - local state = vim.b[bufnr].nes_state - if not state then - return - end - - for _, suggestion in ipairs(state.suggestions) do - if suggestion.ui then - M._dismiss_suggestion_ui(bufnr, suggestion.ui) - end - end - vim.b[bufnr].nes_state = nil +---@type nes.api.Client? +local api_client + +---@param filename string +---@param original_code string +---@param current_code string +---@param cursor [integer, integer] (1,0)-indexed (row, col) +---@param lang string +---@return fun() cancel +function M.fetch_suggestions(filename, original_code, current_code, cursor, lang, callback) + if current_code == original_code then + callback({}) + return function() end + end + local ctx = Context.new(filename, original_code, current_code, cursor, lang) + local payload = ctx:payload() + + if not api_client then + api_client = require("nes.api").new_client() + end + + return api_client.call( + payload.messages, + vim.schedule_wrap(function(err, stdout) + if err then + require("nes.util").notify("call api failed: " .. vim.inspect(err), { level = vim.log.levels.ERROR }) + callback({}) + return + end + local next_version = vim.trim(stdout) + assert(next_version) + local edits = ctx:generate_edits(next_version) or {} + callback(edits) + end) + ) end return M diff --git a/lua/nes/health.lua b/lua/nes/health.lua new file mode 100644 index 0000000..de4eef1 --- /dev/null +++ b/lua/nes/health.lua @@ -0,0 +1,15 @@ +local M = {} + +function M.check() + vim.health.start("System") + local required_binaries = { "curl" } + for _, name in ipairs(required_binaries) do + if vim.fn.executable(name) == 0 then + vim.health.error(name .. " is not installed") + else + vim.health.ok(name .. " is installed") + end + end +end + +return M diff --git a/lua/nes/init.lua b/lua/nes/init.lua deleted file mode 100644 index 15610b9..0000000 --- a/lua/nes/init.lua +++ /dev/null @@ -1,24 +0,0 @@ -local M = {} - -function M.setup(opts) - opts = opts or {} - - vim.api.nvim_set_hl(0, "NesAdd", { link = "DiffAdd", default = true }) - vim.api.nvim_set_hl(0, "NesDelete", { link = "DiffDelete", default = true }) - vim.api.nvim_set_hl(0, "NesApply", { link = "DiffText", default = true }) -end - -setmetatable(M, { - __index = function(_, key) - if vim.startswith(key, "_") then - -- hide private function - return - end - local core = require("nes.core") - if core[key] then - return core[key] - end - end, -}) - -return M diff --git a/lua/nes/lsp/server.lua b/lua/nes/lsp/server.lua new file mode 100644 index 0000000..d5d32cf --- /dev/null +++ b/lua/nes/lsp/server.lua @@ -0,0 +1,316 @@ +local Methods = vim.lsp.protocol.Methods + +---@class nes.DocumentState +---@field original lsp.TextDocumentItem +---@field current lsp.TextDocumentItem +---@field pending_edits? nes.InlineEdit[] +---@field last_applied? integer + +---@alias nes.Workspace table + +---@alias nes.InlineEditFilter fun(edit: lsp.TextEdit): boolean + +---@class nes.Server +---@field dispatchers vim.lsp.rpc.Dispatchers +---@field private _workspace nes.Workspace +---@field private _initialized boolean +---@field private _client_initialized boolean +---@field private _closing boolean +---@field private _filters nes.InlineEditFilter[] +---@field private _inflights table -- message_id -> cancel function +---@field private _next_message_id integer +local Server = {} +Server.__index = Server + +---@type lsp.ServerCapabilities +local capabilities = { + textDocumentSync = 1, + workspace = { + workspaceFolders = { + supported = true, + changeNotifications = true, + }, + }, +} + +---@param dispatchers vim.lsp.rpc.Dispatchers +---@return nes.Server +function Server.new(dispatchers) + local self = setmetatable({ + dispatchers = dispatchers, + _workspace = {}, + _next_message_id = 1, + _initialized = false, + _closing = false, + _inflights = {}, + _filters = { + -- no more than 3 lines edit + function(edit) + return edit.range["end"].line - edit.range.start.line < 3 + end, + function(edit) + return #vim.split(edit.newText, "\n") < 3 + end, + }, + }, Server) + return self +end + +---@return vim.lsp.rpc.PublicClient +function Server:new_public_client() + return { + request = function(...) + return self:on_request(...) + end, + notify = function(...) + return self:on_notify(...) + end, + is_closing = function() + return self:is_closing() + end, + terminate = function() + self:terminate() + end, + } +end + +--- Receives a request from the LSP client +--- +---@param method vim.lsp.protocol.Method.ClientToServer The invoked LSP method +---@param params table? Parameters for the invoked LSP method +---@param callback fun(err: lsp.ResponseError?, result: any) Callback to invoke +---@param notify_reply_callback? fun(message_id: integer) Callback to invoke as soon as a request is no longer pending +---@return boolean success +---@return integer? message_id +function Server:on_request(method, params, callback, notify_reply_callback) + if self._closing then + return false + end + + if method ~= Methods.initialize and not self._client_initialized then + vim.notify("client not iniitialized: " .. method, vim.log.levels.WARN) + return false + end + + notify_reply_callback = notify_reply_callback or function() end + + local handler = self[method] + if not handler then + vim.notify("method not support: " .. method, vim.log.levels.WARN) + return false + end + local message_id = self:new_message_id() + + handler = vim.schedule_wrap(handler) + local wrapped_cb = vim.schedule_wrap(function(err, result) + callback(err, result) + if not err then + notify_reply_callback(message_id) + end + end) + handler(self, params, wrapped_cb, message_id) + + return true, message_id +end + +--- Receives a notification from the LSP client. +---@param method string The invoked LSP method +---@param params table? Parameters for the invoked LSP method +---@return boolean +function Server:on_notify(method, params) + if self._closing then + return false + end + + method = method + params = params + local handler = self[method] + if not handler then + vim.notify("No handler for method: " .. method, vim.log.levels.WARN) + return false + end + vim.schedule(function() + handler(self, params, function() end) + end) + return true +end + +---@return boolean +function Server:is_closing() + return self._closing +end + +function Server:terminate() + self._closing = true + self._workspace = nil +end + +function Server:new_message_id() + local id = self._next_message_id + self._next_message_id = self._next_message_id + 1 + return id +end + +---@param params lsp.InitializeParams +Server[Methods.initialize] = function(self, params, callback) + local _ = params + ---@type lsp.InitializeResult + local result = { + capabilities = capabilities, + serverInfo = { + name = "nes", + version = "0.1.0", + }, + } + self._initialized = true + vim.schedule(function() + self.dispatchers.server_request(Methods.window_logMessage, { type = 3, message = "NES initialized" }) + end) + callback(nil, result) +end + +---@param params lsp.InitializedParams +Server[Methods.initialized] = function(self, params, callback) + local _ = params + self._client_initialized = true + callback() +end + +---@param params lsp.DidOpenTextDocumentParams +Server[Methods.textDocument_didOpen] = function(self, params, callback) + self._workspace[params.textDocument.uri] = { + original = params.textDocument, + current = vim.deepcopy(params.textDocument), + } + callback() +end + +---@param params lsp.DidSaveTextDocumentParams +Server[Methods.textDocument_didSave] = function(self, params, callback) + local state = self._workspace[params.textDocument.uri] + if not state then + callback({ code = 1, message = "no state" }) + return + end + state.original = vim.deepcopy(state.current) + self._workspace[params.textDocument.uri] = state + callback() +end + +---@param params lsp.DidChangeTextDocumentParams +Server[Methods.textDocument_didChange] = function(self, params, callback) + local state = self._workspace[params.textDocument.uri] + if not state then + callback({ code = 1, message = "no state" }) + return + end + state.current.version = params.textDocument.version + state.current.text = params.contentChanges[1].text + + state.pending_edits = nil + state.last_applied = nil + + self._workspace[params.textDocument.uri] = state + callback() +end + +---@param params lsp.DidCloseTextDocumentParams +Server[Methods.textDocument_didClose] = function(self, params, callback) + self._workspace[params.textDocument.uri] = nil + callback() +end + +---@param params lsp.CancelParams +Server[Methods.dollar_cancelRequest] = function(self, params, callback) + local cancel = self._inflights[params.id] or function() end + cancel() + + callback() +end + +---@class nes.InlineEditParams : lsp.TextDocumentPositionParams +---@field version integer + +---@class nes.InlineEdit: lsp.TextEdit +---@field command? lsp.Command +---@field text string +---@field textDocument lsp.VersionedTextDocumentIdentifier + +---@param params nes.InlineEditParams +Server["textDocument/copilotInlineEdit"] = function(self, params, callback, message_id) + for msg_id, cancel in pairs(self._inflights) do + cancel() + self._inflights[msg_id] = nil + end + + local state = self._workspace[params.textDocument.uri] + if not state then + callback({ code = 1, message = "no state" }) + return + end + + local version = state.current.version + + local pending = state.pending_edits or {} + local last_applied = state.last_applied or 0 + local next_edit = pending[last_applied + 1] + + if next_edit then + callback(nil, { edits = { next_edit } }) + state.last_applied = last_applied + 1 + return + end + + state.pending_edits = nil + state.last_applied = nil + + local cursor = { params.position.line + 1, params.position.character } + local filename = vim.fn.fnamemodify(vim.uri_to_fname(state.original.uri), ":") + local cancel = require("nes.core").fetch_suggestions( + filename, + state.original.text, + state.current.text, + cursor, + state.original.languageId, + ---@param edits lsp.TextEdit[] + function(edits) + self._inflights[message_id] = nil + + if version ~= state.current.version then + -- drop outdated suggestions + callback(nil, { edits = {} }) + return + end + ---@type nes.InlineEdit[] + local inline_edits = {} + for _, edit in ipairs(edits) do + local ok = true + for _, filter in ipairs(self._filters) do + if not filter(edit) then + ok = false + break + end + end + + if ok then + table.insert(inline_edits, { + range = edit.range, + text = edit.newText, + newText = edit.newText, + textDocument = { + uri = state.current.uri, + version = state.current.version, + }, + } --[[@as nes.InlineEdit]]) + end + end + state.pending_edits = inline_edits + state.last_applied = 0 + callback(nil, { edits = { inline_edits[1] } }) + end + ) + + self._inflights[message_id] = cancel +end + +return Server diff --git a/lua/nes/util.lua b/lua/nes/util.lua new file mode 100644 index 0000000..8ae358a --- /dev/null +++ b/lua/nes/util.lua @@ -0,0 +1,245 @@ +local M = {} + +function M.notify(text, opts) + opts = opts or {} + opts.title = "[NES] " .. (opts.title or "") + opts.level = opts.level or vim.log.levels.INFO + vim.notify(text, opts.level, { title = opts.title }) +end + +---@class nes.util.Curl +local Curl = {} + +function Curl.request(method, url, opts) + opts = opts or {} + local bin = opts.binary or "curl" + local args = { bin, "-sSL", url, "-X", method } + for key, value in pairs(opts.headers or {}) do + table.insert(args, "-H") + table.insert(args, key .. ": " .. value) + end + if opts.body then + table.insert(args, "-d") + table.insert(args, "@-") + end + return vim.system(args, { + stdin = opts.body, + text = true, + timeout = opts.timeout, + env = opts.env, + stdout = opts.stdout, + stderr = opts.stderr, + }, opts.on_exit) +end + +function Curl.get(url, opts) + return Curl.request("GET", url, opts) +end + +function Curl.post(url, opts) + return Curl.request("POST", url, opts) +end + +M.Curl = Curl + +---@param a string +---@param b string +---@param opts? {line_offset?: integer, diff?: vim.diff.Opts} +---@return lsp.TextEdit[] +function M.text_edits_from_diff(a, b, opts) + local res = {} + + local old_lines = vim.split(a, "\n", { plain = true }) + local new_lines = vim.split(b, "\n", { plain = true }) + + opts = opts or {} + opts.line_offset = opts.line_offset or 0 + opts.diff = opts.diff + or { + ignore_cr_at_eol = true, + ignore_whitespace_change_at_eol = true, + ignore_blank_lines = true, + ignore_whitespace = true, + } + opts.diff.algorithm = "minimal" + opts.diff.on_hunk = function(start_a, count_a, start_b, count_b) + -- no change + if count_a == 0 and count_b == 0 then + return + end + + if count_a > 0 then + if count_b == 0 then + -- delete lines + local edit = { + range = { + start = { line = opts.line_offset + start_a - 1, character = 0 }, + ["end"] = { + line = opts.line_offset + start_a - 1 + count_a, + character = 0, + }, + }, + newText = "", + } + table.insert(res, edit) + return + end + if count_a == 1 and count_b == 1 then + -- try inline edit + local inline_edit = M._calculate_inline_edit(old_lines[start_a], new_lines[start_b]) + if inline_edit then + local edit = { + range = { + start = { line = opts.line_offset + start_a - 1, character = inline_edit.start_col }, + ["end"] = { line = opts.line_offset + start_a - 1, character = inline_edit.end_col }, + }, + newText = inline_edit.text, + } + table.insert(res, edit) + return + end + end + + -- replace lines + local edit = { + range = { + start = { line = opts.line_offset + start_a - 1, character = 0 }, + ["end"] = { + line = opts.line_offset + start_a - 1 + count_a - 1, + character = #old_lines[start_a + count_a - 1], + }, + }, + newText = table.concat(vim.list_slice(new_lines, start_b, start_b + count_b - 1), "\n"), + } + table.insert(res, edit) + return + end + if count_b > 0 then + if start_a == 0 then + local edit = { + range = { + start = { line = opts.line_offset, character = 0 }, + ["end"] = { + line = opts.line_offset, + character = 0, + }, + }, + newText = "\n" .. table.concat(vim.list_slice(new_lines, start_b, start_b + count_b - 1), "\n"), + } + table.insert(res, edit) + return + end + -- add lines + local edit = { + range = { + start = { line = opts.line_offset + start_a - 1, character = #old_lines[start_a] }, + ["end"] = { + line = opts.line_offset + start_a - 1, + character = #old_lines[start_a], + }, + }, + newText = "\n" .. table.concat(vim.list_slice(new_lines, start_b, start_b + count_b - 1), "\n"), + } + table.insert(res, edit) + return + end + + assert(false, "unreachable") + end + + vim.diff(a, b, opts.diff) + + return res +end + +---@private +---@class InlineEdit +---@field start_col integer 0-indexed +---@field end_col integer 0-indexed +---@field text string + +---generated by gemini-2.5-pro +---@param a string a single line string +---@param b string a single line string +---@return InlineEdit? inline_edit only returns if the edit is a single add/delete of a contiguous block +function M._calculate_inline_edit(a, b) + -- If the strings are identical, there's no edit. + if a == b then + return nil + end + + local len_a = #a + local len_b = #b + + -- Find the length of the common prefix (0-indexed length) + local prefix_len = 0 + local min_len = math.min(len_a, len_b) + while prefix_len < min_len and a:sub(prefix_len + 1, prefix_len + 1) == b:sub(prefix_len + 1, prefix_len + 1) do + prefix_len = prefix_len + 1 + end + + -- Find the length of the common suffix after the prefix (0-indexed length) + local suffix_len = 0 + -- Loop backwards from the end, ensuring we don't overlap with the prefix already found + -- We compare characters from the end of string `a` and string `b`. + -- The index calculation `len - suffix_len` gives the 1-based index from the start. + -- We need to ensure this index is greater than the prefix_len (0-indexed). + -- `len - suffix_len > prefix_len` is equivalent to `len - prefix_len > suffix_len`. + while + suffix_len < len_a - prefix_len + and suffix_len < len_b - prefix_len + and a:sub(len_a - suffix_len, len_a - suffix_len) == b:sub(len_b - suffix_len, len_b - suffix_len) + do + suffix_len = suffix_len + 1 + end + + -- Check if the differing parts are contiguous and cover the entire difference. + -- If the total length of the matched prefix and suffix is greater than the length + -- of either string, it means the prefix and suffix overlap or meet exactly + -- within one or both strings. + -- If they meet exactly (prefix_len + suffix_len == len_a and prefix_len + suffix_len == len_b) + -- and the strings are different, it implies a change exactly at the junction point(s) + -- or a simple replacement, which is not considered a single "add/delete a substring" + -- in the sense of adding/removing a block of text between common parts. + if prefix_len + suffix_len > len_a or prefix_len + suffix_len > len_b then + return nil -- Not a simple single add/delete of a contiguous block + end + + -- Extract the differing middle parts (using 1-based indexing for string.sub) + -- The middle part starts *after* the prefix and ends *before* the suffix. + local middle_start_idx = prefix_len + 1 -- 1-based index + local middle_end_idx_a = len_a - suffix_len -- 1-based index + local middle_end_idx_b = len_b - suffix_len -- 1-based index + + local middle_a = a:sub(middle_start_idx, middle_end_idx_a) + local middle_b = b:sub(middle_start_idx, middle_end_idx_b) + + -- Analyze the middle parts to determine the type of edit + if middle_a == "" and middle_b ~= "" then + -- Case: Addition (a -> b by adding middle_b) + -- The addition happens at the position after the prefix. + -- start_col is the 0-indexed position *before* the added text. + -- end_col is the 0-indexed position *after* the added text (same as start for insertion). + return { + start_col = prefix_len, -- 0-indexed column where addition starts + end_col = prefix_len, -- 0-indexed column where addition ends (exclusive) + text = middle_b, -- The text that was added + } + elseif middle_a ~= "" and middle_b == "" then + -- Case: Deletion (a -> b by deleting middle_a) + -- The deletion spans from the end of the prefix to the start of the suffix in 'a'. + -- start_col is the 0-indexed position of the first character deleted. + -- end_col is the 0-indexed position *after* the last character deleted in the original string 'a'. + return { + start_col = prefix_len, -- 0-indexed column where deletion starts + end_col = len_a - suffix_len, -- 0-indexed column where deletion ends (exclusive, relative to 'a') + text = "", + } + else + -- Case: Both middle parts are non-empty (replacement or multiple changes) + -- Case: Both middle parts are empty (covered by initial a == b check, or implies prefix/suffix covers everything but strings are different, handled by validity check above) + return nil -- Not a simple single add/delete + end +end + +return M diff --git a/plugin/nes.lua b/plugin/nes.lua new file mode 100644 index 0000000..2d56d7c --- /dev/null +++ b/plugin/nes.lua @@ -0,0 +1,3 @@ +vim.api.nvim_set_hl(0, "NesAdd", { link = "DiffAdd", default = true }) +vim.api.nvim_set_hl(0, "NesDelete", { link = "DiffDelete", default = true }) +vim.api.nvim_set_hl(0, "NesApply", { link = "DiffText", default = true }) diff --git a/scripts/minimal_init.lua b/scripts/minimal_init.lua new file mode 100644 index 0000000..1ee5df3 --- /dev/null +++ b/scripts/minimal_init.lua @@ -0,0 +1,12 @@ +-- Add current directory to 'runtimepath' to be able to use 'lua' files +vim.cmd([[let &rtp.=','.getcwd()]]) + +-- Set up 'mini.test' only when calling headless Neovim (like with `make test`) +if #vim.api.nvim_list_uis() == 0 then + -- Add 'mini.nvim' to 'runtimepath' to be able to use 'mini.test' + -- Assumed that 'mini.nvim' is stored in 'deps/mini.nvim' + vim.cmd("set rtp+=deps/mini.nvim") + + -- Set up 'mini.test' + require("mini.test").setup() +end diff --git a/tests/test_util.lua b/tests/test_util.lua new file mode 100644 index 0000000..dfad8c0 --- /dev/null +++ b/tests/test_util.lua @@ -0,0 +1,174 @@ +local new_set = MiniTest.new_set +local eq = MiniTest.expect.equality + +local T = new_set() + +---@private +---@alias Case.Args {a: string, b: string, opts?: table} + +---@private +---@alias Case.Expected lsp.TextEdit[] + +---@private +---@alias Case [Case.Args, Case.Expected] + +T["text_edits_from_diff"] = new_set() + +---@type table +local cases = { + ["no change"] = { + { + a = "aaa", + b = "aaa", + }, + {}, + }, + ["replace single line"] = { + { + a = "prefix\naaa\nsuffix", + b = "prefix\nbbb\nsuffix", + }, + { + { + range = { + start = { line = 1, character = 0 }, + ["end"] = { line = 1, character = 3 }, + }, + newText = "bbb", + }, + }, + }, + ["less to more"] = { + { + a = "prefix\naaa\nsuffix", + b = "prefix\nbbb\nccc\nsuffix", + }, + { + { + range = { + start = { line = 1, character = 0 }, + ["end"] = { line = 1, character = 3 }, + }, + newText = "bbb\nccc", + }, + }, + }, + ["more to less"] = { + { + a = "prefix\naaa\nbbb\nsuffix", + b = "prefix\nccc\nsuffix", + }, + { + { + range = { + start = { line = 1, character = 0 }, + ["end"] = { line = 2, character = 3 }, + }, + newText = "ccc", + }, + }, + }, + ["delete lines"] = { + { + a = "prefix\naaa\nbbb\nsuffix", + b = "prefix\nsuffix", + }, + { + { + range = { + start = { line = 1, character = 0 }, + ["end"] = { line = 3, character = 0 }, + }, + newText = "", + }, + }, + }, + ["add lines"] = { + { + a = "prefix\nsuffix", + b = "prefix\naaa\nbbb\nsuffix", + }, + { + { + range = { + start = { line = 0, character = 6 }, + ["end"] = { line = 0, character = 6 }, + }, + newText = "\naaa\nbbb", + }, + }, + }, + ["no suffix"] = { + { + a = "prefix\n", + b = "prefix\naaa\nbbb", + }, + { + { + range = { + start = { line = 0, character = 6 }, + ["end"] = { line = 0, character = 6 }, + }, + newText = "\naaa\nbbb", + }, + }, + }, + ["line offset"] = { + { + a = "prefix\naaa\nsuffix", + b = "prefix\nbbb\nsuffix", + opts = { line_offset = 10 }, + }, + { + { + range = { + start = { line = 11, character = 0 }, + ["end"] = { line = 11, character = 3 }, + }, + newText = "bbb", + }, + }, + }, + ["inline add"] = { + { + a = "prefix\naaaccc\nsuffix", + b = "prefix\naaabbbccc\nsuffix", + }, + { + { + range = { + start = { line = 1, character = 3 }, + ["end"] = { line = 1, character = 3 }, + }, + newText = "bbb", + }, + }, + }, + ["inline delete"] = { + { + a = "prefix\naaabbbccc\nsuffix", + b = "prefix\naaaccc\nsuffix", + }, + { + { + range = { + start = { line = 1, character = 3 }, + ["end"] = { line = 1, character = 6 }, + }, + newText = "", + }, + }, + }, +} + +do + for name, case in pairs(cases) do + T["text_edits_from_diff"][name] = function() + local args, expected = unpack(case) + local actual = require("nes.util").text_edits_from_diff(args.a, args.b, args.opts) + eq(actual, expected) + end + end +end + +return T