diff --git a/contrib/drivers/mssql/mssql_do_exec.go b/contrib/drivers/mssql/mssql_do_exec.go new file mode 100644 index 00000000000..c0eb915623a --- /dev/null +++ b/contrib/drivers/mssql/mssql_do_exec.go @@ -0,0 +1,191 @@ +package mssql + +import ( + "context" + "database/sql" + "fmt" + "regexp" + "strings" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" +) + +const ( + backIdInsertHeadDefault = "INSERT INTO" + backIdInsertHeadInsertIgnore = "INSERT IGNORE INTO" + + autoIncrementName = "auto_increment" + mssqlOutPutKey = "OUTPUT" + mssqlInsertedObjName = "INSERTED" + mssqlAffectFd = " 1 as AffectCount" + affectCountFieldName = "AffectCount" + mssqlPrimaryKeyName = "PRI" + fdId = "ID" + positionInsertValues = ") VALUES" // find the position of the string "VALUES" in the INSERT SQL statement to embed output code for retrieving the last inserted ID +) + +// DoExec commits the sql string and its arguments to underlying driver +// through given link object and returns the execution result. +func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args ...interface{}) (result sql.Result, err error) { + // Transaction checks. + if link == nil { + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + // Firstly, check and retrieve transaction link from context. + link = &txLinkMssql{tx.GetSqlTX()} + } else if link, err = d.Core.MasterLink(); err != nil { + // Or else it creates one from master node. + return nil, err + } + } else if !link.IsTransaction() { + // If current link is not transaction link, it checks and retrieves transaction from context. + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + link = &txLinkMssql{tx.GetSqlTX()} + } + } + + // SQL filtering. + sqlStr, args = d.Core.FormatSqlBeforeExecuting(sqlStr, args) + sqlStr, args, err = d.DoFilter(ctx, link, sqlStr, args) + if err != nil { + return nil, err + } + + if !(strings.HasPrefix(sqlStr, backIdInsertHeadDefault) || strings.HasPrefix(sqlStr, backIdInsertHeadInsertIgnore)) { + return d.Core.DoExec(ctx, link, sqlStr, args) + } + // find the first pos + pos := strings.Index(sqlStr, positionInsertValues) + + table := d.GetTableNameFromSql(sqlStr) + outPutSql := d.GetInsertOutputSql(ctx, table) + // rebuild sql add output + var ( + sqlValueBefore = sqlStr[:pos+1] + sqlValueAfter = sqlStr[pos+1:] + ) + + sqlStr = fmt.Sprintf("%s%s%s", sqlValueBefore, outPutSql, sqlValueAfter) + + // fmt.Println("sql str:", sqlStr) + // Link execution. + var out gdb.DoCommitOutput + out, err = d.DoCommit(ctx, gdb.DoCommitInput{ + Link: link, + Sql: sqlStr, + Args: args, + Stmt: nil, + Type: gdb.SqlTypeQueryContext, + IsTransaction: link.IsTransaction(), + }) + if err != nil { + return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err + } + var ( + aCount int64 // affect count + lId int64 // last insert id + ) + stdSqlResult := out.Records + if len(stdSqlResult) == 0 { + err = gerror.WrapCode(gcode.CodeDbOperationError, gerror.New("affectcount is zero"), `sql.Result.RowsAffected failed`) + return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err + } + // get affect count + aCount = stdSqlResult[0].GMap().GetVar(affectCountFieldName).Int64() + // get last_insert_id + lId = stdSqlResult[0].GMap().GetVar(fdId).Int64() + + return &InsertResult{lastInsertId: lId, rowsAffected: aCount}, err +} + +// GetTableNameFromSql get table name from sql statement +// It handles table string like: +// "user" +// "user u" +// "DbLog.dbo.user", +// "user as u". +func (d *Driver) GetTableNameFromSql(sqlStr string) (table string) { + // INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?) + leftChars, rightChars := d.GetChars() + trimStr := leftChars + rightChars + "[] " + pattern := "INTO(.+?)\\(" + regCompile := regexp.MustCompile(pattern) + tableInfo := regCompile.FindStringSubmatch(sqlStr) + //get the first one. after the first it may be content of the value, it's not table name. + table = tableInfo[1] + table = strings.Trim(table, " ") + if strings.Contains(table, ".") { + tmpAry := strings.Split(table, ".") + // the last one is tablename + table = tmpAry[len(tmpAry)-1] + } else if strings.Contains(table, "as") || strings.Contains(table, " ") { + tmpAry := strings.Split(table, "as") + if len(tmpAry) < 2 { + tmpAry = strings.Split(table, " ") + } + // get the first one + table = tmpAry[0] + } + table = strings.Trim(table, trimStr) + return table +} + +// txLink is used to implement interface Link for TX. +type txLinkMssql struct { + *sql.Tx +} + +// IsTransaction returns if current Link is a transaction. +func (l *txLinkMssql) IsTransaction() bool { + return true +} + +// IsOnMaster checks and returns whether current link is operated on master node. +// Note that, transaction operation is always operated on master node. +func (l *txLinkMssql) IsOnMaster() bool { + return true +} + +// InsertResult instance of sql.Result +type InsertResult struct { + lastInsertId int64 + rowsAffected int64 + err error +} + +func (r *InsertResult) LastInsertId() (int64, error) { + return r.lastInsertId, r.err +} + +func (r *InsertResult) RowsAffected() (int64, error) { + return r.rowsAffected, r.err +} + +// GetInsertOutputSql gen get last_insert_id code +func (m *Driver) GetInsertOutputSql(ctx context.Context, table string) string { + fds, errFd := m.GetDB().TableFields(ctx, table) + if errFd != nil { + return "" + } + extraSqlAry := make([]string, 0) + extraSqlAry = append(extraSqlAry, fmt.Sprintf(" %s %s", mssqlOutPutKey, mssqlAffectFd)) + incrNo := 0 + if len(fds) > 0 { + for _, fd := range fds { + // has primary key and is auto-incement + if fd.Extra == autoIncrementName && fd.Key == mssqlPrimaryKeyName && !fd.Null { + incrNoStr := "" + if incrNo == 0 { // fixed first field named id, convenient to get + incrNoStr = fmt.Sprintf(" as %s", fdId) + } + + extraSqlAry = append(extraSqlAry, fmt.Sprintf("%s.%s%s", mssqlInsertedObjName, fd.Name, incrNoStr)) + incrNo++ + } + // fmt.Printf("null:%t name:%s key:%s k:%s \n", fd.Null, fd.Name, fd.Key, k) + } + } + return strings.Join(extraSqlAry, ",") + // sql example:INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?) +} diff --git a/contrib/drivers/mssql/mssql_z_unit_basic_test.go b/contrib/drivers/mssql/mssql_z_unit_basic_test.go index 7235f0b1653..0e92c709f9e 100644 --- a/contrib/drivers/mssql/mssql_z_unit_basic_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_basic_test.go @@ -13,11 +13,15 @@ import ( "testing" "time" + "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gxml" "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/test/gtest" + + "github.com/gogf/gf/contrib/drivers/mssql/v2" ) func TestTables(t *testing.T) { @@ -148,6 +152,54 @@ func TestDoInsert(t *testing.T) { }) } +func TestDoInsertGetId(t *testing.T) { + // create test table + createInsertAndGetIdTableForTest() + gtest.C(t, func(t *gtest.T) { + table := "ip_to_id" + data := map[string]interface{}{ + "ip": "192.168.179.1", + } + id, err := db.InsertAndGetId(gctx.New(), table, data) + t.AssertNil(err) + t.AssertGT(id, 0) + // fmt.Println("id:", id) + + // multiple insert test + dataAry := []map[string]interface{}{{"ip": "192.168.5.9"}, {"ip": "192.168.5.10"}} + id1, err1 := db.InsertAndGetId(gctx.New(), table, dataAry) + t.AssertNil(err1) + t.AssertGT(id1, 0) + }) +} + +func TestGetTableFromSql(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + okTable := "ip_to_id" + sqlStr := "INSERT INTO \"ip_to_id\"(\"ip\") VALUES(?)" + dbMssql, _ := db.GetCore().GetDB().(*gdb.DriverWrapperDB).DB.(*mssql.Driver) + //fmt.Println("db:", fmt.Sprintf("%T", dbMssql), " ok:", ok) + table := dbMssql.GetTableNameFromSql(sqlStr) + // fmt.Println("default table:", table) + t.Assert(table, okTable) + + sqlStr = "INSERT INTO \"MyLogDb\".\"dbo\".\"ip_to_id\"(\"ip\") VALUES(?)" + table = dbMssql.GetTableNameFromSql(sqlStr) + // fmt.Println("MyLogDb.dbo.ip_to_id table:", table) + t.Assert(table, okTable) + + sqlStr = "INSERT INTO \"ip_to_id\" as \"tt\" (\"ip\") VALUES(?)" + table = dbMssql.GetTableNameFromSql(sqlStr) + // fmt.Println("ip_to_id as tt table:", table) + t.Assert(table, okTable) + + sqlStr = "INSERT INTO \"ip_to_id\" \"tt\" (\"ip\") VALUES(?)" + table = dbMssql.GetTableNameFromSql(sqlStr) + // fmt.Println("ip_to_id tt table:", table) + t.Assert(table, okTable) + }) +} + func Test_DB_Ping(t *testing.T) { gtest.C(t, func(t *gtest.T) { err1 := db.PingMaster() diff --git a/contrib/drivers/mssql/mssql_z_unit_init_test.go b/contrib/drivers/mssql/mssql_z_unit_init_test.go index 08e925c5a40..e1284ab3e3d 100644 --- a/contrib/drivers/mssql/mssql_z_unit_init_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_init_test.go @@ -25,18 +25,23 @@ var ( ) const ( - TableSize = 10 - TestDbUser = "sa" - TestDbPass = "LoremIpsum86" + TableSize = 10 + TableName = "t_user" + TestSchema1 = "test1" + TestSchema2 = "test2" + TableNamePrefix1 = "gf_" + TestDbUser = "sa" + TestDbPass = "LoremIpsum86" // "theone@123" + CreateTime = "2018-10-24 10:00:00" ) func init() { node := gdb.ConfigNode{ - Host: "127.0.0.1", + Host: "127.0.0.1", // 192.168.5.72 127.0.0.1 Port: "1433", User: TestDbUser, Pass: TestDbPass, - Name: "master", + Name: "test", // "QPLogDB", Type: "mssql", Role: "master", Charset: "utf8", @@ -142,3 +147,26 @@ func dropTable(table string) { gtest.Fatal(err) } } + +// createInsertAndGetIdTableForTest test for InsertAndGetId +func createInsertAndGetIdTableForTest() (name string) { + + if _, err := db.Exec(context.Background(), ` +IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='ip_to_id' and xtype='U') +begin + CREATE TABLE [ip_to_id]( + [id] [int] IDENTITY(1,1) NOT NULL, + [ip] [varchar](128) NULL, + CONSTRAINT [PK_ip_to_id] PRIMARY KEY CLUSTERED + ( + [id] ASC + )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] + ) ON [PRIMARY] +end + `); err != nil { + gtest.Fatal(err) + } + + db.Schema(db.GetConfig().Name) + return +}