Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 159 additions & 23 deletions src/cdomains/affineEquality/sparseImplementation/listMatrix.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ open RatOps

open Batteries

module M = Messages

let timing_wrap = Vector.timing_wrap

module type SparseMatrix =
Expand All @@ -16,6 +18,10 @@ sig
val rref_vec: t -> vec -> t Option.t

val rref_matrix: t -> t -> t Option.t

val linear_disjunct: t -> t -> t
(** [linear_disjunct m1 m2] returns a matrix that contains the linear disjunct of [m1] and [m2].
The result is in rref. If [m1] and [m2] are not linearly disjunct, an exception is raised. *)
end

module type SparseMatrixFunctor =
Expand Down Expand Up @@ -296,29 +302,34 @@ module ListMatrix: SparseMatrixFunctor =
@param v A vector with number of entries equal to the number of columns of [v].
*)
let rref_vec m v =
if is_empty m then (* In this case, v is normalized and returned *)
BatOption.map (fun (_, value) -> init_with_vec @@ div_row v value) (V.find_first_non_zero v)
else (* We try to normalize v and check if a contradiction arises. If not, we insert v at the appropriate place in m (depending on the pivot) *)
let pivot_positions = get_pivot_positions m in
(* filtered_pivots are only the pivots which have a non-zero entry in the corresponding column of v. Only those are relevant to subtract from v *)
let filtered_pivots = List.rev @@ fst @@ List.fold_left (fun (res, pivs_tail) (col_idx, value) ->
let pivs_tail = List.drop_while (fun (_, piv_col, _) -> piv_col < col_idx) pivs_tail in (* Skipping until possible match of both cols *)
match pivs_tail with
| [] -> (res, [])
| (row_idx, piv_col, row) :: ps when piv_col = col_idx -> ((row_idx, piv_col, row, value) :: res, ps)
| _ -> (res, pivs_tail)
) ([], pivot_positions) (V.to_sparse_list v) in
let v_after_elim = List.fold_left (fun acc (row_idx, pivot_position, piv_row, v_at_piv) ->
sub_scaled_row acc piv_row v_at_piv
) v filtered_pivots in
match V.find_first_non_zero v_after_elim with (* now we check for contradictions and finally insert v *)
| None -> Some m (* v is zero vector and was therefore already covered by m *)
| Some (idx, value) ->
if idx = (num_cols m - 1) then
None
else
let normalized_v = V.map_f_preserves_zero (fun x -> x /: value) v_after_elim in
Some (insert_v_according_to_piv m normalized_v idx pivot_positions)
let res =
if is_empty m then (* In this case, v is normalized and returned *)
BatOption.map (fun (_, value) -> init_with_vec @@ div_row v value) (V.find_first_non_zero v)
else (* We try to normalize v and check if a contradiction arises. If not, we insert v at the appropriate place in m (depending on the pivot) *)
let pivot_positions = get_pivot_positions m in
(* filtered_pivots are only the pivots which have a non-zero entry in the corresponding column of v. Only those are relevant to subtract from v *)
let filtered_pivots = List.rev @@ fst @@ List.fold_left (fun (res, pivs_tail) (col_idx, value) ->
let pivs_tail = List.drop_while (fun (_, piv_col, _) -> piv_col < col_idx) pivs_tail in (* Skipping until possible match of both cols *)
match pivs_tail with
| [] -> (res, [])
| (row_idx, piv_col, row) :: ps when piv_col = col_idx -> ((row_idx, piv_col, row, value) :: res, ps)
| _ -> (res, pivs_tail)
) ([], pivot_positions) (V.to_sparse_list v)
in
let v_after_elim = List.fold_left (fun acc (row_idx, pivot_position, piv_row, v_at_piv) ->
sub_scaled_row acc piv_row v_at_piv
) v filtered_pivots in
match V.find_first_non_zero v_after_elim with (* now we check for contradictions and finally insert v *)
| None -> Some m (* v is zero vector and was therefore already covered by m *)
| Some (idx, value) ->
if idx = (num_cols m - 1) then
None
else
let normalized_v = V.map_f_preserves_zero (fun x -> x /: value) v_after_elim in
Some (insert_v_according_to_piv m normalized_v idx pivot_positions)
in
if M.tracing then M.trace "rref_vec" "rref_vec: m:\n%s, v: %s => res:\n%s" (show m) (V.show v) (match res with None -> "None" | Some r -> show r);
res

