From 4b4638b370e898a9d10709464b8cd460b7dcfd0c Mon Sep 17 00:00:00 2001 From: zhengchun Date: Tue, 9 Apr 2024 21:53:51 +0800 Subject: [PATCH] #96, allows node-set numeric operator on `+`, `-`, `*`, `MOD()`, `DIV()` --- build.go | 2 +- operator.go | 30 ++++++++++++++---------------- query.go | 4 ++-- xpath_test.go | 23 +++++++++++++++++++++++ 4 files changed, 40 insertions(+), 19 deletions(-) diff --git a/build.go b/build.go index e079e96..44f87d8 100644 --- a/build.go +++ b/build.go @@ -665,7 +665,7 @@ func (b *builder) processOperator(root *operatorNode, props *builderProp) (query var qyOutput query switch root.Op { case "+", "-", "*", "div", "mod": // Numeric operator - var exprFunc func(interface{}, interface{}) interface{} + var exprFunc func(iterator, interface{}, interface{}) interface{} switch root.Op { case "+": exprFunc = plusFunc diff --git a/operator.go b/operator.go index 12aadc1..2820152 100644 --- a/operator.go +++ b/operator.go @@ -1,7 +1,6 @@ package xpath import ( - "reflect" "strconv" ) @@ -247,44 +246,43 @@ var orFunc = func(t iterator, m, n interface{}) interface{} { return logicalFuncs[t1][t2](t, "or", m, n) } -func numericExpr(m, n interface{}, cb func(float64, float64) float64) float64 { - typ := reflect.TypeOf(float64(0)) - a := reflect.ValueOf(m).Convert(typ) - b := reflect.ValueOf(n).Convert(typ) - return cb(a.Float(), b.Float()) +func numericExpr(t iterator, m, n interface{}, cb func(float64, float64) float64) float64 { + a := asNumber(t, m) + b := asNumber(t, n) + return cb(a, b) } // plusFunc is an `+` operator. -var plusFunc = func(m, n interface{}) interface{} { - return numericExpr(m, n, func(a, b float64) float64 { +var plusFunc = func(t iterator, m, n interface{}) interface{} { + return numericExpr(t, m, n, func(a, b float64) float64 { return a + b }) } // minusFunc is an `-` operator. -var minusFunc = func(m, n interface{}) interface{} { - return numericExpr(m, n, func(a, b float64) float64 { +var minusFunc = func(t iterator, m, n interface{}) interface{} { + return numericExpr(t, m, n, func(a, b float64) float64 { return a - b }) } // mulFunc is an `*` operator. -var mulFunc = func(m, n interface{}) interface{} { - return numericExpr(m, n, func(a, b float64) float64 { +var mulFunc = func(t iterator, m, n interface{}) interface{} { + return numericExpr(t, m, n, func(a, b float64) float64 { return a * b }) } // divFunc is an `DIV` operator. -var divFunc = func(m, n interface{}) interface{} { - return numericExpr(m, n, func(a, b float64) float64 { +var divFunc = func(t iterator, m, n interface{}) interface{} { + return numericExpr(t, m, n, func(a, b float64) float64 { return a / b }) } // modFunc is an 'MOD' operator. -var modFunc = func(m, n interface{}) interface{} { - return numericExpr(m, n, func(a, b float64) float64 { +var modFunc = func(t iterator, m, n interface{}) interface{} { + return numericExpr(t, m, n, func(a, b float64) float64 { return float64(int(a) % int(b)) }) } diff --git a/query.go b/query.go index fe6f488..a4d1dce 100644 --- a/query.go +++ b/query.go @@ -999,7 +999,7 @@ func (l *logicalQuery) Properties() queryProp { type numericQuery struct { Left, Right query - Do func(interface{}, interface{}) interface{} + Do func(iterator, interface{}, interface{}) interface{} } func (n *numericQuery) Select(t iterator) NodeNavigator { @@ -1009,7 +1009,7 @@ func (n *numericQuery) Select(t iterator) NodeNavigator { func (n *numericQuery) Evaluate(t iterator) interface{} { m := n.Left.Evaluate(t) k := n.Right.Evaluate(t) - return n.Do(m, k) + return n.Do(t, m, k) } func (n *numericQuery) Clone() query { diff --git a/xpath_test.go b/xpath_test.go index e845860..cd62243 100644 --- a/xpath_test.go +++ b/xpath_test.go @@ -3,6 +3,7 @@ package xpath import ( "bytes" "fmt" + "math" "sort" "strings" "testing" @@ -235,6 +236,28 @@ func TestMustCompile(t *testing.T) { } } +func Test_plusFunc(t *testing.T) { + // 1+1 + assertEqual(t, float64(2), plusFunc(nil, float64(1), float64(1))) + // string + + assertEqual(t, float64(2), plusFunc(nil, "1", "1")) + // invalid string + v := plusFunc(nil, "a", 1) + assertTrue(t, math.IsNaN(v.(float64))) + // Nodeset + // TODO +} + +func Test_minusFunc(t *testing.T) { + // 1 - 1 + assertEqual(t, float64(0), minusFunc(nil, float64(1), float64(1))) + // string + assertEqual(t, float64(0), minusFunc(nil, "1", "1")) + // invalid string + v := minusFunc(nil, "a", 1) + assertTrue(t, math.IsNaN(v.(float64))) +} + func TestNodeType(t *testing.T) { tests := []struct { expr string