Skip to content

Commit 7b5f72b

Browse files
committed
Add nil coalescing operator
1 parent 4c29199 commit 7b5f72b

13 files changed

+170
-14
lines changed

Diff for: checker/checker.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
3838
}
3939
default:
4040
if t != nil {
41-
if t.Kind() == reflect.Interface {
42-
t = t.Elem()
43-
}
4441
if t.Kind() == v.config.Expect {
4542
return t, nil
4643
}
@@ -358,6 +355,21 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
358355
return ret, info{}
359356
}
360357

358+
case "??":
359+
if l == nil && r != nil {
360+
return r, info{}
361+
}
362+
if l != nil && r == nil {
363+
return l, info{}
364+
}
365+
if l == nil && r == nil {
366+
return nilType, info{}
367+
}
368+
if r.AssignableTo(l) {
369+
return l, info{}
370+
}
371+
return anyType, info{}
372+
361373
default:
362374
return v.error(node, "unknown operator (%v)", node.Operator)
363375

Diff for: checker/checker_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ var successTests = []string{
121121
"Duration + Any == Time",
122122
"Any + Duration == Time",
123123
"Any.A?.B == nil",
124+
"(Any.Bool ?? Bool) > 0",
125+
"Bool ?? Bool",
124126
}
125127

126128
func TestCheck(t *testing.T) {

Diff for: compiler/compiler.go

+7
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,13 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
424424
c.compile(node.Right)
425425
c.emit(OpRange)
426426

427+
case "??":
428+
c.compile(node.Left)
429+
end := c.emit(OpJumpIfNotNil, placeholder)
430+
c.emit(OpPop)
431+
c.compile(node.Right)
432+
c.patchJump(end)
433+
427434
default:
428435
panic(fmt.Sprintf("unknown operator (%v)", node.Operator))
429436

Diff for: compiler/compiler_test.go

+19
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,25 @@ func TestCompile(t *testing.T) {
230230
Arguments: []int{0, 1, 0, 2},
231231
},
232232
},
233+
{
234+
`A ?? 1`,
235+
vm.Program{
236+
Constants: []interface{}{
237+
&runtime.Field{
238+
Index: []int{0},
239+
Path: []string{"A"},
240+
},
241+
1,
242+
},
243+
Bytecode: []vm.Opcode{
244+
vm.OpLoadField,
245+
vm.OpJumpIfNotNil,
246+
vm.OpPop,
247+
vm.OpPush,
248+
},
249+
Arguments: []int{0, 2, 0, 1},
250+
},
251+
},
233252
}
234253

