diff --git a/ast/ast.go b/ast/ast.go index cfc762333..98042df02 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -72,22 +72,16 @@ func (n *NodeBase) SetFreeVariables(idents Identifiers) { // --------------------------------------------------------------------------- -// +gen stringer -type CompKind int - -const ( - CompFor CompKind = iota - CompIf -) +type IfSpec struct { + Expr Node +} -// TODO(sbarzowski) separate types for two kinds -// TODO(sbarzowski) bonus points for attaching ifs to the previous for -type CompSpec struct { - Kind CompKind - VarName *Identifier // nil when kind != compSpecFor - Expr Node +type ForSpec struct { + VarName Identifier + Expr Node + Conditions []IfSpec + Outer *ForSpec } -type CompSpecs []CompSpec // --------------------------------------------------------------------------- @@ -136,7 +130,7 @@ type ArrayComp struct { NodeBase Body Node TrailingComma bool - Specs CompSpecs + Spec ForSpec } // --------------------------------------------------------------------------- @@ -486,7 +480,7 @@ type ObjectComp struct { NodeBase Fields ObjectFields TrailingComma bool - Specs CompSpecs + Spec ForSpec } // --------------------------------------------------------------------------- diff --git a/ast/compkind_stringer.go b/ast/compkind_stringer.go deleted file mode 100644 index dfd76e0cb..000000000 --- a/ast/compkind_stringer.go +++ /dev/null @@ -1,20 +0,0 @@ -// Generated by: main -// TypeWriter: stringer -// Directive: +gen on astCompKind - -package ast - -import ( - "fmt" -) - -const _astCompKind_name = "astCompForastCompIf" - -var _astCompKind_index = [...]uint8{0, 10, 19} - -func (i CompKind) String() string { - if i < 0 || i+1 >= CompKind(len(_astCompKind_index)) { - return fmt.Sprintf("astCompKind(%d)", i) - } - return _astCompKind_name[_astCompKind_index[i]:_astCompKind_index[i+1]] -} diff --git a/builtins.go b/builtins.go index 2b063c7bd..0f0e7d0ef 100644 --- a/builtins.go +++ b/builtins.go @@ -184,6 +184,32 @@ func builtinMakeArray(e *evaluator, szp potentialValue, funcp potentialValue) (v return makeValueArray(elems), nil } +func builtinFlatMap(e *evaluator, funcp potentialValue, arrp potentialValue) (value, error) { + arr, err := e.evaluateArray(arrp) + if err != nil { + return nil, err + } + fun, err := e.evaluateFunction(funcp) + if err != nil { + return nil, err + } + num := int(arr.length()) + // Start with capacity of the original array. + // This may spare us a few reallocations. + // TODO(sbarzowski) verify that it actually helps + elems := make([]potentialValue, 0, num) + for i := 0; i < num; i++ { + returned, err := e.evaluateArray(fun.call(args(arr.elements[i]))) + if err != nil { + return nil, err + } + for _, elem := range returned.elements { + elems = append(elems, elem) + } + } + return makeValueArray(elems), nil +} + func builtinNegation(e *evaluator, xp potentialValue) (value, error) { x, err := e.evaluateBoolean(xp) if err != nil { @@ -448,6 +474,7 @@ var funcBuiltins = map[string]evalCallable{ "length": &UnaryBuiltin{name: "length", function: builtinLength, parameters: ast.Identifiers{"x"}}, "toString": &UnaryBuiltin{name: "toString", function: builtinToString, parameters: ast.Identifiers{"x"}}, "makeArray": &BinaryBuiltin{name: "makeArray", function: builtinMakeArray, parameters: ast.Identifiers{"sz", "func"}}, + "flatMap": &BinaryBuiltin{name: "flatMap", function: builtinFlatMap, parameters: ast.Identifiers{"func", "arr"}}, "primitiveEquals": &BinaryBuiltin{name: "primitiveEquals", function: primitiveEquals, parameters: ast.Identifiers{"sz", "func"}}, "objectFieldsEx": &BinaryBuiltin{name: "objectFields", function: builtinObjectFieldsEx, parameters: ast.Identifiers{"obj", "hidden"}}, "objectHasEx": &TernaryBuiltin{name: "objectHasEx", function: builtinObjectHasEx, parameters: ast.Identifiers{"obj", "fname", "hidden"}}, diff --git a/desugarer.go b/desugarer.go index c01552b9a..9c3af400f 100644 --- a/desugarer.go +++ b/desugarer.go @@ -210,19 +210,31 @@ func desugarFields(location ast.LocationRange, fields *ast.ObjectFields, objLeve return nil } -func desugarArrayComp(astComp *ast.ArrayComp, objLevel int) (ast.Node, error) { - return &ast.LiteralNull{}, nil - // TODO(sbarzowski) this - switch astComp.Specs[0].Kind { - case ast.CompFor: - panic("TODO") - case ast.CompIf: - panic("TODO") - default: - panic("TODO") +func simpleLambda(body ast.Node, paramName ast.Identifier) ast.Node { + return &ast.Function{ + Body: body, + Parameters: ast.Identifiers{paramName}, } } +func desugarForSpec(inside ast.Node, forSpec *ast.ForSpec) (ast.Node, error) { + // TODO(sbarzowski) support ifs + function := simpleLambda(inside, forSpec.VarName) + current := buildStdCall("flatMap", function, forSpec.Expr) + if forSpec.Outer == nil { + return current, nil + } + return desugarForSpec(current, forSpec.Outer) +} + +func wrapInArray(inside ast.Node) ast.Node { + return &ast.Array{Elements: ast.Nodes{inside}} +} + +func desugarArrayComp(comp *ast.ArrayComp, objLevel int) (ast.Node, error) { + return desugarForSpec(wrapInArray(comp.Body), &comp.Spec) +} + func desugarObjectComp(astComp *ast.ObjectComp, objLevel int) (ast.Node, error) { return &ast.LiteralNull{}, nil // TODO(sbarzowski) this @@ -306,6 +318,10 @@ func desugar(astPtr *ast.Node, objLevel int) (err error) { return err } *astPtr = comp + err = desugar(astPtr, objLevel) + if err != nil { + return err + } case *ast.Assert: if node.Message == nil { diff --git a/evaluator.go b/evaluator.go index 6e4438eb4..3f29b88ad 100644 --- a/evaluator.go +++ b/evaluator.go @@ -110,6 +110,23 @@ func (e *evaluator) evaluateBoolean(pv potentialValue) (*valueBoolean, error) { return e.getBoolean(v) } +func (e *evaluator) getArray(val value) (*valueArray, error) { + switch v := val.(type) { + case *valueArray: + return v, nil + default: + return nil, e.typeErrorSpecific(val, &valueArray{}) + } +} + +func (e *evaluator) evaluateArray(pv potentialValue) (*valueArray, error) { + v, err := e.evaluate(pv) + if err != nil { + return nil, err + } + return e.getArray(v) +} + func (e *evaluator) getFunction(val value) (*valueFunction, error) { switch v := val.(type) { case *valueFunction: diff --git a/parser/parser.go b/parser/parser.go index 6eb7c1e75..d248d5d1d 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -353,7 +353,7 @@ func (p *parser) parseObjectRemainder(tok *token) (ast.Node, *token, error) { if field.Kind != ast.ObjectFieldExpr { return nil, nil, MakeStaticError("Object comprehensions can only have [e] fields.", next.loc) } - specs, last, err := p.parseComprehensionSpecs(tokenBraceR) + spec, last, err := p.parseComprehensionSpecs(tokenBraceR) if err != nil { return nil, nil, err } @@ -361,7 +361,7 @@ func (p *parser) parseObjectRemainder(tok *token) (ast.Node, *token, error) { NodeBase: ast.NewNodeBaseLoc(locFromTokens(tok, last)), Fields: fields, TrailingComma: gotComma, - Specs: *specs, + Spec: *spec, }, last, nil } @@ -537,50 +537,53 @@ func (p *parser) parseObjectRemainder(tok *token) (ast.Node, *token, error) { } /* parses for x in expr for y in expr if expr for z in expr ... */ -func (p *parser) parseComprehensionSpecs(end tokenKind) (*ast.CompSpecs, *token, error) { - var specs ast.CompSpecs - for { - varID, err := p.popExpect(tokenIdentifier) - if err != nil { - return nil, nil, err - } - id := ast.Identifier(varID.data) - _, err = p.popExpect(tokenIn) - if err != nil { - return nil, nil, err - } - arr, err := p.parse(maxPrecedence) +func (p *parser) parseComprehensionSpecs(end tokenKind) (*ast.ForSpec, *token, error) { + var ifSpecs []ast.IfSpec + + varID, err := p.popExpect(tokenIdentifier) + if err != nil { + return nil, nil, err + } + id := ast.Identifier(varID.data) + _, err = p.popExpect(tokenIn) + if err != nil { + return nil, nil, err + } + arr, err := p.parse(maxPrecedence) + if err != nil { + return nil, nil, err + } + forSpec := &ast.ForSpec{ + VarName: id, + Expr: arr, + } + + maybeIf := p.pop() + for ; maybeIf.kind == tokenIf; maybeIf = p.pop() { + cond, err := p.parse(maxPrecedence) if err != nil { return nil, nil, err } - specs = append(specs, ast.CompSpec{ - Kind: ast.CompFor, - VarName: &id, - Expr: arr, + ifSpecs = append(ifSpecs, ast.IfSpec{ + Expr: cond, }) + } + forSpec.Conditions = ifSpecs + if maybeIf.kind == end { + return forSpec, maybeIf, nil + } - maybeIf := p.pop() - for ; maybeIf.kind == tokenIf; maybeIf = p.pop() { - cond, err := p.parse(maxPrecedence) - if err != nil { - return nil, nil, err - } - specs = append(specs, ast.CompSpec{ - Kind: ast.CompIf, - VarName: nil, - Expr: cond, - }) - } - if maybeIf.kind == end { - return &specs, maybeIf, nil - } - - if maybeIf.kind != tokenFor { - return nil, nil, MakeStaticError( - fmt.Sprintf("Expected for, if or %v after for clause, got: %v", end, maybeIf), maybeIf.loc) - } + if maybeIf.kind != tokenFor { + return nil, nil, MakeStaticError( + fmt.Sprintf("Expected for, if or %v after for clause, got: %v", end, maybeIf), maybeIf.loc) + } + nextSpec, last, err := p.parseComprehensionSpecs(end) + if err != nil { + return nil, nil, err } + nextSpec.Outer = forSpec + return nextSpec, last, nil } // Assumes that the leading '[' has already been consumed and passed as tok. @@ -609,7 +612,7 @@ func (p *parser) parseArray(tok *token) (ast.Node, error) { if next.kind == tokenFor { // It's a comprehension p.pop() - specs, last, err := p.parseComprehensionSpecs(tokenBracketR) + spec, last, err := p.parseComprehensionSpecs(tokenBracketR) if err != nil { return nil, err } @@ -617,7 +620,7 @@ func (p *parser) parseArray(tok *token) (ast.Node, error) { NodeBase: ast.NewNodeBaseLoc(locFromTokens(tok, last)), Body: first, TrailingComma: gotComma, - Specs: *specs, + Spec: *spec, }, nil } // Not a comprehension: It can have more elements. diff --git a/testdata/arrcomp.golden b/testdata/arrcomp.golden new file mode 100644 index 000000000..1e3ec7217 --- /dev/null +++ b/testdata/arrcomp.golden @@ -0,0 +1 @@ +[ ] diff --git a/testdata/arrcomp.input b/testdata/arrcomp.input new file mode 100644 index 000000000..16da891df --- /dev/null +++ b/testdata/arrcomp.input @@ -0,0 +1 @@ +[x for x in []] diff --git a/testdata/arrcomp2.golden b/testdata/arrcomp2.golden new file mode 100644 index 000000000..a238abc06 --- /dev/null +++ b/testdata/arrcomp2.golden @@ -0,0 +1,5 @@ +[ + 1, + 2, + 3 +] diff --git a/testdata/arrcomp2.input b/testdata/arrcomp2.input new file mode 100644 index 000000000..ff04f502f --- /dev/null +++ b/testdata/arrcomp2.input @@ -0,0 +1 @@ +[x for x in [1, 2, 3]] diff --git a/testdata/arrcomp3.golden b/testdata/arrcomp3.golden new file mode 100644 index 000000000..8586b7829 --- /dev/null +++ b/testdata/arrcomp3.golden @@ -0,0 +1,38 @@ +[ + [ + 1, + "a" + ], + [ + 1, + "b" + ], + [ + 1, + "c" + ], + [ + 2, + "a" + ], + [ + 2, + "b" + ], + [ + 2, + "c" + ], + [ + 3, + "a" + ], + [ + 3, + "b" + ], + [ + 3, + "c" + ] +] diff --git a/testdata/arrcomp3.input b/testdata/arrcomp3.input new file mode 100644 index 000000000..0d92baab3 --- /dev/null +++ b/testdata/arrcomp3.input @@ -0,0 +1,2 @@ +[[var, var2] for var in [1, 2, 3] for var2 in ["a", "b", "c"]] + diff --git a/testdata/arrcomp4.golden b/testdata/arrcomp4.golden new file mode 100644 index 000000000..2d3b8efc2 --- /dev/null +++ b/testdata/arrcomp4.golden @@ -0,0 +1,5 @@ +[ + 2, + 3, + 4 +] diff --git a/testdata/arrcomp4.input b/testdata/arrcomp4.input new file mode 100644 index 000000000..26858b630 --- /dev/null +++ b/testdata/arrcomp4.input @@ -0,0 +1 @@ +[y for x in [1, 2, 3] for y in [x + 1]] diff --git a/testdata/arrcomp5.golden b/testdata/arrcomp5.golden new file mode 100644 index 000000000..5bbe24c08 --- /dev/null +++ b/testdata/arrcomp5.golden @@ -0,0 +1 @@ +testdata/arrcomp5:1:14-15 Unknown variable: x diff --git a/testdata/arrcomp5.input b/testdata/arrcomp5.input new file mode 100644 index 000000000..c0b8e809f --- /dev/null +++ b/testdata/arrcomp5.input @@ -0,0 +1,2 @@ +[y for y in [x + 1] for x in [1, 2, 3]] + diff --git a/testdata/std.flatmap.golden b/testdata/std.flatmap.golden new file mode 100644 index 000000000..1e3ec7217 --- /dev/null +++ b/testdata/std.flatmap.golden @@ -0,0 +1 @@ +[ ] diff --git a/testdata/std.flatmap.input b/testdata/std.flatmap.input new file mode 100644 index 000000000..b35cfb365 --- /dev/null +++ b/testdata/std.flatmap.input @@ -0,0 +1 @@ +std.flatMap(function(x) [], [1, 2, 3]) diff --git a/testdata/std.flatmap2.golden b/testdata/std.flatmap2.golden new file mode 100644 index 000000000..74a3da67a --- /dev/null +++ b/testdata/std.flatmap2.golden @@ -0,0 +1,17 @@ +[ + 1, + 2, + 3, + 1, + 2, + 3, + 1, + 2, + 3, + 1, + 2, + 3, + 1, + 2, + 3 +] diff --git a/testdata/std.flatmap2.input b/testdata/std.flatmap2.input new file mode 100644 index 000000000..246ea5b76 --- /dev/null +++ b/testdata/std.flatmap2.input @@ -0,0 +1 @@ +std.flatMap(function(x) [1, 2, 3], ["a", 2, 3, 4, 5]) diff --git a/testdata/std.flatmap3.golden b/testdata/std.flatmap3.golden new file mode 100644 index 000000000..1e3ec7217 --- /dev/null +++ b/testdata/std.flatmap3.golden @@ -0,0 +1 @@ +[ ] diff --git a/testdata/std.flatmap3.input b/testdata/std.flatmap3.input new file mode 100644 index 000000000..dce95ef1d --- /dev/null +++ b/testdata/std.flatmap3.input @@ -0,0 +1 @@ +std.flatMap(function(x) error "never happens", []) diff --git a/testdata/std.flatmap4.golden b/testdata/std.flatmap4.golden new file mode 100644 index 000000000..56613b44b --- /dev/null +++ b/testdata/std.flatmap4.golden @@ -0,0 +1,12 @@ +[ + 1, + 1, + 2, + 2, + 3, + 3, + 4, + 4, + 5, + 5 +] diff --git a/testdata/std.flatmap4.input b/testdata/std.flatmap4.input new file mode 100644 index 000000000..34a40fb84 --- /dev/null +++ b/testdata/std.flatmap4.input @@ -0,0 +1 @@ +std.flatMap(function(x) [x, x], [1, 2, 3, 4, 5]) diff --git a/testdata/std.flatmap5.golden b/testdata/std.flatmap5.golden new file mode 100644 index 000000000..bc150003f --- /dev/null +++ b/testdata/std.flatmap5.golden @@ -0,0 +1 @@ +RUNTIME ERROR: a diff --git a/testdata/std.flatmap5.input b/testdata/std.flatmap5.input new file mode 100644 index 000000000..1a1de7674 --- /dev/null +++ b/testdata/std.flatmap5.input @@ -0,0 +1,2 @@ +local failWith(x) = error x; +std.type(std.flatMap(failWith, ["a", "b", "c"])) diff --git a/value.go b/value.go index 29b85802a..5f3569070 100644 --- a/value.go +++ b/value.go @@ -131,9 +131,24 @@ type valueArray struct { elements []potentialValue } +func (arr *valueArray) length() int { + return len(arr.elements) +} + func makeValueArray(elements []potentialValue) *valueArray { + // We don't want to keep a bigger array than necessary + // so we create a new one with minimal capacity + var arrayElems []potentialValue + if len(elements) == cap(elements) { + arrayElems = elements + } else { + arrayElems = make([]potentialValue, len(elements)) + for i := range elements { + arrayElems[i] = elements[i] + } + } return &valueArray{ - elements: elements, + elements: arrayElems, } }