Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(database/gdb): issue where the Count/Value/Array query logic was incompatible with the old version when users extended the returned result fields using the Select Hook #3995

Merged
merged 2 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion contrib/drivers/mysql/mysql_z_unit_issue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"testing"
"time"

"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
Expand Down Expand Up @@ -1283,12 +1284,12 @@ func Test_Issue3754(t *testing.T) {
func Test_Issue3626(t *testing.T) {
table := "issue3626"
array := gstr.SplitAndTrim(gtest.DataContent(`issue3626.sql`), ";")
defer dropTable(table)
for _, v := range array {
if _, err := db.Exec(ctx, v); err != nil {
gtest.Error(err)
}
}
defer dropTable(table)

// Insert.
gtest.C(t, func(t *gtest.T) {
Expand Down Expand Up @@ -1377,3 +1378,34 @@ func Test_Issue3932(t *testing.T) {
t.Assert(one["id"], 10)
})
}

// https://github.com/gogf/gf/issues/3968
func Test_Issue3968(t *testing.T) {
table := createInitTable()
defer dropTable(table)

gtest.C(t, func(t *gtest.T) {
var hook = gdb.HookHandler{
Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
result, err = in.Next(ctx)
if err != nil {
return nil, err
}
if result != nil {
for i, _ := range result {
result[i]["location"] = gvar.New("ny")
}
}
return
},
}
var (
count int
result gdb.Result
)
err := db.Model(table).Hook(hook).ScanAndCount(&result, &count, false)
t.AssertNil(err)
t.Assert(count, 10)
t.Assert(len(result), 10)
})
}
21 changes: 11 additions & 10 deletions database/gdb/gdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,13 @@ const (
linkPattern = `(\w+):([\w\-\$]*):(.*?)@(\w+?)\((.+?)\)/{0,1}([^\?]*)\?{0,1}(.*)`
)

type queryType int
type SelectType int

const (
queryTypeNormal queryType = iota
queryTypeCount
queryTypeValue
SelectTypeDefault SelectType = iota
SelectTypeCount
SelectTypeValue
SelectTypeArray
)

type joinOperator string
Expand Down Expand Up @@ -700,21 +701,21 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
}
// Exclude the right border value.
var (
min = 0
max = 0
random = grand.N(0, total-1)
minWeight = 0
maxWeight = 0
random = grand.N(0, total-1)
)
for i := 0; i < len(cg); i++ {
max = min + cg[i].Weight*100
if random >= min && random < max {
maxWeight = minWeight + cg[i].Weight*100
if random >= minWeight && random < maxWeight {
// ====================================================
// Return a COPY of the ConfigNode.
// ====================================================
node := ConfigNode{}
node = cg[i]
return &node
}
min = max
minWeight = maxWeight
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion database/gdb/gdb_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Mo
unionTypeStr = "UNION"
}
for _, v := range unions {
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, queryTypeNormal, false)
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, SelectTypeDefault, false)
if composedSqlStr == "" {
composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder)
} else {
Expand Down
2 changes: 0 additions & 2 deletions database/gdb/gdb_core_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ type internalCtxData struct {
}

// column stores column data in ctx for internal usage purpose.
// Deprecated.
// TODO remove this usage in future.
type internalColumnData struct {
// The first column in result response from database server.
// This attribute is used for Value/Count selection statement purpose,
Expand Down
9 changes: 5 additions & 4 deletions database/gdb/gdb_model_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args .
}

func (m *Model) saveSelectResultToCache(
ctx context.Context, queryType queryType, result Result, sql string, args ...interface{},
ctx context.Context, selectType SelectType, result Result, sql string, args ...interface{},
) (err error) {
if !m.cacheEnabled || m.tx != nil {
return
Expand All @@ -108,18 +108,19 @@ func (m *Model) saveSelectResultToCache(
// Special handler for Value/Count operations result.
if len(result) > 0 {
var core = m.db.GetCore()
switch queryType {
case queryTypeValue, queryTypeCount:
switch selectType {
case SelectTypeValue, SelectTypeArray, SelectTypeCount:
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
if result[0][internalData.FirstResultColumn].IsEmpty() {
result = nil
}
}
default:
}
}

// In case of Cache Penetration.
if result.IsEmpty() {
if result != nil && result.IsEmpty() {
if m.cacheOption.Force {
result = Result{}
} else {
Expand Down
11 changes: 6 additions & 5 deletions database/gdb/gdb_model_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@ type internalParamHookDelete struct {
// which is usually not be interesting for upper business hook handler.
type HookSelectInput struct {
internalParamHookSelect
Model *Model // Current operation Model.
Table string // The table name that to be used. Update this attribute to change target table name.
Schema string // The schema name that to be used. Update this attribute to change target schema name.
Sql string // The sql string that to be committed.
Args []interface{} // The arguments of sql.
Model *Model // Current operation Model.
Table string // The table name that to be used. Update this attribute to change target table name.
Schema string // The schema name that to be used. Update this attribute to change target schema name.
Sql string // The sql string that to be committed.
Args []interface{} // The arguments of sql.
SelectType SelectType // The type of this SELECT operation.
}

// HookInsertInput holds the parameters for insert hook operation.
Expand Down
108 changes: 76 additions & 32 deletions database/gdb/gdb_model_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// see Model.Where.
func (m *Model) All(where ...interface{}) (Result, error) {
var ctx = m.GetCtx()
return m.doGetAll(ctx, false, where...)
return m.doGetAll(ctx, SelectTypeDefault, false, where...)
}

// AllAndCount retrieves all records and the total count of records from the model.
Expand Down Expand Up @@ -69,7 +69,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in
}

// Retrieve all records
result, err = m.doGetAll(m.GetCtx(), false)
result, err = m.doGetAll(m.GetCtx(), SelectTypeDefault, false)
return
}

Expand Down Expand Up @@ -110,7 +110,7 @@ func (m *Model) One(where ...interface{}) (Record, error) {
if len(where) > 0 {
return m.Where(where[0], where[1:]...).One()
}
all, err := m.doGetAll(ctx, true)
all, err := m.doGetAll(ctx, SelectTypeDefault, true)
if err != nil {
return nil, err
}
Expand All @@ -136,24 +136,41 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) {
return m.Fields(gconv.String(fieldsAndWhere[0])).Array()
}
}
all, err := m.All()

var (
field string
core = m.db.GetCore()
ctx = core.injectInternalColumn(m.GetCtx())
)
all, err := m.doGetAll(ctx, SelectTypeArray, false)
if err != nil {
return nil, err
}
var field string
if len(all) > 0 {
var recordFields = m.getRecordFields(all[0])
if len(recordFields) > 1 {
// it returns error if there are multiple fields in the result record.
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
len(recordFields),
gjson.MustEncodeString(recordFields),
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return nil, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
if len(recordFields) == 1 {
field = recordFields[0]
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
field = internalData.FirstResultColumn
if field == "" {
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
field = recordFields[0]
} else {
// it returns error if there are multiple fields in the result record.
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
len(recordFields),
gjson.MustEncodeString(recordFields),
)
}
}
}
return all.Array(field), nil
Expand Down Expand Up @@ -398,13 +415,26 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
}
}
var (
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeValue, true)
all, err = m.doGetAllBySql(ctx, queryTypeValue, sqlWithHolder, holderArgs...)
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeValue, true)
all, err = m.doGetAllBySql(ctx, SelectTypeValue, sqlWithHolder, holderArgs...)
)
if err != nil {
return nil, err
}
if len(all) > 0 {
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return nil, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v, nil
}
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
for _, v := range all[0] {
Expand Down Expand Up @@ -445,13 +475,26 @@ func (m *Model) Count(where ...interface{}) (int, error) {
return m.Where(where[0], where[1:]...).Count()
}
var (
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false)
all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...)
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeCount, false)
all, err = m.doGetAllBySql(ctx, SelectTypeCount, sqlWithHolder, holderArgs...)
)
if err != nil {
return 0, err
}
if len(all) > 0 {
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return 0, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v.Int(), nil
}
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
for _, v := range all[0] {
Expand Down Expand Up @@ -616,17 +659,17 @@ func (m *Model) Having(having interface{}, args ...interface{}) *Model {
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) (Result, error) {
func (m *Model) doGetAll(ctx context.Context, selectType SelectType, limit1 bool, where ...interface{}) (Result, error) {
if len(where) > 0 {
return m.Where(where[0], where[1:]...).All()
}
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, queryTypeNormal, limit1)
return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...)
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, selectType, limit1)
return m.doGetAllBySql(ctx, selectType, sqlWithHolder, holderArgs...)
}

