Skip to content

Commit

Permalink
Merge pull request #22 from VarusHsu/master
Browse files Browse the repository at this point in the history
A lot of new features
  • Loading branch information
lqs authored Mar 19, 2024
2 parents aa857a1 + 5b9c119 commit 59302de
Show file tree
Hide file tree
Showing 13 changed files with 589 additions and 51 deletions.
23 changes: 23 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,29 @@ func commaOrderBys(scope scope, orderBys []OrderBy) (string, error) {
}

func getCallerInfo(db database, retry bool) string {
if !db.enableCallerInfo {
return ""
}
extraInfo := ""
if retry {
extraInfo += " (retry)"
}
for i := 0; true; i++ {
_, file, line, ok := runtime.Caller(i)
if !ok {
break
}
if file == "" || strings.Contains(file, "/sqlingo@v") {
continue
}
segs := strings.Split(file, "/")
name := segs[len(segs)-1]
return fmt.Sprintf("/* %s:%d%s */ ", name, line, extraInfo)
}
return ""
}

func getTxCallerInfo(db transaction, retry bool) string {
if !db.enableCallerInfo {
return ""
}
Expand Down
137 changes: 109 additions & 28 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,23 @@ package sqlingo
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
)

const (
// for colorful terminal print
green = "\033[32m"
red = "\033[31m"
blue = "\033[34m"
reset = "\033[0m"
)

