Skip to content

Commit

Permalink
Add thread counter (#21)
Browse files Browse the repository at this point in the history
* For tracking and restricting threads spawned by student code
  • Loading branch information
just-max authored Jul 15, 2024
1 parent 6870040 commit e47c06d
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/stdlib-variants/thread-counter/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
(library
(name thread_counter)
(public_name less-power.thread-counter)
(libraries common mtime mtime.clock.os threads))
241 changes: 241 additions & 0 deletions src/stdlib-variants/thread-counter/thread_counter.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
open Common

let add_err (e : 'e) = function Error es -> Error (es @ [e]) | Ok _ -> Error [e]

module Debug = struct
let disabled_ fmt = Stdlib.(Printf.ifprintf stderr fmt)

let enabled_ fmt =
Stdlib.(
Printf.ksprintf
(fun s ->
prerr_string s;
prerr_newline ())
fmt)

let d_ = disabled_
let i_ = enabled_
end

open Debug

module ThreadH = Hashtbl.Make (struct
type t = Thread.t

let equal t1 t2 = compare (Thread.id t1) (Thread.id t2) = 0
let hash t = Hashtbl.hash (Thread.id t)
end)

module Counter = struct

open Ctx_util
open Ctx_util.Syntax

let lock_if b m = if b then Mutex.protect m else empty_context' ()

(** Note: we enforce that spawned threads don't raise uncaught exceptions,
which in theory changes the semantics of threads. The value of being
able to report stray exceptions outweighs the slim chance anyone
would rely on being able to ignore exceptions in threads. *)

type 'a finished = Return of 'a | Uncaught of Util.exn_info | Overflow of int
type 'a state = Running | Finished of 'a finished

type 'a group = {
mutable state : 'a state;
finished : Condition.t; (* predicate: state <> Running, mutex: owner.mut *)
mutable thread_count : int;
thread_limit : int option;
owner : t;
}

and g = G : 'a group -> g
and t = { mut : Mutex.t; groups : g ThreadH.t }

let finish ?(lock = true) group fin =
d_ "finish";

let< _ = lock_if lock group.owner.mut in
if group.state <> Running then failwith "finish: already finished";
group.state <- Finished fin;
Condition.broadcast group.finished

let try_finish group fin =
let< _ = Mutex.protect group.owner.mut in
if group.state = Running then
finish ~lock:false group fin

let try_return group x = try_finish group (Return x)

let spawn_thread ?(lock = true) ?group cnt (f : _ -> unit) x =
d_ "create_thread";

let make_thread group f x =
let cnt = group.owner in
let tid =
Thread.create
(fun () ->
d_ "thread#%d started" Thread.(self () |> id);
Fun.protect
(fun () ->
Util.try_to_result f x
|> Result.iter_error (fun e -> try_finish group (Uncaught e)))
~finally:(fun () ->
let< _ = Mutex.protect cnt.mut in
let tid = Thread.self () in
group.thread_count <- group.thread_count - 1;
ThreadH.remove cnt.groups tid;
d_ "thread#%d finished" Thread.(self () |> id)))
()
in
ThreadH.replace cnt.groups tid (G group);
group.thread_count <- group.thread_count + 1;
tid
in

let spawn = fun group ->
match group.state, group.thread_limit with
| Running, Some limit when group.thread_count >= limit ->
finish ~lock:false group (Overflow limit); Thread.self ()
| Running, _ -> make_thread group f x
| Finished _, _ -> Thread.self ()
in

let< _ = lock_if lock cnt.mut in
match group with
| Some g -> spawn g
| None -> let G g = ThreadH.find cnt.groups (Thread.self ()) in spawn g

let create_counter () = { mut = Mutex.create (); groups = ThreadH.create 32 }

let create_group ?thread_limit cnt =
d_ "create_group";
(* no lock needed, since we don't write to the owner;
operations on the group will lock instead *)
{
state = Running;
finished = Condition.create ();
thread_count = 0;
thread_limit;
owner = cnt;
}

let get_thread_count group =
let< _ = Mutex.protect group.owner.mut in group.thread_count

(** Wait for threads in a group to complete. Group must be finished first. *)
let join_group ~leftover_thread_limit ~timeout group =
(* busy waits to implement timeout *)
d_ "join_group";

let _ =
let< _ = Mutex.protect group.owner.mut in
if group.state = Running then failwith "join_group: still running"
in

let time0 = Mtime_clock.counter () in
let rec loop () =
let remaining = get_thread_count group in
if remaining <= leftover_thread_limit then remaining
else if Mtime.Span.compare (Mtime_clock.count time0) timeout >= 0 then remaining
else (Thread.yield (); loop ())
in
loop ()

(** {1 High-level group operations} *)

type thread_group_err =
| ThreadLimitReached of int
| ThreadsLeftOver of { left_over : int; limit : int }
| ExceptionRaised of { main : bool; exn_info : Util.exn_info }

(** Create a group that runs the given function, then sets the return value
as the return value of the group. *)
let spawn_thread_group ?thread_limit cnt f x =
let group = create_group ?thread_limit cnt in
spawn_thread ~group cnt Util.(try_return group % try_to_result f) x |> ignore;
group

(** Wait for a group to finish then join its threads ({!join_group}). *)
let collect_thread_group ?leftover_limit group =
let cnt = group.owner in
let fin =
let rec loop () = match group.state with
| Running -> d_ "still running"; Condition.wait group.finished cnt.mut; loop ()
| Finished fin -> fin
in
let< _ = Mutex.protect cnt.mut in loop ()
in

let leftover_count =
match leftover_limit with
| Some (n, t) -> join_group ~leftover_thread_limit:n ~timeout:t group
| None -> get_thread_count group
in

fin, leftover_count

(** Check the return value of {!collect_thread_group} *)
let check_thread_group_result ?leftover_limit fin leftover_count =
let r = match fin with
| Return (Ok x) -> Ok x
| Return (Error e) -> Error [ExceptionRaised { main = true; exn_info = e }]
| Uncaught e -> Error [ExceptionRaised { main = false; exn_info = e }]
| Overflow l -> Error [ThreadLimitReached l]
in

match leftover_limit with
| Some (lim, (_ : Mtime.span)) when leftover_count > lim ->
r |> add_err (ThreadsLeftOver { left_over = leftover_count ; limit = lim })
| _ -> r

(** Combines {!spawn_thread_group}, {!collect_thread_group}, and {!check_thread_group_result}.*)
let run_in_thread_group ?thread_limit ?leftover_limit cnt f x =
let group = spawn_thread_group cnt f x ?thread_limit in
let fin, leftover_count = collect_thread_group group ?leftover_limit in
check_thread_group_result ?leftover_limit fin leftover_count

(* the rest is for error reporting, could do with less code/abstraction... *)

let string_of_thread_group_err = function
| ThreadLimitReached n ->
Printf.sprintf "Too many threads were used (> %d)" n
| ThreadsLeftOver { left_over = n; limit } ->
Printf.sprintf "Too many threads were left running (%d > %d)" n limit
| ExceptionRaised { main; exn_info = Util.{exn; backtrace} } ->
Printf.sprintf "%s thread raised an exception: %s\n%s"
(if main then "The main" else "A created")
(Printexc.to_string exn) (Printexc.raw_backtrace_to_string backtrace)
|> String.trim

let string_of_thread_group_errs = function
| [] -> "Unknown error in a thread group" (* this shouldn't happen *)
| [err] -> "Error in a thread group: " ^ string_of_thread_group_err err
| errs ->
"Multiple errors in a thread group:\n" ^
(errs |> List.mapi (fun i err -> [
Printf.sprintf "+----------- %d -----------+" (i + 1);
string_of_thread_group_err err])
|> List.concat |> String.concat "\n")

exception ThreadGroupErrs of thread_group_err list

let _ =
Printexc.register_printer
(function ThreadGroupErrs errs -> Some (string_of_thread_group_errs errs) | _ -> None)

let thread_group_result_to_exn = function
| Ok x -> x
| Error errs -> raise (ThreadGroupErrs errs)

end

module CounterInstance () = struct
let instance = Counter.create_counter ()

(** Like {!Stdlib.Thread}, but with counted threads *)
module Thread = struct
include Thread
let create f x = Counter.spawn_thread instance Util.(ignore % f) x
end
end

0 comments on commit e47c06d

Please sign in to comment.