From fbd266fad058e5ca8e4faf62bb0c93521cc30dfa Mon Sep 17 00:00:00 2001 From: oldme <45782393+oldme-git@users.noreply.github.com> Date: Wed, 6 Mar 2024 19:22:58 +0800 Subject: [PATCH] enhance: add `Save` operation support for SQLite #2764 (#3315) --- contrib/drivers/sqlite/sqlite.go | 3 - contrib/drivers/sqlite/sqlite_do_filter.go | 10 --- .../drivers/sqlite/sqlite_format_upsert.go | 68 +++++++++++++++++++ .../drivers/sqlite/sqlite_z_unit_core_test.go | 25 ++++--- .../drivers/sqlite/sqlite_z_unit_init_test.go | 5 -- .../sqlite/sqlite_z_unit_model_test.go | 56 +++++++++++++-- .../sqlitecgo/sqlitecgo_z_unit_core_test.go | 25 ++++--- .../sqlitecgo/sqlitecgo_z_unit_model_test.go | 56 +++++++++++++-- 8 files changed, 196 insertions(+), 52 deletions(-) create mode 100644 contrib/drivers/sqlite/sqlite_format_upsert.go diff --git a/contrib/drivers/sqlite/sqlite.go b/contrib/drivers/sqlite/sqlite.go index 65554930bca..de48aaad988 100644 --- a/contrib/drivers/sqlite/sqlite.go +++ b/contrib/drivers/sqlite/sqlite.go @@ -5,9 +5,6 @@ // You can obtain one at https://github.com/gogf/gf. // Package sqlite implements gdb.Driver, which supports operations for database SQLite. -// -// Note: -// 1. It does not support Save features. package sqlite import ( diff --git a/contrib/drivers/sqlite/sqlite_do_filter.go b/contrib/drivers/sqlite/sqlite_do_filter.go index 1fa344962f8..fa605767435 100644 --- a/contrib/drivers/sqlite/sqlite_do_filter.go +++ b/contrib/drivers/sqlite/sqlite_do_filter.go @@ -10,8 +10,6 @@ import ( "context" "github.com/gogf/gf/v2/database/gdb" - "github.com/gogf/gf/v2/errors/gcode" - "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/text/gstr" ) @@ -24,14 +22,6 @@ func (d *Driver) DoFilter(ctx context.Context, link gdb.Link, sql string, args [ case gstr.HasPrefix(sql, gdb.InsertOperationReplace): sql = "INSERT OR REPLACE" + sql[len(gdb.InsertOperationReplace):] - - default: - if gstr.Contains(sql, gdb.InsertOnDuplicateKeyUpdate) { - return sql, args, gerror.NewCode( - gcode.CodeNotSupported, - `Save operation is not supported by sqlite driver`, - ) - } } return d.Core.DoFilter(ctx, link, sql, args) } diff --git a/contrib/drivers/sqlite/sqlite_format_upsert.go b/contrib/drivers/sqlite/sqlite_format_upsert.go new file mode 100644 index 00000000000..34fc3ccce64 --- /dev/null +++ b/contrib/drivers/sqlite/sqlite_format_upsert.go @@ -0,0 +1,68 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package sqlite + +import ( + "fmt" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// FormatUpsert returns SQL clause of type upsert for SQLite. +// For example: ON CONFLICT (id) DO UPDATE SET ... +func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) { + if len(option.OnConflict) == 0 { + return "", gerror.New("Please specify conflict columns") + } + + var onDuplicateStr string + if option.OnDuplicateStr != "" { + onDuplicateStr = option.OnDuplicateStr + } else if len(option.OnDuplicateMap) > 0 { + for k, v := range option.OnDuplicateMap { + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + switch v.(type) { + case gdb.Raw, *gdb.Raw: + onDuplicateStr += fmt.Sprintf( + "%s=%s", + d.Core.QuoteWord(k), + v, + ) + default: + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(k), + d.Core.QuoteWord(gconv.String(v)), + ) + } + } + } else { + for _, column := range columns { + // If it's SAVE operation, do not automatically update the creating time. + if d.Core.IsSoftCreatedFieldName(column) { + continue + } + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(column), + d.Core.QuoteWord(column), + ) + } + } + + conflictKeys := gstr.Join(option.OnConflict, ",") + + return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil +} diff --git a/contrib/drivers/sqlite/sqlite_z_unit_core_test.go b/contrib/drivers/sqlite/sqlite_z_unit_core_test.go index b45fd3ddc4e..a8a4b232ff8 100644 --- a/contrib/drivers/sqlite/sqlite_z_unit_core_test.go +++ b/contrib/drivers/sqlite/sqlite_z_unit_core_test.go @@ -425,19 +425,22 @@ func Test_DB_BatchInsert_Struct(t *testing.T) { } func Test_DB_Save(t *testing.T) { - table := createInitTable() + table := createTable() defer dropTable(table) - gtest.C(t, func(t *gtest.T) { - timeStr := gtime.Now().String() - _, err := db.Save(ctx, table, g.Map{ - "id": 1, - "passport": "t1", - "password": "25d55ad283aa400af464c76d713c07ad", - "nickname": "T11", - "create_time": timeStr, - }) - t.Assert(err, ErrorSave) + createTable("t_user") + defer dropTable("t_user") + + i := 10 + data := g.Map{ + "id": i, + "passport": fmt.Sprintf(`t%d`, i), + "password": fmt.Sprintf(`p%d`, i), + "nickname": fmt.Sprintf(`T%d`, i), + "create_time": gtime.Now().String(), + } + _, err := db.Save(ctx, "t_user", data, 10) + gtest.AssertNE(err, nil) }) } diff --git a/contrib/drivers/sqlite/sqlite_z_unit_init_test.go b/contrib/drivers/sqlite/sqlite_z_unit_init_test.go index 4a5f1a67fd6..6d82d9412e3 100644 --- a/contrib/drivers/sqlite/sqlite_z_unit_init_test.go +++ b/contrib/drivers/sqlite/sqlite_z_unit_init_test.go @@ -11,8 +11,6 @@ import ( "github.com/gogf/gf/v2/container/garray" "github.com/gogf/gf/v2/database/gdb" - "github.com/gogf/gf/v2/errors/gcode" - "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/os/gfile" @@ -27,9 +25,6 @@ var ( configNode gdb.ConfigNode dbDir = gfile.Temp("sqlite") ctx = gctx.New() - - // Error - ErrorSave = gerror.NewCode(gcode.CodeNotSupported, `Save operation is not supported by sqlite driver`) ) const ( diff --git a/contrib/drivers/sqlite/sqlite_z_unit_model_test.go b/contrib/drivers/sqlite/sqlite_z_unit_model_test.go index a6d6f61a57e..95473ff4063 100644 --- a/contrib/drivers/sqlite/sqlite_z_unit_model_test.go +++ b/contrib/drivers/sqlite/sqlite_z_unit_model_test.go @@ -364,14 +364,58 @@ func Test_Model_Save(t *testing.T) { table := createTable() defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime *gtime.Time + } + var ( + user User + count int + result sql.Result + err error + ) + + result, err = db.Model(table).Data(g.Map{ "id": 1, - "passport": "t111", - "password": "25d55ad283aa400af464c76d713c07ad", - "nickname": "T111", + "passport": "CN", + "password": "12345678", + "nickname": "oldme", "create_time": CreateTime, - }).Save() - t.Assert(err, ErrorSave) + }).OnConflict("id").Save() + t.AssertNil(nil) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + err = db.Model(table).Scan(&user) + t.Assert(err, nil) + t.Assert(user.Id, 1) + t.Assert(user.Passport, "CN") + t.Assert(user.Password, "12345678") + t.Assert(user.NickName, "oldme") + t.Assert(user.CreateTime.String(), CreateTime) + + _, err = db.Model(table).Data(g.Map{ + "id": 1, + "passport": "CN", + "password": "abc123456", + "nickname": "to be not to be", + "create_time": CreateTime, + }).OnConflict("id").Save() + t.AssertNil(err) + + err = db.Model(table).Scan(&user) + t.Assert(err, nil) + t.Assert(user.Passport, "CN") + t.Assert(user.Password, "abc123456") + t.Assert(user.NickName, "to be not to be") + t.Assert(user.CreateTime.String(), CreateTime) + + count, err = db.Model(table).Count() + t.Assert(err, nil) + t.Assert(count, 1) }) } diff --git a/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_core_test.go b/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_core_test.go index 7dbffd7ee26..dbfdb23def1 100644 --- a/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_core_test.go +++ b/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_core_test.go @@ -425,19 +425,22 @@ func Test_DB_BatchInsert_Struct(t *testing.T) { } func Test_DB_Save(t *testing.T) { - table := createInitTable() + table := createTable() defer dropTable(table) - gtest.C(t, func(t *gtest.T) { - timeStr := gtime.Now().String() - _, err := db.Save(ctx, table, g.Map{ - "id": 1, - "passport": "t1", - "password": "25d55ad283aa400af464c76d713c07ad", - "nickname": "T11", - "create_time": timeStr, - }) - t.Assert(err, ErrorSave) + createTable("t_user") + defer dropTable("t_user") + + i := 10 + data := g.Map{ + "id": i, + "passport": fmt.Sprintf(`t%d`, i), + "password": fmt.Sprintf(`p%d`, i), + "nickname": fmt.Sprintf(`T%d`, i), + "create_time": gtime.Now().String(), + } + _, err := db.Save(ctx, "t_user", data, 10) + gtest.AssertNE(err, nil) }) } diff --git a/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_model_test.go b/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_model_test.go index b9f1760d184..610e46a942e 100644 --- a/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_model_test.go +++ b/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_model_test.go @@ -364,14 +364,58 @@ func Test_Model_Save(t *testing.T) { table := createTable() defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime *gtime.Time + } + var ( + user User + count int + result sql.Result + err error + ) + + result, err = db.Model(table).Data(g.Map{ "id": 1, - "passport": "t111", - "password": "25d55ad283aa400af464c76d713c07ad", - "nickname": "T111", + "passport": "CN", + "password": "12345678", + "nickname": "oldme", "create_time": CreateTime, - }).Save() - t.Assert(err, ErrorSave) + }).OnConflict("id").Save() + t.AssertNil(nil) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + err = db.Model(table).Scan(&user) + t.Assert(err, nil) + t.Assert(user.Id, 1) + t.Assert(user.Passport, "CN") + t.Assert(user.Password, "12345678") + t.Assert(user.NickName, "oldme") + t.Assert(user.CreateTime.String(), CreateTime) + + _, err = db.Model(table).Data(g.Map{ + "id": 1, + "passport": "CN", + "password": "abc123456", + "nickname": "to be not to be", + "create_time": CreateTime, + }).OnConflict("id").Save() + t.AssertNil(err) + + err = db.Model(table).Scan(&user) + t.Assert(err, nil) + t.Assert(user.Passport, "CN") + t.Assert(user.Password, "abc123456") + t.Assert(user.NickName, "to be not to be") + t.Assert(user.CreateTime.String(), CreateTime) + + count, err = db.Model(table).Count() + t.Assert(err, nil) + t.Assert(count, 1) }) }