Skip to content

Commit

Permalink
refactor: remove Index type from AST
Browse files Browse the repository at this point in the history
This change brings the NKI AST closer to the original Python
AST, at the expense of losing some syntactic constraints.
This will make it easier to do more of the processing on the
Lean side.
  • Loading branch information
govereau committed Nov 20, 2024
1 parent ec3e131 commit 162796b
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 79 deletions.
1 change: 0 additions & 1 deletion Export.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ run_meta
flip List.forM (genPython h)
[ `NKL.Const
, `NKL.Expr
, `NKL.Index
, `NKL.Stmt
, `NKL.Fun
]
63 changes: 23 additions & 40 deletions NKL/Encode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def encConst : Const -> ByteArray
| .int i => tag 0x03 [encInt i]
| .float f => tag 0x04 [encFloat f]
| .string s => tag 0x05 [encString s]
| .dots => tag 0x06 []

def decConst : DecodeM Const := do
let val <- next
Expand All @@ -173,6 +174,7 @@ def decConst : DecodeM Const := do
| 0x03 => return .int (<- decInt)
| 0x04 => return .float (<- decFloat)
| 0x05 => return .string (<- decString)
| 0x06 => return .dots
| _ => throw s!"Unknown Const tag value {val}"

private def chkConst (c: Const) : Bool :=
Expand All @@ -184,74 +186,55 @@ private def chkConst (c: Const) : Bool :=
#guard chkConst (.int 1)
#guard chkConst (.float 1.0)
#guard chkConst (.string "str")
#guard chkConst .dots

------------------------------------------------------------------------------
-- Expressions

mutual
partial def encExpr : Expr -> ByteArray
| .value c => tag 0x10 [encConst c]
| .bvar s => tag 0x11 [encString s]
| .var s _ => tag 0x12 [encString s]
| .subscript e ix => tag 0x13 [encExpr e, encList encIndex ix]
| .binop op l r => tag 0x14 [encString op, encExpr l, encExpr r]
| .cond c t e => tag 0x15 [encExpr c, encExpr t, encExpr e]
| .tuple es => tag 0x16 [encList encExpr es]
| .list es => tag 0x17 [encList encExpr es]
| .call f ax => tag 0x18 [encExpr f, encList encExpr ax]
| .gridcall f ix ax => tag 0x19 [encExpr f, encList encIndex ix, encList encExpr ax]

partial def encIndex : Index -> ByteArray
| .coord i => tag 0x20 [encExpr i]
| .slice l u step => tag 0x21 [encExpr l, encExpr u, encExpr step]
| .dots => tag 0x22 []
end

mutual
| .subscript e ix => tag 0x13 [encExpr e, encList encExpr ix]
| .slice l u step => tag 0x14 [encExpr l, encExpr u, encExpr step]
| .binop op l r => tag 0x15 [encString op, encExpr l, encExpr r]
| .cond c t e => tag 0x16 [encExpr c, encExpr t, encExpr e]
| .tuple es => tag 0x17 [encList encExpr es]
| .list es => tag 0x18 [encList encExpr es]
| .call f ax => tag 0x19 [encExpr f, encList encExpr ax]
| .gridcall f ix ax => tag 0x1a [encExpr f, encList encExpr ix, encList encExpr ax]

partial def decExpr : DecodeM Expr := do
match (<- next) with
| 0x10 => return .value (<- decConst)
| 0x11 => return .bvar (<- decString)
| 0x12 => return .var (<- decString) ""
| 0x13 => return .subscript (<-decExpr) (<- decList decIndex)
| 0x14 => return .binop (<- decString) (<- decExpr) (<- decExpr)
| 0x15 => return .cond (<- decExpr) (<- decExpr) (<- decExpr)
| 0x16 => return .tuple (<- decList decExpr)
| 0x17 => return .list (<- decList decExpr)
| 0x18 => return .call (<- decExpr) (<- decList decExpr)
| 0x19 => return .gridcall (<- decExpr) (<- decList decIndex) (<- decList decExpr)
| 0x13 => return .subscript (<- decExpr) (<- decList decExpr)
| 0x14 => return .slice (<- decExpr) (<- decExpr) (<- decExpr)
| 0x15 => return .binop (<- decString) (<- decExpr) (<- decExpr)
| 0x16 => return .cond (<- decExpr) (<- decExpr) (<- decExpr)
| 0x17 => return .tuple (<- decList decExpr)
| 0x18 => return .list (<- decList decExpr)
| 0x19 => return .call (<- decExpr) (<- decList decExpr)
| 0x1a => return .gridcall (<- decExpr) (<- decList decExpr) (<- decList decExpr)
| t => throw s!"Unknown tag in Expr {t}"

partial def decIndex : DecodeM Index := do
match (<- next) with
| 0x20 => return .coord (<- decExpr)
| 0x21 => return .slice (<- decExpr) (<- decExpr) (<- decExpr)
| 0x22 => return .dots
| t => throw s!"Unknown tag in Index {t}"
end

private def chkExpr (e : Expr) : Bool :=
(decode' decExpr $ encExpr e) == some e
private def chkIndex (i : Index) : Bool :=
(decode' decIndex $ encIndex i) == some i

private def nil := Expr.value .nil
private def ndx := Index.coord nil

#guard chkExpr nil
#guard chkExpr (.bvar "var")
#guard chkExpr (.var "var" "")
#guard chkExpr (.subscript nil [ndx, ndx, ndx])
#guard chkExpr (.subscript nil [nil, nil, nil])
#guard chkExpr (.slice nil nil nil)
#guard chkExpr (.binop "op" nil nil)
#guard chkExpr (.cond nil nil nil)
#guard chkExpr (.tuple [nil, nil, nil])
#guard chkExpr (.list [nil, nil, nil])
#guard chkExpr (.call nil [nil, nil, nil])
#guard chkExpr (.gridcall nil [ndx, ndx, ndx] [nil, nil, nil])

#guard chkIndex ndx
#guard chkIndex (.slice nil nil nil)
#guard chkIndex .dots
#guard chkExpr (.gridcall nil [nil, nil, nil] [nil, nil, nil])

------------------------------------------------------------------------------
-- Statements
Expand Down
14 changes: 4 additions & 10 deletions NKL/NKI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,22 @@ inductive Const where
| int (value: Int)
| float (value: Float)
| string (value: String)
| dots
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

mutual
inductive Expr where
| value (c: Const)
| bvar (name: String)
| var (name value: String)
| subscript (tensor: Expr) (ix: List Index)
| subscript (tensor: Expr) (ix: List Expr)
| slice (l u step: Expr)
| binop (op: String) (left right: Expr)
| cond (e thn els: Expr)
| tuple (xs: List Expr)
| list (xs: List Expr)
| call (f: Expr) (args: List Expr)
| gridcall (f: Expr) (ix: List Index) (args: List Expr)
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

inductive Index where
| coord (i : Expr)
| slice (l u step: Expr)
| dots
| gridcall (f: Expr) (ix: List Expr) (args: List Expr)
deriving Repr, BEq, Lean.ToJson, Lean.FromJson
end

inductive Stmt where
| ret (e: Expr)
Expand Down
16 changes: 4 additions & 12 deletions NKL/PrettyPrint.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,28 @@ instance : ToString Const where
| .int i => toString i
| .float f => toString f
| .string s => s

| .dots => "..."

mutual
private partial def exps_ s l := String.intercalate s (List.map expr l)
private partial def exps := exps_ ","
private partial def ndxs l := String.intercalate "," (List.map ndx l)

private partial def expr : Expr -> String
| .value c => toString c
| .bvar s | .var s _ => s
| .subscript e ix => expr e ++ "[" ++ ndxs ix ++ "]"
| .subscript e ix => expr e ++ "[" ++ exps ix ++ "]"
| .slice l u s => exps_ ":" [l,u,s]
| .binop op l r => op ++ "(" ++ expr l ++ "," ++ expr r ++ ")"
| .cond e thn els => expr thn ++ " if " ++ expr e ++ " else " ++ expr els
| .tuple es => "(" ++ exps es ++ ")"
| .list es => "[" ++ exps es ++ "]"
| .call f es => expr f ++ "(" ++ exps es ++ ")"
| .gridcall f ix es => expr f ++ "[" ++ ndxs ix ++ "](" ++ exps es ++ ")"

private partial def ndx : Index -> String
| .coord e => expr e
| .slice l u s => exps_ ":" [l,u,s]
| .dots => "..."
| .gridcall f ix es => expr f ++ "[" ++ exps ix ++ "](" ++ exps es ++ ")"
end

instance : ToString Expr where
toString := expr

instance : ToString Index where
toString := ndx

mutual
private partial def stmts sp l :=
String.intercalate "\n" $ List.map (stmt sp) l
Expand Down
18 changes: 6 additions & 12 deletions interop/nkl/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def opr(e: ast.AST):

def const(c):
if c is None: return Nil()
elif c is Ellipsis: return Dots()
elif isinstance(c, bool): return Bool(c)
elif isinstance(c, int): return Int(c)
elif isinstance(c, float): return Float(c)
Expand Down Expand Up @@ -93,16 +94,6 @@ def translate(self, tree: ast.mod) -> Fun:
case _:
assert 0, "expecting function definition"

# expressions appearing under a subscript
def index(self, e: ast.expr):
match e:
case ast.Ellipsis():
return Dots()
case ast.Slice(l,u,s):
return Slice(self.expr(l), self.expr(u), self.expr(s))
case _:
return Coord(self.expr(e))

def expr(self, e: ast.expr):
if e is None:
return Value(Nil())
Expand Down Expand Up @@ -155,9 +146,12 @@ def compare(ops, l, rs):

# subscript
case ast.Subscript(l, ast.Tuple(ix)):
return Subscript(self.expr(l), list(map(self.index, ix)))
return Subscript(self.expr(l), list(map(self.expr, ix)))
case ast.Subscript(l, ix):
return Subscript(self.expr(l), [self.index(ix)])
return Subscript(self.expr(l), [self.expr(ix)])
# only appears under subscript
case ast.Slice(l,u,s):
return Slice(self.expr(l), self.expr(u), self.expr(s))

# literals
case ast.Tuple(es):
Expand Down
8 changes: 4 additions & 4 deletions interop/test/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def test_vars(term):

@pytest.mark.parametrize("term",
[ [],
[Coord(var)],
[Coord(var), Coord(nil)],
[var],
[var, nil],
[Slice(var, var, var)],
[Slice(nil, nil, nil), Coord(nil)],
[Dots()],
[Slice(nil, nil, nil), nil],
[Value(Dots())],
])
def test_subscript(term):
load(Fun("f", [], [Assign(Var("a", 0), Subscript(var, term))]))

0 comments on commit 162796b

Please sign in to comment.