Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ AutoHashEquals = "2.1.0"
DataStructures = "0.18"
DocStringExtensions = "0.8, 0.9"
Reexport = "0.2, 1"
TermInterface = "0.3.3"
TermInterface = "0.4"
TimerOutputs = "0.5"
julia = "1.8"

Expand Down
54 changes: 0 additions & 54 deletions scratch/eggify.jl

This file was deleted.

3 changes: 1 addition & 2 deletions src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include("../docstrings.jl")

using DataStructures
using TermInterface
using TermInterface: head
using TimerOutputs
using Metatheory: alwaystrue, cleanast, binarize
using Metatheory.Patterns
Expand Down Expand Up @@ -31,8 +32,6 @@ export merge!
export in_same_class
export addexpr!
export rebuild!
export settermtype!
export gettermtype

include("analysis.jl")
export analyze!
Expand Down
6 changes: 4 additions & 2 deletions src/EGraphs/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,10 @@ function rec_extract(g::EGraph, costfun, id::EClassId; cse_env = nothing)
elseif n isa ENodeTerm
children = map(arg -> rec_extract(g, costfun, arg; cse_env = cse_env), n.args)
meta = getdata(eclass, :metadata_analysis, nothing)
T = symtype(n)
egraph_reconstruct_expression(T, operation(n), collect(children); metadata = meta, exprhead = exprhead(n))

h = head(n)
args = head_symbol(h) == :call ? [operation(n); children...] : children
maketerm(h, args; metadata = meta)
else
error("Unknown ENode Type $(typeof(n))")
end
Expand Down
70 changes: 19 additions & 51 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ end
Base.:(==)(a::ENodeLiteral, b::ENodeLiteral) = hash(a) == hash(b)

TermInterface.istree(n::ENodeLiteral) = false
TermInterface.exprhead(n::ENodeLiteral) = nothing
TermInterface.operation(n::ENodeLiteral) = n.value
TermInterface.arity(n::ENodeLiteral) = 0

Expand All @@ -37,25 +36,23 @@ end


mutable struct ENodeTerm <: AbstractENode
exprhead::Union{Symbol,Nothing}
head::Any
operation::Any
symtype::Type
args::Vector{EClassId}
hash::Ref{UInt} # hash cache
ENodeTerm(exprhead, operation, symtype, c_ids) = new(exprhead, operation, symtype, c_ids, Ref{UInt}(0))
ENodeTerm(head, operation, c_ids) = new(head, operation, c_ids, Ref{UInt}(0))
end


function Base.:(==)(a::ENodeTerm, b::ENodeTerm)
hash(a) == hash(b) && a.operation == b.operation
end


TermInterface.istree(n::ENodeTerm) = true
TermInterface.symtype(n::ENodeTerm) = n.symtype
TermInterface.exprhead(n::ENodeTerm) = n.exprhead
TermInterface.head(n::ENodeTerm) = n.head
TermInterface.operation(n::ENodeTerm) = n.operation
TermInterface.arguments(n::ENodeTerm) = n.args
TermInterface.children(n::ENodeTerm) = [n.head; n.args...]
TermInterface.arity(n::ENodeTerm) = length(n.args)

# This optimization comes from SymbolicUtils
Expand All @@ -65,7 +62,7 @@ function Base.hash(t::ENodeTerm, salt::UInt)
!iszero(salt) && return hash(hash(t, zero(UInt)), salt)
h = t.hash[]
!iszero(h) && return h
h′ = hash(t.args, hash(t.exprhead, hash(t.operation, salt)))
h′ = hash(t.args, hash(t.head, hash(t.operation, salt)))
t.hash[] = h′
return h′
end
Expand All @@ -80,9 +77,7 @@ mutable struct EClass
data::AnalysisData
end

function toexpr(n::ENodeTerm)
Expr(:call, :ENode, exprhead(n), operation(n), symtype(n), arguments(n))
end
toexpr(n::ENodeTerm) = Expr(:call, :ENode, head(n), operation(n), arguments(n))