let rref_vec m v = timing_wrap "rref_vec" (rref_vec m) v

Expand Down Expand Up @@ -372,4 +383,129 @@ module ListMatrix: SparseMatrixFunctor =

let is_covered_by m1 m2 = timing_wrap "is_covered_by" (is_covered_by m1) m2

(** Direct implementation of https://doi.org/10.1007/BF00268497 , chapter 5.2 Calculation of Linear Disjunction
also available at https://www-apr.lip6.fr/~mine/enseignement/mpri/attic/2014-2015/exos/karr.pdf
only difference is the implementation being optimized for sparse matrices in row representation
*)
let linear_disjunct m1 m2 =
let maxcols = num_cols m1 in
let inverse_termorder = fun x y -> y - x in
let rev_matrix = List.map (fun x -> V.of_sparse_list (V.length x) (List.rev @@ V.to_sparse_list x) ) in
let del_col m i = List.map (fun v -> V.tail_afterindex v i) m in
let safe_get_row m i =
try List.nth m i with
| Invalid_argument _ -> V.zero_vec (num_cols m) (* if row is empty, we return zero *)
in
let safe_remove_row m i =
try remove_row m i with (* remove_row can fail for sparse representations *)
| Invalid_argument _ -> m (* if row is empty, we return the original matrix *)
in

let col_and_rc m colidx rowidx =
let col = get_col_upper_triangular m colidx in
let rc = try V.nth col rowidx with (* V.nth could be integrated into get_col for the last few bits of performance... *)
| Invalid_argument _ -> A.zero (* if col is empty, we return zero *) in
col, rc
in

let push_col m colidx col =
List.mapi (fun idx row ->
match V.nth col idx with
| valu when A.equal A.zero valu -> row (* if the value is zero, we do not change the row *)
| valu -> V.push_first row colidx valu
| exception _ -> row
) m
in

let case_two a r col_b =
let a_r = get_row a r in
let res = map2i (fun i x y -> if i < r then
V.map2_f_preserves_zero (fun u j -> u +: y *: j) x a_r
else x) a col_b in
if M.tracing then M.trace "linear_disjunct_cases" "case_two: \na:\n%s, r:%d,\n col_b: %s, a_r: %s, => res:\n%s" (show a) r (V.show col_b) (V.show a_r) (show res);
res
in