// Database is the interface of a database with underlying sql.DB object.
type Database interface {
// Get the underlying sql.DB object of the database
Expand All @@ -20,11 +34,12 @@ type Database interface {
// Executes a statement with context
ExecuteContext(ctx context.Context, sql string) (sql.Result, error)
// Set the logger function
SetLogger(logger func(sql string, durationNano int64))
SetLogger(logger LoggerFunc)
// Set the retry policy function.
// The retry policy function returns true if needs retry.
SetRetryPolicy(retryPolicy func(err error) bool)
// enable or disable caller info
// EnableCallerInfo enable or disable the caller info in the log.
// Deprecated: use SetLogger instead
EnableCallerInfo(enableCallerInfo bool)
// Set a interceptor function
SetInterceptor(interceptor InterceptorFunc)
Expand All @@ -43,25 +58,95 @@ type Database interface {
Update(table Table) updateWithSet
// Initiate a DELETE FROM statement
DeleteFrom(table Table) deleteWithTable
}

type txOrDB interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
// Begin Start a new transaction and returning a Transaction object.
// the DDL operations using the returned Transaction object will
// regard as one time transaction.
// User must manually call Commit() or Rollback() to end the transaction,
// after that, more DDL operations or TCL will return error.
Begin() (Transaction, error)
}

var (
once sync.Once
srcPrefix string
)

type database struct {
db *sql.DB
tx *sql.Tx
logger func(sql string, durationNano int64)
logger LoggerFunc
dialect dialect
retryPolicy func(error) bool
enableCallerInfo bool
interceptor InterceptorFunc
}

func (d *database) SetLogger(logger func(sql string, durationNano int64)) {
d.logger = logger
type LoggerFunc func(sql string, durationNano int64, isTx bool, retry bool)

func (d *database) SetLogger(loggerFunc LoggerFunc) {
d.logger = loggerFunc
}

// defaultLogger is sqlingo default logger,
// which print log to stderr and regard executing time gt 100ms as slow sql.
func defaultLogger(sql string, durationNano int64, isTx bool, retry bool) {
// for finding code position, try once is enough
once.Do(func() {
// $GOPATH/pkg/mod/github.com/lqs/[email protected]/database.go
_, file, _, _ := runtime.Caller(0)
// $GOPATH/pkg/mod/github.com/lqs/[email protected]
srcPrefix = filepath.Dir(file)
})

var file string
var line int
var ok bool
for i := 0; i < 16; i++ {
_, file, line, ok = runtime.Caller(i)
// `!strings.HasPrefix(file, srcPrefix)` jump out when using sqlingo as dependent package
// `strings.HasSuffix(file, "_test.go")` jump out when executing unit test cases
// `!ok` this is so terrible for something unexpected happened
if !ok || !strings.HasPrefix(file, srcPrefix) || strings.HasSuffix(file, "_test.go") {
break
}
}

// convert durationNano (int64) to time.Duration
du := time.Duration(durationNano)
// todo shouldn't append ';' here
if !strings.HasSuffix(sql, ";") {
sql += ";"
}

sb := strings.Builder{}
sb.Grow(32)
sb.WriteString("|")
sb.WriteString(du.String())
if isTx {
sb.WriteString("|transaction") // todo using something traceable
}
if retry {
sb.WriteString("|retry")
}
sb.WriteString("|")

line1 := strings.Join(
[]string{
"[sqlingo]",
time.Now().Format("2006-01-02 15:04:05"),
sb.String(),
file + ":" + fmt.Sprint(line),
},
" ")

// print to stderr
fmt.Fprintln(os.Stderr, blue+line1+reset)
if du < 100*time.Millisecond {
fmt.Fprintf(os.Stderr, "%s%s%s\n", green, sql, reset)
} else {
fmt.Fprintf(os.Stderr, "%s%s%s\n", red, sql, reset)
}
fmt.Fprintln(os.Stderr)
}

func (d *database) SetRetryPolicy(retryPolicy func(err error) bool) {
Expand All @@ -76,7 +161,9 @@ func (d *database) SetInterceptor(interceptor InterceptorFunc) {
d.interceptor = interceptor
}

// Open a database, similar to sql.Open
// Open a database, similar to sql.Open.
// `db` using a default logger, which print log to stderr and regard executing time gt 100ms as slow sql.
// To disable the default logger, use `db.SetLogger(nil)`.
func Open(driverName string, dataSourceName string) (db Database, err error) {
var sqlDB *sql.DB
if dataSourceName != "" {
Expand All @@ -86,6 +173,7 @@ func Open(driverName string, dataSourceName string) (db Database, err error) {
}
}
db = Use(driverName, sqlDB)
db.SetLogger(defaultLogger)
return
}

Expand All @@ -101,13 +189,6 @@ func (d database) GetDB() *sql.DB {
return d.db
}

func (d database) getTxOrDB() txOrDB {
if d.tx != nil {
return d.tx
}
return d.db
}

func (d database) Query(sqlString string) (Cursor, error) {
return d.QueryContext(context.Background(), sqlString)
}
Expand All @@ -116,10 +197,9 @@ func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, e
isRetry := false
for {
sqlStringWithCallerInfo := getCallerInfo(d, isRetry) + sqlString

rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo)
rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo, isRetry)
if err != nil {
isRetry = d.tx == nil && d.retryPolicy != nil && d.retryPolicy(err)
isRetry = d.retryPolicy != nil && d.retryPolicy(err)
if isRetry {
continue
}
Expand All @@ -129,30 +209,30 @@ func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, e
}
}

func (d database) queryContextOnce(ctx context.Context, sqlStringWithCallerInfo string) (*sql.Rows, error) {
func (d database) queryContextOnce(ctx context.Context, sqlString string, retry bool) (*sql.Rows, error) {
if ctx == nil {
ctx = context.Background()
}
startTime := time.Now().UnixNano()
defer func() {
endTime := time.Now().UnixNano()
if d.logger != nil {
d.logger(sqlStringWithCallerInfo, endTime-startTime)
d.logger(sqlString, endTime-startTime, false, retry)
}
}()

interceptor := d.interceptor
var rows *sql.Rows
invoker := func(ctx context.Context, sql string) (err error) {
rows, err = d.getTxOrDB().QueryContext(ctx, sql)
rows, err = d.GetDB().QueryContext(ctx, sql)
return
}

var err error
if interceptor == nil {
err = invoker(ctx, sqlStringWithCallerInfo)
err = invoker(ctx, sqlString)
} else {
err = interceptor(ctx, sqlStringWithCallerInfo, invoker)
err = interceptor(ctx, sqlString, invoker)
}
if err != nil {
return nil, err
Expand All @@ -165,6 +245,7 @@ func (d database) Execute(sqlString string) (sql.Result, error) {
return d.ExecuteContext(context.Background(), sqlString)
}

// ExecuteContext todo Is there need retry?
func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Result, error) {
if ctx == nil {
ctx = context.Background()
Expand All @@ -174,13 +255,13 @@ func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Res
defer func() {
endTime := time.Now().UnixNano()
if d.logger != nil {
d.logger(sqlStringWithCallerInfo, endTime-startTime)
d.logger(sqlStringWithCallerInfo, endTime-startTime, false, false)
}
}()

var result sql.Result
invoker := func(ctx context.Context, sql string) (err error) {
result, err = d.getTxOrDB().ExecContext(ctx, sql)
result, err = d.GetDB().ExecContext(ctx, sql)
return
}
var err error
Expand Down
2 changes: 1 addition & 1 deletion database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestDatabase(t *testing.T) {
interceptorExecutedCount++
return invoker(ctx, sql)
})
db.SetLogger(func(sql string, durationNano int64) {
db.SetLogger(func(sql string, durationNano int64, _, _ bool) {
if sql != "SELECT 1" {
t.Error(sql)
}
Expand Down
7 changes: 7 additions & 0 deletions delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,12 @@ func (s deleteStatus) Execute() (sql.Result, error) {
if err != nil {
return nil, err
}

// use transaction if it exists, otherwise use database.
// this is because when s.scope.Transaction is not nil,
// it must be built by transaction.
if s.scope.Transaction != nil {
return s.scope.Transaction.Execute(sqlString)
}
return s.scope.Database.Execute(sqlString)
}
8 changes: 5 additions & 3 deletions expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ func (e expression) GetTable() Table {
}

type scope struct {
Database *database
Tables []Table
lastJoin *join
// Transaction should be nil if without transaction begin
Transaction *transaction
Database *database
Tables []Table
lastJoin *join
}

func staticExpression(sql string, priority priority, isBool bool) expression {
Expand Down
4 changes: 3 additions & 1 deletion field.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ func newField(table Table, fieldName string) actualField {
expression: expression{
builder: func(scope scope) (string, error) {
dialect := dialectUnknown
if scope.Database != nil {
if scope.Transaction != nil {
dialect = scope.Transaction.dialect
} else if scope.Database != nil {
dialect = scope.Database.dialect
}
if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {
Expand Down
3 changes: 3 additions & 0 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,5 +187,8 @@ func (s insertStatus) Execute() (result sql.Result, err error) {
if err != nil {
return nil, err
}
if s.scope.Transaction != nil {
return s.scope.Transaction.Execute(sqlString)
}
return s.scope.Database.Execute(sqlString)
}
Loading

0 comments on commit 59302de

Please sign in to comment.