Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concurrent fiber solver #11362

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions otherlibs/stdune/src/hashtbl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ struct
|> List.sort ~compare:(fun (k, _) (k', _) -> Dyn.compare k k'))
;;

let to_list t = foldi t ~init:[] ~f:(fun key v acc -> (key, v) :: acc)

let filteri_inplace t ~f =
filter_map_inplace t ~f:(fun ~key ~data ->
match f ~key ~data with
Expand Down
1 change: 1 addition & 0 deletions otherlibs/stdune/src/hashtbl_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ module type S = sig
val to_dyn : ('v -> Dyn.t) -> 'v t -> Dyn.t
val filteri_inplace : 'a t -> f:(key:key -> data:'a -> bool) -> unit
val length : _ t -> int
val to_list : 'a t -> (key * 'a) list
end
1 change: 1 addition & 0 deletions otherlibs/stdune/src/table.ml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ let filteri_inplace (type input output) ((module T) : (input, output) t) ~f =
;;

let length (type input output) ((module T) : (input, output) t) = T.H.length T.value
let to_list (type input output) ((module T) : (input, output) t) = T.H.to_list T.value

module Multi = struct
let cons t x v =
Expand Down
1 change: 1 addition & 0 deletions otherlibs/stdune/src/table.mli
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ val iter : (_, 'v) t -> f:('v -> unit) -> unit
val filteri_inplace : ('a, 'b) t -> f:(key:'a -> data:'b -> bool) -> unit
val length : (_, _) t -> int
val values : (_, 'a) t -> 'a list
val to_list : ('a, 'b) t -> ('a * 'b) list

module Multi : sig
type ('k, 'v) t
Expand Down
74 changes: 41 additions & 33 deletions src/dune_pkg/opam_solver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -669,36 +669,38 @@ module Solver = struct
might need, adding all of them to [sat_problem]. *)
let build_problem context root_req sat ~dummy_impl =
(* For each (iface, source) we have a list of implementations. *)
let impl_cache = ref Input.Role.Map.empty in
let impl_cache = Fiber_cache.create (module Input.Role) in
let conflict_classes = Conflict_classes.create sat in
let+ () =
let rec lookup_impl expand_deps role =
match Input.Role.Map.find !impl_cache role with
| Some s -> Fiber.return s
| None ->
let* clause, impls =
Candidates.make_impl_clause sat context ~dummy_impl role
in
impl_cache := Input.Role.Map.set !impl_cache role clause;
let+ () =
Fiber.sequential_iter impls ~f:(fun { var = impl_var; impl } ->
Conflict_classes.process conflict_classes impl_var impl;
match expand_deps with
| `No_expand -> Fiber.return ()
| `Expand_and_collect_conflicts deferred ->
Input.Impl.requires role impl
|> Fiber.sequential_iter ~f:(fun (dep : Input.dependency) ->
match dep.importance with
| Ensure -> process_dep expand_deps impl_var dep
| Prevent ->
(* Defer processing restricting deps until all essential
deps have been processed for the entire problem.
Restricting deps will be processed later without
recurring into their dependencies. *)
deferred := (impl_var, dep) :: !deferred;
Fiber.return ()))
in
clause
let impls = ref [] in
let* clause =
Fiber_cache.find_or_add impl_cache role ~f:(fun () ->
let+ clause, impls' =
Candidates.make_impl_clause sat context ~dummy_impl role
in
impls := impls';
clause)
in
let+ () =
Fiber.parallel_iter !impls ~f:(fun { var = impl_var; impl } ->
Conflict_classes.process conflict_classes impl_var impl;
match expand_deps with
| `No_expand -> Fiber.return ()
| `Expand_and_collect_conflicts deferred ->
Input.Impl.requires role impl
|> Fiber.parallel_iter ~f:(fun (dep : Input.dependency) ->
match dep.importance with
| Ensure -> process_dep expand_deps impl_var dep
| Prevent ->
(* Defer processing restricting deps until all essential
deps have been processed for the entire problem.
Restricting deps will be processed later without
recurring into their dependencies. *)
deferred := (impl_var, dep) :: !deferred;
Fiber.return ()))
in
clause
and process_dep expand_deps user_var (dep : Input.dependency) : unit Fiber.t =
(* Process a dependency of [user_var]:
- find the candidate implementations to satisfy it
Expand Down Expand Up @@ -749,13 +751,12 @@ module Solver = struct
restricting dependencies are irrelevant to solving the dependency
problem. *)
List.rev !conflicts
|> Fiber.sequential_iter ~f:(fun (impl_var, dep) ->
|> Fiber.parallel_iter ~f:(fun (impl_var, dep) ->
process_dep `No_expand impl_var dep)
(* All impl_candidates have now been added, so snapshot the cache. *)
in
let impl_clauses = !impl_cache in
Conflict_classes.seal conflict_classes;
impl_clauses
impl_cache
;;

(** [do_solve model req] finds an implementation matching the given
Expand All @@ -780,7 +781,8 @@ module Solver = struct
*)
let sat = S.create () in
let dummy_impl = if closest_match then Some Input.Dummy else None in
let+ impl_clauses = build_problem context root_req sat ~dummy_impl in
let* impl_clauses = build_problem context root_req sat ~dummy_impl in
let+ impl_clauses = Fiber_cache.to_table impl_clauses in
(* Run the solve *)
let decider () =
(* Walk the current solution, depth-first, looking for the first
Expand All @@ -792,7 +794,7 @@ module Solver = struct
then None (* Break cycles *)
else (
Table.set seen req true;
match Input.Role.Map.find_exn impl_clauses req |> Candidates.state with
match Table.find_exn impl_clauses req |> Candidates.state with
| Unselected -> None
| Undecided lit -> Some lit
| Selected deps ->
Expand All @@ -814,7 +816,13 @@ module Solver = struct
| None -> None
| Some _solution ->
(* Build the results object *)
Some (Input.Role.Map.filter_map impl_clauses ~f:Candidates.selected)
Some
(Table.to_list impl_clauses
|> List.filter_map ~f:(fun (key, v) ->
match Candidates.selected v with
| None -> None
| Some v -> Some (key, v))
|> Input.Role.Map.of_list_exn)
;;
end

Expand Down
Loading