From d3050c1ed3fc0bee9f8ce394d35a3132f8a629e3 Mon Sep 17 00:00:00 2001 From: CyJaySong Date: Tue, 24 Dec 2024 11:53:53 +0800 Subject: [PATCH] fix(database/gdb): gdb.Counter not work in OnDuplicate --- .../drivers/mysql/mysql_z_unit_model_test.go | 22 ++++++++ contrib/drivers/pgsql/pgsql_format_upsert.go | 19 +++++++ .../drivers/pgsql/pgsql_z_unit_model_test.go | 22 ++++++++ .../drivers/sqlite/sqlite_format_upsert.go | 19 +++++++ .../sqlite/sqlite_z_unit_model_test.go | 22 ++++++++ database/gdb/gdb_core.go | 51 ++++++++++--------- database/gdb/gdb_core_underlying.go | 16 ++++++ 7 files changed, 146 insertions(+), 25 deletions(-) diff --git a/contrib/drivers/mysql/mysql_z_unit_model_test.go b/contrib/drivers/mysql/mysql_z_unit_model_test.go index 9be10ff1313..d28f461c241 100644 --- a/contrib/drivers/mysql/mysql_z_unit_model_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_model_test.go @@ -2812,6 +2812,28 @@ func Test_Model_OnDuplicate(t *testing.T) { }) } +func Test_Model_OnDuplicateWithCounter(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "id": gdb.Counter{Field: "id", Value: 999999}, + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.AssertNil(one) + }) +} + func Test_Model_OnDuplicateEx(t *testing.T) { table := createInitTable() defer dropTable(table) diff --git a/contrib/drivers/pgsql/pgsql_format_upsert.go b/contrib/drivers/pgsql/pgsql_format_upsert.go index c4c8af91122..fc003cb4ce5 100644 --- a/contrib/drivers/pgsql/pgsql_format_upsert.go +++ b/contrib/drivers/pgsql/pgsql_format_upsert.go @@ -40,6 +40,25 @@ func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInse d.Core.QuoteWord(k), v, ) + case gdb.Counter, *gdb.Counter: + var counter gdb.Counter + switch value := v.(type) { + case gdb.Counter: + counter = value + case *gdb.Counter: + counter = *value + } + operator, columnVal := "+", counter.Value + if columnVal < 0 { + operator, columnVal = "-", -columnVal + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s%s%s", + d.QuoteWord(k), + d.QuoteWord(counter.Field), + operator, + gconv.String(columnVal), + ) default: onDuplicateStr += fmt.Sprintf( "%s=EXCLUDED.%s", diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index d7748f07f17..3a51ba264dc 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -521,6 +521,28 @@ func Test_Model_OnDuplicate(t *testing.T) { }) } +func Test_Model_OnDuplicateWithCounter(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "id": gdb.Counter{Field: "id", Value: 999999}, + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.AssertNil(one) + }) +} + func Test_Model_OnDuplicateEx(t *testing.T) { table := createInitTable() defer dropTable(table) diff --git a/contrib/drivers/sqlite/sqlite_format_upsert.go b/contrib/drivers/sqlite/sqlite_format_upsert.go index 5821144a13e..c80d5dfd74e 100644 --- a/contrib/drivers/sqlite/sqlite_format_upsert.go +++ b/contrib/drivers/sqlite/sqlite_format_upsert.go @@ -40,6 +40,25 @@ func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInse d.Core.QuoteWord(k), v, ) + case gdb.Counter, *gdb.Counter: + var counter gdb.Counter + switch value := v.(type) { + case gdb.Counter: + counter = value + case *gdb.Counter: + counter = *value + } + operator, columnVal := "+", counter.Value + if columnVal < 0 { + operator, columnVal = "-", -columnVal + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s%s%s", + d.QuoteWord(k), + d.QuoteWord(counter.Field), + operator, + gconv.String(columnVal), + ) default: onDuplicateStr += fmt.Sprintf( "%s=EXCLUDED.%s", diff --git a/contrib/drivers/sqlite/sqlite_z_unit_model_test.go b/contrib/drivers/sqlite/sqlite_z_unit_model_test.go index 19e97bfa51a..03e8465c7fc 100644 --- a/contrib/drivers/sqlite/sqlite_z_unit_model_test.go +++ b/contrib/drivers/sqlite/sqlite_z_unit_model_test.go @@ -4324,3 +4324,25 @@ func Test_OrderRandom(t *testing.T) { t.Assert(len(result), TableSize) }) } + +func Test_Model_OnDuplicateWithCounter(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "id": gdb.Counter{Field: "id", Value: 999999}, + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.AssertNil(one) + }) +} diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 50971611258..06dcbd7ca16 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -583,24 +583,8 @@ func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data inter switch kind { case reflect.Map, reflect.Struct: var ( - fields []string - dataMap map[string]interface{} - counterHandler = func(column string, counter Counter) { - if counter.Value != 0 { - column = c.QuoteWord(column) - var ( - columnRef = c.QuoteWord(counter.Field) - columnVal = counter.Value - operator = "+" - ) - if columnVal < 0 { - operator = "-" - columnVal = -columnVal - } - fields = append(fields, fmt.Sprintf("%s=%s%s?", column, columnRef, operator)) - params = append(params, columnVal) - } - } + fields []string + dataMap map[string]interface{} ) dataMap, err = c.ConvertDataForRecord(ctx, data, table) if err != nil { @@ -620,13 +604,21 @@ func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data inter } for _, k := range keysInSequence { v := dataMap[k] - switch value := v.(type) { - case *Counter: - counterHandler(k, *value) - - case Counter: - counterHandler(k, value) - + switch v.(type) { + case Counter, *Counter: + var counter Counter + switch value := v.(type) { + case Counter: + counter = value + case *Counter: + counter = *value + } + if counter.Value == 0 { + continue + } + operator, columnVal := c.getCounterAlter(counter) + fields = append(fields, fmt.Sprintf("%s=%s%s?", c.QuoteWord(k), c.QuoteWord(counter.Field), operator)) + params = append(params, columnVal) default: if s, ok := v.(Raw); ok { fields = append(fields, c.QuoteWord(k)+"="+gconv.String(s)) @@ -796,3 +788,12 @@ func (c *Core) IsSoftCreatedFieldName(fieldName string) bool { func (c *Core) FormatSqlBeforeExecuting(sql string, args []interface{}) (newSql string, newArgs []interface{}) { return handleSliceAndStructArgsForSql(sql, args) } + +// getCounterAlter The polished counter alter +func (c *Core) getCounterAlter(counter Counter) (operator string, columnVal float64) { + operator, columnVal = "+", counter.Value + if columnVal < 0 { + operator, columnVal = "-", -columnVal + } + return +} diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 045d11c65af..25c60a4baf7 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -388,6 +388,22 @@ func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) c.QuoteWord(k), v, ) + case Counter, *Counter: + var counter Counter + switch value := v.(type) { + case Counter: + counter = value + case *Counter: + counter = *value + } + operator, columnVal := c.getCounterAlter(counter) + onDuplicateStr += fmt.Sprintf( + "%s=%s%s%s", + c.QuoteWord(k), + c.QuoteWord(counter.Field), + operator, + gconv.String(columnVal), + ) default: onDuplicateStr += fmt.Sprintf( "%s=VALUES(%s)",