Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ luac.out
*.x86_64
*.hex

deps/
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Run all test files
test: deps/mini.nvim
nvim --headless --noplugin -u ./scripts/minimal_init.lua -c "lua MiniTest.run()"

# Run test from file at `$FILE` environment variable
test_file: deps/mini.nvim
nvim --headless --noplugin -u ./scripts/minimal_init.lua -c "lua MiniTest.run_file('$(FILE)')"

# Download 'mini.nvim' to use its 'mini.test' testing module
deps/mini.nvim:
@mkdir -p deps
git clone --filter=blob:none https://github.com/echasnovski/mini.nvim $@
5 changes: 0 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Features may be incomplete, bugs are likely to occur and breaking changes may oc

- `curl` in your `PATH`
- `github/copilot.vim` or `zbirenbaum/copilot.lua`: we don't depend on them directly, but you should sign in first
- `nvim-lua/plenary.nvim`

# Installation

Expand All @@ -22,10 +21,6 @@ Features may be incomplete, bugs are likely to occur and breaking changes may oc
```lua
{
'Xuyuanp/nes.nvim',
event = 'VeryLazy',
dependencies = {
'nvim-lua/plenary.nvim',
},
opts = {},
}

Expand Down
11 changes: 11 additions & 0 deletions lsp/nes.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---@type vim.lsp.Config
return {
cmd = function(dispatchers)
local server = require("nes.lsp.server").new(dispatchers)
return server:new_public_client()
end,
root_dir = vim.uv.cwd(),
capabilities = {
workspace = { workspaceFolders = true },
},
}
108 changes: 71 additions & 37 deletions lua/nes/api.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
local curl = require("plenary.curl")

local nvim_version = vim.version()
local Curl = require("nes.util").Curl

local M = {}

Expand Down Expand Up @@ -37,78 +36,113 @@ local function get_oauth_token()
end
end

local function get_api_token()
if _api_token and _api_token.expires_at > os.time() + 60000 then
return _api_token
local function with_token(cb)
if _api_token and _api_token.expires_at > os.time() + 60 then
cb(nil, _api_token)
return
end

local oauth_token = get_oauth_token()
if not oauth_token then
error("OAuth token not found")
cb("OAuth token not found")
return
end

local request = curl.get("https://api.github.com/copilot_internal/v2/token", {
return Curl.get("https://api.github.com/copilot_internal/v2/token", {
headers = {
Authorization = "Bearer " .. oauth_token,
["Accept"] = "application/json",
["User-Agent"] = "vscode-chat/dev",
},
on_error = function(err)
error("token request error: " .. err)
on_exit = function(out)
if out.code ~= 0 then
cb(out.stderr or out.stdout or ("code: " .. out.code))
return
end
_api_token = vim.json.decode(out.stdout)
cb(nil, _api_token)
end,
})
_api_token = vim.json.decode(request.body)
return _api_token
end

function M.call(payload, callback)
local api_token = get_api_token()
local base_url = api_token.endpoints.proxy or api_token.endpoints.api

local output = ""

local _request = curl.post(base_url .. "/chat/completions", {
function M._call(base_url, api_key, payload, callback)
return Curl.post(base_url .. "/chat/completions", {
headers = {
Authorization = "Bearer " .. api_token.token,
Authorization = "Bearer " .. api_key,
["User-Agent"] = "vscode-chat/dev",
["Content-Type"] = "application/json",
["Copilot-Integration-Id"] = "vscode-chat",
["editor-version"] = ("Neovim/%d.%d.%d"):format(nvim_version.major, nvim_version.minor, nvim_version.patch),
["editor-plugin-version"] = "nes/0.1.0",
},
on_error = function(err)
error("api request error: " .. err)
end,
body = vim.json.encode(payload),
stream = function(_, chunk)
if not chunk then
on_exit = function(out)
if out.code ~= 0 then
callback("")
require("nes.util").notify(out.stderr or ("code: " .. out.code), { level = vim.log.levels.ERROR })
return
end
if vim.startswith(chunk, "data: ") then
chunk = chunk:sub(6)
end
if chunk == "[DONE]" then
return
local stdout = out.stdout
local lines = vim.split(stdout, "\n", { plain = true })
local chunks = {}
for _, line in ipairs(lines) do
line = vim.trim(line)
if line ~= "" then
if vim.startswith(line, "data: ") then
line = line:sub(7)
end
if line ~= "[DONE]" then
table.insert(chunks, line)
end
end
end
local ok, event = pcall(vim.json.decode, chunk)

local json_chunks = string.format("[%s]", table.concat(chunks, ","))
local ok, events = pcall(vim.json.decode, json_chunks)
if not ok then
callback("")
return
end
if event and event.choices and event.choices[1] then
local choice = event.choices[1]
if choice.delta and choice.delta.content then
output = output .. choice.delta.content

local output = ""
for _, event in ipairs(events) do
if event.choices and event.choices[1] then
local choice = event.choices[1]
if choice.delta and choice.delta.content then
output = output .. choice.delta.content
end
end
end
end,
callback = function()
callback(output)
end,
})
end

---@return fun() cancel
function M.call(payload, callback)
local job
job = with_token(vim.schedule_wrap(function(err, api_token)
if err then
require("nes.util").notify(
"Failed to get API token: " .. vim.inspect(err),
{ level = vim.log.levels.ERROR }
)
callback("")
return
end
local base_url = api_token.endpoints.proxy or api_token.endpoints.api
job = M._call(base_url, api_token.token, payload, callback)
end))

return function()
if job then
job:kill(-1)
end
end
end

function M.debug()
vim.print(get_api_token())
vim.print(_api_token)
end

return M
73 changes: 43 additions & 30 deletions lua/nes/context.lua
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ what I will do next. Do not skip any lines. Do not be lazy.
]]

---@class nes.Context
---@field bufnr number
---@field cursor [integer, integer]
---@field cursor [integer, integer] (1,0)-indexed
---@field original_code string
---@field edits string
---@field current_version table
Expand All @@ -61,30 +60,37 @@ local Context = {}
Context.__index = Context

---@return nes.Context
function Context.new(bufnr)
function Context.new_from_buffer(bufnr)
local filename = vim.fn.fnamemodify(vim.api.nvim_buf_get_name(bufnr), ":")
local original_code = vim.fn.readfile(filename)
local current_version = Context.get_current_version(bufnr)
local original_code = table.concat(vim.fn.readfile(filename, ""), "\n")
local current_code = table.concat(vim.api.nvim_buf_get_lines(bufnr, 0, -1, false), "\n")
local cursor = vim.api.nvim_win_get_cursor(0)
local filetype = vim.bo[bufnr].filetype
return Context.new(filename, original_code, current_code, cursor, filetype)
end

---@param filename string
---@param original_code string
---@param current_code string
---@param cursor [integer, integer] (row, col), (1,0)-indexed
---@param lang string
---@return nes.Context
function Context.new(filename, original_code, current_code, cursor, lang)
local self = {
bufnr = bufnr,
cursor = current_version.cursor,
cursor = cursor,
original_code = table.concat(
vim.iter(original_code)
vim.iter(vim.split(original_code, "\n", { plain = true }))
:enumerate()
:map(function(i, line)
return string.format("%d│%s", i, line)
end)
:totable(),
"\n"
),
edits = vim.diff(
table.concat(original_code, "\n"),
table.concat(vim.api.nvim_buf_get_lines(bufnr, 0, -1, false), "\n"),
{ algorithm = "minimal" }
),
edits = vim.diff(original_code, current_code, { algorithm = "minimal" }),
filename = filename,
current_version = current_version,
filetype = vim.bo[bufnr].filetype,
current_version = Context._get_current_version(current_code, cursor),
filetype = lang,
}
setmetatable(self, Context)
return self
Expand Down Expand Up @@ -131,27 +137,34 @@ function Context:payload()
}
end

function Context.get_current_version(bufnr)
local cursor = vim.api.nvim_win_get_cursor(0)
function Context._get_current_version(text, cursor)
local row, col = cursor[1] - 1, cursor[2]
local start_row = row - 20
if start_row < 0 then
start_row = 0
end
local end_row = row + 20
if end_row >= vim.api.nvim_buf_line_count(bufnr) then
end_row = vim.api.nvim_buf_line_count(bufnr) - 1
end
local end_col = vim.api.nvim_buf_get_lines(bufnr, end_row, end_row + 1, false)[1]:len()
local lines = vim.split(text, "\n", { plain = true })
local start_row = math.max(row - 20, 0)
local end_row = math.min(row + 20, #lines)
local start_col = 0
local end_col = lines[end_row]:len()

local before_cursor = vim.api.nvim_buf_get_text(bufnr, start_row, 0, row, col, {})
local after_cursor = vim.api.nvim_buf_get_text(bufnr, row, col, end_row, end_col, {})
return {
local before_cursor_lines = vim.list_slice(lines, start_row + 1, row)
local after_cursor_lines = vim.list_slice(lines, row + 2, end_row + 1)
local before_cursor_text = lines[row + 1]:sub(1, col)
local after_cursor_text = lines[row + 1]:sub(col + 1)

local res = {
cursor = cursor,
start_row = start_row,
end_row = end_row,
text = string.format("%s<|cursor|>%s", table.concat(before_cursor, "\n"), table.concat(after_cursor, "\n")),
start_col = start_col,
end_col = end_col,
text = string.format(
"%s\n%s<|cursor|>%s\n%s",
table.concat(before_cursor_lines, "\n"),
before_cursor_text,
after_cursor_text,
table.concat(after_cursor_lines, "\n")
),
}
return res
end

return Context
Loading