Skip to content

Commit

Permalink
fix(watcher): race condition between positions update and watch run on (
Browse files Browse the repository at this point in the history
#479)

BufWritePost
  • Loading branch information
YaroSpace authored Dec 27, 2024
1 parent 6d3d22c commit 5cd797d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
14 changes: 13 additions & 1 deletion lua/neotest/consumers/watch/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ local neotest = {}
---@class neotest.consumers.watch
neotest.watch = {}

local init = function(client)
client.listeners.discover_positions = function(_, tree)
for _, watcher in pairs(watchers) do
if watcher.tree:data().path == tree:data().path
and not watcher.discover_positions_event.is_set() then
watcher.discover_positions_event.set()
end
end
end
end

local function get_valid_client(bufnr)
local clients = nio.lsp.get_clients({ bufnr = bufnr })
for _, client in ipairs(clients) do
Expand Down Expand Up @@ -200,7 +211,8 @@ function neotest.watch.is_watching(position_id)
end

neotest.watch = setmetatable(neotest.watch, {
__call = function()
__call = function(_, client)
init(client)
return neotest.watch
end,
})
Expand Down
8 changes: 8 additions & 0 deletions lua/neotest/consumers/watch/watcher.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ local config = require("neotest.config")
---@class neotest.consumers.watch.Watcher
---@field lsp_client nio.lsp.Client
---@field autocmd_id? string
---@field tree neotest.Tree
---@field discover_positions_event nio.control.Future
local Watcher = {}

function Watcher:new(lsp_client)
Expand Down Expand Up @@ -159,6 +161,9 @@ function Watcher:watch(tree, args)
logger.debug("Built dependencies in", elapsed, "ms for", tree:data().id, ":", dependencies)
local dependants = self:_build_dependants(dependencies)

self.tree = tree
self.discover_positions_event = nio.control.future()

self.autocmd_id = nio.api.nvim_create_autocmd("BufWritePost", {
callback = function(autocmd_args)
if type(args.run_predicate) == "function" and not args.run_predicate(autocmd_args.buf) then
Expand All @@ -172,6 +177,9 @@ function Watcher:watch(tree, args)
return
end

self.discover_positions_event.wait()
self.discover_positions_event = nio.control.future()

if tree:data().type ~= "dir" then
run.run(vim.tbl_extend("keep", { tree:data().id }, args))
else
Expand Down

0 comments on commit 5cd797d

Please sign in to comment.