@@ -450,39 +450,73 @@ func JSONArrayQuery(column string) *JSONArrayExpression {
450
450
}
451
451
452
452
type JSONArrayExpression struct {
453
+ contains bool
454
+ in bool
453
455
column string
454
456
keys []string
455
457
equalsValue interface {}
456
458
}
457
459
458
- // Contains checks if the column[keys] has contains the value given. The keys parameter is only supported for MySQL.
460
+ // Contains checks if the column[keys] contains the value given. The keys parameter is only supported for MySQL.
459
461
func (json * JSONArrayExpression ) Contains (value interface {}, keys ... string ) * JSONArrayExpression {
462
+ json .contains = true
460
463
json .equalsValue = value
461
464
json .keys = keys
462
465
return json
463
466
}
464
467
468
+ // In checks if columns[keys] is in the array value given. This method is only supported for MySQL.
469
+ func (json * JSONArrayExpression ) In (value interface {}, keys ... string ) * JSONArrayExpression {
470
+ json .in = true
471
+ json .keys = keys
472
+ json .equalsValue = value
473
+ return json
474
+ }
475
+
465
476
// Build implements clause.Expression
466
477
func (json * JSONArrayExpression ) Build (builder clause.Builder ) {
467
478
if stmt , ok := builder .(* gorm.Statement ); ok {
468
479
switch stmt .Dialector .Name () {
469
480
case "mysql" :
470
- builder .WriteString ("JSON_CONTAINS(" + stmt .Quote (json .column ) + ",JSON_ARRAY(" )
471
- builder .AddVar (stmt , json .equalsValue )
472
- builder .WriteByte (')' )
473
- if len (json .keys ) > 0 {
481
+ switch {
482
+ case json .contains :
483
+ builder .WriteString ("JSON_CONTAINS(" + stmt .Quote (json .column ) + ",JSON_ARRAY(" )
484
+ builder .AddVar (stmt , json .equalsValue )
485
+ builder .WriteByte (')' )
486
+ if len (json .keys ) > 0 {
487
+ builder .WriteByte (',' )
488
+ builder .AddVar (stmt , jsonQueryJoin (json .keys ))
489
+ }
490
+ builder .WriteByte (')' )
491
+ case json .in :
492
+ builder .WriteString ("JSON_CONTAINS(JSON_ARRAY" )
493
+ builder .AddVar (stmt , json .equalsValue )
474
494
builder .WriteByte (',' )
475
- builder .AddVar (stmt , jsonQueryJoin (json .keys ))
495
+ if len (json .keys ) > 0 {
496
+ builder .WriteString ("JSON_EXTRACT(" )
497
+ }
498
+ builder .WriteQuoted (json .column )
499
+ if len (json .keys ) > 0 {
500
+ builder .WriteByte (',' )
501
+ builder .AddVar (stmt , jsonQueryJoin (json .keys ))
502
+ builder .WriteByte (')' )
503
+ }
504
+ builder .WriteByte (')' )
476
505
}
477
- builder .WriteByte (')' )
478
506
case "sqlite" :
479
- builder .WriteString ("exists(SELECT 1 FROM json_each(" + stmt .Quote (json .column ) + ") WHERE value = " )
480
- builder .AddVar (stmt , json .equalsValue )
481
- builder .WriteString (")" )
507
+ switch {
508
+ case json .contains :
509
+ builder .WriteString ("exists(SELECT 1 FROM json_each(" + stmt .Quote (json .column ) + ") WHERE value = " )
510
+ builder .AddVar (stmt , json .equalsValue )
511
+ builder .WriteString (")" )
512
+ }
482
513
case "postgres" :
483
- builder .WriteString (stmt .Quote (json .column ))
484
- builder .WriteString (" ? " )
485
- builder .AddVar (stmt , json .equalsValue )
514
+ switch {
515
+ case json .contains :
516
+ builder .WriteString (stmt .Quote (json .column ))
517
+ builder .WriteString (" ? " )
518
+ builder .AddVar (stmt , json .equalsValue )
519
+ }
486
520
}
487
521
}
488
522
}
0 commit comments