Skip to content

Commit bb320aa

Browse files
authored
Merge pull request #179 from huandu/feature/cte-table-for-update-delete
Automatically ref names of CTETables in DELETE and UPDATE statements
2 parents 134c901 + 96c9b25 commit bb320aa

File tree

5 files changed

+141
-29
lines changed

5 files changed

+141
-29
lines changed

cte.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,12 @@ func (cteb *CTEBuilder) TableNames() []string {
138138
return tableNames
139139
}
140140

141-
// tableNamesForSelect returns a list of table names which should be automatically added to FROM clause.
142-
// It's not public, as this feature is designed only for SelectBuilder right now.
143-
func (cteb *CTEBuilder) tableNamesForSelect() []string {
141+
// tableNamesForFrom returns a list of table names which should be automatically added to FROM clause.
142+
// It's not public, as this feature is designed only for SelectBuilder/UpdateBuilder/DeleteBuilder right now.
143+
func (cteb *CTEBuilder) tableNamesForFrom() []string {
144144
cnt := 0
145145

146-
// It's rare that the ShouldAddToTableList() returns true.
146+
// ShouldAddToTableList() unlikely returns true.
147147
// Count it before allocating any memory for better performance.
148148
for _, query := range cteb.queries {
149149
if query.ShouldAddToTableList() {

cte_test.go

+37
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,43 @@ func ExampleCTEBuilder() {
8282
// [users valid_users]
8383
}
8484

85+
func ExampleCTEBuilder_update() {
86+
builder := With(
87+
CTETable("users", "user_id").As(
88+
Select("user_id").From("vip_users"),
89+
),
90+
).Update("orders").Set(
91+
"orders.transport_fee = 0",
92+
).Where(
93+
"users.user_id = orders.user_id",
94+
)
95+
96+
sqlForMySQL, _ := builder.BuildWithFlavor(MySQL)
97+
sqlForPostgreSQL, _ := builder.BuildWithFlavor(PostgreSQL)
98+
99+
fmt.Println(sqlForMySQL)
100+
fmt.Println(sqlForPostgreSQL)
101+
102+
// Output:
103+
// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders, users SET orders.transport_fee = 0 WHERE users.user_id = orders.user_id
104+
// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders FROM users SET orders.transport_fee = 0 WHERE users.user_id = orders.user_id
105+
}
106+
107+
func ExampleCTEBuilder_delete() {
108+
sql := With(
109+
CTETable("users", "user_id").As(
110+
Select("user_id").From("cheaters"),
111+
),
112+
).DeleteFrom("awards").Where(
113+
"users.user_id = awards.user_id",
114+
).String()
115+
116+
fmt.Println(sql)
117+
118+
// Output:
119+
// WITH users (user_id) AS (SELECT user_id FROM cheaters) DELETE FROM awards, users WHERE users.user_id = awards.user_id
120+
}
121+
85122
func TestCTEBuilder(t *testing.T) {
86123
a := assert.New(t)
87124
cteb := newCTEBuilder()

delete.go

+40-11
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ type DeleteBuilder struct {
4545
whereClauseProxy *whereClauseProxy
4646
whereClauseExpr string
4747

48-
cteBuilder string
49-
table string
48+
cteBuilderVar string
49+
cteBuilder *CTEBuilder
50+
51+
tables []string
5052
orderByCols []string
5153
order string
5254
limit int
@@ -60,24 +62,48 @@ type DeleteBuilder struct {
6062
var _ Builder = new(DeleteBuilder)
6163

6264
// DeleteFrom sets table name in DELETE.
63-
func DeleteFrom(table string) *DeleteBuilder {
64-
return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table)
65+
func DeleteFrom(table ...string) *DeleteBuilder {
66+
return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table...)
6567
}
6668

6769
// With sets WITH clause (the Common Table Expression) before DELETE.
6870
func (db *DeleteBuilder) With(builder *CTEBuilder) *DeleteBuilder {
6971
db.marker = deleteMarkerAfterWith
70-
db.cteBuilder = db.Var(builder)
72+
db.cteBuilderVar = db.Var(builder)
73+
db.cteBuilder = builder
7174
return db
7275
}
7376

7477
// DeleteFrom sets table name in DELETE.
75-
func (db *DeleteBuilder) DeleteFrom(table string) *DeleteBuilder {
76-
db.table = Escape(table)
78+
func (db *DeleteBuilder) DeleteFrom(table ...string) *DeleteBuilder {
79+
db.tables = table
7780
db.marker = deleteMarkerAfterDeleteFrom
7881
return db
7982
}
8083

84+
// TableNames returns all table names in this DELETE statement.
85+
func (db *DeleteBuilder) TableNames() []string {
86+
var additionalTableNames []string
87+
88+
if db.cteBuilder != nil {
89+
additionalTableNames = db.cteBuilder.tableNamesForFrom()
90+
}
91+
92+
var tableNames []string
93+
94+
if len(db.tables) > 0 && len(additionalTableNames) > 0 {
95+
tableNames = make([]string, len(db.tables)+len(additionalTableNames))
96+
copy(tableNames, db.tables)
97+
copy(tableNames[len(db.tables):], additionalTableNames)
98+
} else if len(db.tables) > 0 {
99+
tableNames = db.tables
100+
} else if len(additionalTableNames) > 0 {
101+
tableNames = additionalTableNames
102+
}
103+
104+
return tableNames
105+
}
106+
81107
// Where sets expressions of WHERE in DELETE.
82108
func (db *DeleteBuilder) Where(andExpr ...string) *DeleteBuilder {
83109
if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 {
@@ -146,17 +172,20 @@ func (db *DeleteBuilder) Build() (sql string, args []interface{}) {
146172
// BuildWithFlavor returns compiled DELETE string and args with flavor and initial args.
147173
// They can be used in `DB#Query` of package `database/sql` directly.
148174
func (db *DeleteBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
175+
149176
buf := newStringBuilder()
150177
db.injection.WriteTo(buf, deleteMarkerInit)
151178

152-
if db.cteBuilder != "" {
153-
buf.WriteLeadingString(db.cteBuilder)
179+
if db.cteBuilder != nil {
180+
buf.WriteLeadingString(db.cteBuilderVar)
154181
db.injection.WriteTo(buf, deleteMarkerAfterWith)
155182
}
156183

157-
if len(db.table) > 0 {
184+
tableNames := db.TableNames()
185+
186+
if len(tableNames) > 0 {
158187
buf.WriteLeadingString("DELETE FROM ")
159-
buf.WriteString(db.table)
188+
buf.WriteStrings(tableNames, ", ")
160189
}
161190

162191
db.injection.WriteTo(buf, deleteMarkerAfterDeleteFrom)

select.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ func Select(col ...string) *SelectBuilder {
9696
return DefaultFlavor.NewSelectBuilder().Select(col...)
9797
}
9898

99-
// TableNames returns all table names in a SELECT.
99+
// TableNames returns all table names in this SELECT statement.
100100
func (sb *SelectBuilder) TableNames() []string {
101101
var additionalTableNames []string
102102

103103
if sb.cteBuilder != nil {
104-
additionalTableNames = sb.cteBuilder.tableNamesForSelect()
104+
additionalTableNames = sb.cteBuilder.tableNamesForFrom()
105105
}
106106

107107
var tableNames []string

update.go

+58-12
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ type UpdateBuilder struct {
4747
whereClauseProxy *whereClauseProxy
4848
whereClauseExpr string
4949

50-
cteBuilder string
51-
table string
50+
cteBuilderVar string
51+
cteBuilder *CTEBuilder
52+
53+
tables []string
5254
assignments []string
5355
orderByCols []string
5456
order string
@@ -63,24 +65,46 @@ type UpdateBuilder struct {
6365
var _ Builder = new(UpdateBuilder)
6466

6567
// Update sets table name in UPDATE.
66-
func Update(table string) *UpdateBuilder {
67-
return DefaultFlavor.NewUpdateBuilder().Update(table)
68+
func Update(table ...string) *UpdateBuilder {
69+
return DefaultFlavor.NewUpdateBuilder().Update(table...)
6870
}
6971

7072
// With sets WITH clause (the Common Table Expression) before UPDATE.
7173
func (ub *UpdateBuilder) With(builder *CTEBuilder) *UpdateBuilder {
7274
ub.marker = updateMarkerAfterWith
73-
ub.cteBuilder = ub.Var(builder)
75+
ub.cteBuilderVar = ub.Var(builder)
76+
ub.cteBuilder = builder
7477
return ub
7578
}
7679

7780
// Update sets table name in UPDATE.
78-
func (ub *UpdateBuilder) Update(table string) *UpdateBuilder {
79-
ub.table = Escape(table)
81+
func (ub *UpdateBuilder) Update(table ...string) *UpdateBuilder {
82+
ub.tables = table
8083
ub.marker = updateMarkerAfterUpdate
8184
return ub
8285
}
8386

87+
// TableNames returns all table names in this UPDATE statement.
88+
func (ub *UpdateBuilder) TableNames() (tableNames []string) {
89+
var additionalTableNames []string
90+
91+
if ub.cteBuilder != nil {
92+
additionalTableNames = ub.cteBuilder.tableNamesForFrom()
93+
}
94+
95+
if len(ub.tables) > 0 && len(additionalTableNames) > 0 {
96+
tableNames = make([]string, len(ub.tables)+len(additionalTableNames))
97+
copy(tableNames, ub.tables)
98+
copy(tableNames[len(ub.tables):], additionalTableNames)
99+
} else if len(ub.tables) > 0 {
100+
tableNames = ub.tables
101+
} else if len(additionalTableNames) > 0 {
102+
tableNames = additionalTableNames
103+
}
104+
105+
return tableNames
106+
}
107+
84108
// Set sets the assignments in SET.
85109
func (ub *UpdateBuilder) Set(assignment ...string) *UpdateBuilder {
86110
ub.assignments = assignment
@@ -212,14 +236,36 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
212236
buf := newStringBuilder()
213237
ub.injection.WriteTo(buf, updateMarkerInit)
214238

215-
if ub.cteBuilder != "" {
216-
buf.WriteLeadingString(ub.cteBuilder)
239+
if ub.cteBuilder != nil {
240+
buf.WriteLeadingString(ub.cteBuilderVar)
217241
ub.injection.WriteTo(buf, updateMarkerAfterWith)
218242
}
219243

220-
if len(ub.table) > 0 {
221-
buf.WriteLeadingString("UPDATE ")
222-
buf.WriteString(ub.table)
244+
switch flavor {
245+
case MySQL:
246+
// CTE table names should be written after UPDATE keyword in MySQL.
247+
tableNames := ub.TableNames()
248+
249+
if len(tableNames) > 0 {
250+
buf.WriteLeadingString("UPDATE ")
251+
buf.WriteStrings(tableNames, ", ")
252+
}
253+
254+
default:
255+
if len(ub.tables) > 0 {
256+
buf.WriteLeadingString("UPDATE ")
257+
buf.WriteStrings(ub.tables, ", ")
258+
259+
// For ISO SQL, CTE table names should be written after FROM keyword.
260+
if ub.cteBuilder != nil {
261+
cteTableNames := ub.cteBuilder.tableNamesForFrom()
262+
263+
if len(cteTableNames) > 0 {
264+
buf.WriteLeadingString("FROM ")
265+
buf.WriteStrings(cteTableNames, ", ")
266+
}
267+
}
268+
}
223269
}
224270

225271
ub.injection.WriteTo(buf, updateMarkerAfterUpdate)

0 commit comments

Comments
 (0)