function Base.show(io::IO, x::ENodeTerm)
print(io, toexpr(x))
Expand Down Expand Up @@ -191,8 +186,8 @@ mutable struct EGraph
analyses::Dict{Union{Symbol,Function},Union{Symbol,Function}}
"a cache mapping function symbols to e-classes that contain e-nodes with that function symbol."
symcache::Dict{Any,Vector{EClassId}}
default_termtype::Type
termtypes::TermTypes
head_type::Type
# termtypes::TermTypes
numclasses::Int
numnodes::Int
"If we use global buffers we may need to lock. Defaults to true."
Expand All @@ -209,7 +204,7 @@ end
EGraph(expr)
Construct an EGraph from a starting symbolic expression `expr`.
"""
function EGraph(; needslock::Bool = false, buffer_size = DEFAULT_BUFFER_SIZE)
function EGraph(; needslock::Bool = false, buffer_size = DEFAULT_BUFFER_SIZE, head_type = ExprHead)
EGraph(
IntDisjointSet(),
Dict{EClassId,EClass}(),
Expand All @@ -218,8 +213,8 @@ function EGraph(; needslock::Bool = false, buffer_size = DEFAULT_BUFFER_SIZE)
-1,
Dict{Union{Symbol,Function},Union{Symbol,Function}}(),
Dict{Any,Vector{EClassId}}(),
Expr,
TermTypes(),
head_type,
# TermTypes(),
0,
0,
needslock,
Expand All @@ -234,7 +229,7 @@ function maybelock!(f::Function, g::EGraph)
end

function EGraph(e; keepmeta = false, kwargs...)
g = EGraph(kwargs...)
g = EGraph(; kwargs...)
keepmeta && addanalysis!(g, :metadata_analysis)
g.root = addexpr!(g, e; keepmeta = keepmeta)
g
Expand All @@ -249,22 +244,6 @@ function addanalysis!(g::EGraph, analysis_name::Symbol)
g.analyses[analysis_name] = analysis_name
end

function settermtype!(g::EGraph, f, ar, T)
g.termtypes[(f, ar)] = T
end

function settermtype!(g::EGraph, T)
g.default_termtype = T
end

function gettermtype(g::EGraph, f, ar)
if haskey(g.termtypes, (f, ar))
g.termtypes[(f, ar)]
else
g.default_termtype
end
end


"""
Returns the canonical e-class id for a given e-class.
Expand All @@ -284,7 +263,7 @@ canonicalize(g::EGraph, n::ENodeLiteral) = n
function canonicalize(g::EGraph, n::ENodeTerm)
if arity(n) > 0
new_args = map(x -> find(g, x), n.args)
return ENodeTerm(exprhead(n), operation(n), symtype(n), new_args)
return ENodeTerm(head(n), operation(n), new_args)
end
return n
end
Expand Down Expand Up @@ -367,7 +346,7 @@ function addexpr!(g::EGraph, se; keepmeta = false)::EClassId

id = add!(g, if istree(se)
class_ids::Vector{EClassId} = [addexpr!(g, arg; keepmeta = keepmeta) for arg in arguments(e)]
ENodeTerm(exprhead(e), operation(e), symtype(e), class_ids)
ENodeTerm(head(e), operation(e), class_ids)
else
# constant enode
ENodeLiteral(e)
Expand Down Expand Up @@ -525,39 +504,28 @@ function reachable(g::EGraph, id::EClassId)
return hist
end


"""
When extracting symbolic expressions from an e-graph, we need
to instruct the e-graph how to rebuild expressions of a certain type.
This function must be extended by the user to add new types of expressions that can be manipulated by e-graphs.
"""
function egraph_reconstruct_expression(T::Type{Expr}, op, args; metadata = nothing, exprhead = :call)
similarterm(Expr(:call, :_), op, args; metadata = metadata, exprhead = exprhead)
end

# Thanks to Max Willsey and Yihong Zhang

import Metatheory: lookup_pat

function lookup_pat(g::EGraph, p::PatTerm)::EClassId
@assert isground(p)

eh = exprhead(p)
op = operation(p)
args = arguments(p)
ar = arity(p)

T = gettermtype(g, op, ar)
eh = g.head_type(head_symbol(head(p)))

ids = map(x -> lookup_pat(g, x), args)
!all((>)(0), ids) && return -1

if T == Expr && op isa Union{Function,DataType}
id = lookup(g, ENodeTerm(eh, op, T, ids))
id < 0 && return lookup(g, ENodeTerm(eh, nameof(op), T, ids))
if g.head_type == ExprHead && op isa Union{Function,DataType}
id = lookup(g, ENodeTerm(eh, op, ids))
id < 0 && return lookup(g, ENodeTerm(eh, nameof(op), ids))
return id
else
return lookup(g, ENodeTerm(eh, op, T, ids))
return lookup(g, ENodeTerm(eh, op, ids))
end
end

Expand Down
7 changes: 3 additions & 4 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,13 @@ end
instantiate_enode!(bindings::Bindings, g::EGraph, p::Any)::EClassId = add!(g, ENodeLiteral(p))
instantiate_enode!(bindings::Bindings, g::EGraph, p::PatVar)::EClassId = bindings[p.idx][1]
function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId
eh = exprhead(p)
op = operation(p)
ar = arity(p)
args = arguments(p)
T = gettermtype(g, op, ar)
# TODO add predicate check `quotes_operation`
new_op = T == Expr && op isa Union{Function,DataType} ? nameof(op) : op
add!(g, ENodeTerm(eh, new_op, T, map(arg -> instantiate_enode!(bindings, g, arg), args)))
new_op = g.head_type == ExprHead && op isa Union{Function,DataType} ? nameof(op) : op
eh = g.head_type(head_symbol(head(p)))
add!(g, ENodeTerm(eh, new_op, map(arg -> instantiate_enode!(bindings, g, arg), args)))
end

function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction)
Expand Down
18 changes: 9 additions & 9 deletions src/Library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,36 @@ using Metatheory.Rules