235254
for _, test := range tests {

Diff for: docs/Language-Definition.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ d>
7979
<tr>
8080
<td>Conditional</td>
8181
<td>
82-
<code>?:</code> (ternary)
82+
<code>?:</code> (ternary), <code>??</code> (nil coalescing)
8383
</td>
8484
</tr>
8585
<tr>
@@ -147,6 +147,15 @@ without checking if the struct or the map is `nil`. If the struct or the map is
147147
author?.User?.Name
148148
```
149149

150+
#### Nil coalescing
151+
152+
The `??` operator can be used to return the left-hand side if it is not `nil`,
153+
otherwise the right-hand side is returned.
154+
155+
```c++
156+
author?.User?.Name ?? "Anonymous"
157+
```
158+
150159
### Slice Operator
151160

152161
The slice operator `[:]` can be used to access a slice of an array.

Diff for: expr_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,42 @@ func TestFunction(t *testing.T) {
17101710
assert.Equal(t, 20, out)
17111711
}
17121712

1713+
// Nil coalescing operator
1714+
func TestRun_NilCoalescingOperator(t *testing.T) {
1715+
env := map[string]interface{}{
1716+
"foo": map[string]interface{}{
1717+
"bar": "value",
1718+
},
1719+
}
1720+
1721+
t.Run("value", func(t *testing.T) {
1722+
p, err := expr.Compile(`foo.bar ?? "default"`, expr.Env(env))
1723+
assert.NoError(t, err)
1724+
1725+
out, err := expr.Run(p, env)
1726+
assert.NoError(t, err)
1727+
assert.Equal(t, "value", out)
1728+
})
1729+
1730+
t.Run("default", func(t *testing.T) {
1731+
p, err := expr.Compile(`foo.baz ?? "default"`, expr.Env(env))
1732+
assert.NoError(t, err)
1733+
1734+
out, err := expr.Run(p, env)
1735+
assert.NoError(t, err)
1736+
assert.Equal(t, "default", out)
1737+
})
1738+
1739+
t.Run("default with chain", func(t *testing.T) {
1740+
p, err := expr.Compile(`foo?.bar ?? "default"`, expr.Env(env))
1741+
assert.NoError(t, err)
1742+
1743+
out, err := expr.Run(p, map[string]interface{}{})
1744+
assert.NoError(t, err)
1745+
assert.Equal(t, "default", out)
1746+
})
1747+
}
1748+
17131749
// Mock types
17141750

17151751
type mockEnv struct {

Diff for: parser/lexer/lexer_test.go

+9
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,15 @@ var lexTests = []lexTest{
180180
{Kind: EOF},
181181
},
182182
},
183+
{
184+
`foo ?? bar`,
185+
[]Token{
186+
{Kind: Identifier, Value: "foo"},
187+
{Kind: Operator, Value: "??"},
188+
{Kind: Identifier, Value: "bar"},
189+
{Kind: EOF},
190+
},
191+
},
183192
}
184193

185194
func compareTokens(i1, i2 []Token) bool {

Diff for: parser/lexer/state.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func not(l *lexer) stateFn {
156156
}
157157

158158
func questionMark(l *lexer) stateFn {
159-
l.accept(".")
159+
l.accept(".?")
160160
l.emit(Operator)
161161
return root
162162
}

Diff for: parser/parser.go

+21-9
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ var binaryOperators = map[string]operator{
5858
"%": {60, left},
5959
"**": {100, right},
6060
"^": {100, right},
61+
"??": {500, left},
6162
}
6263

6364
var builtins = map[string]builtin{
@@ -113,9 +114,13 @@ func Parse(input string) (*Tree, error) {
113114
}
114115

115116
func (p *parser) error(format string, args ...interface{}) {
117+
p.errorAt(p.current, format, args...)
118+
}
119+
120+
func (p *parser) errorAt(token Token, format string, args ...interface{}) {
116121
if p.err == nil { // show first error
117122
p.err = &file.Error{
118-
Location: p.current.Location,
123+
Location: token.Location,
119124
Message: fmt.Sprintf(format, args...),
120125
}
121126
}
@@ -143,22 +148,28 @@ func (p *parser) expect(kind Kind, values ...string) {
143148
func (p *parser) parseExpression(precedence int) Node {
144149
nodeLeft := p.parsePrimary()
145150

146-
token := p.current
147-
for token.Is(Operator) && p.err == nil {
151+
lastOperator := ""
152+
opToken := p.current
153+
for opToken.Is(Operator) && p.err == nil {
148154
negate := false
149155
var notToken Token
150156

151-
if token.Is(Operator, "not") {
157+
if opToken.Is(Operator, "not") {
152158
p.next()
153159
notToken = p.current
154160
negate = true
155-
token = p.current
161+
opToken = p.current
156162
}
157163

158-
if op, ok := binaryOperators[token.Value]; ok {
164+
if op, ok := binaryOperators[opToken.Value]; ok {
159165
if op.precedence >= precedence {
160166
p.next()
161167

168+
if lastOperator == "??" && opToken.Value != "??" && !opToken.Is(Bracket, "(") {
169+
p.errorAt(opToken, "Operator (%v) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses.", opToken.Value)
170+
break
171+
}
172+
162173
var nodeRight Node
163174
if op.associativity == left {
164175
nodeRight = p.parseExpression(op.precedence + 1)
@@ -167,11 +178,11 @@ func (p *parser) parseExpression(precedence int) Node {
167178
}
168179

169180
nodeLeft = &BinaryNode{
170-
Operator: token.Value,
181+
Operator: opToken.Value,
171182
Left: nodeLeft,
172183
Right: nodeRight,
173184
}
174-
nodeLeft.SetLocation(token.Location)
185+
nodeLeft.SetLocation(opToken.Location)
175186

176187
if negate {
177188
nodeLeft = &UnaryNode{
@@ -181,7 +192,8 @@ func (p *parser) parseExpression(precedence int) Node {
181192
nodeLeft.SetLocation(notToken.Location)
182193
}
183194

184-
token = p.current
195+
lastOperator = opToken.Value
196+
opToken = p.current
185197
continue
186198
}
187199
}

Diff for: parser/parser_test.go

+41
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,42 @@ func TestParse(t *testing.T) {
408408
"[]",
409409
&ArrayNode{},
410410
},
411+
{
412+
"foo ?? bar",
413+
&BinaryNode{Operator: "??",
414+
Left: &IdentifierNode{Value: "foo"},
415+
Right: &IdentifierNode{Value: "bar"}},
416+
},
417+
{
418+
"foo ?? bar ?? baz",
419+
&BinaryNode{Operator: "??",
420+
Left: &BinaryNode{Operator: "??",
421+
Left: &IdentifierNode{Value: "foo"},
422+
Right: &IdentifierNode{Value: "bar"}},
423+
Right: &IdentifierNode{Value: "baz"}},
424+
},
425+
{
426+
"foo ?? (bar || baz)",
427+
&BinaryNode{Operator: "??",
428+
Left: &IdentifierNode{Value: "foo"},
429+
Right: &BinaryNode{Operator: "||",
430+
Left: &IdentifierNode{Value: "bar"},
431+
Right: &IdentifierNode{Value: "baz"}}},
432+
},
433+
{
434+
"foo || bar ?? baz",
435+
&BinaryNode{Operator: "||",
436+
Left: &IdentifierNode{Value: "foo"},
437+
Right: &BinaryNode{Operator: "??",
438+
Left: &IdentifierNode{Value: "bar"},
439+
Right: &IdentifierNode{Value: "baz"}}},
440+
},
441+
{
442+
"foo ?? bar()",
443+
&BinaryNode{Operator: "??",
444+
Left: &IdentifierNode{Value: "foo"},
445+
Right: &CallNode{Callee: &IdentifierNode{Value: "bar"}}},
446+
},
411447
}
412448
for _, test := range parseTests {
413449
actual, err := parser.Parse(test.input)
@@ -479,6 +515,11 @@ a map key must be a quoted string, a number, a identifier, or an expression encl
479515
unexpected token Operator(",") (1:16)
480516
| {foo:1, bar:2, ,}
481517
| ...............^
518+
519+
foo ?? bar || baz
520+
Operator (||) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses. (1:12)
521+
| foo ?? bar || baz
522+
| ...........^
482523
`
483524

484525
func TestParse_error(t *testing.T) {

Diff for: vm/opcodes.go

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ const (
2626
OpJumpIfTrue
2727
OpJumpIfFalse
2828
OpJumpIfNil
29+
OpJumpIfNotNil
2930
OpJumpIfEnd
3031
OpJumpBackward
3132
OpIn

Diff for: vm/program.go

+3
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ func (program *Program) Disassemble() string {
138138
case OpJumpIfNil:
139139
jump("OpJumpIfNil")
140140

141+
case OpJumpIfNotNil:
142+
jump("OpJumpIfNotNil")
143+
141144
case OpJumpIfEnd:
142145
jump("OpJumpIfEnd")
143146

Diff for: vm/vm.go

+5
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ func (vm *VM) Run(program *Program, env interface{}) (_ interface{}, err error)
174174
vm.ip += arg
175175
}
176176

177+
case OpJumpIfNotNil:
178+
if !runtime.IsNil(vm.current()) {
179+
vm.ip += arg
180+
}
181+
177182
case OpJumpIfEnd:
178183
scope := vm.Scope()
179184
if scope.It >= scope.Len {

0 commit comments

Comments
 (0)