diff --git a/src/stdlib-variants/thread-counter/thread_counter.ml b/src/stdlib-variants/thread-counter/thread_counter.ml index c97511b..d8e9ccc 100644 --- a/src/stdlib-variants/thread-counter/thread_counter.ml +++ b/src/stdlib-variants/thread-counter/thread_counter.ml @@ -29,6 +29,7 @@ end) module Counter = struct open Ctx_util + open Ctx_util.Syntax let lock_if b m = if b then lock_mutex m else empty_context' () @@ -37,7 +38,7 @@ module Counter = struct able to report stray exceptions outweighs the slim chance anyone would rely on being able to ignore exceptions in threads. *) - type 'a finished = Return of 'a | Uncaught of exn_info | Overflow of int + type 'a finished = Return of 'a | Uncaught of Util.exn_info | Overflow of int type 'a state = Running | Finished of 'a finished type 'a group = { @@ -119,7 +120,10 @@ module Counter = struct owner = cnt; } - let join_group ~timeout ?(leftover_thread_limit = 0) group = + let get_thread_count group = + let< _ = lock_mutex group.owner.mut in group.thread_count + + let join_group ~leftover_thread_limit ~timeout group = (* group must be stopped first; busy waits to implement timeout *) d_ "join_group"; @@ -130,9 +134,9 @@ module Counter = struct let time0 = Mtime_clock.counter () in let rec loop () = - let remaining = (let< _ = lock_mutex group.owner.mut in group.thread_count) in + let remaining = get_thread_count group in if remaining <= leftover_thread_limit then remaining - else if Mtime.Span.compare (Mtime_clock.count time0) timeout > 0 then remaining + else if Mtime.Span.compare (Mtime_clock.count time0) timeout >= 0 then remaining else (Thread.yield (); loop ()) in loop () @@ -140,9 +144,9 @@ module Counter = struct type thread_group_err = | ThreadLimitReached of int | ThreadsLeftOver of { left_over : int; limit : int } - | ExceptionRaised of { main : bool; exn_info : exn_info } + | ExceptionRaised of { main : bool; exn_info : Util.exn_info } - let spawn_thread_group ?thread_limit ~join_timeout ?leftover_thread_limit cnt f x = + let spawn_thread_group_no_check ?thread_limit ?leftover_limit cnt f x = let group = create_group ?thread_limit cnt in spawn_thread ~group cnt Util.(try_return group % try_to_result f) x |> ignore; @@ -155,8 +159,15 @@ module Counter = struct let< _ = lock_mutex cnt.mut in loop () in - let leftover_count = join_group ~timeout:join_timeout ?leftover_thread_limit group in + let leftover_count = + match leftover_limit with + | Some (n, t) -> join_group ~leftover_thread_limit:n ~timeout:t group + | None -> get_thread_count group + in + + fin, leftover_count + let check_spawn_thread_group ?leftover_limit fin leftover_count = let r = match[@warning "-4"] fin with | Return (Ok x) -> Ok x | Return (Error e) -> Error [ExceptionRaised { main = true; exn_info = e }] @@ -164,11 +175,15 @@ module Counter = struct | Overflow l -> Error [ThreadLimitReached l] in - match leftover_thread_limit with - | Some lim when leftover_count > lim -> + match leftover_limit with + | Some (lim, (_ : Mtime.span)) when leftover_count > lim -> r |> add_err (ThreadsLeftOver { left_over = leftover_count ; limit = lim }) | _ -> r + let spawn_thread_group ?thread_limit ?leftover_limit cnt f x = + let fin, leftover_count = spawn_thread_group_no_check cnt f x ?thread_limit ?leftover_limit in + check_spawn_thread_group ?leftover_limit fin leftover_count + (* the rest is for error reporting, could do with less code/abstraction... *) let string_of_thread_group_err = function @@ -180,9 +195,10 @@ module Counter = struct Printf.sprintf "%s thread raised an exception: %s\n%s" (if main then "The main" else "A created") (Printexc.to_string exn) (Printexc.raw_backtrace_to_string backtrace) + |> String.trim let string_of_thread_group_errs = function - | [] -> "Unknown error in a thread group" + | [] -> "Unknown error in a thread group" (* this shouldn't happen *) | [err] -> "Error in a thread group: " ^ string_of_thread_group_err err | errs -> "Multiple errors in a thread group:\n" ^