From 01acaab28875fd6601fff397d596ec737534eeaa Mon Sep 17 00:00:00 2001 From: Huan Du Date: Mon, 9 Sep 2024 01:52:01 +0800 Subject: [PATCH] fix #163: CTE API refactory; see issue for details --- cte.go | 82 +++++++++++++++++++++------ cte_test.go | 8 ++- ctequery.go | 135 ++++++++++++++++++++++++++++++++++++++++++++ ctetable.go | 106 ---------------------------------- delete.go | 14 +++++ delete_test.go | 15 +++++ flavor.go | 4 +- select.go | 38 ++++++++++--- select_test.go | 18 ++++++ update.go | 14 +++++ update_test.go | 17 ++++++ whereclause.go | 12 ++-- whereclause_test.go | 28 ++++++--- 13 files changed, 343 insertions(+), 148 deletions(-) create mode 100644 ctequery.go delete mode 100644 ctetable.go diff --git a/cte.go b/cte.go index 4ef19d7..4611f56 100644 --- a/cte.go +++ b/cte.go @@ -9,12 +9,12 @@ const ( ) // With creates a new CTE builder with default flavor. -func With(tables ...*CTETableBuilder) *CTEBuilder { +func With(tables ...*CTEQueryBuilder) *CTEBuilder { return DefaultFlavor.NewCTEBuilder().With(tables...) } // WithRecursive creates a new recursive CTE builder with default flavor. -func WithRecursive(tables ...*CTETableBuilder) *CTEBuilder { +func WithRecursive(tables ...*CTEQueryBuilder) *CTEBuilder { return DefaultFlavor.NewCTEBuilder().WithRecursive(tables...) } @@ -28,8 +28,8 @@ func newCTEBuilder() *CTEBuilder { // CTEBuilder is a CTE (Common Table Expression) builder. type CTEBuilder struct { recursive bool - tableNames []string - tableBuilderVars []string + queries []*CTEQueryBuilder + queryBuilderVars []string args *Args @@ -40,24 +40,22 @@ type CTEBuilder struct { var _ Builder = new(CTEBuilder) // With sets the CTE name and columns. -func (cteb *CTEBuilder) With(tables ...*CTETableBuilder) *CTEBuilder { - tableNames := make([]string, 0, len(tables)) - tableBuilderVars := make([]string, 0, len(tables)) +func (cteb *CTEBuilder) With(queries ...*CTEQueryBuilder) *CTEBuilder { + queryBuilderVars := make([]string, 0, len(queries)) - for _, table := range tables { - tableNames = append(tableNames, table.TableName()) - tableBuilderVars = append(tableBuilderVars, cteb.args.Add(table)) + for _, query := range queries { + queryBuilderVars = append(queryBuilderVars, cteb.args.Add(query)) } - cteb.tableNames = tableNames - cteb.tableBuilderVars = tableBuilderVars + cteb.queries = queries + cteb.queryBuilderVars = queryBuilderVars cteb.marker = cteMarkerAfterWith return cteb } // WithRecursive sets the CTE name and columns and turns on the RECURSIVE keyword. -func (cteb *CTEBuilder) WithRecursive(tables ...*CTETableBuilder) *CTEBuilder { - cteb.With(tables...).recursive = true +func (cteb *CTEBuilder) WithRecursive(queries ...*CTEQueryBuilder) *CTEBuilder { + cteb.With(queries...).recursive = true return cteb } @@ -67,6 +65,18 @@ func (cteb *CTEBuilder) Select(col ...string) *SelectBuilder { return sb.With(cteb).Select(col...) } +// DeleteFrom creates a new DeleteBuilder to build a DELETE statement using this CTE. +func (cteb *CTEBuilder) DeleteFrom(table string) *DeleteBuilder { + db := cteb.args.Flavor.NewDeleteBuilder() + return db.With(cteb).DeleteFrom(table) +} + +// Update creates a new UpdateBuilder to build an UPDATE statement using this CTE. +func (cteb *CTEBuilder) Update(table string) *UpdateBuilder { + ub := cteb.args.Flavor.NewUpdateBuilder() + return ub.With(cteb).Update(table) +} + // String returns the compiled CTE string. func (cteb *CTEBuilder) String() string { sql, _ := cteb.Build() @@ -83,12 +93,12 @@ func (cteb *CTEBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{} buf := newStringBuilder() cteb.injection.WriteTo(buf, cteMarkerInit) - if len(cteb.tableBuilderVars) > 0 { + if len(cteb.queryBuilderVars) > 0 { buf.WriteLeadingString("WITH ") if cteb.recursive { buf.WriteString("RECURSIVE ") } - buf.WriteStrings(cteb.tableBuilderVars, ", ") + buf.WriteStrings(cteb.queryBuilderVars, ", ") } cteb.injection.WriteTo(buf, cteMarkerAfterWith) @@ -110,5 +120,43 @@ func (cteb *CTEBuilder) SQL(sql string) *CTEBuilder { // TableNames returns all table names in a CTE. func (cteb *CTEBuilder) TableNames() []string { - return cteb.tableNames + if len(cteb.queryBuilderVars) == 0 { + return nil + } + + tableNames := make([]string, 0, len(cteb.queries)) + + for _, query := range cteb.queries { + tableNames = append(tableNames, query.TableName()) + } + + return tableNames +} + +// tableNamesForSelect returns a list of table names which should be automatically added to FROM clause. +// It's not public, as this feature is designed only for SelectBuilder right now. +func (cteb *CTEBuilder) tableNamesForSelect() []string { + cnt := 0 + + // It's rare that the ShouldAddToTableList() returns true. + // Count it before allocating any memory for better performance. + for _, query := range cteb.queries { + if query.ShouldAddToTableList() { + cnt++ + } + } + + if cnt == 0 { + return nil + } + + tableNames := make([]string, 0, cnt) + + for _, query := range cteb.queries { + if query.ShouldAddToTableList() { + tableNames = append(tableNames, query.TableName()) + } + } + + return tableNames } diff --git a/cte_test.go b/cte_test.go index ae3bbc6..c22d23f 100644 --- a/cte_test.go +++ b/cte_test.go @@ -32,7 +32,7 @@ func ExampleWith() { func ExampleWithRecursive() { sb := WithRecursive( - CTETable("source_accounts", "id", "parent_id").As( + CTEQuery("source_accounts", "id", "parent_id").As( UnionAll( Select("p.id", "p.parent_id"). From("accounts AS p"). @@ -85,7 +85,7 @@ func ExampleCTEBuilder() { func TestCTEBuilder(t *testing.T) { a := assert.New(t) cteb := newCTEBuilder() - ctetb := newCTETableBuilder() + ctetb := newCTEQueryBuilder() cteb.SQL("/* init */") cteb.With(ctetb) cteb.SQL("/* after with */") @@ -97,6 +97,8 @@ func TestCTEBuilder(t *testing.T) { ctetb.As(Select("a", "b").From("t")) ctetb.SQL("/* after table as */") + a.Equal(cteb.TableNames(), []string{ctetb.TableName()}) + sql, args := cteb.Build() a.Equal(sql, "/* init */ WITH /* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */ /* after with */") a.Assert(args == nil) @@ -109,7 +111,7 @@ func TestRecursiveCTEBuilder(t *testing.T) { a := assert.New(t) cteb := newCTEBuilder() cteb.recursive = true - ctetb := newCTETableBuilder() + ctetb := newCTEQueryBuilder() cteb.SQL("/* init */") cteb.With(ctetb) cteb.SQL("/* after with */") diff --git a/ctequery.go b/ctequery.go new file mode 100644 index 0000000..581a31a --- /dev/null +++ b/ctequery.go @@ -0,0 +1,135 @@ +// Copyright 2024 Huan Du. All rights reserved. +// Licensed under the MIT license that can be found in the LICENSE file. + +package sqlbuilder + +const ( + cteQueryMarkerInit injectionMarker = iota + cteQueryMarkerAfterTable + cteQueryMarkerAfterAs +) + +// CTETable creates a new CTE query builder with default flavor, marking it as a table. +// +// The resulting CTE query can be used in a `SelectBuilder“, where its table name will be +// automatically included in the FROM clause. +func CTETable(name string, cols ...string) *CTEQueryBuilder { + return DefaultFlavor.NewCTEQueryBuilder().AddToTableList().Table(name, cols...) +} + +// CTEQuery creates a new CTE query builder with default flavor. +func CTEQuery(name string, cols ...string) *CTEQueryBuilder { + return DefaultFlavor.NewCTEQueryBuilder().Table(name, cols...) +} + +func newCTEQueryBuilder() *CTEQueryBuilder { + return &CTEQueryBuilder{ + args: &Args{}, + injection: newInjection(), + } +} + +// CTEQueryBuilder is a builder to build one table in CTE (Common Table Expression). +type CTEQueryBuilder struct { + name string + cols []string + builderVar string + + // if true, this query's table name will be automatically added to the table list + // in FROM clause of SELECT statement. + autoAddToTableList bool + + args *Args + + injection *injection + marker injectionMarker +} + +var _ Builder = new(CTEQueryBuilder) + +// CTETableBuilder is an alias of CTEQueryBuilder for backward compatibility. +// Deprecated: use CTEQueryBuilder instead. +type CTETableBuilder = CTEQueryBuilder + +// Table sets the table name and columns in a CTE table. +func (ctetb *CTEQueryBuilder) Table(name string, cols ...string) *CTEQueryBuilder { + ctetb.name = name + ctetb.cols = cols + ctetb.marker = cteQueryMarkerAfterTable + return ctetb +} + +// As sets the builder to select data. +func (ctetb *CTEQueryBuilder) As(builder Builder) *CTEQueryBuilder { + ctetb.builderVar = ctetb.args.Add(builder) + ctetb.marker = cteQueryMarkerAfterAs + return ctetb +} + +// AddToTableList sets flag to add table name to table list in FROM clause of SELECT statement. +func (ctetb *CTEQueryBuilder) AddToTableList() *CTEQueryBuilder { + ctetb.autoAddToTableList = true + return ctetb +} + +// ShouldAddToTableList returns flag to add table name to table list in FROM clause of SELECT statement. +func (ctetb *CTEQueryBuilder) ShouldAddToTableList() bool { + return ctetb.autoAddToTableList +} + +// String returns the compiled CTE string. +func (ctetb *CTEQueryBuilder) String() string { + sql, _ := ctetb.Build() + return sql +} + +// Build returns compiled CTE string and args. +func (ctetb *CTEQueryBuilder) Build() (sql string, args []interface{}) { + return ctetb.BuildWithFlavor(ctetb.args.Flavor) +} + +// BuildWithFlavor builds a CTE with the specified flavor and initial arguments. +func (ctetb *CTEQueryBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { + buf := newStringBuilder() + ctetb.injection.WriteTo(buf, cteQueryMarkerInit) + + if ctetb.name != "" { + buf.WriteLeadingString(ctetb.name) + + if len(ctetb.cols) > 0 { + buf.WriteLeadingString("(") + buf.WriteStrings(ctetb.cols, ", ") + buf.WriteString(")") + } + + ctetb.injection.WriteTo(buf, cteQueryMarkerAfterTable) + } + + if ctetb.builderVar != "" { + buf.WriteLeadingString("AS (") + buf.WriteString(ctetb.builderVar) + buf.WriteRune(')') + + ctetb.injection.WriteTo(buf, cteQueryMarkerAfterAs) + } + + return ctetb.args.CompileWithFlavor(buf.String(), flavor, initialArg...) +} + +// SetFlavor sets the flavor of compiled sql. +func (ctetb *CTEQueryBuilder) SetFlavor(flavor Flavor) (old Flavor) { + old = ctetb.args.Flavor + ctetb.args.Flavor = flavor + return +} + +// SQL adds an arbitrary sql to current position. +func (ctetb *CTEQueryBuilder) SQL(sql string) *CTEQueryBuilder { + ctetb.injection.SQL(ctetb.marker, sql) + return ctetb +} + +// TableName returns the CTE table name. +func (ctetb *CTEQueryBuilder) TableName() string { + return ctetb.name +} diff --git a/ctetable.go b/ctetable.go deleted file mode 100644 index 8fbd70a..0000000 --- a/ctetable.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2024 Huan Du. All rights reserved. -// Licensed under the MIT license that can be found in the LICENSE file. - -package sqlbuilder - -const ( - cteTableMarkerInit injectionMarker = iota - cteTableMarkerAfterTable - cteTableMarkerAfterAs -) - -// CTETable creates a new CTE table builder with default flavor. -func CTETable(name string, cols ...string) *CTETableBuilder { - return DefaultFlavor.NewCTETableBuilder().Table(name, cols...) -} - -func newCTETableBuilder() *CTETableBuilder { - return &CTETableBuilder{ - args: &Args{}, - injection: newInjection(), - } -} - -// CTETableBuilder is a builder to build one table in CTE (Common Table Expression). -type CTETableBuilder struct { - name string - cols []string - builderVar string - - args *Args - - injection *injection - marker injectionMarker -} - -// Table sets the table name and columns in a CTE table. -func (ctetb *CTETableBuilder) Table(name string, cols ...string) *CTETableBuilder { - ctetb.name = name - ctetb.cols = cols - ctetb.marker = cteTableMarkerAfterTable - return ctetb -} - -// As sets the builder to select data. -func (ctetb *CTETableBuilder) As(builder Builder) *CTETableBuilder { - ctetb.builderVar = ctetb.args.Add(builder) - ctetb.marker = cteTableMarkerAfterAs - return ctetb -} - -// String returns the compiled CTE string. -func (ctetb *CTETableBuilder) String() string { - sql, _ := ctetb.Build() - return sql -} - -// Build returns compiled CTE string and args. -func (ctetb *CTETableBuilder) Build() (sql string, args []interface{}) { - return ctetb.BuildWithFlavor(ctetb.args.Flavor) -} - -// BuildWithFlavor builds a CTE with the specified flavor and initial arguments. -func (ctetb *CTETableBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { - buf := newStringBuilder() - ctetb.injection.WriteTo(buf, cteTableMarkerInit) - - if ctetb.name != "" { - buf.WriteLeadingString(ctetb.name) - - if len(ctetb.cols) > 0 { - buf.WriteLeadingString("(") - buf.WriteStrings(ctetb.cols, ", ") - buf.WriteString(")") - } - - ctetb.injection.WriteTo(buf, cteTableMarkerAfterTable) - } - - if ctetb.builderVar != "" { - buf.WriteLeadingString("AS (") - buf.WriteString(ctetb.builderVar) - buf.WriteRune(')') - - ctetb.injection.WriteTo(buf, cteTableMarkerAfterAs) - } - - return ctetb.args.CompileWithFlavor(buf.String(), flavor, initialArg...) -} - -// SetFlavor sets the flavor of compiled sql. -func (ctetb *CTETableBuilder) SetFlavor(flavor Flavor) (old Flavor) { - old = ctetb.args.Flavor - ctetb.args.Flavor = flavor - return -} - -// SQL adds an arbitrary sql to current position. -func (ctetb *CTETableBuilder) SQL(sql string) *CTETableBuilder { - ctetb.injection.SQL(ctetb.marker, sql) - return ctetb -} - -// TableName returns the CTE table name. -func (ctetb *CTETableBuilder) TableName() string { - return ctetb.name -} diff --git a/delete.go b/delete.go index 3541847..b49c4d2 100644 --- a/delete.go +++ b/delete.go @@ -9,6 +9,7 @@ import ( const ( deleteMarkerInit injectionMarker = iota + deleteMarkerAfterWith deleteMarkerAfterDeleteFrom deleteMarkerAfterWhere deleteMarkerAfterOrderBy @@ -44,6 +45,7 @@ type DeleteBuilder struct { whereClauseProxy *whereClauseProxy whereClauseExpr string + cteBuilder string table string orderByCols []string order string @@ -62,6 +64,13 @@ func DeleteFrom(table string) *DeleteBuilder { return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table) } +// With sets WITH clause (the Common Table Expression) before DELETE. +func (db *DeleteBuilder) With(builder *CTEBuilder) *DeleteBuilder { + db.marker = deleteMarkerAfterWith + db.cteBuilder = db.Var(builder) + return db +} + // DeleteFrom sets table name in DELETE. func (db *DeleteBuilder) DeleteFrom(table string) *DeleteBuilder { db.table = Escape(table) @@ -140,6 +149,11 @@ func (db *DeleteBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{ buf := newStringBuilder() db.injection.WriteTo(buf, deleteMarkerInit) + if db.cteBuilder != "" { + buf.WriteLeadingString(db.cteBuilder) + db.injection.WriteTo(buf, deleteMarkerAfterWith) + } + if len(db.table) > 0 { buf.WriteLeadingString("DELETE FROM ") buf.WriteString(db.table) diff --git a/delete_test.go b/delete_test.go index f644481..3fd6c4f 100644 --- a/delete_test.go +++ b/delete_test.go @@ -65,3 +65,18 @@ func ExampleDeleteBuilder_SQL() { // /* before */ DELETE FROM demo.user PARTITION (p0) WHERE id > ? /* after where */ ORDER BY id /* after order by */ LIMIT 10 /* after limit */ // [1234] } + +func ExampleDeleteBuilder_With() { + sql := With( + CTEQuery("users").As( + Select("id", "name").From("users").Where("name IS NULL"), + ), + ).DeleteFrom("orders").Where( + "users.id = orders.user_id", + ).String() + + fmt.Println(sql) + + // Output: + // WITH users AS (SELECT id, name FROM users WHERE name IS NULL) DELETE FROM orders WHERE users.id = orders.user_id +} diff --git a/flavor.go b/flavor.go index c5dc63a..0bf7bdd 100644 --- a/flavor.go +++ b/flavor.go @@ -149,8 +149,8 @@ func (f Flavor) NewCTEBuilder() *CTEBuilder { } // NewCTETableBuilder creates a new CTE table builder with flavor. -func (f Flavor) NewCTETableBuilder() *CTETableBuilder { - b := newCTETableBuilder() +func (f Flavor) NewCTEQueryBuilder() *CTEQueryBuilder { + b := newCTEQueryBuilder() b.SetFlavor(f) return b } diff --git a/select.go b/select.go index 6a6bd47..0edf3e7 100644 --- a/select.go +++ b/select.go @@ -66,7 +66,9 @@ type SelectBuilder struct { whereClauseProxy *whereClauseProxy whereClauseExpr string - cteBuilder string + cteBuilderVar string + cteBuilder *CTEBuilder + distinct bool tables []string selectCols []string @@ -96,14 +98,32 @@ func Select(col ...string) *SelectBuilder { // TableNames returns all table names in a SELECT. func (sb *SelectBuilder) TableNames() []string { - return sb.tables + var additionalTableNames []string + + if sb.cteBuilder != nil { + additionalTableNames = sb.cteBuilder.tableNamesForSelect() + } + + var tableNames []string + + if len(sb.tables) > 0 && len(additionalTableNames) > 0 { + tableNames = make([]string, len(sb.tables)+len(additionalTableNames)) + copy(tableNames, sb.tables) + copy(tableNames[len(sb.tables):], additionalTableNames) + } else if len(sb.tables) > 0 { + tableNames = sb.tables + } else if len(additionalTableNames) > 0 { + tableNames = additionalTableNames + } + + return tableNames } // With sets WITH clause (the Common Table Expression) before SELECT. func (sb *SelectBuilder) With(builder *CTEBuilder) *SelectBuilder { sb.marker = selectMarkerAfterWith - sb.cteBuilder = sb.Var(builder) - sb.tables = append(sb.tables, builder.TableNames()...) + sb.cteBuilderVar = sb.Var(builder) + sb.cteBuilder = builder return sb } @@ -284,8 +304,8 @@ func (sb *SelectBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{ oraclePage := flavor == Oracle && (sb.limit >= 0 || sb.offset >= 0) - if sb.cteBuilder != "" { - buf.WriteLeadingString(sb.cteBuilder) + if sb.cteBuilderVar != "" { + buf.WriteLeadingString(sb.cteBuilderVar) sb.injection.WriteTo(buf, selectMarkerAfterWith) } @@ -341,9 +361,11 @@ func (sb *SelectBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{ } } - if len(sb.tables) > 0 { + tableNames := sb.TableNames() + + if len(tableNames) > 0 { buf.WriteLeadingString("FROM ") - buf.WriteStrings(sb.tables, ", ") + buf.WriteStrings(tableNames, ", ") } sb.injection.WriteTo(buf, selectMarkerAfterFrom) diff --git a/select_test.go b/select_test.go index d7ee322..b7d0023 100644 --- a/select_test.go +++ b/select_test.go @@ -346,3 +346,21 @@ func ExampleSelectBuilder_NumCol() { // Output: // 3 } + +func ExampleSelectBuilder_With() { + sql := With( + CTEQuery("users").As( + Select("id", "name").From("users").Where("prime IS NOT NULL"), + ), + + // The CTE table orders will be added to table list of FROM clause automatically. + CTETable("orders").As( + Select("id", "user_id").From("orders"), + ), + ).Select("orders.id").Join("users", "orders.user_id = users.id").Limit(10).String() + + fmt.Println(sql) + + // Output: + // WITH users AS (SELECT id, name FROM users WHERE prime IS NOT NULL), orders AS (SELECT id, user_id FROM orders) SELECT orders.id FROM orders JOIN users ON orders.user_id = users.id LIMIT 10 +} diff --git a/update.go b/update.go index 2a04ac1..ec17e8b 100644 --- a/update.go +++ b/update.go @@ -10,6 +10,7 @@ import ( const ( updateMarkerInit injectionMarker = iota + updateMarkerAfterWith updateMarkerAfterUpdate updateMarkerAfterSet updateMarkerAfterWhere @@ -46,6 +47,7 @@ type UpdateBuilder struct { whereClauseProxy *whereClauseProxy whereClauseExpr string + cteBuilder string table string assignments []string orderByCols []string @@ -65,6 +67,13 @@ func Update(table string) *UpdateBuilder { return DefaultFlavor.NewUpdateBuilder().Update(table) } +// With sets WITH clause (the Common Table Expression) before UPDATE. +func (ub *UpdateBuilder) With(builder *CTEBuilder) *UpdateBuilder { + ub.marker = updateMarkerAfterWith + ub.cteBuilder = ub.Var(builder) + return ub +} + // Update sets table name in UPDATE. func (ub *UpdateBuilder) Update(table string) *UpdateBuilder { ub.table = Escape(table) @@ -203,6 +212,11 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{ buf := newStringBuilder() ub.injection.WriteTo(buf, updateMarkerInit) + if ub.cteBuilder != "" { + buf.WriteLeadingString(ub.cteBuilder) + ub.injection.WriteTo(buf, updateMarkerAfterWith) + } + if len(ub.table) > 0 { buf.WriteLeadingString("UPDATE ") buf.WriteString(ub.table) diff --git a/update_test.go b/update_test.go index a8d0acf..9d2323d 100644 --- a/update_test.go +++ b/update_test.go @@ -132,3 +132,20 @@ func ExampleUpdateBuilder_NumAssignment() { // Output: // 3 } + +func ExampleUpdateBuilder_With() { + sql := With( + CTEQuery("users").As( + Select("id", "name").From("users").Where("prime IS NOT NULL"), + ), + ).Update("orders").Set( + "orders.transport_fee = 0", + ).Where( + "users.id = orders.user_id", + ).String() + + fmt.Println(sql) + + // Output: + // WITH users AS (SELECT id, name FROM users WHERE prime IS NOT NULL) UPDATE orders SET orders.transport_fee = 0 WHERE users.id = orders.user_id +} diff --git a/whereclause.go b/whereclause.go index 3e80b93..70451a9 100644 --- a/whereclause.go +++ b/whereclause.go @@ -87,9 +87,9 @@ func (wc *WhereClause) SetFlavor(flavor Flavor) (old Flavor) { } // AddWhereExpr adds an AND expression to WHERE clause with the specified arguments. -func (wc *WhereClause) AddWhereExpr(args *Args, andExpr ...string) { +func (wc *WhereClause) AddWhereExpr(args *Args, andExpr ...string) *WhereClause { if len(andExpr) == 0 { - return + return wc } // Merge with last clause if possible. @@ -98,7 +98,7 @@ func (wc *WhereClause) AddWhereExpr(args *Args, andExpr ...string) { if lastClause.args == args { lastClause.andExprs = append(lastClause.andExprs, andExpr...) - return + return wc } } @@ -106,13 +106,15 @@ func (wc *WhereClause) AddWhereExpr(args *Args, andExpr ...string) { args: args, andExprs: andExpr, }) + return wc } // AddWhereClause adds all clauses in the whereClause to the wc. -func (wc *WhereClause) AddWhereClause(whereClause *WhereClause) { +func (wc *WhereClause) AddWhereClause(whereClause *WhereClause) *WhereClause { if whereClause == nil { - return + return wc } wc.clauses = append(wc.clauses, whereClause.clauses...) + return wc } diff --git a/whereclause_test.go b/whereclause_test.go index 53f9967..1a64054 100644 --- a/whereclause_test.go +++ b/whereclause_test.go @@ -207,26 +207,40 @@ func TestWhereClauseSharedInstances(t *testing.T) { a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ?") a.Equal(db.String(), "DELETE FROM t WHERE id = ?") + // Add more WhereClause. + cond := NewCond() + moreWhereClause := NewWhereClause().AddWhereExpr( + cond.Args, + cond.GreaterEqualThan("credit", 100), + ) + + // The moreWhereClause is added to whereClause. + // All builders sharing the same WhereClause will have the same new cluase. + sb.AddWhereClause(moreWhereClause) + a.Equal(sb.String(), "SELECT * FROM t WHERE id = ? AND credit >= ?") + a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND credit >= ?") + a.Equal(db.String(), "DELETE FROM t WHERE id = ? AND credit >= ?") + // Copied WhereClause is independent from the original. ub.WhereClause = CopyWhereClause(whereClause) ub.Where(ub.GreaterEqualThan("level", 10)) db.Where(db.In("status", 1, 2)) - a.Equal(sb.String(), "SELECT * FROM t WHERE id = ? AND status IN (?, ?)") - a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND level >= ?") - a.Equal(db.String(), "DELETE FROM t WHERE id = ? AND status IN (?, ?)") + a.Equal(sb.String(), "SELECT * FROM t WHERE id = ? AND credit >= ? AND status IN (?, ?)") + a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND credit >= ? AND level >= ?") + a.Equal(db.String(), "DELETE FROM t WHERE id = ? AND credit >= ? AND status IN (?, ?)") // Clear the WhereClause and add new where clause and expressions. db.WhereClause = nil db.AddWhereClause(ub.WhereClause) db.AddWhereExpr(db.Args, db.Equal("deleted", 0)) - a.Equal(sb.String(), "SELECT * FROM t WHERE id = ? AND status IN (?, ?)") - a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND level >= ?") - a.Equal(db.String(), "DELETE FROM t WHERE id = ? AND level >= ? AND deleted = ?") + a.Equal(sb.String(), "SELECT * FROM t WHERE id = ? AND credit >= ? AND status IN (?, ?)") + a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND credit >= ? AND level >= ?") + a.Equal(db.String(), "DELETE FROM t WHERE id = ? AND credit >= ? AND level >= ? AND deleted = ?") // Nested WhereClause. ub.Where(ub.NotIn("id", sb)) sb.Where(sb.NotEqual("flag", "normal")) - a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND level >= ? AND id NOT IN (SELECT * FROM t WHERE id = ? AND status IN (?, ?) AND flag <> ?)") + a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND credit >= ? AND level >= ? AND id NOT IN (SELECT * FROM t WHERE id = ? AND credit >= ? AND status IN (?, ?) AND flag <> ?)") } func TestEmptyWhereExpr(t *testing.T) {