Skip to content

Commit eed7c07

Browse files
committed
commit in progress
1 parent 101720d commit eed7c07

21 files changed

+2080
-1919
lines changed

eopt/dune

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
(library
22
(name castor_eopt)
33
(libraries core castor)
4-
(preprocess (pps ppx_compare ppx_sexp_conv ppx_let ppx_expect ppx_sexp_message))
4+
(preprocess (pps ppx_compare ppx_sexp_conv ppx_let ppx_expect ppx_sexp_message ppx_hash))
55
(inline_tests))

eopt/extract.ml

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
open Core
2+
open Castor.Ast
3+
module Name = Castor.Name
4+
module Egraph = Castor.Egraph
5+
module Visitors = Castor.Visitors
6+
module AE = Egraph.AstEGraph
7+
8+
type stage = Compile | Run [@@deriving compare, equal, hash, sexp]
9+
type stage_env = (Name.t * stage) list [@@deriving compare, hash, sexp]
10+
11+
module Memo_key = struct
12+
type t = stage * stage_env * Egraph.Id.t [@@deriving compare, hash, sexp_of]
13+
end
14+
15+
let incr = List.map ~f:(fun (n, s) -> (Name.incr n, s))
16+
let comptime = List.map ~f:(fun n -> (Name.zero n, Compile))
17+
let runtime = List.map ~f:(fun n -> (Name.zero n, Run))
18+
19+
exception Stage_error
20+
21+
let extract_well_staged g params qs =
22+
let schema = AE.schema g in
23+
let g' = AE.create () in
24+
let memo = Hashtbl.create (module Memo_key) in
25+
let rec extract_pred_exn stage stage_env p : _ pred =
26+
match p with
27+
| `Name n ->
28+
if Set.mem params n && [%equal: stage] stage Run then p
29+
else if Name.is_bound n then
30+
match List.Assoc.find ~equal:[%equal: Name.t] stage_env n with
31+
| Some stage' when [%equal: stage] stage stage' -> p
32+
| _ -> raise_notrace Stage_error
33+
else p
34+
| _ ->
35+
Visitors.Map.pred
36+
(extract_query_exn stage stage_env)
37+
(extract_pred_exn stage stage_env)
38+
p
39+
and extract_query_exn stage stage_env q : Egraph.Id.t =
40+
match extract_query stage stage_env q with
41+
| Some q -> q
42+
| None -> raise_notrace Stage_error
43+
and extract_runtime_enode_exn stage_env enode =
44+
print_s
45+
[%message
46+
"extracting runtime enode"
47+
(enode : (Egraph.Id.t pred, Egraph.Id.t) query)];
48+
49+
let ret =
50+
match enode with
51+
| AHashIdx x ->
52+
let hi_keys = extract_query_exn Compile stage_env x.hi_keys in
53+
let hi_values =
54+
extract_query_exn Run
55+
(incr stage_env @ comptime (schema x.hi_keys))
56+
x.hi_values
57+
in
58+
let hi_key_layout =
59+
Option.map ~f:(extract_query_exn Run stage_env) x.hi_key_layout
60+
in
61+
let hi_lookup =
62+
List.map x.hi_lookup ~f:(extract_pred_exn Run stage_env)
63+
in
64+
AHashIdx { hi_keys; hi_values; hi_key_layout; hi_lookup }
65+
| AOrderedIdx x ->
66+
let oi_keys = extract_query_exn Compile stage_env x.oi_keys in
67+
let oi_values =
68+
extract_query_exn Run
69+
(incr stage_env @ comptime (schema x.oi_keys))
70+
x.oi_values
71+
in
72+
let oi_key_layout =
73+
Option.map x.oi_key_layout ~f:(extract_query_exn Compile stage_env)
74+
in
75+
let oi_lookup =
76+
let extract_bound_exn (p, b) =
77+
(extract_pred_exn Run stage_env p, b)
78+
in
79+
List.map x.oi_lookup
80+
~f:(Tuple2.map ~f:(Option.map ~f:extract_bound_exn))
81+
in
82+
AOrderedIdx { oi_keys; oi_values; oi_key_layout; oi_lookup }
83+
| AList x ->
84+
let l_keys = extract_query_exn Compile stage_env x.l_keys in
85+
let l_values =
86+
extract_query_exn Run
87+
(incr stage_env @ comptime (schema x.l_keys))
88+
x.l_values
89+
in
90+
AList { l_keys; l_values }
91+
| AScalar x ->
92+
let s_pred = extract_pred_exn Compile stage_env x.s_pred in
93+
AScalar { x with s_pred }
94+
| DepJoin x ->
95+
let d_lhs = extract_query_exn Run stage_env x.d_lhs in
96+
let d_rhs =
97+
extract_query_exn Run
98+
(incr stage_env @ runtime (schema x.d_lhs))
99+
x.d_rhs
100+
in
101+
DepJoin { d_lhs; d_rhs }
102+
| q ->
103+
Visitors.Map.query
104+
(extract_query_exn Run stage_env)
105+
(extract_pred_exn Run stage_env)
106+
q
107+
in
108+
print_s
109+
[%message
110+
"extract runtime enode"
111+
(enode : (Egraph.Id.t pred, Egraph.Id.t) query)
112+
(ret : (Egraph.Id.t pred, Egraph.Id.t) query)];
113+
ret
114+
and extract_compile_enode_exn stage_env = function
115+
| AHashIdx _ | AOrderedIdx _ | AList _ | AScalar _ ->
116+
raise_notrace Stage_error
117+
| q ->
118+
Visitors.Map.query
119+
(extract_query_exn Compile stage_env)
120+
(extract_pred_exn Compile stage_env)
121+
q
122+
and extract_query' stage stage_env (q : Egraph.Id.t) : Egraph.Id.t option =
123+
let extract_enode_exn =
124+
match stage with
125+
| Run -> extract_runtime_enode_exn stage_env
126+
| Compile -> extract_compile_enode_exn stage_env
127+
in
128+
let enodes =
129+
AE.enodes g q
130+
|> Iter.filter_map (fun enode ->
131+
try Some (extract_enode_exn enode) with Stage_error -> None)
132+
|> Iter.to_list
133+
in
134+
if List.is_empty enodes then None
135+
else
136+
List.map enodes ~f:(AE.add g')
137+
|> List.reduce_exn ~f:(AE.merge g')
138+
|> Option.return
139+
and extract_query stage stage_env q =
140+
let max_debruijn =
141+
List.filter_map stage_env ~f:(fun (n, _) ->
142+
match n.name with Bound (i, _) -> Some i | _ -> None)
143+
|> List.max_elt ~compare:Int.compare
144+
|> Option.value ~default:(-1)
145+
in
146+
if max_debruijn > AE.max_debruijn_index g q then None
147+
else
148+
let ret =
149+
match Hashtbl.find memo (stage, stage_env, q) with
150+
| Some x -> x
151+
| None ->
152+
let x = extract_query' stage stage_env q in
153+
Hashtbl.add_exn memo ~key:(stage, stage_env, q) ~data:x;
154+
x
155+
in
156+
print_s
157+
[%message
158+
"extract query"
159+
(stage : stage)
160+
(stage_env : stage_env)
161+
(q : Egraph.Id.t)
162+
(ret : Egraph.Id.t option)];
163+
ret
164+
in
165+
let qs' = List.map qs ~f:(extract_query Run []) in
166+
(g', qs')

lib/egraph.ml

+23-9
Original file line numberDiff line numberDiff line change
@@ -395,24 +395,35 @@ module AstLang = struct
395395

396396
and map_args_pred f p = V.Map.pred f (map_args_pred f) p
397397

398-
let match_func q q' =
399-
try
400-
ignore (V.Map2.query () () q q');
401-
true
402-
with V.Map2.Mismatch -> false
398+
let match_func _ _ = assert false
403399
end
404400

405-
module UnitAnalysis = struct
401+
module UnitAnalysis (L : LANG) = struct
402+
type 'a lang = 'a L.t
406403
type t = unit [@@deriving sexp_of, equal]
407404

408405
let of_enode _ _ = ()
409406
let merge _ _ = Ok ()
410407
end
411408

412409
module OptAnalysis = struct
413-
type t = { schema : Schema.t; free : Set.M(Name).t }
410+
type t = { schema : Schema.t; free : Set.M(Name).t; max_debruijn_index : int }
414411
[@@deriving sexp_of, equal]
415412

413+
let rec max_debruijn_index_query data q =
414+
Visitors.Reduce.query (-1) max
415+
(fun id -> (data id).max_debruijn_index)
416+
(max_debruijn_index_pred data)
417+
q
418+
419+
and max_debruijn_index_pred data = function
420+
| `Name { Name.name = Bound (idx, _); _ } -> idx
421+
| (p : _ Ast.ppred) ->
422+
Visitors.Reduce.pred (-1) max
423+
(fun id -> (data id).max_debruijn_index)
424+
(max_debruijn_index_pred data)
425+
p
426+
416427
let of_enode data q =
417428
{
418429
schema = Schema.schema_query_open (fun id -> (data id).schema) q;
@@ -421,6 +432,7 @@ module OptAnalysis = struct
421432
~schema:(fun id -> Set.of_list (module Name) (data id).schema)
422433
(fun id -> (data id).free)
423434
q;
435+
max_debruijn_index = max_debruijn_index_query data q;
424436
}
425437