let case_three col1 col2 m1 m2 result ridx cidx = (* no new pivots at ridx/cidx *)
let sub_and_lastterm c1 c2 = (* return last element/idx pair that differs*)
let len = V.length c1 in
let c1 = V.to_sparse_list c1 in
let c2 = V.to_sparse_list c2 in
let rec sub_and_last_aux (acclist,acc) c1 c2 =
match c1, c2 with
| (i,_)::_,_ when i >= ridx -> (acclist,acc) (* we are done, no more entries in c1 that are relevant *)
| (i1, v1) :: xs1, (i2, v2) :: xs2 when i1 = i2 ->
let res = A.sub v1 v2 in
let acc = if A.equal res A.zero then acc else Some (i1, v1, v2) in
sub_and_last_aux ((i1,res)::acclist,acc) xs1 xs2
| (i1, v1) :: xs1, (i2, v2) :: xs2 when i1 < i2 -> sub_and_last_aux ((i1,v1)::acclist,Some (i1,v1,A.zero)) xs1 ((i2, v2)::xs2)
| (i1, v1) :: xs1, (i2, v2) :: xs2 (* when i1 > i2 *)-> sub_and_last_aux ((i2,A.neg v2)::acclist,Some (i2,A.zero,v2)) ((i1, v1)::xs1) xs2
| (i,v)::xs ,[] -> sub_and_last_aux ((i,v)::acclist,Some (i,v,A.zero)) xs []
| [], (i,v)::xs -> sub_and_last_aux ((i,v)::acclist,Some (i,A.zero,v)) [] xs
| [], [] -> (acclist,acc)
in
let resl,rest = sub_and_last_aux ([],None) c1 c2 in
if M.tracing then M.trace "linear_disjunct_cases" "sub_and_last: ridx: %d c1: %s, c2: %s, resultlist: %s, result_pivot: %s" ridx (V.show col1) (V.show col2) (String.concat "," (List.map (fun (i,v) -> Printf.sprintf "(%d,%s)" i (A.to_string v)) resl)) (match rest with None -> "None" | Some (i,v1,v2) -> Printf.sprintf "(%d,%s,%s)" i (A.to_string v1) (A.to_string v2));
V.of_sparse_list len (List.rev resl), rest
in
let coldiff,lastdiff = sub_and_lastterm col1 col2 in
match lastdiff with
| None ->
let sameinboth=get_col_upper_triangular m1 cidx in
if M.tracing then M.trace "linear_disjunct_cases" "case_three: no difference found, cidx: %d, ridx: %d, coldiff: %s, sameinboth: %s" cidx ridx (V.show coldiff) (V.show sameinboth);
(del_col m1 cidx, del_col m2 cidx, push_col result cidx sameinboth, ridx) (* No difference found -> (del_col m1 cidx, del_col m2 cidx, push hd to result, ridx)*)
| Some (idx,x,y) ->
let r1 = safe_get_row m1 idx in
let r2 = safe_get_row m2 idx in
let resrow = safe_get_row result idx in
let diff = x -: y in
let multiply_by_t termorder m t =
map2i (fun i x c -> if i <= ridx then
let beta = c /: diff in
V.map2_f_preserves_zero_helper (termorder) (fun u j -> u -: (beta *: j)) x t
else x) m coldiff
in
let transformed_res = multiply_by_t (inverse_termorder) result resrow in
let transformed_a = multiply_by_t (-) m1 r1 in
let alpha = get_col_upper_triangular transformed_a cidx in
let res = push_col transformed_res cidx alpha in
if M.tracing then M.trace "linear_disjunct_cases" "case_three: found difference at ridx: %d idx: %d, x: %s, y: %s, diff: %s, m1: \n%s, m2:\n%s, res:\n%s"
ridx idx (A.to_string x) (A.to_string y) (A.to_string diff) (show m1) (show m2) (show @@ rev_matrix res);
safe_remove_row (transformed_a) idx, safe_remove_row (multiply_by_t (-) m2 r2) idx, safe_remove_row (res) idx, ridx - 1
in

