diff --git a/Export.lean b/Export.lean index 061326d..9b81163 100644 --- a/Export.lean +++ b/Export.lean @@ -82,7 +82,6 @@ run_meta flip List.forM (genPython h) [ `NKL.Const , `NKL.Expr - , `NKL.Index , `NKL.Stmt , `NKL.Fun ] diff --git a/NKL/Encode.lean b/NKL/Encode.lean index b57e787..49c5a6e 100644 --- a/NKL/Encode.lean +++ b/NKL/Encode.lean @@ -163,6 +163,7 @@ def encConst : Const -> ByteArray | .int i => tag 0x03 [encInt i] | .float f => tag 0x04 [encFloat f] | .string s => tag 0x05 [encString s] + | .dots => tag 0x06 [] def decConst : DecodeM Const := do let val <- next @@ -173,6 +174,7 @@ def decConst : DecodeM Const := do | 0x03 => return .int (<- decInt) | 0x04 => return .float (<- decFloat) | 0x05 => return .string (<- decString) + | 0x06 => return .dots | _ => throw s!"Unknown Const tag value {val}" private def chkConst (c: Const) : Bool := @@ -184,74 +186,55 @@ private def chkConst (c: Const) : Bool := #guard chkConst (.int 1) #guard chkConst (.float 1.0) #guard chkConst (.string "str") +#guard chkConst .dots ------------------------------------------------------------------------------ -- Expressions -mutual partial def encExpr : Expr -> ByteArray | .value c => tag 0x10 [encConst c] | .bvar s => tag 0x11 [encString s] | .var s _ => tag 0x12 [encString s] - | .subscript e ix => tag 0x13 [encExpr e, encList encIndex ix] - | .binop op l r => tag 0x14 [encString op, encExpr l, encExpr r] - | .cond c t e => tag 0x15 [encExpr c, encExpr t, encExpr e] - | .tuple es => tag 0x16 [encList encExpr es] - | .list es => tag 0x17 [encList encExpr es] - | .call f ax => tag 0x18 [encExpr f, encList encExpr ax] - | .gridcall f ix ax => tag 0x19 [encExpr f, encList encIndex ix, encList encExpr ax] - -partial def encIndex : Index -> ByteArray - | .coord i => tag 0x20 [encExpr i] - | .slice l u step => tag 0x21 [encExpr l, encExpr u, encExpr step] - | .dots => tag 0x22 [] -end - -mutual + | .subscript e ix => tag 0x13 [encExpr e, encList encExpr ix] + | .slice l u step => tag 0x14 [encExpr l, encExpr u, encExpr step] + | .binop op l r => tag 0x15 [encString op, encExpr l, encExpr r] + | .cond c t e => tag 0x16 [encExpr c, encExpr t, encExpr e] + | .tuple es => tag 0x17 [encList encExpr es] + | .list es => tag 0x18 [encList encExpr es] + | .call f ax => tag 0x19 [encExpr f, encList encExpr ax] + | .gridcall f ix ax => tag 0x1a [encExpr f, encList encExpr ix, encList encExpr ax] + partial def decExpr : DecodeM Expr := do match (<- next) with | 0x10 => return .value (<- decConst) | 0x11 => return .bvar (<- decString) | 0x12 => return .var (<- decString) "" - | 0x13 => return .subscript (<-decExpr) (<- decList decIndex) - | 0x14 => return .binop (<- decString) (<- decExpr) (<- decExpr) - | 0x15 => return .cond (<- decExpr) (<- decExpr) (<- decExpr) - | 0x16 => return .tuple (<- decList decExpr) - | 0x17 => return .list (<- decList decExpr) - | 0x18 => return .call (<- decExpr) (<- decList decExpr) - | 0x19 => return .gridcall (<- decExpr) (<- decList decIndex) (<- decList decExpr) + | 0x13 => return .subscript (<- decExpr) (<- decList decExpr) + | 0x14 => return .slice (<- decExpr) (<- decExpr) (<- decExpr) + | 0x15 => return .binop (<- decString) (<- decExpr) (<- decExpr) + | 0x16 => return .cond (<- decExpr) (<- decExpr) (<- decExpr) + | 0x17 => return .tuple (<- decList decExpr) + | 0x18 => return .list (<- decList decExpr) + | 0x19 => return .call (<- decExpr) (<- decList decExpr) + | 0x1a => return .gridcall (<- decExpr) (<- decList decExpr) (<- decList decExpr) | t => throw s!"Unknown tag in Expr {t}" -partial def decIndex : DecodeM Index := do - match (<- next) with - | 0x20 => return .coord (<- decExpr) - | 0x21 => return .slice (<- decExpr) (<- decExpr) (<- decExpr) - | 0x22 => return .dots - | t => throw s!"Unknown tag in Index {t}" -end - private def chkExpr (e : Expr) : Bool := (decode' decExpr $ encExpr e) == some e -private def chkIndex (i : Index) : Bool := - (decode' decIndex $ encIndex i) == some i private def nil := Expr.value .nil -private def ndx := Index.coord nil #guard chkExpr nil #guard chkExpr (.bvar "var") #guard chkExpr (.var "var" "") -#guard chkExpr (.subscript nil [ndx, ndx, ndx]) +#guard chkExpr (.subscript nil [nil, nil, nil]) +#guard chkExpr (.slice nil nil nil) #guard chkExpr (.binop "op" nil nil) #guard chkExpr (.cond nil nil nil) #guard chkExpr (.tuple [nil, nil, nil]) #guard chkExpr (.list [nil, nil, nil]) #guard chkExpr (.call nil [nil, nil, nil]) -#guard chkExpr (.gridcall nil [ndx, ndx, ndx] [nil, nil, nil]) - -#guard chkIndex ndx -#guard chkIndex (.slice nil nil nil) -#guard chkIndex .dots +#guard chkExpr (.gridcall nil [nil, nil, nil] [nil, nil, nil]) ------------------------------------------------------------------------------ -- Statements diff --git a/NKL/NKI.lean b/NKL/NKI.lean index e357528..8f124b8 100644 --- a/NKL/NKI.lean +++ b/NKL/NKI.lean @@ -20,28 +20,22 @@ inductive Const where | int (value: Int) | float (value: Float) | string (value: String) + | dots deriving Repr, BEq, Lean.ToJson, Lean.FromJson -mutual inductive Expr where | value (c: Const) | bvar (name: String) | var (name value: String) - | subscript (tensor: Expr) (ix: List Index) + | subscript (tensor: Expr) (ix: List Expr) + | slice (l u step: Expr) | binop (op: String) (left right: Expr) | cond (e thn els: Expr) | tuple (xs: List Expr) | list (xs: List Expr) | call (f: Expr) (args: List Expr) - | gridcall (f: Expr) (ix: List Index) (args: List Expr) - deriving Repr, BEq, Lean.ToJson, Lean.FromJson - -inductive Index where - | coord (i : Expr) - | slice (l u step: Expr) - | dots + | gridcall (f: Expr) (ix: List Expr) (args: List Expr) deriving Repr, BEq, Lean.ToJson, Lean.FromJson -end inductive Stmt where | ret (e: Expr) diff --git a/NKL/PrettyPrint.lean b/NKL/PrettyPrint.lean index ec7e87b..92ccf7d 100644 --- a/NKL/PrettyPrint.lean +++ b/NKL/PrettyPrint.lean @@ -15,36 +15,28 @@ instance : ToString Const where | .int i => toString i | .float f => toString f | .string s => s - + | .dots => "..." mutual private partial def exps_ s l := String.intercalate s (List.map expr l) private partial def exps := exps_ "," -private partial def ndxs l := String.intercalate "," (List.map ndx l) private partial def expr : Expr -> String | .value c => toString c | .bvar s | .var s _ => s - | .subscript e ix => expr e ++ "[" ++ ndxs ix ++ "]" + | .subscript e ix => expr e ++ "[" ++ exps ix ++ "]" + | .slice l u s => exps_ ":" [l,u,s] | .binop op l r => op ++ "(" ++ expr l ++ "," ++ expr r ++ ")" | .cond e thn els => expr thn ++ " if " ++ expr e ++ " else " ++ expr els | .tuple es => "(" ++ exps es ++ ")" | .list es => "[" ++ exps es ++ "]" | .call f es => expr f ++ "(" ++ exps es ++ ")" - | .gridcall f ix es => expr f ++ "[" ++ ndxs ix ++ "](" ++ exps es ++ ")" - -private partial def ndx : Index -> String - | .coord e => expr e - | .slice l u s => exps_ ":" [l,u,s] - | .dots => "..." + | .gridcall f ix es => expr f ++ "[" ++ exps ix ++ "](" ++ exps es ++ ")" end instance : ToString Expr where toString := expr -instance : ToString Index where - toString := ndx - mutual private partial def stmts sp l := String.intercalate "\n" $ List.map (stmt sp) l diff --git a/interop/nkl/loader.py b/interop/nkl/loader.py index b51f2be..5a4557f 100644 --- a/interop/nkl/loader.py +++ b/interop/nkl/loader.py @@ -29,6 +29,7 @@ def opr(e: ast.AST): def const(c): if c is None: return Nil() + elif c is Ellipsis: return Dots() elif isinstance(c, bool): return Bool(c) elif isinstance(c, int): return Int(c) elif isinstance(c, float): return Float(c) @@ -93,16 +94,6 @@ def translate(self, tree: ast.mod) -> Fun: case _: assert 0, "expecting function definition" - # expressions appearing under a subscript - def index(self, e: ast.expr): - match e: - case ast.Ellipsis(): - return Dots() - case ast.Slice(l,u,s): - return Slice(self.expr(l), self.expr(u), self.expr(s)) - case _: - return Coord(self.expr(e)) - def expr(self, e: ast.expr): if e is None: return Value(Nil()) @@ -155,9 +146,12 @@ def compare(ops, l, rs): # subscript case ast.Subscript(l, ast.Tuple(ix)): - return Subscript(self.expr(l), list(map(self.index, ix))) + return Subscript(self.expr(l), list(map(self.expr, ix))) case ast.Subscript(l, ix): - return Subscript(self.expr(l), [self.index(ix)]) + return Subscript(self.expr(l), [self.expr(ix)]) + # only appears under subscript + case ast.Slice(l,u,s): + return Slice(self.expr(l), self.expr(u), self.expr(s)) # literals case ast.Tuple(es): diff --git a/interop/test/test_json.py b/interop/test/test_json.py index 28a3096..034264c 100644 --- a/interop/test/test_json.py +++ b/interop/test/test_json.py @@ -40,11 +40,11 @@ def test_vars(term): @pytest.mark.parametrize("term", [ [], - [Coord(var)], - [Coord(var), Coord(nil)], + [var], + [var, nil], [Slice(var, var, var)], - [Slice(nil, nil, nil), Coord(nil)], - [Dots()], + [Slice(nil, nil, nil), nil], + [Value(Dots())], ]) def test_subscript(term): load(Fun("f", [], [Assign(Var("a", 0), Subscript(var, term))]))