Skip to content

Commit 709c5dd

Browse files
authored
Propagate uint32 arguments types in ast (#438)
1 parent f527a33 commit 709c5dd

File tree

3 files changed

+69
-7
lines changed

3 files changed

+69
-7
lines changed

checker/checker.go

+29-6
Original file line numberDiff line numberDiff line change
@@ -924,8 +924,13 @@ func (v *checker) checkArguments(name string, fn reflect.Type, method bool, argu
924924
}
925925

926926
if isFloat(in) {
927-
t = floatType
928-
traverseAndReplaceIntegerNodesWithFloatNodes(&arg)
927+
traverseAndReplaceIntegerNodesWithFloatNodes(&arguments[i], in)
928+
continue
929+
}
930+
931+
if isInteger(in) && isInteger(t) && kind(t) != kind(in) {
932+
traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in)
933+
continue
929934
}
930935

931936
if t == nil {
@@ -943,19 +948,37 @@ func (v *checker) checkArguments(name string, fn reflect.Type, method bool, argu
943948
return fn.Out(0), nil
944949
}
945950

946-
func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node) {
951+
func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newType reflect.Type) {
947952
switch (*node).(type) {
948953
case *ast.IntegerNode:
949954
*node = &ast.FloatNode{Value: float64((*node).(*ast.IntegerNode).Value)}
955+
(*node).SetType(newType)
956+
case *ast.UnaryNode:
957+
unaryNode := (*node).(*ast.UnaryNode)
958+
traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node, newType)
959+
case *ast.BinaryNode:
960+
binaryNode := (*node).(*ast.BinaryNode)
961+
switch binaryNode.Operator {
962+
case "+", "-", "*":
963+
traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left, newType)
964+
traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right, newType)
965+
}
966+
}
967+
}
968+
969+
func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newType reflect.Type) {
970+
switch (*node).(type) {
971+
case *ast.IntegerNode:
972+
(*node).SetType(newType)
950973
case *ast.UnaryNode:
951974
unaryNode := (*node).(*ast.UnaryNode)
952-
traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node)
975+
traverseAndReplaceIntegerNodesWithIntegerNodes(&unaryNode.Node, newType)
953976
case *ast.BinaryNode:
954977
binaryNode := (*node).(*ast.BinaryNode)
955978
switch binaryNode.Operator {
956979
case "+", "-", "*":
957-
traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left)
958-
traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right)
980+
traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Left, newType)
981+
traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Right, newType)
959982
}
960983
}
961984
}

compiler/compiler.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,17 @@ func (c *compiler) IntegerNode(node *ast.IntegerNode) {
307307
}
308308

309309
func (c *compiler) FloatNode(node *ast.FloatNode) {
310-
c.emitPush(node.Value)
310+
t := node.Type()
311+
if t == nil {
312+
c.emitPush(node.Value)
313+
return
314+
}
315+
switch t.Kind() {
316+
case reflect.Float32:
317+
c.emitPush(float32(node.Value))
318+
case reflect.Float64:
319+
c.emitPush(node.Value)
320+
}
311321
}
312322

313323
func (c *compiler) BoolNode(node *ast.BoolNode) {

expr_test.go

+29
Original file line numberDiff line numberDiff line change
@@ -1969,3 +1969,32 @@ func TestMemoryBudget(t *testing.T) {
19691969
})
19701970
}
19711971
}
1972+
1973+
func TestIssue432(t *testing.T) {
1974+
env := map[string]any{
1975+
"func": func(
1976+
paramUint32 uint32,
1977+
paramUint16 uint16,
1978+
paramUint8 uint8,
1979+
paramUint uint,
1980+
paramInt32 int32,
1981+
paramInt16 int16,
1982+
paramInt8 int8,
1983+
paramInt int,
1984+
paramFloat64 float64,
1985+
paramFloat32 float32,
1986+
) float64 {
1987+
return float64(paramUint32) + float64(paramUint16) + float64(paramUint8) + float64(paramUint) +
1988+
float64(paramInt32) + float64(paramInt16) + float64(paramInt8) + float64(paramInt) +
1989+
float64(paramFloat64) + float64(paramFloat32)
1990+
},
1991+
}
1992+
code := `func(1,1,1,1,1,1,1,1,1,1)`
1993+
1994+
program, err := expr.Compile(code, expr.Env(env))
1995+
assert.NoError(t, err)
1996+
1997+
out, err := expr.Run(program, env)
1998+
assert.NoError(t, err)
1999+
assert.Equal(t, float64(10), out)
2000+
}

0 commit comments

Comments
 (0)