Skip to content

Commit 653a058

Browse files
committed
Add more type checks
1 parent aae3b0f commit 653a058

File tree

4 files changed

+204
-57
lines changed

4 files changed

+204
-57
lines changed

Diff for: eval_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,22 @@ var evalTests = []evalTest{
189189
}{1, 1},
190190
true,
191191
},
192+
{
193+
`A == B`,
194+
struct {
195+
A float64
196+
B interface{}
197+
}{1, new(interface{})},
198+
false,
199+
},
200+
{
201+
`A == B`,
202+
struct {
203+
A interface{}
204+
B float64
205+
}{new(interface{}), 1},
206+
false,
207+
},
192208
{
193209
`[true][A]`,
194210
&struct{ A int }{0},

Diff for: type.go

+91-38
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func (n unaryNode) Type(table typesTable) (Type, error) {
5454

5555
switch n.operator {
5656
case "!", "not":
57-
if isBoolType(ntype) {
57+
if isBoolType(ntype) || isInterfaceType(ntype) {
5858
return boolType, nil
5959
}
6060
return nil, fmt.Errorf(`invalid operation: %v (mismatched type %v)`, n, ntype)
@@ -80,8 +80,15 @@ func (n binaryNode) Type(table typesTable) (Type, error) {
8080
return boolType, nil
8181
}
8282
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)
83+
8384
case "or", "||", "and", "&&":
84-
if isBoolType(ltype) && isBoolType(rtype) {
85+
if (isBoolType(ltype) || isInterfaceType(ltype)) && (isBoolType(rtype) || isInterfaceType(rtype)) {
86+
return boolType, nil
87+
}
88+
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)
89+
90+
case "|", "^", "&", "<", ">", ">=", "<=", "+", "-", "*", "/", "%", "**", "..":
91+
if (isNumberType(ltype) || isInterfaceType(ltype)) && (isNumberType(rtype) || isInterfaceType(rtype)) {
8592
return boolType, nil
8693
}
8794
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)
@@ -91,7 +98,19 @@ func (n binaryNode) Type(table typesTable) (Type, error) {
9198
}
9299

93100
func (n matchesNode) Type(table typesTable) (Type, error) {
94-
return boolType, nil
101+
var err error
102+
ltype, err := n.left.Type(table)
103+
if err != nil {
104+
return nil, err
105+
}
106+
rtype, err := n.right.Type(table)
107+
if err != nil {
108+
return nil, err
109+
}
110+
if (isStringType(ltype) || isInterfaceType(ltype)) && (isStringType(rtype) || isInterfaceType(rtype)) {
111+
return boolType, nil
112+
}
113+
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)
95114
}
96115

97116
func (n propertyNode) Type(table typesTable) (Type, error) {
@@ -141,8 +160,14 @@ func (n methodNode) Type(table typesTable) (Type, error) {
141160
}
142161

143162
func (n builtinNode) Type(table typesTable) (Type, error) {
163+
for _, node := range n.arguments {
164+
_, err := node.Type(table)
165+
if err != nil {
166+
return nil, err
167+
}
168+
}
144169
if _, ok := builtins[n.name]; ok {
145-
return nil, nil
170+
return interfaceType, nil
146171
}
147172
return nil, fmt.Errorf("%v undefined", n)
148173
}
@@ -167,7 +192,7 @@ func (n conditionalNode) Type(table typesTable) (Type, error) {
167192
if err != nil {
168193
return nil, err
169194
}
170-
if !isBoolType(ctype) {
195+
if !isBoolType(ctype) && !isInterfaceType(ctype) {
171196
return nil, fmt.Errorf("non-bool %v (type %v) used as condition", n.cond, ctype)
172197
}
173198
_, err = n.exp1.Type(table)
@@ -216,60 +241,88 @@ func (n pairNode) Type(table typesTable) (Type, error) {
216241

217242
// helper funcs for reflect
218243

219-
func isComparable(ltype Type, rtype Type) bool {
220-
ltype = dereference(ltype)
221-
if ltype == nil {
222-
return true
223-
}
224-
rtype = dereference(rtype)
225-
if rtype == nil {
226-
return true
244+
func isComparable(l Type, r Type) bool {
245+
l = dereference(l)
246+
r = dereference(r)
247+
248+
if l == nil || r == nil {
249+
return true // It is possible to compare with nil.
227250
}
228251

229-
if canBeNumberType(ltype) && canBeNumberType(rtype) {
252+
if isNumberType(l) && isNumberType(r) {
230253
return true
231-
} else if ltype.Kind() == reflect.Interface {
254+
} else if l.Kind() == reflect.Interface {
232255
return true
233-
} else if rtype.Kind() == reflect.Interface {
256+
} else if r.Kind() == reflect.Interface {
234257
return true
235-
} else if ltype == rtype {
258+
} else if l == r {
236259
return true
237260
}
238261
return false
239262
}
240263

241-
func isBoolType(ntype Type) bool {
242-
ntype = dereference(ntype)
243-
if ntype == nil {
244-
return false
264+
func isInterfaceType(t Type) bool {
265+
t = dereference(t)
266+
if t != nil {
267+
switch t.Kind() {
268+
case reflect.Interface:
269+
return true
270+
}
245271
}
272+
return false
273+
}
246274

247-
switch ntype.Kind() {
248-
case reflect.Interface:
249-
return true
250-
case reflect.Bool:
251-
return true
275+
func isNumberType(t Type) bool {
276+
t = dereference(t)
277+
if t != nil {
278+
switch t.Kind() {
279+
case reflect.Float32, reflect.Float64:
280+
fallthrough
281+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
282+
fallthrough
283+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
284+
return true
285+
}
252286
}
253287
return false
254288
}
255289

256-
func fieldType(ntype Type, name string) (Type, bool) {
257-
ntype = dereference(ntype)
258-
if ntype == nil {
259-
return nil, false
290+
func isBoolType(t Type) bool {
291+
t = dereference(t)
292+
if t != nil {
293+
switch t.Kind() {
294+
case reflect.Bool:
295+
return true
296+
}
260297
}
298+
return false
299+
}
261300

262-
switch ntype.Kind() {
263-
case reflect.Interface:
264-
return interfaceType, true
265-
case reflect.Struct:
266-
if t, ok := ntype.FieldByName(name); ok {
267-
return t.Type, true
301+
func isStringType(t Type) bool {
302+
t = dereference(t)
303+
if t != nil {
304+
switch t.Kind() {
305+
case reflect.String:
306+
return true
268307
}
269-
case reflect.Map:
270-
return ntype.Elem(), true
271308
}
309+
return false
310+
}
272311

312+
func fieldType(ntype Type, name string) (Type, bool) {
313+
ntype = dereference(ntype)
314+
if ntype != nil {
315+
switch ntype.Kind() {
316+
case reflect.Interface:
317+
return interfaceType, true
318+
case reflect.Struct:
319+
if t, ok := ntype.FieldByName(name); ok {
320+
return t.Type, true
321+
}
322+
case reflect.Map:
323+
return ntype.Elem(), true
324+
}
325+
}
273326
return nil, false
274327
}
275328

Diff for: type_test.go

+96-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ var typeTests = []typeTest{
2525
"Fn(Any)",
2626
"Foo.Fn()",
2727
"true ? Any : Any",
28+
"Ok && Any",
29+
"Str matches 'ok'",
30+
"Str matches Any",
31+
"Any matches Any",
2832
"len([])",
2933
"true == false",
3034
"nil",
@@ -39,6 +43,7 @@ var typeTests = []typeTest{
3943
"Num == Abc",
4044
"Abc == Num",
4145
"1 == 2 and true or Ok",
46+
"Int == Any",
4247
"IntPtr == Int",
4348
"!OkPtr == Ok",
4449
"1 == NumPtr",
@@ -47,6 +52,20 @@ var typeTests = []typeTest{
4752
"nil == nil",
4853
"nil == IntPtr",
4954
"Foo2p.Bar.Baz",
55+
"Int | Num",
56+
"Int ^ Num",
57+
"Int & Num",
58+
"Int < Num",
59+
"Int > Num",
60+
"Int >= Num",
61+
"Int <= Num",
62+
"Int + Num",
63+
"Int - Num",
64+
"Int * Num",
65+
"Int / Num",
66+
"Int % Num",
67+
"Int ** Num",
68+
"Int .. Num",
5069
}
5170

5271
var typeErrorTests = []typeErrorTest{
@@ -110,6 +129,10 @@ var typeErrorTests = []typeErrorTest{
110129
"Map['str'].Not",
111130
`Map["str"].Not undefined (type *expr_test.foo has no field Not)`,
112131
},
132+
{
133+
"Ok && IntPtr",
134+
"invalid operation: (Ok && IntPtr) (mismatched types bool and *int)",
135+
},
113136
{
114137
"No ? Any.Ok : Any.Not",
115138
"unknown name No",
@@ -123,8 +146,16 @@ var typeErrorTests = []typeErrorTest{
123146
"unknown name No",
124147
},
125148
{
126-
"Any ? Any : Any",
127-
"non-bool Any (type map[string]interface {}) used as condition",
149+
"Many ? Any : Any",
150+
"non-bool Many (type map[string]interface {}) used as condition",
151+
},
152+
{
153+
"Str matches Int",
154+
"invalid operation: (Str matches Int) (mismatched types string and int)",
155+
},
156+
{
157+
"Int matches Str",
158+
"invalid operation: (Int matches Str) (mismatched types int and string)",
128159
},
129160
{
130161
"!Not",
@@ -166,6 +197,66 @@ var typeErrorTests = []typeErrorTest{
166197
"not IntPtr",
167198
"invalid operation: not IntPtr (mismatched type *int)",
168199
},
200+
{
201+
"len(Not)",
202+
"unknown name Not",
203+
},
204+
{
205+
"Int | Ok",
206+
"invalid operation: (Int | Ok) (mismatched types int and bool)",
207+
},
208+
{
209+
"Int ^ Ok",
210+
"invalid operation: (Int ^ Ok) (mismatched types int and bool)",
211+
},
212+
{
213+
"Int & Ok",
214+
"invalid operation: (Int & Ok) (mismatched types int and bool)",
215+
},
216+
{
217+
"Int < Ok",
218+
"invalid operation: (Int < Ok) (mismatched types int and bool)",
219+
},
220+
{
221+
"Int > Ok",
222+
"invalid operation: (Int > Ok) (mismatched types int and bool)",
223+
},
224+
{
225+
"Int >= Ok",
226+
"invalid operation: (Int >= Ok) (mismatched types int and bool)",
227+
},
228+
{
229+
"Int <= Ok",
230+
"invalid operation: (Int <= Ok) (mismatched types int and bool)",
231+
},
232+
{
233+
"Int + Ok",
234+
"invalid operation: (Int + Ok) (mismatched types int and bool)",
235+
},
236+
{
237+
"Int - Ok",
238+
"invalid operation: (Int - Ok) (mismatched types int and bool)",
239+
},
240+
{
241+
"Int * Ok",
242+
"invalid operation: (Int * Ok) (mismatched types int and bool)",
243+
},
244+
{
245+
"Int / Ok",
246+
"invalid operation: (Int / Ok) (mismatched types int and bool)",
247+
},
248+
{
249+
"Int % Ok",
250+
"invalid operation: (Int % Ok) (mismatched types int and bool)",
251+
},
252+
{
253+
"Int ** Ok",
254+
"invalid operation: (Int ** Ok) (mismatched types int and bool)",
255+
},
256+
{
257+
"Int .. Ok",
258+
"invalid operation: (Int .. Ok) (mismatched types int and bool)",
259+
},
169260
}
170261

171262
type abc interface {
@@ -183,9 +274,10 @@ type payload struct {
183274
Abc abc
184275
Foo *foo
185276
Arr []*foo
186-
Irr []interface{}
187277
Map map[string]*foo
188-
Any map[string]interface{}
278+
Any interface{}
279+
Irr []interface{}
280+
Many map[string]interface{}
189281
Fn func()
190282
Ok bool
191283
Num float64

Diff for: utils.go

+1-15
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,7 @@ func cast(v interface{}) (float64, error) {
5555

5656
func canBeNumber(v interface{}) bool {
5757
if v != nil {
58-
return canBeNumberType(reflect.TypeOf(v))
59-
}
60-
return false
61-
}
62-
63-
func canBeNumberType(t Type) bool {
64-
if t != nil {
65-
switch t.Kind() {
66-
case reflect.Float32, reflect.Float64:
67-
fallthrough
68-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
69-
fallthrough
70-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
71-
return true
72-
}
58+
return isNumberType(reflect.TypeOf(v))
7359
}
7460
return false
7561
}

0 commit comments

Comments
 (0)