macro commutativity(op)
RewriteRule(PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatTerm(:call, op, [PatVar(:b), PatVar(:a)]))
RewriteRule(PatTerm(PatHead(:call), op, PatVar(:a), PatVar(:b)), PatTerm(PatHead(:call), op, PatVar(:b), PatVar(:a)))
end

macro right_associative(op)
RewriteRule(
PatTerm(:call, op, [PatVar(:a), PatTerm(:call, op, [PatVar(:b), PatVar(:c)])]),
PatTerm(:call, op, [PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatVar(:c)]),
PatTerm(PatHead(:call), op, PatVar(:a), PatTerm(PatHead(:call), op, PatVar(:b), PatVar(:c))),
PatTerm(PatHead(:call), op, PatTerm(PatHead(:call), op, PatVar(:a), PatVar(:b)), PatVar(:c)),
)
end
macro left_associative(op)
RewriteRule(
PatTerm(:call, op, [PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatVar(:c)]),
PatTerm(:call, op, [PatVar(:a), PatTerm(:call, op, [PatVar(:b), PatVar(:c)])]),
PatTerm(PatHead(:call), op, PatTerm(PatHead(:call), op, PatVar(:a), PatVar(:b)), PatVar(:c)),
PatTerm(PatHead(:call), op, PatVar(:a), PatTerm(PatHead(:call), op, PatVar(:b), PatVar(:c))),
)
end


macro identity_left(op, id)
RewriteRule(PatTerm(:call, op, [id, PatVar(:a)]), PatVar(:a))
RewriteRule(PatTerm(PatHead(:call), op, id, PatVar(:a)), PatVar(:a))
end

macro identity_right(op, id)
RewriteRule(PatTerm(:call, op, [PatVar(:a), id]), PatVar(:a))
RewriteRule(PatTerm(PatHead(:call), op, PatVar(:a), id), PatVar(:a))
end

macro inverse_left(op, id, invop)
RewriteRule(PatTerm(:call, op, [PatTerm(:call, invop, [PatVar(:a)]), PatVar(:a)]), id)
RewriteRule(PatTerm(PatHead(:call), op, PatTerm(PatHead(:call), invop, PatVar(:a)), PatVar(:a)), id)
end
macro inverse_right(op, id, invop)
RewriteRule(PatTerm(:call, op, [PatVar(:a), PatTerm(:call, invop, [PatVar(:a)])]), id)
RewriteRule(PatTerm(PatHead(:call), op, PatVar(:a), PatTerm(PatHead(:call), invop, PatVar(:a))), id)
end


Expand Down
1 change: 1 addition & 0 deletions src/Metatheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using DataStructures
using Base.Meta
using Reexport
using TermInterface
using TermInterface: head

@inline alwaystrue(x) = true

Expand Down
Loading