426438
let merge x x' =
@@ -466,10 +478,12 @@ module AstEGraph = struct
466478
let choose_exn g = choose_bounded_exn g 10
467479
let choose g id = try Some (choose_exn g id) with Choose_failed -> None
468480
let schema g id = (data g id).schema
481+
let max_debruijn_index g id = (data g id).max_debruijn_index
469482
end
470483

471484
let%expect_test "" =
472-
let module E = Make (SymbolLang (String)) (UnitAnalysis) in
485+
let module L = SymbolLang (String) in
486+
let module E = Make (L) (UnitAnalysis (L)) in
473487
let g = E.create () in
474488
let x = E.add g { func = "x"; args = [] } in
475489
let y = E.add g { func = "y"; args = [] } in
@@ -528,7 +542,7 @@ let%expect_test "" =
528542

529543
let%expect_test "" =
530544
let module L = SymbolLang (String) in
531-
let module E = Make (L) (UnitAnalysis) in
545+
let module E = Make (L) (UnitAnalysis (L)) in
532546
let g = E.create () in
533547
let x = E.add g { func = "x"; args = [] }
534548
and y = E.add g { func = "y"; args = [] }

lib/egraph.mli

+3-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ module AstLang : sig
9090
include LANG with type 'a t := 'a t
9191
end
9292

