Skip to content

Commit 2c916f8

Browse files
authored
fix(database/gdb): issue where the Count/Value/Array query logic was incompatible with the old version when users extended the returned result fields using the Select Hook (#3995)
1 parent 42eae41 commit 2c916f8

File tree

7 files changed

+132
-55
lines changed

7 files changed

+132
-55
lines changed

contrib/drivers/mysql/mysql_z_unit_issue_test.go

+33-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414
"time"
1515

16+
"github.com/gogf/gf/v2/container/gvar"
1617
"github.com/gogf/gf/v2/database/gdb"
1718
"github.com/gogf/gf/v2/frame/g"
1819
"github.com/gogf/gf/v2/os/gtime"
@@ -1283,12 +1284,12 @@ func Test_Issue3754(t *testing.T) {
12831284
func Test_Issue3626(t *testing.T) {
12841285
table := "issue3626"
12851286
array := gstr.SplitAndTrim(gtest.DataContent(`issue3626.sql`), ";")
1287+
defer dropTable(table)
12861288
for _, v := range array {
12871289
if _, err := db.Exec(ctx, v); err != nil {
12881290
gtest.Error(err)
12891291
}
12901292
}
1291-
defer dropTable(table)
12921293

12931294
// Insert.
12941295
gtest.C(t, func(t *gtest.T) {
@@ -1377,3 +1378,34 @@ func Test_Issue3932(t *testing.T) {
13771378
t.Assert(one["id"], 10)
13781379
})
13791380
}
1381+
1382+
// https://github.com/gogf/gf/issues/3968
1383+
func Test_Issue3968(t *testing.T) {
1384+
table := createInitTable()
1385+
defer dropTable(table)
1386+
1387+
gtest.C(t, func(t *gtest.T) {
1388+
var hook = gdb.HookHandler{
1389+
Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
1390+
result, err = in.Next(ctx)
1391+
if err != nil {
1392+
return nil, err
1393+
}
1394+
if result != nil {
1395+
for i, _ := range result {
1396+
result[i]["location"] = gvar.New("ny")
1397+
}
1398+
}
1399+
return
1400+
},
1401+
}
1402+
var (
1403+
count int
1404+
result gdb.Result
1405+
)
1406+
err := db.Model(table).Hook(hook).ScanAndCount(&result, &count, false)
1407+
t.AssertNil(err)
1408+
t.Assert(count, 10)
1409+
t.Assert(len(result), 10)
1410+
})
1411+
}

database/gdb/gdb.go

+11-10
Original file line numberDiff line numberDiff line change
@@ -396,12 +396,13 @@ const (
396396
linkPattern = `(\w+):([\w\-\$]*):(.*?)@(\w+?)\((.+?)\)/{0,1}([^\?]*)\?{0,1}(.*)`
397397
)
398398

399-
type queryType int
399+
type SelectType int
400400

401401
const (
402-
queryTypeNormal queryType = iota
403-
queryTypeCount
404-
queryTypeValue
402+
SelectTypeDefault SelectType = iota
403+
SelectTypeCount
404+
SelectTypeValue
405+
SelectTypeArray
405406
)
406407

407408
type joinOperator string
@@ -700,21 +701,21 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
700701
}
701702
// Exclude the right border value.
702703
var (
703-
min = 0
704-
max = 0
705-
random = grand.N(0, total-1)
704+
minWeight = 0
705+
maxWeight = 0
706+
random = grand.N(0, total-1)
706707
)
707708
for i := 0; i < len(cg); i++ {
708-
max = min + cg[i].Weight*100
709-
if random >= min && random < max {
709+
maxWeight = minWeight + cg[i].Weight*100
710+
if random >= minWeight && random < maxWeight {
710711
// ====================================================
711712
// Return a COPY of the ConfigNode.
712713
// ====================================================
713714
node := ConfigNode{}
714715
node = cg[i]
715716
return &node
716717
}
717-
min = max
718+
minWeight = maxWeight
718719
}
719720
return nil
720721
}

database/gdb/gdb_core.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Mo
278278
unionTypeStr = "UNION"
279279
}
280280
for _, v := range unions {
281-
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, queryTypeNormal, false)
281+
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, SelectTypeDefault, false)
282282
if composedSqlStr == "" {
283283
composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder)
284284
} else {

database/gdb/gdb_core_ctx.go

-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ type internalCtxData struct {
2323
}
2424

2525
// column stores column data in ctx for internal usage purpose.
26-
// Deprecated.
27-
// TODO remove this usage in future.
2826
type internalColumnData struct {
2927
// The first column in result response from database server.
3028
// This attribute is used for Value/Count selection statement purpose,

database/gdb/gdb_model_cache.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args .
9090
}
9191

9292
func (m *Model) saveSelectResultToCache(
93-
ctx context.Context, queryType queryType, result Result, sql string, args ...interface{},
93+
ctx context.Context, selectType SelectType, result Result, sql string, args ...interface{},
9494
) (err error) {
9595
if !m.cacheEnabled || m.tx != nil {
9696
return
@@ -108,18 +108,19 @@ func (m *Model) saveSelectResultToCache(
108108
// Special handler for Value/Count operations result.
109109
if len(result) > 0 {
110110
var core = m.db.GetCore()
111-
switch queryType {
112-
case queryTypeValue, queryTypeCount:
111+
switch selectType {
112+
case SelectTypeValue, SelectTypeArray, SelectTypeCount:
113113
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
114114
if result[0][internalData.FirstResultColumn].IsEmpty() {
115115
result = nil
116116
}
117117
}
118+
default:
118119
}
119120
}
120121

121122
// In case of Cache Penetration.
122-
if result.IsEmpty() {
123+
if result != nil && result.IsEmpty() {
123124
if m.cacheOption.Force {
124125
result = Result{}
125126
} else {

database/gdb/gdb_model_hook.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,12 @@ type internalParamHookDelete struct {
6666
// which is usually not be interesting for upper business hook handler.
6767
type HookSelectInput struct {
6868
internalParamHookSelect
69-
Model *Model // Current operation Model.
70-
Table string // The table name that to be used. Update this attribute to change target table name.
71-
Schema string // The schema name that to be used. Update this attribute to change target schema name.
72-
Sql string // The sql string that to be committed.
73-
Args []interface{} // The arguments of sql.
69+
Model *Model // Current operation Model.
70+
Table string // The table name that to be used. Update this attribute to change target table name.
71+
Schema string // The schema name that to be used. Update this attribute to change target schema name.
72+
Sql string // The sql string that to be committed.
73+
Args []interface{} // The arguments of sql.
74+
SelectType SelectType // The type of this SELECT operation.
7475
}
7576

7677
// HookInsertInput holds the parameters for insert hook operation.

database/gdb/gdb_model_select.go

+76-32
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828
// see Model.Where.
2929
func (m *Model) All(where ...interface{}) (Result, error) {
3030
var ctx = m.GetCtx()
31-
return m.doGetAll(ctx, false, where...)
31+
return m.doGetAll(ctx, SelectTypeDefault, false, where...)
3232
}
3333

3434
// AllAndCount retrieves all records and the total count of records from the model.
@@ -69,7 +69,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in
6969
}
7070

7171
// Retrieve all records
72-
result, err = m.doGetAll(m.GetCtx(), false)
72+
result, err = m.doGetAll(m.GetCtx(), SelectTypeDefault, false)
7373
return
7474
}
7575

@@ -110,7 +110,7 @@ func (m *Model) One(where ...interface{}) (Record, error) {
110110
if len(where) > 0 {
111111
return m.Where(where[0], where[1:]...).One()
112112
}
113-
all, err := m.doGetAll(ctx, true)
113+
all, err := m.doGetAll(ctx, SelectTypeDefault, true)
114114
if err != nil {
115115
return nil, err
116116
}
@@ -136,24 +136,41 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) {
136136
return m.Fields(gconv.String(fieldsAndWhere[0])).Array()
137137
}
138138
}
139-
all, err := m.All()
139+
140+
var (
141+
field string
142+
core = m.db.GetCore()
143+
ctx = core.injectInternalColumn(m.GetCtx())
144+
)
145+
all, err := m.doGetAll(ctx, SelectTypeArray, false)
140146
if err != nil {
141147
return nil, err
142148
}
143-
var field string
144149
if len(all) > 0 {
145-
var recordFields = m.getRecordFields(all[0])
146-
if len(recordFields) > 1 {
147-
// it returns error if there are multiple fields in the result record.
148-
return nil, gerror.NewCodef(
149-
gcode.CodeInvalidParameter,
150-
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
151-
len(recordFields),
152-
gjson.MustEncodeString(recordFields),
150+
internalData := core.getInternalColumnFromCtx(ctx)
151+
if internalData == nil {
152+
return nil, gerror.NewCode(
153+
gcode.CodeInternalError,
154+
`query count error: the internal context data is missing. there's internal issue should be fixed`,
153155
)
154156
}
155-
if len(recordFields) == 1 {
156-
field = recordFields[0]
157+
// If FirstResultColumn present, it returns the value of the first record of the first field.
158+
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
159+
field = internalData.FirstResultColumn
160+
if field == "" {
161+
// Fields number check.
162+
var recordFields = m.getRecordFields(all[0])
163+
if len(recordFields) == 1 {
164+
field = recordFields[0]
165+
} else {
166+
// it returns error if there are multiple fields in the result record.
167+
return nil, gerror.NewCodef(
168+
gcode.CodeInvalidParameter,
169+
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
170+
len(recordFields),
171+
gjson.MustEncodeString(recordFields),
172+
)
173+
}
157174
}
158175
}
159176
return all.Array(field), nil
@@ -398,13 +415,26 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
398415
}
399416
}
400417
var (
401-
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeValue, true)
402-
all, err = m.doGetAllBySql(ctx, queryTypeValue, sqlWithHolder, holderArgs...)
418+
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeValue, true)
419+
all, err = m.doGetAllBySql(ctx, SelectTypeValue, sqlWithHolder, holderArgs...)
403420
)
404421
if err != nil {
405422
return nil, err
406423
}
407424
if len(all) > 0 {
425+
internalData := core.getInternalColumnFromCtx(ctx)
426+
if internalData == nil {
427+
return nil, gerror.NewCode(
428+
gcode.CodeInternalError,
429+
`query count error: the internal context data is missing. there's internal issue should be fixed`,
430+
)
431+
}
432+
// If FirstResultColumn present, it returns the value of the first record of the first field.
433+
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
434+
if v, ok := all[0][internalData.FirstResultColumn]; ok {
435+
return v, nil
436+
}
437+
// Fields number check.
408438
var recordFields = m.getRecordFields(all[0])
409439
if len(recordFields) == 1 {
410440
for _, v := range all[0] {
@@ -445,13 +475,26 @@ func (m *Model) Count(where ...interface{}) (int, error) {
445475
return m.Where(where[0], where[1:]...).Count()
446476
}
447477
var (
448-
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false)
449-
all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...)
478+
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeCount, false)
479+
all, err = m.doGetAllBySql(ctx, SelectTypeCount, sqlWithHolder, holderArgs...)
450480
)
451481
if err != nil {
452482
return 0, err
453483
}
454484
if len(all) > 0 {
485+
internalData := core.getInternalColumnFromCtx(ctx)
486+
if internalData == nil {
487+
return 0, gerror.NewCode(
488+
gcode.CodeInternalError,
489+
`query count error: the internal context data is missing. there's internal issue should be fixed`,
490+
)
491+
}
492+
// If FirstResultColumn present, it returns the value of the first record of the first field.
493+
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
494+
if v, ok := all[0][internalData.FirstResultColumn]; ok {
495+
return v.Int(), nil
496+
}
497+
// Fields number check.
455498
var recordFields = m.getRecordFields(all[0])
456499
if len(recordFields) == 1 {
457500
for _, v := range all[0] {
@@ -616,17 +659,17 @@ func (m *Model) Having(having interface{}, args ...interface{}) *Model {
616659
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
617660
// The optional parameter `where` is the same as the parameter of Model.Where function,
618661
// see Model.Where.
619-
func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) (Result, error) {
662+
func (m *Model) doGetAll(ctx context.Context, selectType SelectType, limit1 bool, where ...interface{}) (Result, error) {
620663
if len(where) > 0 {
621664
return m.Where(where[0], where[1:]...).All()
622665
}
623-
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, queryTypeNormal, limit1)
624-
return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...)
666+
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, selectType, limit1)
667+
return m.doGetAllBySql(ctx, selectType, sqlWithHolder, holderArgs...)
625668
}
626669

627670
// doGetAllBySql does the select statement on the database.
628671
func (m *Model) doGetAllBySql(
629-
ctx context.Context, queryType queryType, sql string, args ...interface{},
672+
ctx context.Context, selectType SelectType, sql string, args ...interface{},
630673
) (result Result, err error) {
631674
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
632675
return
@@ -639,24 +682,25 @@ func (m *Model) doGetAllBySql(
639682
},
640683
handler: m.hookHandler.Select,
641684
},
642-
Model: m,
643-
Table: m.tables,
644-
Sql: sql,
645-
Args: m.mergeArguments(args),
685+
Model: m,
686+
Table: m.tables,
687+
Sql: sql,
688+
Args: m.mergeArguments(args),
689+
SelectType: selectType,
646690
}
647691
if result, err = in.Next(ctx); err != nil {
648692
return
649693
}
650694

651-
err = m.saveSelectResultToCache(ctx, queryType, result, sql, args...)
695+
err = m.saveSelectResultToCache(ctx, selectType, result, sql, args...)
652696
return
653697
}
654698

655699
func (m *Model) getFormattedSqlAndArgs(
656-
ctx context.Context, queryType queryType, limit1 bool,
700+
ctx context.Context, selectType SelectType, limit1 bool,
657701
) (sqlWithHolder string, holderArgs []interface{}) {
658-
switch queryType {
659-
case queryTypeCount:
702+
switch selectType {
703+
case SelectTypeCount:
660704
queryFields := "COUNT(1)"
661705
if len(m.fields) > 0 {
662706
// DO NOT quote the m.fields here, in case of fields like:
@@ -698,7 +742,7 @@ func (m *Model) getFormattedSqlAndArgs(
698742

699743
func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []interface{}) {
700744
holder, args = m.getFormattedSqlAndArgs(
701-
ctx, queryTypeNormal, false,
745+
ctx, SelectTypeDefault, false,
702746
)
703747
args = m.mergeArguments(args)
704748
return

0 commit comments

Comments
 (0)