Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package cast
import "time"

const errorMsg = "unable to cast %#v of type %T to %T"
const errorMsgWith = "unable to cast %#v of type %T to %T: %w"

// Basic is a type parameter constraint for functions accepting basic types.
//
Expand Down
54 changes: 35 additions & 19 deletions number.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -149,22 +150,22 @@ func toNumberE[T Number](i any, parseFn func(string) (T, error)) (T, error) {
}

v, err := parseFn(s)
if err == nil {
return v, nil
if err != nil {
return 0, fmt.Errorf(errorMsgWith, i, i, n, err)
}

return 0, fmt.Errorf(errorMsg, i, i, n)
return v, nil
case json.Number:
if s == "" {
return 0, nil
}

v, err := parseFn(string(s))
if err == nil {
return v, nil
if err != nil {
return 0, fmt.Errorf(errorMsgWith, i, i, n, err)
}

return 0, fmt.Errorf(errorMsg, i, i, n)
return v, nil
case float64EProvider:
if _, ok := any(n).(float64); !ok {
return 0, fmt.Errorf(errorMsg, i, i, n)
Expand Down Expand Up @@ -293,22 +294,22 @@ func toUnsignedNumberE[T Number](i any, parseFn func(string) (T, error)) (T, err
}

v, err := parseFn(s)
if err == nil {
return v, nil
if err != nil {
return 0, fmt.Errorf(errorMsgWith, i, i, n, err)
}

return 0, fmt.Errorf(errorMsg, i, i, n)
return v, nil
case json.Number:
if s == "" {
return 0, nil
}

v, err := parseFn(string(s))
if err == nil {
return v, nil
if err != nil {
return 0, fmt.Errorf(errorMsgWith, i, i, n, err)
}

return 0, fmt.Errorf(errorMsg, i, i, n)
return v, nil
case float64EProvider:
if _, ok := any(n).(float64); !ok {
return 0, fmt.Errorf(errorMsg, i, i, n)
Expand Down Expand Up @@ -413,7 +414,7 @@ func parseInt[T integer](s string) (T, error) {
}

func parseUint[T unsigned](s string) (T, error) {
v, err := strconv.ParseUint(trimDecimal(s), 0, 0)
v, err := strconv.ParseUint(strings.TrimLeft(trimDecimal(s), "+"), 0, 0)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -520,13 +521,28 @@ func trimZeroDecimal(s string) string {
return s
}

// trimming decimals seems significantly faster than parsing to float first
//
// see BenchmarkDecimal
var stringNumberRe = regexp.MustCompile(`^([-+]?\d*)(\.\d*)?$`)

// see [BenchmarkDecimal] for details about the implementation
func trimDecimal(s string) string {
// trim the decimal part (if any)
if i := strings.Index(s, "."); i >= 0 {
s = s[:i]
if !strings.Contains(s, ".") {
return s
}

matches := stringNumberRe.FindStringSubmatch(s)
if matches != nil {
// matches[1] is the captured integer part with sign
s = matches[1]

// handle special cases
switch s {
case "-", "+":
s += "0"
case "":
s = "0"
}
Comment on lines +537 to +543
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should explain these cases here, as it's not obvious

Suggested change
// handle special cases
switch s {
case "-", "+":
s += "0"
case "":
s = "0"
}
// handle special cases
switch s {
case "-", "+": // like -.25 or +.25
s += "0"
case "": // like .25
s = "0"
}


return s
}

return s
Expand Down
126 changes: 111 additions & 15 deletions number_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package cast

import (
"regexp"
"strconv"
"strings"
"testing"

qt "github.com/frankban/quicktest"
Expand Down Expand Up @@ -45,42 +47,136 @@ func TestTrimZeroDecimal(t *testing.T) {
}

func TestTrimDecimal(t *testing.T) {
c := qt.New(t)
testCases := []struct {
input string
expected string
}{
{"10.0", "10"},
{"10.010", "10"},
{"00000.00001", "00000"},
{"-0001.0", "-0001"},
{".5", "0"},
{"+12.", "+12"},
{"+.25", "+0"},
{"-.25", "-0"},
{"0.0000000000", "0"},
{"0.0000000001", "0"},
{"10.0000000000", "10"},
{"10.0000000001", "10"},
{"10000000000000.0000000000", "10000000000000"},

{"10...17", "10...17"},
{"10.foobar", "10.foobar"},
{"10.0i", "10.0i"},
{"10.0E9", "10.0E9"},
}

c.Assert(trimDecimal("10.0"), qt.Equals, "10")
c.Assert(trimDecimal("10.00"), qt.Equals, "10")
c.Assert(trimDecimal("10.010"), qt.Equals, "10")
c.Assert(trimDecimal("0.0000000000"), qt.Equals, "0")
c.Assert(trimDecimal("0.00000000001"), qt.Equals, "0")
for _, testCase := range testCases {
// TODO: remove after minimum Go version is >=1.22
testCase := testCase

t.Run(testCase.input, func(t *testing.T) {
c := qt.New(t)

c.Assert(trimDecimal(testCase.input), qt.Equals, testCase.expected)
})
}
}

// Analysis (in the order of performance):
//
// - Trimming decimals based on decimal point yields a lot of incorrectly parsed values.
// - Parsing to float might be better, but we still need to cast the number, it might overflow, problematic.
// - Regex parsing is an order of magnitude slower, but it yields correct results.
func BenchmarkDecimal(b *testing.B) {
testCases := []string{"10.0", "10.00", "10.010", "0.0000000000", "0.0000000001", "10.0000000000", "10.0000000001", "10000000000000.0000000000"}
testCases := []struct {
input string
expectError bool
}{
{"10.0", false},
{"10.00", false},
{"10.010", false},
{"0.0000000000", false},
{"0.0000000001", false},
{"10.0000000000", false},
{"10.0000000001", false},
{"10000000000000.0000000000", false},

// {"10...17", true},
// {"10.foobar", true},
// {"10.0i", true},
// {"10.0E9", true},
}

trimDecimalString := func(s string) string {
// trim the decimal part (if any)
if i := strings.Index(s, "."); i >= 0 {
s = s[:i]
}

return s
}

re := regexp.MustCompile(`^([-+]?\d*)(\.\d*)?$`)

trimDecimalRegex := func(s string) string {
matches := re.FindStringSubmatch(s)
if matches != nil {
// matches[1] is the captured integer part with sign
return matches[1]
}

return s
}

for _, testCase := range testCases {
// TODO: remove after minimum Go version is >=1.22
testCase := testCase

b.Run(testCase, func(b *testing.B) {
b.Run(testCase.input, func(b *testing.B) {
b.Run("ParseFloat", func(b *testing.B) {
// TODO: use b.Loop() once updated to Go 1.24
for i := 0; i < b.N; i++ {
v, err := strconv.ParseFloat(testCase, 64)
if err != nil {
b.Fatal(err)
v, err := strconv.ParseFloat(testCase.input, 64)
if (err != nil) != testCase.expectError {
if err != nil {
b.Fatal(err)
}

b.Fatal("expected error, but got none")
}

n := int64(v)
_ = n
}
})

b.Run("TrimDecimal", func(b *testing.B) {
b.Run("TrimDecimalString", func(b *testing.B) {
// TODO: use b.Loop() once updated to Go 1.24
for i := 0; i < b.N; i++ {
v, err := strconv.ParseInt(trimDecimalString(testCase.input), 0, 0)
if (err != nil) != testCase.expectError {
if err != nil {
b.Fatal(err)
}

b.Fatal("expected error, but got none")
}

_ = v
}
})

b.Run("TrimDecimalRegex", func(b *testing.B) {
// TODO: use b.Loop() once updated to Go 1.24
for i := 0; i < b.N; i++ {
v, err := strconv.ParseInt(trimDecimal(testCase), 0, 0)
if err != nil {
b.Fatal(err)
v, err := strconv.ParseInt(trimDecimalRegex(testCase.input), 0, 0)
if (err != nil) != testCase.expectError {
if err != nil {
b.Fatal(err)
}

b.Fatal("expected error, but got none")
}

_ = v
Expand Down
25 changes: 25 additions & 0 deletions number_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ var numberContexts = map[string]numberContext{
},
}

// TODO: separate test and failure cases?
// Kinda hard to track cases right now.
func generateNumberTestCases(samples []any) []testCase {
zero := samples[0]
one := samples[1]
Expand All @@ -169,7 +171,9 @@ func generateNumberTestCases(samples []any) []testCase {
_ = overflowString

kind := reflect.TypeOf(zero).Kind()
isSint := kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || kind == reflect.Int32 || kind == reflect.Int64
isUint := kind == reflect.Uint || kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || kind == reflect.Uint64
isInt := isSint || isUint

// Some precision is lost when converting from float64 to float32.
eightPoint31_32 := eightPoint31
Expand Down Expand Up @@ -231,6 +235,27 @@ func generateNumberTestCases(samples []any) []testCase {
// Failure cases
{"test", zero, true},
{testing.T{}, zero, true},

{"10...17", zero, true},
{"10.foobar", zero, true},
{"10.0i", zero, true},
}

if isInt {
testCases = append(
testCases,

testCase{".5", zero, false},
testCase{"+8.", eight, false},
testCase{"+.25", zero, false},
testCase{"-.25", zero, isUint},

testCase{"10.0E9", zero, true},
)
} else if kind == reflect.Float32 {
testCases = append(testCases, testCase{"10.0E9", float32(10000000000.000000), false})
} else if kind == reflect.Float64 {
testCases = append(testCases, testCase{"10.0E9", float64(10000000000.000000), false})
}

if isUint && underflowString != nil {
Expand Down