Skip to content

Commit f01ca29

Browse files
authored
Merge branch 'go-gorm:master' into master
2 parents b0f2f14 + 328f301 commit f01ca29

36 files changed

+656
-197
lines changed

.github/workflows/tests.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
sqlite:
1717
strategy:
1818
matrix:
19-
go: ['1.18', '1.17', '1.16']
19+
go: ['1.19', '1.18', '1.17', '1.16']
2020
platform: [ubuntu-latest] # can not run in windows OS
2121
runs-on: ${{ matrix.platform }}
2222

@@ -42,7 +42,7 @@ jobs:
4242
strategy:
4343
matrix:
4444
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
45-
go: ['1.18', '1.17', '1.16']
45+
go: ['1.19', '1.18', '1.17', '1.16']
4646
platform: [ubuntu-latest]
4747
runs-on: ${{ matrix.platform }}
4848

@@ -86,7 +86,7 @@ jobs:
8686
strategy:
8787
matrix:
8888
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
89-
go: ['1.18', '1.17', '1.16']
89+
go: ['1.19', '1.18', '1.17', '1.16']
9090
platform: [ubuntu-latest] # can not run in macOS and Windows
9191
runs-on: ${{ matrix.platform }}
9292

@@ -128,7 +128,7 @@ jobs:
128128
sqlserver:
129129
strategy:
130130
matrix:
131-
go: ['1.18', '1.17', '1.16']
131+
go: ['1.19', '1.18', '1.17', '1.16']
132132
platform: [ubuntu-latest] # can not run test in macOS and windows
133133
runs-on: ${{ matrix.platform }}
134134

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ documents
33
coverage.txt
44
_book
55
.idea
6-
vendor
6+
vendor
7+
.vscode

association.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,9 @@ func (association *Association) buildCondition() *DB {
507507
joinStmt.AddClause(queryClause)
508508
}
509509
joinStmt.Build("WHERE")
510-
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
510+
if len(joinStmt.SQL.String()) > 0 {
511+
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
512+
}
511513
}
512514

513515
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{

callbacks.go

+6-31
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package gorm
22

33
import (
44
"context"
5-
"database/sql"
65
"errors"
76
"fmt"
87
"reflect"
@@ -16,13 +15,12 @@ import (
1615
func initializeCallbacks(db *DB) *callbacks {
1716
return &callbacks{
1817
processors: map[string]*processor{
19-
"create": {db: db},
20-
"query": {db: db},
21-
"update": {db: db},
22-
"delete": {db: db},
23-
"row": {db: db},
24-
"raw": {db: db},
25-
"transaction": {db: db},
18+
"create": {db: db},
19+
"query": {db: db},
20+
"update": {db: db},
21+
"delete": {db: db},
22+
"row": {db: db},
23+
"raw": {db: db},
2624
},
2725
}
2826
}
@@ -74,29 +72,6 @@ func (cs *callbacks) Raw() *processor {
7472
return cs.processors["raw"]
7573
}
7674

77-
func (cs *callbacks) Transaction() *processor {
78-
return cs.processors["transaction"]
79-
}
80-
81-
func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB {
82-
var err error
83-
84-
switch beginner := tx.Statement.ConnPool.(type) {
85-
case TxBeginner:
86-
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
87-
case ConnPoolBeginner:
88-
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
89-
default:
90-
err = ErrInvalidTransaction
91-
}
92-
93-
if err != nil {
94-
_ = tx.AddError(err)
95-
}
96-
97-
return tx
98-
}
99-
10075
func (p *processor) Execute(db *DB) *DB {
10176
// call scopes
10277
for len(db.Statement.scopes) > 0 {

callbacks/associations.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
206206
}
207207
}
208208

209-
cacheKey := utils.ToStringKey(relPrimaryValues)
209+
cacheKey := utils.ToStringKey(relPrimaryValues...)
210210
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
211211
identityMap[cacheKey] = true
212212
if isPtr {
@@ -292,7 +292,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
292292
}
293293
}
294294

295-
cacheKey := utils.ToStringKey(relPrimaryValues)
295+
cacheKey := utils.ToStringKey(relPrimaryValues...)
296296
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
297297
identityMap[cacheKey] = true
298298
distinctElems = reflect.Append(distinctElems, elem)

callbacks/update.go

+17-15
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ func Update(config *Config) func(db *gorm.DB) {
7070
if db.Statement.SQL.Len() == 0 {
7171
db.Statement.SQL.Grow(180)
7272
db.Statement.AddClauseIfNotExists(clause.Update{})
73-
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
74-
db.Statement.AddClause(set)
75-
} else if _, ok := db.Statement.Clauses["SET"]; !ok {
76-
return
73+
if _, ok := db.Statement.Clauses["SET"]; !ok {
74+
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
75+
db.Statement.AddClause(set)
76+
} else {
77+
return
78+
}
7779
}
7880

7981
db.Statement.Build(db.Statement.BuildClauses...)
@@ -158,21 +160,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
158160
switch stmt.ReflectValue.Kind() {
159161
case reflect.Slice, reflect.Array:
160162
if size := stmt.ReflectValue.Len(); size > 0 {
161-
var primaryKeyExprs []clause.Expression
163+
var isZero bool
162164
for i := 0; i < size; i++ {
163-
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields))
164-
var notZero bool
165-
for idx, field := range stmt.Schema.PrimaryFields {
166-
value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
167-
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
168-
notZero = notZero || !isZero
169-
}
170-
if notZero {
171-
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
165+
for _, field := range stmt.Schema.PrimaryFields {
166+
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
167+
if !isZero {
168+
break
169+
}
172170
}
173171
}
174172

175-
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
173+
if !isZero {
174+
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
175+
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
176+
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
177+
}
176178
}
177179
case reflect.Struct:
178180
for _, field := range stmt.Schema.PrimaryFields {

0 commit comments

Comments
 (0)