// doGetAllBySql does the select statement on the database.
func (m *Model) doGetAllBySql(
ctx context.Context, queryType queryType, sql string, args ...interface{},
ctx context.Context, selectType SelectType, sql string, args ...interface{},
) (result Result, err error) {
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
return
Expand All @@ -639,24 +682,25 @@ func (m *Model) doGetAllBySql(
},
handler: m.hookHandler.Select,
},
Model: m,
Table: m.tables,
Sql: sql,
Args: m.mergeArguments(args),
Model: m,
Table: m.tables,
Sql: sql,
Args: m.mergeArguments(args),
SelectType: selectType,
}
if result, err = in.Next(ctx); err != nil {
return
}

err = m.saveSelectResultToCache(ctx, queryType, result, sql, args...)
err = m.saveSelectResultToCache(ctx, selectType, result, sql, args...)
return
}

func (m *Model) getFormattedSqlAndArgs(
ctx context.Context, queryType queryType, limit1 bool,
ctx context.Context, selectType SelectType, limit1 bool,
) (sqlWithHolder string, holderArgs []interface{}) {
switch queryType {
case queryTypeCount:
switch selectType {
case SelectTypeCount:
queryFields := "COUNT(1)"
if len(m.fields) > 0 {
// DO NOT quote the m.fields here, in case of fields like:
Expand Down Expand Up @@ -698,7 +742,7 @@ func (m *Model) getFormattedSqlAndArgs(

func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []interface{}) {
holder, args = m.getFormattedSqlAndArgs(
ctx, queryTypeNormal, false,
ctx, SelectTypeDefault, false,
)
args = m.mergeArguments(args)
return
Expand Down
Loading