let rec lindisjunc_aux currentrowindex currentcolindex m1 m2 result =
if M.tracing then M.trace "linear_disjunct" "result so far: \n%s, currentrowindex: %d, currentcolindex: %d, m1: \n%s, m2:\n%s "
(show @@ rev_matrix result) currentrowindex currentcolindex (show m1) (show m2);
if currentcolindex >= maxcols then result
else
let col1, rc1 = col_and_rc m1 currentcolindex currentrowindex in
let col2, rc2 = col_and_rc m2 currentcolindex currentrowindex in
match Z.to_int @@ A.get_num rc1, Z.to_int @@ A.get_num rc2 with
| 1, 1 -> lindisjunc_aux
(currentrowindex + 1) (currentcolindex+1)
(del_col m1 currentrowindex) (del_col m2 currentrowindex)
(List.mapi (fun idx row -> if idx = currentrowindex then V.push_first row currentcolindex A.one else row) result)
| 1, 0 -> let beta = get_col_upper_triangular m2 currentcolindex in
if M.tracing then M.trace "linear_disjunct_cases" "case 1,0: currentrowindex: %d, currentcolindex: %d, m1: \n%s, m2:\n%s , beta %s" currentrowindex currentcolindex (show m1) (show m2) (V.show beta);
lindisjunc_aux
(currentrowindex) (currentcolindex+1)
(safe_remove_row (case_two m1 currentrowindex col2) currentrowindex) (safe_remove_row m2 currentrowindex)
(safe_remove_row (push_col result currentcolindex beta) currentrowindex)
| 0, 1 -> let beta = get_col_upper_triangular m1 currentcolindex in
if M.tracing then M.trace "linear_disjunct_cases" "case 0,1: currentrowindex: %d, currentcolindex: %d, m1: \n%s, m2:\n%s , beta %s" currentrowindex currentcolindex (show m1) (show m2) (V.show beta);
lindisjunc_aux
(currentrowindex) (currentcolindex+1)
(safe_remove_row m1 currentrowindex) (safe_remove_row (case_two m2 currentrowindex col1) currentrowindex)
(safe_remove_row (push_col result currentcolindex beta) currentrowindex)
| 0, 0 -> let m1 , m2, result, currentrowindex = case_three col1 col2 m1 m2 result currentrowindex currentcolindex in
lindisjunc_aux currentrowindex (currentcolindex+1) m1 m2 result (* we need to process m1, m2 and result *)
| a,b -> failwith ("matrix not in rref m1: " ^ (string_of_int a) ^ (string_of_int b)^(show m1) ^ " m2: " ^ (show m2))
in
(* create a totally empty intial result, with dimensions rows x cols *)
let pseudoempty = BatList.make (max (num_rows m1) (num_rows m1)) (V.zero_vec (num_cols m1)) in
let res = rev_matrix @@ lindisjunc_aux 0 0 m1 m2 pseudoempty in
if M.tracing then M.tracel "linear_disjunct" "linear_disjunct between \n%s and \n%s =>\n%s" (show m1) (show m2) (show res);
res

end
56 changes: 50 additions & 6 deletions src/cdomains/affineEquality/sparseImplementation/sparseVector.ml
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
open Vector
open RatOps

module M = Messages

open Batteries

module type SparseVector =
sig
include Vector
val push_first: t -> int -> num -> t

val is_zero_vec: t -> bool

val tail_afterindex: t -> int -> t

val insert_zero_at_indices: t -> (int * int) list -> int -> t

val remove_at_indices: t -> int list -> t
Expand All @@ -23,6 +29,8 @@ sig

val map2_f_preserves_zero: (num -> num -> num) -> t -> t -> t

val map2_f_preserves_zero_helper: (int -> int -> int) -> (num -> num -> num) -> t -> t -> t

val find2i_f_false_at_zero: (num -> num -> bool) -> t -> t -> int

val apply_with_c_f_preserves_zero: (num -> num -> num) -> num -> t -> t
Expand Down Expand Up @@ -86,7 +94,8 @@ module SparseVector: SparseVectorFunctor =
let show v =
let rec sparse_list_str i l =
if i >= v.len then "]"
else match l with
else
match l with
| [] -> (A.to_string A.zero) ^" "^ (sparse_list_str (i + 1) l)
| (idx, value) :: xs ->
if i = idx then (A.to_string value) ^" "^ sparse_list_str (i + 1) xs
Expand Down Expand Up @@ -126,6 +135,13 @@ module SparseVector: SparseVectorFunctor =
| Some (idx, value) when idx = n -> value
| _ -> A.zero

let push_first v n num =
if n >= v.len then raise (Invalid_argument "Index out of bounds")
else let res =
{v with entries = (n,num)::v.entries} in
if M.tracing then M.trace "push_first" "pushed %s at index %d, new length: %d, resulting in %s" (A.to_string num) n res.len (res.entries |> List.map (fun (i, x) -> Printf.sprintf "(%d, %s)" i (A.to_string x)) |> String.concat ", ");
res

