diff --git a/cast.go b/cast.go index ec14aec..8d85539 100644 --- a/cast.go +++ b/cast.go @@ -9,6 +9,7 @@ package cast import "time" const errorMsg = "unable to cast %#v of type %T to %T" +const errorMsgWith = "unable to cast %#v of type %T to %T: %w" // Basic is a type parameter constraint for functions accepting basic types. // diff --git a/number.go b/number.go index 1cca23c..a58dc4d 100644 --- a/number.go +++ b/number.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "fmt" + "regexp" "strconv" "strings" "time" @@ -149,22 +150,22 @@ func toNumberE[T Number](i any, parseFn func(string) (T, error)) (T, error) { } v, err := parseFn(s) - if err == nil { - return v, nil + if err != nil { + return 0, fmt.Errorf(errorMsgWith, i, i, n, err) } - return 0, fmt.Errorf(errorMsg, i, i, n) + return v, nil case json.Number: if s == "" { return 0, nil } v, err := parseFn(string(s)) - if err == nil { - return v, nil + if err != nil { + return 0, fmt.Errorf(errorMsgWith, i, i, n, err) } - return 0, fmt.Errorf(errorMsg, i, i, n) + return v, nil case float64EProvider: if _, ok := any(n).(float64); !ok { return 0, fmt.Errorf(errorMsg, i, i, n) @@ -293,22 +294,22 @@ func toUnsignedNumberE[T Number](i any, parseFn func(string) (T, error)) (T, err } v, err := parseFn(s) - if err == nil { - return v, nil + if err != nil { + return 0, fmt.Errorf(errorMsgWith, i, i, n, err) } - return 0, fmt.Errorf(errorMsg, i, i, n) + return v, nil case json.Number: if s == "" { return 0, nil } v, err := parseFn(string(s)) - if err == nil { - return v, nil + if err != nil { + return 0, fmt.Errorf(errorMsgWith, i, i, n, err) } - return 0, fmt.Errorf(errorMsg, i, i, n) + return v, nil case float64EProvider: if _, ok := any(n).(float64); !ok { return 0, fmt.Errorf(errorMsg, i, i, n) @@ -413,7 +414,7 @@ func parseInt[T integer](s string) (T, error) { } func parseUint[T unsigned](s string) (T, error) { - v, err := strconv.ParseUint(trimDecimal(s), 0, 0) + v, err := strconv.ParseUint(strings.TrimLeft(trimDecimal(s), "+"), 0, 0) if err != nil { return 0, err } @@ -520,13 +521,28 @@ func trimZeroDecimal(s string) string { return s } -// trimming decimals seems significantly faster than parsing to float first -// -// see BenchmarkDecimal +var stringNumberRe = regexp.MustCompile(`^([-+]?\d*)(\.\d*)?$`) + +// see [BenchmarkDecimal] for details about the implementation func trimDecimal(s string) string { - // trim the decimal part (if any) - if i := strings.Index(s, "."); i >= 0 { - s = s[:i] + if !strings.Contains(s, ".") { + return s + } + + matches := stringNumberRe.FindStringSubmatch(s) + if matches != nil { + // matches[1] is the captured integer part with sign + s = matches[1] + + // handle special cases + switch s { + case "-", "+": + s += "0" + case "": + s = "0" + } + + return s } return s diff --git a/number_internal_test.go b/number_internal_test.go index b95c082..6d333d0 100644 --- a/number_internal_test.go +++ b/number_internal_test.go @@ -6,7 +6,9 @@ package cast import ( + "regexp" "strconv" + "strings" "testing" qt "github.com/frankban/quicktest" @@ -45,29 +47,103 @@ func TestTrimZeroDecimal(t *testing.T) { } func TestTrimDecimal(t *testing.T) { - c := qt.New(t) + testCases := []struct { + input string + expected string + }{ + {"10.0", "10"}, + {"10.010", "10"}, + {"00000.00001", "00000"}, + {"-0001.0", "-0001"}, + {".5", "0"}, + {"+12.", "+12"}, + {"+.25", "+0"}, + {"-.25", "-0"}, + {"0.0000000000", "0"}, + {"0.0000000001", "0"}, + {"10.0000000000", "10"}, + {"10.0000000001", "10"}, + {"10000000000000.0000000000", "10000000000000"}, + + {"10...17", "10...17"}, + {"10.foobar", "10.foobar"}, + {"10.0i", "10.0i"}, + {"10.0E9", "10.0E9"}, + } - c.Assert(trimDecimal("10.0"), qt.Equals, "10") - c.Assert(trimDecimal("10.00"), qt.Equals, "10") - c.Assert(trimDecimal("10.010"), qt.Equals, "10") - c.Assert(trimDecimal("0.0000000000"), qt.Equals, "0") - c.Assert(trimDecimal("0.00000000001"), qt.Equals, "0") + for _, testCase := range testCases { + // TODO: remove after minimum Go version is >=1.22 + testCase := testCase + + t.Run(testCase.input, func(t *testing.T) { + c := qt.New(t) + + c.Assert(trimDecimal(testCase.input), qt.Equals, testCase.expected) + }) + } } +// Analysis (in the order of performance): +// +// - Trimming decimals based on decimal point yields a lot of incorrectly parsed values. +// - Parsing to float might be better, but we still need to cast the number, it might overflow, problematic. +// - Regex parsing is an order of magnitude slower, but it yields correct results. func BenchmarkDecimal(b *testing.B) { - testCases := []string{"10.0", "10.00", "10.010", "0.0000000000", "0.0000000001", "10.0000000000", "10.0000000001", "10000000000000.0000000000"} + testCases := []struct { + input string + expectError bool + }{ + {"10.0", false}, + {"10.00", false}, + {"10.010", false}, + {"0.0000000000", false}, + {"0.0000000001", false}, + {"10.0000000000", false}, + {"10.0000000001", false}, + {"10000000000000.0000000000", false}, + + // {"10...17", true}, + // {"10.foobar", true}, + // {"10.0i", true}, + // {"10.0E9", true}, + } + + trimDecimalString := func(s string) string { + // trim the decimal part (if any) + if i := strings.Index(s, "."); i >= 0 { + s = s[:i] + } + + return s + } + + re := regexp.MustCompile(`^([-+]?\d*)(\.\d*)?$`) + + trimDecimalRegex := func(s string) string { + matches := re.FindStringSubmatch(s) + if matches != nil { + // matches[1] is the captured integer part with sign + return matches[1] + } + + return s + } for _, testCase := range testCases { // TODO: remove after minimum Go version is >=1.22 testCase := testCase - b.Run(testCase, func(b *testing.B) { + b.Run(testCase.input, func(b *testing.B) { b.Run("ParseFloat", func(b *testing.B) { // TODO: use b.Loop() once updated to Go 1.24 for i := 0; i < b.N; i++ { - v, err := strconv.ParseFloat(testCase, 64) - if err != nil { - b.Fatal(err) + v, err := strconv.ParseFloat(testCase.input, 64) + if (err != nil) != testCase.expectError { + if err != nil { + b.Fatal(err) + } + + b.Fatal("expected error, but got none") } n := int64(v) @@ -75,12 +151,32 @@ func BenchmarkDecimal(b *testing.B) { } }) - b.Run("TrimDecimal", func(b *testing.B) { + b.Run("TrimDecimalString", func(b *testing.B) { + // TODO: use b.Loop() once updated to Go 1.24 + for i := 0; i < b.N; i++ { + v, err := strconv.ParseInt(trimDecimalString(testCase.input), 0, 0) + if (err != nil) != testCase.expectError { + if err != nil { + b.Fatal(err) + } + + b.Fatal("expected error, but got none") + } + + _ = v + } + }) + + b.Run("TrimDecimalRegex", func(b *testing.B) { // TODO: use b.Loop() once updated to Go 1.24 for i := 0; i < b.N; i++ { - v, err := strconv.ParseInt(trimDecimal(testCase), 0, 0) - if err != nil { - b.Fatal(err) + v, err := strconv.ParseInt(trimDecimalRegex(testCase.input), 0, 0) + if (err != nil) != testCase.expectError { + if err != nil { + b.Fatal(err) + } + + b.Fatal("expected error, but got none") } _ = v diff --git a/number_test.go b/number_test.go index 8935206..9d84d18 100644 --- a/number_test.go +++ b/number_test.go @@ -150,6 +150,8 @@ var numberContexts = map[string]numberContext{ }, } +// TODO: separate test and failure cases? +// Kinda hard to track cases right now. func generateNumberTestCases(samples []any) []testCase { zero := samples[0] one := samples[1] @@ -169,7 +171,9 @@ func generateNumberTestCases(samples []any) []testCase { _ = overflowString kind := reflect.TypeOf(zero).Kind() + isSint := kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || kind == reflect.Int32 || kind == reflect.Int64 isUint := kind == reflect.Uint || kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || kind == reflect.Uint64 + isInt := isSint || isUint // Some precision is lost when converting from float64 to float32. eightPoint31_32 := eightPoint31 @@ -231,6 +235,27 @@ func generateNumberTestCases(samples []any) []testCase { // Failure cases {"test", zero, true}, {testing.T{}, zero, true}, + + {"10...17", zero, true}, + {"10.foobar", zero, true}, + {"10.0i", zero, true}, + } + + if isInt { + testCases = append( + testCases, + + testCase{".5", zero, false}, + testCase{"+8.", eight, false}, + testCase{"+.25", zero, false}, + testCase{"-.25", zero, isUint}, + + testCase{"10.0E9", zero, true}, + ) + } else if kind == reflect.Float32 { + testCases = append(testCases, testCase{"10.0E9", float32(10000000000.000000), false}) + } else if kind == reflect.Float64 { + testCases = append(testCases, testCase{"10.0E9", float64(10000000000.000000), false}) } if isUint && underflowString != nil {