From e4e6dc0457e94ee5d8718f780ee59ae724651301 Mon Sep 17 00:00:00 2001 From: Max Lang <17551908+just-max@users.noreply.github.com> Date: Sun, 14 Jul 2024 06:52:28 +0200 Subject: [PATCH] thread counter: rework the implementation to avoid stray threads --- .../thread-counter/thread_counter.ml | 249 +++++++++--------- 1 file changed, 131 insertions(+), 118 deletions(-) diff --git a/src/stdlib-variants/thread-counter/thread_counter.ml b/src/stdlib-variants/thread-counter/thread_counter.ml index 8d7d4d4..c97511b 100644 --- a/src/stdlib-variants/thread-counter/thread_counter.ml +++ b/src/stdlib-variants/thread-counter/thread_counter.ml @@ -1,7 +1,4 @@ -exception TooManyThreads of int -exception ThreadsLeftRunning of int * int -exception StopThread - +open Common let add_err (e : 'e) = function Error es -> Error (es @ [e]) | Ok _ -> Error [e] @@ -22,16 +19,6 @@ end open Debug - -let _ = - Printexc.register_printer (function - | TooManyThreads n -> Some (Printf.sprintf "Used too many threads (> %d)" n) - | ThreadsLeftRunning (n, limit) -> - Some - (Printf.sprintf "Too many threads were left running (%d > %d)" n limit) - | _ -> None) - - module ThreadH = Hashtbl.Make (struct type t = Thread.t @@ -39,154 +26,180 @@ module ThreadH = Hashtbl.Make (struct let hash t = Hashtbl.hash (Thread.id t) end) -type threadset = unit ThreadH.t +module Counter = struct + open Ctx_util -module Counter = struct + let lock_if b m = if b then lock_mutex m else empty_context' () - open Common.Ctx_util + (* Note: we enforce that spawned threads don't raise uncaught exceptions, + which in theory changes the semantics of threads. The value of being + able to report stray exceptions outweighs the slim chance anyone + would rely on being able to ignore exceptions in threads. *) - type state = Running | Stopped | Overflowed + type 'a finished = Return of 'a | Uncaught of exn_info | Overflow of int + type 'a state = Running | Finished of 'a finished - type threadgroup = { - running : threadset; - overflow : Condition.t; (* predicate P(group) := group.state = Overflowed *) - mutable state : state; - max_threads : int; + type 'a group = { + mutable state : 'a state; + finished : Condition.t; (* predicate: state <> Running, mutex: owner.mut *) + mutable thread_count : int; + thread_limit : int option; owner : t; } - and t = { mut : Mutex.t; groups : threadgroup ThreadH.t } + and g = G : 'a group -> g + and t = { mut : Mutex.t; groups : g ThreadH.t } - let create_thread ?group cnt f x = - d_ "create_thread"; + let finish ?(lock = true) group fin = + d_ "finish"; - let< _ = lock_mutex cnt.mut in + let< _ = lock_if lock group.owner.mut in + if group.state <> Running then failwith "finish: already finished"; + group.state <- Finished fin; + Condition.broadcast group.finished - let group = - match group with - | Some g -> g - | None -> ThreadH.find cnt.groups (Thread.self ()) - in + let try_finish group fin = + let< _ = lock_mutex group.owner.mut in + if group.state = Running then + finish ~lock:false group fin - let make_thread () = + let try_return group x = try_finish group (Return x) + + let spawn_thread ?(lock = true) ?group cnt (f : _ -> unit) x = + d_ "create_thread"; + + let make_thread group f x = + let cnt = group.owner in let tid = Thread.create (fun () -> d_ "thread#%d started" Thread.(self () |> id); Fun.protect - (fun () -> f x) + (fun () -> + Util.try_to_result f x + |> Result.iter_error (fun e -> try_finish group (Uncaught e))) ~finally:(fun () -> let< _ = lock_mutex cnt.mut in let tid = Thread.self () in - ThreadH.remove group.running tid; + group.thread_count <- group.thread_count - 1; ThreadH.remove cnt.groups tid; d_ "thread#%d finished" Thread.(self () |> id))) () in - ThreadH.replace cnt.groups tid group; - ThreadH.replace group.running tid (); + ThreadH.replace cnt.groups tid (G group); + group.thread_count <- group.thread_count + 1; tid in - match group.state with - | Running -> - if ThreadH.length group.running >= group.max_threads then ( - group.state <- Overflowed; - Condition.broadcast group.overflow; - (* Mutex.unlock cnt.mut; *) - Thread.self () - (* raise? *) - (* raise StopThread *)) - else make_thread () - | Overflowed | Stopped -> - (* Mutex.unlock cnt.mut (* raise? *); *) - Thread.self () - (* raise StopThread *) + let spawn = fun group -> + match group.state, group.thread_limit with + | Running, Some limit when group.thread_count >= limit -> + finish ~lock:false group (Overflow limit); Thread.self () + | Running, _ -> make_thread group f x + | Finished _, _ -> Thread.self () + in + + let< _ = lock_if lock cnt.mut in + match group with + | Some g -> spawn g + | None -> let G g = ThreadH.find cnt.groups (Thread.self ()) in spawn g let create_counter () = { mut = Mutex.create (); groups = ThreadH.create 32 } - let create_group ~max_threads cnt = + let create_group ?thread_limit cnt = d_ "create_group"; (* no lock needed, since we don't write to the owner; operations on the group will lock instead *) { - running = ThreadH.create 8; state = Running; - overflow = Condition.create (); - max_threads; + finished = Condition.create (); + thread_count = 0; + thread_limit; owner = cnt; } - let wait_for_overflow group = - d_ "wait_for_overflow"; - let< _ = lock_mutex group.owner.mut in - while not (group.state = Overflowed) do - Condition.wait group.overflow group.owner.mut - done + let join_group ~timeout ?(leftover_thread_limit = 0) group = + (* group must be stopped first; busy waits to implement timeout *) + d_ "join_group"; - let set_stopped group = - d_ "set_stopped"; - let< _ = lock_mutex group.owner.mut in - if group.state = Running then group.state <- Stopped + let _ = + let< _ = lock_mutex group.owner.mut in + if group.state = Running then failwith "join_group: still running" + in - let join ~timeout ?(allowed = 0) group = - d_ "join"; - set_stopped group; let time0 = Mtime_clock.counter () in - let rec task () = - let remaining = ThreadH.length group.running in - if remaining <= allowed then remaining + let rec loop () = + let remaining = (let< _ = lock_mutex group.owner.mut in group.thread_count) in + if remaining <= leftover_thread_limit then remaining else if Mtime.Span.compare (Mtime_clock.count time0) timeout > 0 then remaining - else ( - Thread.yield (); - task ()) + else (Thread.yield (); loop ()) in - task () - - let run_counted ~max_threads ~max_leftover ~join_timeout counter f x = - - let result = ref None in - let mut = Mutex.create () in - let task_done = Condition.create () in - - let group = create_group ~max_threads counter in - - create_thread ~group counter - (fun () -> - let r = Some (try Ok (f x) with e -> Error e) in - let< _ = lock_mutex mut in - result := r; - Condition.signal task_done) - () - |> ignore; - - Thread.create - (fun () -> - wait_for_overflow group; - let< _ = lock_mutex mut in - Condition.signal task_done) - () - |> ignore; - - let _ = - let< _ = lock_mutex mut in - while not (Option.is_some !result || group.state = Overflowed) do - Condition.wait task_done mut - done + loop () + + type thread_group_err = + | ThreadLimitReached of int + | ThreadsLeftOver of { left_over : int; limit : int } + | ExceptionRaised of { main : bool; exn_info : exn_info } + + let spawn_thread_group ?thread_limit ~join_timeout ?leftover_thread_limit cnt f x = + let group = create_group ?thread_limit cnt in + + spawn_thread ~group cnt Util.(try_return group % try_to_result f) x |> ignore; + + let fin = + let rec loop () = match group.state with + | Running -> d_ "still running"; Condition.wait group.finished cnt.mut; loop () + | Finished fin -> fin + in + let< _ = lock_mutex cnt.mut in loop () in - - let remaining = join ~allowed:max_leftover ~timeout:join_timeout group in - - let r = - (match !result with - | Some (Ok x) -> Ok x - | Some (Error e) -> Error [e] - | None -> Error [TooManyThreads group.max_threads]) + + let leftover_count = join_group ~timeout:join_timeout ?leftover_thread_limit group in + + let r = match[@warning "-4"] fin with + | Return (Ok x) -> Ok x + | Return (Error e) -> Error [ExceptionRaised { main = true; exn_info = e }] + | Uncaught e -> Error [ExceptionRaised { main = false; exn_info = e }] + | Overflow l -> Error [ThreadLimitReached l] in - if remaining > max_leftover then - add_err (ThreadsLeftRunning (remaining, max_leftover)) r - else r + + match leftover_thread_limit with + | Some lim when leftover_count > lim -> + r |> add_err (ThreadsLeftOver { left_over = leftover_count ; limit = lim }) + | _ -> r + + (* the rest is for error reporting, could do with less code/abstraction... *) + + let string_of_thread_group_err = function + | ThreadLimitReached n -> + Printf.sprintf "Too many threads were used (> %d)" n + | ThreadsLeftOver { left_over = n; limit } -> + Printf.sprintf "Too many threads were left running (%d > %d)" n limit + | ExceptionRaised { main; exn_info = Util.{exn; backtrace} } -> + Printf.sprintf "%s thread raised an exception: %s\n%s" + (if main then "The main" else "A created") + (Printexc.to_string exn) (Printexc.raw_backtrace_to_string backtrace) + + let string_of_thread_group_errs = function + | [] -> "Unknown error in a thread group" + | [err] -> "Error in a thread group: " ^ string_of_thread_group_err err + | errs -> + "Multiple errors in a thread group:\n" ^ + (errs |> List.mapi (fun i err -> [ + Printf.sprintf "+----------- %d -----------+" (i + 1); + string_of_thread_group_err err]) + |> List.concat |> String.concat "\n") + + exception ThreadGroupErrs of thread_group_err list + + let _ = + Printexc.register_printer + (function ThreadGroupErrs errs -> Some (string_of_thread_group_errs errs) | _ -> None) + + let thread_group_result_to_exn = function + | Ok x -> x + | Error errs -> raise (ThreadGroupErrs errs) end @@ -196,6 +209,6 @@ module CounterInstance () = struct (** Like {!Stdlib.Thread}, but with counted threads *) module Thread = struct include Thread - let create f x = Counter.create_thread instance f x + let create f x = Counter.spawn_thread instance Util.(ignore % f) x end end