Skip to content

Commit 31dd62f

Browse files
zane-degZane DeGraffenried
andauthored
Fix boolean support for required_if, required_unless and eqfield (#754)
* Fix boolean support in requireCheckFieldValue, isEqField and isNeField * Added tests Co-authored-by: Zane DeGraffenried <[email protected]>
1 parent b926bf0 commit 31dd62f

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

baked_in.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,9 @@ func isNeCrossStructField(fl FieldLevel) bool {
10381038
case reflect.Slice, reflect.Map, reflect.Array:
10391039
return int64(topField.Len()) != int64(field.Len())
10401040

1041+
case reflect.Bool:
1042+
return topField.Bool() != field.Bool()
1043+
10411044
case reflect.Struct:
10421045

10431046
fieldType := field.Type()
@@ -1085,6 +1088,9 @@ func isEqCrossStructField(fl FieldLevel) bool {
10851088
case reflect.Slice, reflect.Map, reflect.Array:
10861089
return int64(topField.Len()) == int64(field.Len())
10871090

1091+
case reflect.Bool:
1092+
return topField.Bool() == field.Bool()
1093+
10881094
case reflect.Struct:
10891095

10901096
fieldType := field.Type()
@@ -1132,6 +1138,9 @@ func isEqField(fl FieldLevel) bool {
11321138
case reflect.Slice, reflect.Map, reflect.Array:
11331139
return int64(field.Len()) == int64(currentField.Len())
11341140

1141+
case reflect.Bool:
1142+
return field.Bool() == currentField.Bool()
1143+
11351144
case reflect.Struct:
11361145

11371146
fieldType := field.Type()
@@ -1446,6 +1455,9 @@ func requireCheckFieldValue(fl FieldLevel, param string, value string, defaultNo
14461455

14471456
case reflect.Slice, reflect.Map, reflect.Array:
14481457
return int64(field.Len()) == asInt(value)
1458+
1459+
case reflect.Bool:
1460+
return field.Bool() == asBool(value)
14491461
}
14501462

14511463
// default reflect.String:

validator_test.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1670,12 +1670,14 @@ func TestCrossStructNeFieldValidation(t *testing.T) {
16701670
i := 1
16711671
j = 1
16721672
k = 1.543
1673+
b := true
16731674
arr := []string{"test"}
16741675

16751676
s2 := "abcd"
16761677
i2 := 1
16771678
j2 = 1
16781679
k2 = 1.543
1680+
b2 := true
16791681
arr2 := []string{"test"}
16801682
arr3 := []string{"test", "test2"}
16811683
now2 := now
@@ -1696,6 +1698,10 @@ func TestCrossStructNeFieldValidation(t *testing.T) {
16961698
NotEqual(t, errs, nil)
16971699
AssertError(t, errs, "", "", "", "", "necsfield")
16981700

1701+
errs = validate.VarWithValue(b2, b, "necsfield")
1702+
NotEqual(t, errs, nil)
1703+
AssertError(t, errs, "", "", "", "", "necsfield")
1704+
16991705
errs = validate.VarWithValue(arr2, arr, "necsfield")
17001706
NotEqual(t, errs, nil)
17011707
AssertError(t, errs, "", "", "", "", "necsfield")
@@ -1834,6 +1840,7 @@ func TestCrossStructEqFieldValidation(t *testing.T) {
18341840
i := 1
18351841
j = 1
18361842
k = 1.543
1843+
b := true
18371844
arr := []string{"test"}
18381845

18391846
var j2 uint64
@@ -1842,6 +1849,7 @@ func TestCrossStructEqFieldValidation(t *testing.T) {
18421849
i2 := 1
18431850
j2 = 1
18441851
k2 = 1.543
1852+
b2 := true
18451853
arr2 := []string{"test"}
18461854
arr3 := []string{"test", "test2"}
18471855
now2 := now
@@ -1858,6 +1866,9 @@ func TestCrossStructEqFieldValidation(t *testing.T) {
18581866
errs = validate.VarWithValue(k2, k, "eqcsfield")
18591867
Equal(t, errs, nil)
18601868

1869+
errs = validate.VarWithValue(b2, b, "eqcsfield")
1870+
Equal(t, errs, nil)
1871+
18611872
errs = validate.VarWithValue(arr2, arr, "eqcsfield")
18621873
Equal(t, errs, nil)
18631874

@@ -4829,6 +4840,7 @@ func TestIsEqFieldValidation(t *testing.T) {
48294840
i := 1
48304841
j = 1
48314842
k = 1.543
4843+
b := true
48324844
arr := []string{"test"}
48334845
now := time.Now().UTC()
48344846

@@ -4838,6 +4850,7 @@ func TestIsEqFieldValidation(t *testing.T) {
48384850
i2 := 1
48394851
j2 = 1
48404852
k2 = 1.543
4853+
b2 := true
48414854
arr2 := []string{"test"}
48424855
arr3 := []string{"test", "test2"}
48434856
now2 := now
@@ -4854,6 +4867,9 @@ func TestIsEqFieldValidation(t *testing.T) {
48544867
errs = validate.VarWithValue(k2, k, "eqfield")
48554868
Equal(t, errs, nil)
48564869

4870+
errs = validate.VarWithValue(b2, b, "eqfield")
4871+
Equal(t, errs, nil)
4872+
48574873
errs = validate.VarWithValue(arr2, arr, "eqfield")
48584874
Equal(t, errs, nil)
48594875

@@ -10065,12 +10081,15 @@ func TestRequiredUnless(t *testing.T) {
1006510081
Field6 uint `validate:"required_unless=Field5 2" json:"field_6"`
1006610082
Field7 float32 `validate:"required_unless=Field6 0" json:"field_7"`
1006710083
Field8 float64 `validate:"required_unless=Field7 0.0" json:"field_8"`
10084+
Field9 bool `validate:"omitempty" json:"field_9"`
10085+
Field10 string `validate:"required_unless=Field9 true" json:"field_10"`
1006810086
}{
1006910087
FieldE: "test",
1007010088
Field2: &fieldVal,
1007110089
Field3: map[string]string{"key": "val"},
1007210090
Field4: "test",
1007310091
Field5: 2,
10092+
Field9: true,
1007410093
}
1007510094

1007610095
validate := New()
@@ -10090,6 +10109,8 @@ func TestRequiredUnless(t *testing.T) {
1009010109
Field5 string `validate:"required_unless=Field3 0" json:"field_5"`
1009110110
Field6 string `validate:"required_unless=Inner.Field test" json:"field_6"`
1009210111
Field7 string `validate:"required_unless=Inner2.Field test" json:"field_7"`
10112+
Field8 bool `validate:"omitempty" json:"field_8"`
10113+
Field9 string `validate:"required_unless=Field8 true" json:"field_9"`
1009310114
}{
1009410115
Inner: &Inner{Field: &fieldVal},
1009510116
FieldE: "test",
@@ -10100,10 +10121,11 @@ func TestRequiredUnless(t *testing.T) {
1010010121
NotEqual(t, errs, nil)
1010110122

1010210123
ve := errs.(ValidationErrors)
10103-
Equal(t, len(ve), 3)
10124+
Equal(t, len(ve), 4)
1010410125
AssertError(t, errs, "Field3", "Field3", "Field3", "Field3", "required_unless")
1010510126
AssertError(t, errs, "Field4", "Field4", "Field4", "Field4", "required_unless")
1010610127
AssertError(t, errs, "Field7", "Field7", "Field7", "Field7", "required_unless")
10128+
AssertError(t, errs, "Field9", "Field9", "Field9", "Field9", "required_unless")
1010710129

1010810130
defer func() {
1010910131
if r := recover(); r == nil {

0 commit comments

Comments
 (0)