Skip to content

Commit 3da8527

Browse files
committed
Fix operator overloading recursive
Fixes #548
1 parent 62cdd42 commit 3da8527

File tree

4 files changed

+134
-5
lines changed

4 files changed

+134
-5
lines changed

Diff for: expr.go

+18-5
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,24 @@ func Compile(input string, ops ...Option) (*vm.Program, error) {
200200
}
201201

202202
if len(config.Visitors) > 0 {
203-
for _, v := range config.Visitors {
204-
// We need to perform types check, because some visitors may rely on
205-
// types information available in the tree.
206-
_, _ = checker.Check(tree, config)
207-
ast.Walk(&tree.Node, v)
203+
for i := 0; i < 1000; i++ {
204+
more := false
205+
for _, v := range config.Visitors {
206+
// We need to perform types check, because some visitors may rely on
207+
// types information available in the tree.
208+
_, _ = checker.Check(tree, config)
209+
210+
ast.Walk(&tree.Node, v)
211+
212+
if v, ok := v.(interface {
213+
ShouldRepeat() bool
214+
}); ok {
215+
more = more || v.ShouldRepeat()
216+
}
217+
}
218+
if !more {
219+
break
220+
}
208221
}
209222
}
210223
_, err = checker.Check(tree, config)

Diff for: patcher/operator_override.go

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type OperatorOverride struct {
1414
Overrides []string // List of function names to override operator with.
1515
Types conf.TypesTable // Env types.
1616
Functions conf.FunctionsTable // Env functions.
17+
applied bool // Flag to indicate if any override was applied.
1718
}
1819

1920
func (p *OperatorOverride) Visit(node *ast.Node) {
@@ -37,9 +38,14 @@ func (p *OperatorOverride) Visit(node *ast.Node) {
3738
}
3839
newNode.SetType(ret)
3940
ast.Patch(node, newNode)
41+
p.applied = true
4042
}
4143
}
4244

45+
func (p *OperatorOverride) ShouldRepeat() bool {
46+
return p.applied
47+
}
48+
4349
func (p *OperatorOverride) FindSuitableOperatorOverload(l, r reflect.Type) (reflect.Type, string, bool) {
4450
t, fn, ok := p.findSuitableOperatorOverloadInFunctions(l, r)
4551
if !ok {

Diff for: test/operator/issues584/issues584_test.go

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package issues584_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
8+
"github.com/expr-lang/expr"
9+
)
10+
11+
type Env struct{}
12+
13+
type Program struct {
14+
}
15+
16+
func (p *Program) Foo() Value {
17+
return func(e *Env) float64 {
18+
return 5
19+
}
20+
}
21+
22+
func (p *Program) Bar() Value {
23+
return func(e *Env) float64 {
24+
return 100
25+
}
26+
}
27+
28+
func (p *Program) AndCondition(a, b Condition) Conditions {
29+
return Conditions{a, b}
30+
}
31+
32+
func (p *Program) AndConditions(a Conditions, b Condition) Conditions {
33+
return append(a, b)
34+
}
35+
36+
func (p *Program) ValueGreaterThan_float(v Value, i float64) Condition {
37+
return func(e *Env) bool {
38+
realized := v(e)
39+
return realized > i
40+
}
41+
}
42+
43+
func (p *Program) ValueLessThan_float(v Value, i float64) Condition {
44+
return func(e *Env) bool {
45+
realized := v(e)
46+
return realized < i
47+
}
48+
}
49+
50+
type Condition func(e *Env) bool
51+
type Conditions []Condition
52+
53+
type Value func(e *Env) float64
54+
55+
func TestIssue584(t *testing.T) {
56+
code := `Foo() > 1.5 and Bar() < 200.0`
57+
58+
p := &Program{}
59+
60+
opt := []expr.Option{
61+
expr.Env(p),
62+
expr.Operator("and", "AndCondition", "AndConditions"),
63+
expr.Operator(">", "ValueGreaterThan_float"),
64+
expr.Operator("<", "ValueLessThan_float"),
65+
}
66+
67+
program, err := expr.Compile(code, opt...)
68+
assert.Nil(t, err)
69+
70+
state, err := expr.Run(program, p)
71+
assert.Nil(t, err)
72+
assert.NotNil(t, state)
73+
}

Diff for: test/operator/operator_test.go

+37
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,40 @@ func TestOperator_Polymorphic(t *testing.T) {
216216
require.NoError(t, err)
217217
require.Equal(t, 6, output)
218218
}
219+
220+
func TestOperator_recursive_apply(t *testing.T) {
221+
type Decimal struct {
222+
Int int
223+
}
224+
225+
env := map[string]any{
226+
"add": func(a, b Decimal) Decimal {
227+
return Decimal{
228+
Int: a.Int + b.Int,
229+
}
230+
},
231+
"addInt": func(a Decimal, b int) Decimal {
232+
return Decimal{
233+
Int: a.Int + b,
234+
}
235+
},
236+
"a": Decimal{1},
237+
"b": Decimal{2},
238+
"c": Decimal{3},
239+
"d": Decimal{4},
240+
"e": Decimal{5},
241+
}
242+
243+
program, err := expr.Compile(
244+
`a + b + 100 + c + d + e`,
245+
expr.Env(env),
246+
expr.Operator("+", "add"),
247+
expr.Operator("+", "addInt"),
248+
)
249+
require.NoError(t, err)
250+
require.Equal(t, `add(add(add(addInt(add(a, b), 100), c), d), e)`, program.Node().String())
251+
252+
output, err := expr.Run(program, env)
253+
require.NoError(t, err)
254+
require.Equal(t, 115, output.(Decimal).Int)
255+
}

0 commit comments

Comments
 (0)