Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 27 additions & 52 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,22 +224,6 @@ Returns the canonical e-class id for a given e-class.

@inline Base.getindex(g::EGraph, i::Id) = g.classes[IdKey(find(g, i))]

# function canonicalize(g::EGraph, n::VecExpr)::VecExpr
# if !v_isexpr(n)
# v_hash!(n)
# return n
# end
# l = v_arity(n)
# new_n = v_new(l)
# v_set_flag!(new_n, v_flags(n))
# v_set_head!(new_n, v_head(n))
# for i in v_children_range(n)
# @inbounds new_n[i] = find(g, n[i])
# end
# v_hash!(new_n)
# new_n
# end

function canonicalize!(g::EGraph, n::VecExpr)
v_isexpr(n) || @goto ret
for i in (VECEXPR_META_LENGTH + 1):length(n)
Expand All @@ -253,19 +237,16 @@ end

function lookup(g::EGraph, n::VecExpr)::Id
canonicalize!(g, n)
h = IdKey(v_hash(n))

haskey(g.memo, n) ? find(g, g.memo[n]) : 0
id = get(g.memo, n, zero(Id))
iszero(id) ? id : find(g, id)
end


function add_class_by_op(g::EGraph, n, eclass_id)
key = IdKey(v_signature(n))
if haskey(g.classes_by_op, key)
push!(g.classes_by_op[key], eclass_id)
else
g.classes_by_op[key] = [eclass_id]
end
vec = get!(g.classes_by_op, key, Vector{Id}())
push!(vec, eclass_id)
end

"""
Expand All @@ -274,7 +255,8 @@ Inserts an e-node in an [`EGraph`](@ref)
function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis}
canonicalize!(g, n)

haskey(g.memo, n) && return g.memo[n]
id = get(g.memo, n, zero(Id))
iszero(id) || return id

if should_copy
n = copy(n)
Expand All @@ -291,7 +273,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
g.memo[n] = id

add_class_by_op(g, n, id)
eclass = EClass{Analysis}(id, VecExpr[n], Pair{VecExpr,Id}[], make(g, n))
eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n))
g.classes[IdKey(id)] = eclass
modify!(g, eclass)
push!(g.pending, n => id)
Expand Down Expand Up @@ -320,28 +302,22 @@ function addexpr!(g::EGraph, se)::Id
se isa EClass && return se.id
e = preprocess(se)

n = if isexpr(e)
args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)

h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))

# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))

for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
n
else # constant enode
VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)])
isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false)

args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)
h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))
# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))
for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
id = add!(g, n, false)
return id

add!(g, n, false)
end

"""
Expand Down Expand Up @@ -431,10 +407,10 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
while !isempty(g.pending) || !isempty(g.analysis_pending)
while !isempty(g.pending)
(node::VecExpr, eclass_id::Id) = pop!(g.pending)
node = copy(node)
canonicalize!(g, node)
if haskey(g.memo, node)
old_class_id = g.memo[node]
g.memo[node] = eclass_id
old_class_id = get!(g.memo, node, eclass_id)
if old_class_id != eclass_id
did_something = union!(g, old_class_id, eclass_id)
# TODO unique! can node dedup be moved here? compare performance
# did_something && unique!(g[eclass_id].nodes)
Expand Down Expand Up @@ -474,9 +450,8 @@ function check_memo(g::EGraph)::Bool
for (id, class) in g.classes
@assert id.val == class.id
for node in class.nodes
if haskey(test_memo, node)
old_id = test_memo[node]
test_memo[node] = id.val
old_id = get!(test_memo, node, id.val)
if old_id != id.val
@assert find(g, old_id) == find(g, id.val) "Unexpected equivalence $node $(g[find(g, id.val)].nodes) $(g[find(g, old_id)].nodes)"
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/EGraphs/uniquequeue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ function Base.pop!(uq::UniqueQueue{T}) where {T}
v
end

Base.isempty(uq::UniqueQueue) = isempty(uq.vec)
Base.isempty(uq::UniqueQueue) = isempty(uq.vec)
2 changes: 1 addition & 1 deletion src/vecexpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ end

"""The hash of the e-node."""
@inline v_hash(n::VecExpr)::Id = @inbounds n.data[1]
Base.hash(n::VecExpr) = v_hash(n) # IdKey not necessary here
Base.hash(n::VecExpr, h::UInt) = hash(v_hash(n), h) # IdKey not necessary here
Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end])

"""Set e-node hash to zero."""
Expand Down