93-
module UnitAnalysis : ANALYSIS with type t = unit and type 'a lang := 'a
93+
module UnitAnalysis (L : LANG) :
94+
ANALYSIS with type t = unit and type 'a lang = 'a L.t
9495

9596
module AstEGraph : sig
9697
open Ast
@@ -101,4 +102,5 @@ module AstEGraph : sig
101102
val choose : t -> Id.t -> < > annot option
102103
val choose_exn : t -> Id.t -> < > annot
103104
val schema : t -> Id.t -> Schema.t
105+
val max_debruijn_index : t -> Id.t -> int
104106
end

lib/is_serializable.ml

-102
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,6 @@
11
open Core
22
open Ast
33
module V = Visitors
4-
(* module A = Abslayout *)
5-
6-
(* class ['a] stage_iter = *)
7-
(* object (self : 'a) *)
8-
(* inherit [_] V.iter *)
9-
10-
(* method! visit_AList (ctx, phase) { l_keys = rk; l_values = rv; _ } = *)
11-
(* self#visit_t (ctx, `Compile) rk; *)
12-
(* self#visit_t (ctx, phase) rv *)
13-
14-
(* method! visit_AHashIdx (ctx, phase) h = *)
15-
(* List.iter h.hi_lookup ~f:(self#visit_pred (ctx, phase)); *)
16-
(* self#visit_t (ctx, `Compile) h.hi_keys; *)
17-
(* self#visit_t (ctx, phase) h.hi_values *)
18-
19-
(* method! visit_AOrderedIdx (ctx, phase) *)
20-
(* { oi_keys = rk; oi_values = rv; oi_lookup; _ } = *)
21-
(* let bound_iter = *)
22-
(* Option.iter ~f:(fun (p, _) -> self#visit_pred (ctx, phase) p) *)
23-
(* in *)
24-
(* List.iter oi_lookup ~f:(fun (b1, b2) -> *)
25-
(* bound_iter b1; *)
26-
(* bound_iter b2); *)
27-
(* self#visit_t (ctx, `Compile) rk; *)
28-
(* self#visit_t (ctx, phase) rv *)
29-
30-
(* method! visit_AScalar (ctx, _) p = self#visit_pred (ctx, `Compile) p.s_pred *)
31-
(* end *)
324

335
let annotate_stage r =
346
let incr = List.map ~f:(fun (n, s) -> (Name.incr n, s)) in
@@ -91,77 +63,3 @@ let is_static r =
9163
with
9264
| Some (_, `Compile) -> true
9365
| Some (_, `Run) | None -> false)
94-
95-
(* exception Un_serial of string *)
96-
97-
(* class ['a] ops_serializable_visitor = *)
98-
(* object *)
99-
(* inherit ['a] stage_iter as super *)
100-
101-
(* method! visit_t ((), s) r = *)
102-
(* super#visit_t ((), s) r; *)
103-
(* match (s, r.node) with *)
104-
(* | `Run, (Relation _ | GroupBy (_, _, _) | Join _ | OrderBy _ | Dedup _) -> *)
105-
(* raise *)
106-
(* @@ Un_serial *)
107-
(* (Format.asprintf *)
108-
(* "Cannot serialize: Bad operator in run-time position %a" A.pp *)
109-
(* r) *)
110-
(* | _ -> () *)
111-
(* end *)
112-
113-
(* class ['a] names_serializable_visitor stage = *)
114-
(* object *)
115-
(* inherit ['a] stage_iter as super *)
116-
117-
(* method! visit_Name (_, s) n = *)
118-
(* match (stage n, s) with *)
119-
(* | `Compile, `Run | `Run, `Compile -> *)
120-
(* let stage = match s with `Compile -> "compile" | `Run -> "run" in *)
121-
(* let msg = *)
122-
(* Fmt.str "Cannot serialize: Found %a in %s time position." Name.pp n *)
123-
(* stage *)
124-
(* in *)
125-
(* raise @@ Un_serial msg *)
126-
(* | _ -> () *)
127-
128-
(* method! visit_t (_, s) = super#visit_t ((), s) *)
129-
(* end *)
130-
131-
(* (\** Return true if `r` is serializable. This function performs two checks: *)
132-
(* - `r` must not contain any compile time only operations in run time position. *)
133-
(* - Run-time names may only appear in run-time position and vice versa. *\) *)
134-
(* let is_serializeable ?(path = Path.root) ?params r = *)
135-
(* try *)
136-
(* (new ops_serializable_visitor)#visit_t ((), `Run) @@ Path.get_exn path r; *)
137-
(* let stage = stage ?params r in *)
138-
(* (new names_serializable_visitor stage)#visit_t ((), `Run) *)
139-
(* @@ Path.get_exn path r; *)
140-
(* Ok () *)
141-
(* with Un_serial msg -> Error msg *)
142-
143-
(* class ['a] ops_spine_serializable_visitor = *)
144-
(* object *)
145-
(* inherit ['a] ops_serializable_visitor *)
146-
(* method! visit_Exists _ _ = () *)
147-
(* method! visit_First _ _ = () *)
148-
(* end *)
149-
150-
(* class ['a] names_spine_serializable_visitor stage = *)
151-
(* object *)
152-
(* inherit ['a] names_serializable_visitor stage *)
153-
(* method! visit_Exists _ _ = () *)
154-
(* method! visit_First _ _ = () *)
155-
(* end *)
156-
157-
(* (\** Return true if the spine of r (the part of the query with no subqueries) is *)
158-
(* serializable. *\) *)
159-
(* let is_spine_serializeable ?(path = Path.root) ?params r = *)
160-
(* try *)
161-
(* (new ops_spine_serializable_visitor)#visit_t ((), `Run) *)
162-
(* @@ Path.get_exn path r; *)
163-
(* let stage = stage ?params r in *)
164-
(* (new names_spine_serializable_visitor stage)#visit_t ((), `Run) *)
165-
(* @@ Path.get_exn path r; *)
166-
(* Ok () *)
167-
(* with Un_serial msg -> Error msg *)

lib/name.ml

+1
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,4 @@ let pp fmt n =
8080
| Bound (i, x) -> Fmt.pf fmt "%d.%s" i x
8181

8282
let fresh fmt = create (Fresh.name Global.fresh fmt)
83+
let is_bound n = match n.name with Bound _ -> true | _ -> false

lib/name.mli

+1
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ val fresh : (int -> string, unit, string) format -> t
2424
val of_string_exn : string -> t
2525
val scope : t -> int option
2626
val unscoped : t -> t
27+
val is_bound : t -> bool

0 commit comments

Comments
 (0)