(**
[set_nth v n num] returns [v] where the [n]-th entry has been set to [num].
@raise Invalid_argument if [n] is out of bounds.
Expand Down Expand Up @@ -156,6 +172,22 @@ module SparseVector: SparseVectorFunctor =
in
{entries = add_indices_helper v.entries indices 0 []; len = v.len + num_zeros}

(**
[tail_afterindex v n] returns the vector starting after the [n]-th entry, i.e. all entries with index > [n].
@raise Invalid_argument if [n] is out of bounds.
*)
let tail_afterindex v n =
if n >= v.len then raise (Invalid_argument "Index out of bounds")
else
match v.entries with
| [] -> v (* If the vector is empty, return it as is *)
| (headidx,headval) :: _ ->
if M.tracing then M.trace "tail_afterindex" "headidx: %d, n: %d, v.len: %d" headidx n v.len;
if headidx > n then v
else
let entries = List.tl v.entries in
{entries; len = v.len }

(**
[remove_nth v n] returns [v] where the [n]-th entry is removed, decreasing the length of the vector by one.
@raise Invalid_argument if [n] is out of bounds
Expand Down Expand Up @@ -257,13 +289,15 @@ module SparseVector: SparseVectorFunctor =
{v with entries = entries'}

(**
[map2_f_preserves_zero f v v'] returns the mapping of [v] and [v'] specified by [f].
[map2_f_preserves_zero termorder f v v'] returns the mapping of [v] and [v'] specified by [f].

Note that [f] {b must} be such that [f 0 0 = 0]!

[termorder] is a function specifying, if the entries of [v] and [v'] are ordered in increasing or decreasing index order.

@raise Invalid_argument if [v] and [v'] have unequal lengths
*)
let map2_f_preserves_zero f v v' =
let map2_f_preserves_zero_helper termorder f v v' =
let f_rem_zero acc idx e1 e2 =
let r = f e1 e2 in
if r =: A.zero then acc else (idx, r) :: acc
Expand All @@ -274,14 +308,25 @@ module SparseVector: SparseVectorFunctor =
| [], (yidx, yval) :: ys -> aux (f_rem_zero acc yidx A.zero yval) [] ys
| (xidx, xval) :: xs, [] -> aux (f_rem_zero acc xidx xval A.zero) xs []
| (xidx, xval) :: xs, (yidx, yval) :: ys ->
match xidx - yidx with
match termorder xidx yidx with
| d when d < 0 -> aux (f_rem_zero acc xidx xval A.zero) xs v2
| d when d > 0 -> aux (f_rem_zero acc yidx A.zero yval) v1 ys
| _ -> aux (f_rem_zero acc xidx xval yval) xs ys
in
if v.len <> v'.len then raise (Invalid_argument "Unequal lengths") else
{v with entries = List.rev (aux [] v.entries v'.entries)}

(**
[map2_f_preserves_zero f v v'] returns the mapping of [v] and [v'] specified by [f].

Note that [f] {b must} be such that [f 0 0 = 0]!

The entries of [v] and [v'] are assumed to be ordered in increasing index order.

@raise Invalid_argument if [v] and [v'] have unequal lengths
*)
let map2_f_preserves_zero f v v'= map2_f_preserves_zero_helper (-) f v v'

let map2_f_preserves_zero f v1 v2 = timing_wrap "map2_f_preserves_zero" (map2_f_preserves_zero f v1) v2

(**
Expand Down Expand Up @@ -321,6 +366,5 @@ module SparseVector: SparseVectorFunctor =

let rev v =
let entries = List.rev_map (fun (idx, value) -> (v.len - 1 - idx, value)) v.entries in
{entries; len = v.len}

{entries; len = v.len}
end
Loading
Loading