Skip to content

Commit 1399d3c

Browse files
authored
feature: support IN query clauses for nested JSON columns (#272)
* feat: support IN for JSONArrayQuery * chore: restrict json in for other dialects * chore: use builder.WriteQuoted for JSON_EXTRACT * fix: remove superfluous quoted * fix: allow in query for nested values only * fix: resolve JSON type and IN operator incompatibility with older MySQL versions * feat: support in queries for non-nested columns as well
1 parent 610acc2 commit 1399d3c

File tree

2 files changed

+64
-14
lines changed

2 files changed

+64
-14
lines changed

json.go

+47-13
Original file line numberDiff line numberDiff line change
@@ -450,39 +450,73 @@ func JSONArrayQuery(column string) *JSONArrayExpression {
450450
}
451451

452452
type JSONArrayExpression struct {
453+
contains bool
454+
in bool
453455
column string
454456
keys []string
455457
equalsValue interface{}
456458
}
457459

458-
// Contains checks if the column[keys] has contains the value given. The keys parameter is only supported for MySQL.
460+
// Contains checks if the column[keys] contains the value given. The keys parameter is only supported for MySQL.
459461
func (json *JSONArrayExpression) Contains(value interface{}, keys ...string) *JSONArrayExpression {
462+
json.contains = true
460463
json.equalsValue = value
461464
json.keys = keys
462465
return json
463466
}
464467

468+
// In checks if columns[keys] is in the array value given. This method is only supported for MySQL.
469+
func (json *JSONArrayExpression) In(value interface{}, keys ...string) *JSONArrayExpression {
470+
json.in = true
471+
json.keys = keys
472+
json.equalsValue = value
473+
return json
474+
}
475+
465476
// Build implements clause.Expression
466477
func (json *JSONArrayExpression) Build(builder clause.Builder) {
467478
if stmt, ok := builder.(*gorm.Statement); ok {
468479
switch stmt.Dialector.Name() {
469480
case "mysql":
470-
builder.WriteString("JSON_CONTAINS(" + stmt.Quote(json.column) + ",JSON_ARRAY(")
471-
builder.AddVar(stmt, json.equalsValue)
472-
builder.WriteByte(')')
473-
if len(json.keys) > 0 {
481+
switch {
482+
case json.contains:
483+
builder.WriteString("JSON_CONTAINS(" + stmt.Quote(json.column) + ",JSON_ARRAY(")
484+
builder.AddVar(stmt, json.equalsValue)
485+
builder.WriteByte(')')
486+
if len(json.keys) > 0 {
487+
builder.WriteByte(',')
488+
builder.AddVar(stmt, jsonQueryJoin(json.keys))
489+
}
490+
builder.WriteByte(')')
491+
case json.in:
492+
builder.WriteString("JSON_CONTAINS(JSON_ARRAY")
493+
builder.AddVar(stmt, json.equalsValue)
474494
builder.WriteByte(',')
475-
builder.AddVar(stmt, jsonQueryJoin(json.keys))
495+
if len(json.keys) > 0 {
496+
builder.WriteString("JSON_EXTRACT(")
497+
}
498+
builder.WriteQuoted(json.column)
499+
if len(json.keys) > 0 {
500+
builder.WriteByte(',')
501+
builder.AddVar(stmt, jsonQueryJoin(json.keys))
502+
builder.WriteByte(')')
503+
}
504+
builder.WriteByte(')')
476505
}
477-
builder.WriteByte(')')
478506
case "sqlite":
479-
builder.WriteString("exists(SELECT 1 FROM json_each(" + stmt.Quote(json.column) + ") WHERE value = ")
480-
builder.AddVar(stmt, json.equalsValue)
481-
builder.WriteString(")")
507+
switch {
508+
case json.contains:
509+
builder.WriteString("exists(SELECT 1 FROM json_each(" + stmt.Quote(json.column) + ") WHERE value = ")
510+
builder.AddVar(stmt, json.equalsValue)
511+
builder.WriteString(")")
512+
}
482513
case "postgres":
483-
builder.WriteString(stmt.Quote(json.column))
484-
builder.WriteString(" ? ")
485-
builder.AddVar(stmt, json.equalsValue)
514+
switch {
515+
case json.contains:
516+
builder.WriteString(stmt.Quote(json.column))
517+
builder.WriteString(" ? ")
518+
builder.AddVar(stmt, json.equalsValue)
519+
}
486520
}
487521
}
488522
}

json_test.go

+17-1
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,6 @@ func TestJSONArrayQuery(t *testing.T) {
463463
DisplayName: "JSONArray-1",
464464
Config: datatypes.JSON("[\"a\", \"b\"]"),
465465
}
466-
467466
cmp2 := Param{
468467
DisplayName: "JSONArray-2",
469468
Config: datatypes.JSON("[\"c\", \"a\"]"),
@@ -472,6 +471,10 @@ func TestJSONArrayQuery(t *testing.T) {
472471
DisplayName: "JSONArray-3",
473472
Config: datatypes.JSON("{\"test\": [\"a\", \"b\"]}"),
474473
}
474+
cmp4 := Param{
475+
DisplayName: "JSONArray-4",
476+
Config: datatypes.JSON("{\"test\": \"c\"}"),
477+
}
475478

476479
if err := DB.Create(&cmp1).Error; err != nil {
477480
t.Errorf("Failed to create param %v", err)
@@ -482,6 +485,9 @@ func TestJSONArrayQuery(t *testing.T) {
482485
if err := DB.Create(&cmp3).Error; err != nil {
483486
t.Errorf("Failed to create param %v", err)
484487
}
488+
if err := DB.Create(&cmp4).Error; err != nil {
489+
t.Errorf("Failed to create param %v", err)
490+
}
485491

486492
var retSingle1 Param
487493
if err := DB.Where("id = ?", cmp2.ID).First(&retSingle1).Error; err != nil {
@@ -507,5 +513,15 @@ func TestJSONArrayQuery(t *testing.T) {
507513
t.Fatalf("failed to find params with json value and keys, got error %v", err)
508514
}
509515
AssertEqual(t, len(retMultiple), 1)
516+
517+
if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "a"})).Find(&retMultiple).Error; err != nil {
518+
t.Fatalf("failed to find params with json value, got error %v", err)
519+
}
520+
AssertEqual(t, len(retMultiple), 1)
521+
522+
if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "d"}, "test")).Find(&retMultiple).Error; err != nil {
523+
t.Fatalf("failed to find params with json value and keys, got error %v", err)
524+
}
525+
AssertEqual(t, len(retMultiple), 1)
510526
}
511527
}

0 commit comments

Comments
 (0)