From 023633cacacb80767c891b8a0cc6871db8f4e062 Mon Sep 17 00:00:00 2001 From: Paul Govereau Date: Wed, 20 Nov 2024 08:33:12 -0500 Subject: [PATCH] refactor: remove Index type from AST This change brings the NKI AST closer to the original Python AST, at the expense of losing some syntactic constraints. This will make it easier to do more of the processing on the Lean side. --- Export.lean | 1 - NKL/Encode.lean | 63 ++++++++++++++------------------------- NKL/NKI.lean | 14 +++------ NKL/PrettyPrint.lean | 16 +++------- interop/nkl/loader.py | 18 ++++------- interop/test/test_json.py | 8 ++--- 6 files changed, 41 insertions(+), 79 deletions(-) diff --git a/Export.lean b/Export.lean index 061326d..9b81163 100644 --- a/Export.lean +++ b/Export.lean @@ -82,7 +82,6 @@ run_meta flip List.forM (genPython h) [ `NKL.Const , `NKL.Expr - , `NKL.Index , `NKL.Stmt , `NKL.Fun ] diff --git a/NKL/Encode.lean b/NKL/Encode.lean index b57e787..49c5a6e 100644 --- a/NKL/Encode.lean +++ b/NKL/Encode.lean @@ -163,6 +163,7 @@ def encConst : Const -> ByteArray | .int i => tag 0x03 [encInt i] | .float f => tag 0x04 [encFloat f] | .string s => tag 0x05 [encString s] + | .dots => tag 0x06 [] def decConst : DecodeM Const := do let val <- next @@ -173,6 +174,7 @@ def decConst : DecodeM Const := do | 0x03 => return .int (<- decInt) | 0x04 => return .float (<- decFloat) | 0x05 => return .string (<- decString) + | 0x06 => return .dots | _ => throw s!"Unknown Const tag value {val}" private def chkConst (c: Const) : Bool := @@ -184,74 +186,55 @@ private def chkConst (c: Const) : Bool := #guard chkConst (.int 1) #guard chkConst (.float 1.0) #guard chkConst (.string "str") +#guard chkConst .dots ------------------------------------------------------------------------------ -- Expressions -mutual partial def encExpr : Expr -> ByteArray | .value c => tag 0x10 [encConst c] | .bvar s => tag 0x11 [encString s] | .var s _ => tag 0x12 [encString s] - | .subscript e ix => tag 0x13 [encExpr e, encList encIndex ix] - | .binop op l r => tag 0x14 [encString op, encExpr l, encExpr r] - | .cond c t e => tag 0x15 [encExpr c, encExpr t, encExpr e] - | .tuple es => tag 0x16 [encList encExpr es] - | .list es => tag 0x17 [encList encExpr es] - | .call f ax => tag 0x18 [encExpr f, encList encExpr ax] - | .gridcall f ix ax => tag 0x19 [encExpr f, encList encIndex ix, encList encExpr ax] - -partial def encIndex : Index -> ByteArray - | .coord i => tag 0x20 [encExpr i] - | .slice l u step => tag 0x21 [encExpr l, encExpr u, encExpr step] - | .dots => tag 0x22 [] -end - -mutual + | .subscript e ix => tag 0x13 [encExpr e, encList encExpr ix] + | .slice l u step => tag 0x14 [encExpr l, encExpr u, encExpr step] + | .binop op l r => tag 0x15 [encString op, encExpr l, encExpr r] + | .cond c t e => tag 0x16 [encExpr c, encExpr t, encExpr e] + | .tuple es => tag 0x17 [encList encExpr es] + | .list es => tag 0x18 [encList encExpr es] + | .call f ax => tag 0x19 [encExpr f, encList encExpr ax] + | .gridcall f ix ax => tag 0x1a [encExpr f, encList encExpr ix, encList encExpr ax] + partial def decExpr : DecodeM Expr := do match (<- next) with | 0x10 => return .value (<- decConst) | 0x11 => return .bvar (<- decString) | 0x12 => return .var (<- decString) "" - | 0x13 => return .subscript (<-decExpr) (<- decList decIndex) - | 0x14 => return .binop (<- decString) (<- decExpr) (<- decExpr) - | 0x15 => return .cond (<- decExpr) (<- decExpr) (<- decExpr) - | 0x16 => return .tuple (<- decList decExpr) - | 0x17 => return .list (<- decList decExpr) - | 0x18 => return .call (<- decExpr) (<- decList decExpr) - | 0x19 => return .gridcall (<- decExpr) (<- decList decIndex) (<- decList decExpr) + | 0x13 => return .subscript (<- decExpr) (<- decList decExpr) + | 0x14 => return .slice (<- decExpr) (<- decExpr) (<- decExpr) + | 0x15 => return .binop (<- decString) (<- decExpr) (<- decExpr) + | 0x16 => return .cond (<- decExpr) (<- decExpr) (<- decExpr) + | 0x17 => return .tuple (<- decList decExpr) + | 0x18 => return .list (<- decList decExpr) + | 0x19 => return .call (<- decExpr) (<- decList decExpr) + | 0x1a => return .gridcall (<- decExpr) (<- decList decExpr) (<- decList decExpr) | t => throw s!"Unknown tag in Expr {t}" -partial def decIndex : DecodeM Index := do - match (<- next) with - | 0x20 => return .coord (<- decExpr) - | 0x21 => return .slice (<- decExpr) (<- decExpr) (<- decExpr) - | 0x22 => return .dots - | t => throw s!"Unknown tag in Index {t}" -end - private def chkExpr (e : Expr) : Bool := (decode' decExpr $ encExpr e) == some e -private def chkIndex (i : Index) : Bool := - (decode' decIndex $ encIndex i) == some i private def nil := Expr.value .nil -private def ndx := Index.coord nil #guard chkExpr nil #guard chkExpr (.bvar "var") #guard chkExpr (.var "var" "") -#guard chkExpr (.subscript nil [ndx, ndx, ndx]) +#guard chkExpr (.subscript nil [nil, nil, nil]) +#guard chkExpr (.slice nil nil nil) #guard chkExpr (.binop "op" nil nil) #guard chkExpr (.cond nil nil nil) #guard chkExpr (.tuple [nil, nil, nil]) #guard chkExpr (.list [nil, nil, nil]) #guard chkExpr (.call nil [nil, nil, nil]) -#guard chkExpr (.gridcall nil [ndx, ndx, ndx] [nil, nil, nil]) - -#guard chkIndex ndx -#guard chkIndex (.slice nil nil nil) -#guard chkIndex .dots +#guard chkExpr (.gridcall nil [nil, nil, nil] [nil, nil, nil]) ------------------------------------------------------------------------------ -- Statements diff --git a/NKL/NKI.lean b/NKL/NKI.lean index e357528..8f124b8 100644 --- a/NKL/NKI.lean +++ b/NKL/NKI.lean @@ -20,28 +20,22 @@ inductive Const where | int (value: Int) | float (value: Float) | string (value: String) + | dots deriving Repr, BEq, Lean.ToJson, Lean.FromJson -mutual inductive Expr where | value (c: Const) | bvar (name: String) | var (name value: String) - | subscript (tensor: Expr) (ix: List Index) + | subscript (tensor: Expr) (ix: List Expr) + | slice (l u step: Expr) | binop (op: String) (left right: Expr) | cond (e thn els: Expr) | tuple (xs: List Expr) | list (xs: List Expr) | call (f: Expr) (args: List Expr) - | gridcall (f: Expr) (ix: List Index) (args: List Expr) - deriving Repr, BEq, Lean.ToJson, Lean.FromJson - -inductive Index where - | coord (i : Expr) - | slice (l u step: Expr) - | dots + | gridcall (f: Expr) (ix: List Expr) (args: List Expr) deriving Repr, BEq, Lean.ToJson, Lean.FromJson -end inductive Stmt where | ret (e: Expr) diff --git a/NKL/PrettyPrint.lean b/NKL/PrettyPrint.lean index ec7e87b..92ccf7d 100644 --- a/NKL/PrettyPrint.lean +++ b/NKL/PrettyPrint.lean @@ -15,36 +15,28 @@ instance : ToString Const where | .int i => toString i | .float f => toString f | .string s => s - + | .dots => "..." mutual private partial def exps_ s l := String.intercalate s (List.map expr l) private partial def exps := exps_ "," -private partial def ndxs l := String.intercalate "," (List.map ndx l) private partial def expr : Expr -> String | .value c => toString c | .bvar s | .var s _ => s - | .subscript e ix => expr e ++ "[" ++ ndxs ix ++ "]" + | .subscript e ix => expr e ++ "[" ++ exps ix ++ "]" + | .slice l u s => exps_ ":" [l,u,s] | .binop op l r => op ++ "(" ++ expr l ++ "," ++ expr r ++ ")" | .cond e thn els => expr thn ++ " if " ++ expr e ++ " else " ++ expr els | .tuple es => "(" ++ exps es ++ ")" | .list es => "[" ++ exps es ++ "]" | .call f es => expr f ++ "(" ++ exps es ++ ")" - | .gridcall f ix es => expr f ++ "[" ++ ndxs ix ++ "](" ++ exps es ++ ")" - -private partial def ndx : Index -> String - | .coord e => expr e - | .slice l u s => exps_ ":" [l,u,s] - | .dots => "..." + | .gridcall f ix es => expr f ++ "[" ++ exps ix ++ "](" ++ exps es ++ ")" end instance : ToString Expr where toString := expr -instance : ToString Index where - toString := ndx - mutual private partial def stmts sp l := String.intercalate "\n" $ List.map (stmt sp) l diff --git a/interop/nkl/loader.py b/interop/nkl/loader.py index b51f2be..5a4557f 100644 --- a/interop/nkl/loader.py +++ b/interop/nkl/loader.py @@ -29,6 +29,7 @@ def opr(e: ast.AST): def const(c): if c is None: return Nil() + elif c is Ellipsis: return Dots() elif isinstance(c, bool): return Bool(c) elif isinstance(c, int): return Int(c) elif isinstance(c, float): return Float(c) @@ -93,16 +94,6 @@ def translate(self, tree: ast.mod) -> Fun: case _: assert 0, "expecting function definition" - # expressions appearing under a subscript - def index(self, e: ast.expr): - match e: - case ast.Ellipsis(): - return Dots() - case ast.Slice(l,u,s): - return Slice(self.expr(l), self.expr(u), self.expr(s)) - case _: - return Coord(self.expr(e)) - def expr(self, e: ast.expr): if e is None: return Value(Nil()) @@ -155,9 +146,12 @@ def compare(ops, l, rs): # subscript case ast.Subscript(l, ast.Tuple(ix)): - return Subscript(self.expr(l), list(map(self.index, ix))) + return Subscript(self.expr(l), list(map(self.expr, ix))) case ast.Subscript(l, ix): - return Subscript(self.expr(l), [self.index(ix)]) + return Subscript(self.expr(l), [self.expr(ix)]) + # only appears under subscript + case ast.Slice(l,u,s): + return Slice(self.expr(l), self.expr(u), self.expr(s)) # literals case ast.Tuple(es): diff --git a/interop/test/test_json.py b/interop/test/test_json.py index 28a3096..034264c 100644 --- a/interop/test/test_json.py +++ b/interop/test/test_json.py @@ -40,11 +40,11 @@ def test_vars(term): @pytest.mark.parametrize("term", [ [], - [Coord(var)], - [Coord(var), Coord(nil)], + [var], + [var, nil], [Slice(var, var, var)], - [Slice(nil, nil, nil), Coord(nil)], - [Dots()], + [Slice(nil, nil, nil), nil], + [Value(Dots())], ]) def test_subscript(term): load(Fun("f", [], [Assign(Var("a", 0), Subscript(var, term))]))