Skip to content

Commit

Permalink
thread counter: split out api
Browse files Browse the repository at this point in the history
  • Loading branch information
just-max committed Jul 14, 2024
1 parent 46484ff commit d5dccdd
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions src/stdlib-variants/thread-counter/thread_counter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ end)
module Counter = struct

open Ctx_util
open Ctx_util.Syntax

let lock_if b m = if b then lock_mutex m else empty_context' ()

Expand All @@ -37,7 +38,7 @@ module Counter = struct
able to report stray exceptions outweighs the slim chance anyone
would rely on being able to ignore exceptions in threads. *)

type 'a finished = Return of 'a | Uncaught of exn_info | Overflow of int
type 'a finished = Return of 'a | Uncaught of Util.exn_info | Overflow of int
type 'a state = Running | Finished of 'a finished

type 'a group = {
Expand Down Expand Up @@ -119,7 +120,10 @@ module Counter = struct
owner = cnt;
}

let join_group ~timeout ?(leftover_thread_limit = 0) group =
let get_thread_count group =
let< _ = lock_mutex group.owner.mut in group.thread_count

let join_group ~leftover_thread_limit ~timeout group =
(* group must be stopped first; busy waits to implement timeout *)
d_ "join_group";

Expand All @@ -130,19 +134,19 @@ module Counter = struct

let time0 = Mtime_clock.counter () in
let rec loop () =
let remaining = (let< _ = lock_mutex group.owner.mut in group.thread_count) in
let remaining = get_thread_count group in
if remaining <= leftover_thread_limit then remaining
else if Mtime.Span.compare (Mtime_clock.count time0) timeout > 0 then remaining
else if Mtime.Span.compare (Mtime_clock.count time0) timeout >= 0 then remaining
else (Thread.yield (); loop ())
in
loop ()

type thread_group_err =
| ThreadLimitReached of int
| ThreadsLeftOver of { left_over : int; limit : int }
| ExceptionRaised of { main : bool; exn_info : exn_info }
| ExceptionRaised of { main : bool; exn_info : Util.exn_info }

let spawn_thread_group ?thread_limit ~join_timeout ?leftover_thread_limit cnt f x =
let spawn_thread_group_no_check ?thread_limit ?leftover_limit cnt f x =
let group = create_group ?thread_limit cnt in

spawn_thread ~group cnt Util.(try_return group % try_to_result f) x |> ignore;
Expand All @@ -155,20 +159,31 @@ module Counter = struct
let< _ = lock_mutex cnt.mut in loop ()
in

let leftover_count = join_group ~timeout:join_timeout ?leftover_thread_limit group in
let leftover_count =
match leftover_limit with
| Some (n, t) -> join_group ~leftover_thread_limit:n ~timeout:t group
| None -> get_thread_count group
in

fin, leftover_count

let check_spawn_thread_group ?leftover_limit fin leftover_count =
let r = match[@warning "-4"] fin with
| Return (Ok x) -> Ok x
| Return (Error e) -> Error [ExceptionRaised { main = true; exn_info = e }]
| Uncaught e -> Error [ExceptionRaised { main = false; exn_info = e }]
| Overflow l -> Error [ThreadLimitReached l]
in

match leftover_thread_limit with
| Some lim when leftover_count > lim ->
match leftover_limit with
| Some (lim, (_ : Mtime.span)) when leftover_count > lim ->
r |> add_err (ThreadsLeftOver { left_over = leftover_count ; limit = lim })
| _ -> r

let spawn_thread_group ?thread_limit ?leftover_limit cnt f x =
let fin, leftover_count = spawn_thread_group_no_check cnt f x ?thread_limit ?leftover_limit in
check_spawn_thread_group ?leftover_limit fin leftover_count

(* the rest is for error reporting, could do with less code/abstraction... *)

let string_of_thread_group_err = function
Expand All @@ -180,9 +195,10 @@ module Counter = struct
Printf.sprintf "%s thread raised an exception: %s\n%s"
(if main then "The main" else "A created")
(Printexc.to_string exn) (Printexc.raw_backtrace_to_string backtrace)
|> String.trim

let string_of_thread_group_errs = function
| [] -> "Unknown error in a thread group"
| [] -> "Unknown error in a thread group" (* this shouldn't happen *)
| [err] -> "Error in a thread group: " ^ string_of_thread_group_err err
| errs ->
"Multiple errors in a thread group:\n" ^
Expand Down

0 comments on commit d5dccdd

Please sign in to comment.