From a46eac72ce32c8263b19a8ae585fb49cd567ccc2 Mon Sep 17 00:00:00 2001 From: Yonghwan SO Date: Sun, 25 Sep 2022 13:05:40 +0900 Subject: [PATCH] prevent empty query to be executed --- executors.go | 10 ++++++++++ query.go | 4 ++++ query_test.go | 19 +++++++++++++++++++ sql_builder.go | 18 +++++++++++------- 4 files changed, 44 insertions(+), 7 deletions(-) diff --git a/executors.go b/executors.go index dfda8ebc..4a618183 100644 --- a/executors.go +++ b/executors.go @@ -20,10 +20,16 @@ func (c *Connection) Reload(model interface{}) error { }) } +// TODO: consider merging the following two methods. + // Exec runs the given query. func (q *Query) Exec() error { return q.Connection.timeFunc("Exec", func() error { sql, args := q.ToSQL(nil) + if sql == "" { + return fmt.Errorf("empty query") + } + log(logging.SQL, sql, args...) _, err := q.Connection.Store.Exec(sql, args...) return err @@ -36,6 +42,10 @@ func (q *Query) ExecWithCount() (int, error) { count := int64(0) return int(count), q.Connection.timeFunc("Exec", func() error { sql, args := q.ToSQL(nil) + if sql == "" { + return fmt.Errorf("empty query") + } + log(logging.SQL, sql, args...) result, err := q.Connection.Store.Exec(sql, args...) if err != nil { diff --git a/query.go b/query.go index afb1a3a5..6df38d54 100644 --- a/query.go +++ b/query.go @@ -213,6 +213,10 @@ func Q(c *Connection) *Query { // from the `Model` passed in. func (q Query) ToSQL(model *Model, addColumns ...string) (string, []interface{}) { sb := q.toSQLBuilder(model, addColumns...) + // nil model is allowed only when if RawSQL is provided. + if model == nil && (q.RawSQL == nil || q.RawSQL.Fragment == "") { + return "", nil + } return sb.String(), sb.Args() } diff --git a/query_test.go b/query_test.go index 9adf2723..094033aa 100644 --- a/query_test.go +++ b/query_test.go @@ -386,3 +386,22 @@ func Test_ToSQL_RawQuery(t *testing.T) { a.Equal(args, []interface{}{"random", "query"}) }) } + +func Test_RawQuery_Empty(t *testing.T) { + Debug = true + defer func() { Debug = false }() + + t.Run("EmptyQuery", func(t *testing.T) { + r := require.New(t) + transaction(func(tx *Connection) { + r.Error(tx.Q().Exec()) + }) + }) + + t.Run("EmptyRawQuery", func(t *testing.T) { + r := require.New(t) + transaction(func(tx *Connection) { + r.Error(tx.RawQuery("").Exec()) + }) + }) +} diff --git a/sql_builder.go b/sql_builder.go index 18ffaa81..966964ba 100644 --- a/sql_builder.go +++ b/sql_builder.go @@ -17,6 +17,8 @@ type sqlBuilder struct { AddColumns []string sql string args []interface{} + isCompiled bool + err error } func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder { @@ -25,6 +27,7 @@ func newSQLBuilder(q Query, m *Model, addColumns ...string) *sqlBuilder { Model: m, AddColumns: addColumns, args: []interface{}{}, + isCompiled: false, } } @@ -53,19 +56,15 @@ func hasLimitOrOffset(sqlString string) bool { } func (sq *sqlBuilder) String() string { - if sq.sql == "" { + if !sq.isCompiled { sq.compile() } return sq.sql } func (sq *sqlBuilder) Args() []interface{} { - if len(sq.args) == 0 { - if len(sq.Query.RawSQL.Arguments) > 0 { - sq.args = sq.Query.RawSQL.Arguments - } else { - sq.compile() - } + if !sq.isCompiled { + sq.compile() } return sq.args } @@ -83,7 +82,12 @@ func (sq *sqlBuilder) compile() { } sq.sql = sq.Query.RawSQL.Fragment } + sq.args = sq.Query.RawSQL.Arguments } else { + if sq.Model == nil { + sq.err = fmt.Errorf("sqlBuilder.compile() called but no RawSQL and Model specified") + return + } switch sq.Query.Operation { case Select: sq.sql = sq.buildSelectSQL()