diff --git a/internal/compiler/expand.go b/internal/compiler/expand.go index 8ea2fd2d20..29adfa316a 100644 --- a/internal/compiler/expand.go +++ b/internal/compiler/expand.go @@ -38,12 +38,7 @@ func (c *Compiler) expand(qc *QueryCatalog, raw *ast.RawStmt) ([]source.Edit, er func (c *Compiler) quoteIdent(ident string) string { if c.parser.IsReservedKeyword(ident) { - switch c.conf.Engine { - case config.EngineMySQL: - return "`" + ident + "`" - default: - return "\"" + ident + "\"" - } + return c.quote(ident) } if c.conf.Engine == config.EnginePostgreSQL { // camelCase means the column is also camelCase @@ -54,6 +49,15 @@ func (c *Compiler) quoteIdent(ident string) string { return ident } +func (c *Compiler) quote(x string) string { + switch c.conf.Engine { + case config.EngineMySQL: + return "`" + x + "`" + default: + return "\"" + x + "\"" + } +} + func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edit, error) { tables, err := c.sourceTables(qc, node) if err != nil { @@ -132,16 +136,36 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) for _, p := range parts { old = append(old, c.quoteIdent(p)) } - oldString := strings.Join(old, ".") + + var oldString string + var oldFunc func(string) int // use the sqlc.embed string instead if embed, ok := qc.embeds.Find(ref); ok { oldString = embed.Orig() + } else { + oldFunc = func(s string) int { + length := 0 + for i, o := range old { + if hasSeparator := i > 0; hasSeparator { + length++ + } + if strings.HasPrefix(s[length:], o) { + length += len(o) + } else if quoted := c.quote(o); strings.HasPrefix(s[length:], quoted) { + length += len(quoted) + } else { + length += len(o) + } + } + return length + } } edits = append(edits, source.Edit{ Location: res.Location - raw.StmtLocation, Old: oldString, + OldFunc: oldFunc, New: strings.Join(cols, ", "), }) } diff --git a/internal/endtoend/testdata/star_expansion/mysql/go/query.sql.go b/internal/endtoend/testdata/star_expansion/mysql/go/query.sql.go index 8f8d72b09a..6738b4ff63 100644 --- a/internal/endtoend/testdata/star_expansion/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/star_expansion/mysql/go/query.sql.go @@ -52,3 +52,30 @@ func (q *Queries) StarExpansion(ctx context.Context) ([]StarExpansionRow, error) } return items, nil } + +const starQuotedExpansion = `-- name: StarQuotedExpansion :many +SELECT t.a, t.b FROM foo ` + "`" + `t` + "`" + ` +` + +func (q *Queries) StarQuotedExpansion(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, starQuotedExpansion) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/star_expansion/mysql/query.sql b/internal/endtoend/testdata/star_expansion/mysql/query.sql index d330c646f3..4515ce46e6 100644 --- a/internal/endtoend/testdata/star_expansion/mysql/query.sql +++ b/internal/endtoend/testdata/star_expansion/mysql/query.sql @@ -2,3 +2,6 @@ CREATE TABLE foo (a text, b text); /* name: StarExpansion :many */ SELECT *, *, foo.* FROM foo; + +/* name: StarQuotedExpansion :many */ +SELECT `t`.* FROM foo `t`; diff --git a/internal/endtoend/testdata/star_expansion/postgresql/pgx/v4/go/query.sql.go b/internal/endtoend/testdata/star_expansion/postgresql/pgx/v4/go/query.sql.go index 904e5d6157..1fc5ae5f54 100644 --- a/internal/endtoend/testdata/star_expansion/postgresql/pgx/v4/go/query.sql.go +++ b/internal/endtoend/testdata/star_expansion/postgresql/pgx/v4/go/query.sql.go @@ -49,3 +49,27 @@ func (q *Queries) StarExpansion(ctx context.Context) ([]StarExpansionRow, error) } return items, nil } + +const starQuotedExpansion = `-- name: StarQuotedExpansion :many +SELECT t.a, t.b FROM foo "t" +` + +func (q *Queries) StarQuotedExpansion(ctx context.Context) ([]Foo, error) { + rows, err := q.db.Query(ctx, starQuotedExpansion) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/star_expansion/postgresql/pgx/v4/query.sql b/internal/endtoend/testdata/star_expansion/postgresql/pgx/v4/query.sql index 0305062fbd..77ef29f8e7 100644 --- a/internal/endtoend/testdata/star_expansion/postgresql/pgx/v4/query.sql +++ b/internal/endtoend/testdata/star_expansion/postgresql/pgx/v4/query.sql @@ -2,3 +2,6 @@ CREATE TABLE foo (a text, b text); -- name: StarExpansion :many SELECT *, *, foo.* FROM foo; + +-- name: StarQuotedExpansion :many +SELECT "t".* FROM foo "t"; \ No newline at end of file diff --git a/internal/endtoend/testdata/star_expansion/postgresql/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/star_expansion/postgresql/pgx/v5/go/query.sql.go index 56a635f890..206b188a0a 100644 --- a/internal/endtoend/testdata/star_expansion/postgresql/pgx/v5/go/query.sql.go +++ b/internal/endtoend/testdata/star_expansion/postgresql/pgx/v5/go/query.sql.go @@ -50,3 +50,27 @@ func (q *Queries) StarExpansion(ctx context.Context) ([]StarExpansionRow, error) } return items, nil } + +const starQuotedExpansion = `-- name: StarQuotedExpansion :many +SELECT t.a, t.b FROM foo "t" +` + +func (q *Queries) StarQuotedExpansion(ctx context.Context) ([]Foo, error) { + rows, err := q.db.Query(ctx, starQuotedExpansion) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/star_expansion/postgresql/pgx/v5/query.sql b/internal/endtoend/testdata/star_expansion/postgresql/pgx/v5/query.sql index 0305062fbd..77ef29f8e7 100644 --- a/internal/endtoend/testdata/star_expansion/postgresql/pgx/v5/query.sql +++ b/internal/endtoend/testdata/star_expansion/postgresql/pgx/v5/query.sql @@ -2,3 +2,6 @@ CREATE TABLE foo (a text, b text); -- name: StarExpansion :many SELECT *, *, foo.* FROM foo; + +-- name: StarQuotedExpansion :many +SELECT "t".* FROM foo "t"; \ No newline at end of file diff --git a/internal/endtoend/testdata/star_expansion/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/star_expansion/postgresql/stdlib/go/query.sql.go index 8f8d72b09a..b7e6b55a6e 100644 --- a/internal/endtoend/testdata/star_expansion/postgresql/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/star_expansion/postgresql/stdlib/go/query.sql.go @@ -52,3 +52,30 @@ func (q *Queries) StarExpansion(ctx context.Context) ([]StarExpansionRow, error) } return items, nil } + +const starQuotedExpansion = `-- name: StarQuotedExpansion :many +SELECT t.a, t.b FROM foo "t" +` + +func (q *Queries) StarQuotedExpansion(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, starQuotedExpansion) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/star_expansion/postgresql/stdlib/query.sql b/internal/endtoend/testdata/star_expansion/postgresql/stdlib/query.sql index 0305062fbd..77ef29f8e7 100644 --- a/internal/endtoend/testdata/star_expansion/postgresql/stdlib/query.sql +++ b/internal/endtoend/testdata/star_expansion/postgresql/stdlib/query.sql @@ -2,3 +2,6 @@ CREATE TABLE foo (a text, b text); -- name: StarExpansion :many SELECT *, *, foo.* FROM foo; + +-- name: StarQuotedExpansion :many +SELECT "t".* FROM foo "t"; \ No newline at end of file diff --git a/internal/endtoend/testdata/star_expansion/sqlite/go/db.go b/internal/endtoend/testdata/star_expansion/sqlite/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/star_expansion/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/star_expansion/sqlite/go/models.go b/internal/endtoend/testdata/star_expansion/sqlite/go/models.go new file mode 100644 index 0000000000..c0cab4c642 --- /dev/null +++ b/internal/endtoend/testdata/star_expansion/sqlite/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql" +) + +type Foo struct { + A sql.NullString + B sql.NullString +} diff --git a/internal/endtoend/testdata/star_expansion/sqlite/go/query.sql.go b/internal/endtoend/testdata/star_expansion/sqlite/go/query.sql.go new file mode 100644 index 0000000000..b7e6b55a6e --- /dev/null +++ b/internal/endtoend/testdata/star_expansion/sqlite/go/query.sql.go @@ -0,0 +1,81 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const starExpansion = `-- name: StarExpansion :many +SELECT a, b, a, b, foo.a, foo.b FROM foo +` + +type StarExpansionRow struct { + A sql.NullString + B sql.NullString + A_2 sql.NullString + B_2 sql.NullString + A_3 sql.NullString + B_3 sql.NullString +} + +func (q *Queries) StarExpansion(ctx context.Context) ([]StarExpansionRow, error) { + rows, err := q.db.QueryContext(ctx, starExpansion) + if err != nil { + return nil, err + } + defer rows.Close() + var items []StarExpansionRow + for rows.Next() { + var i StarExpansionRow + if err := rows.Scan( + &i.A, + &i.B, + &i.A_2, + &i.B_2, + &i.A_3, + &i.B_3, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const starQuotedExpansion = `-- name: StarQuotedExpansion :many +SELECT t.a, t.b FROM foo "t" +` + +func (q *Queries) StarQuotedExpansion(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, starQuotedExpansion) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.A, &i.B); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/star_expansion/sqlite/query.sql b/internal/endtoend/testdata/star_expansion/sqlite/query.sql new file mode 100644 index 0000000000..6cc8506280 --- /dev/null +++ b/internal/endtoend/testdata/star_expansion/sqlite/query.sql @@ -0,0 +1,7 @@ +CREATE TABLE foo (a text, b text); + +-- name: StarExpansion :many +SELECT *, *, foo.* FROM foo; + +-- name: StarQuotedExpansion :many +SELECT "t".* FROM foo "t"; diff --git a/internal/endtoend/testdata/star_expansion/sqlite/sqlc.json b/internal/endtoend/testdata/star_expansion/sqlite/sqlc.json new file mode 100644 index 0000000000..fc58be5b0d --- /dev/null +++ b/internal/endtoend/testdata/star_expansion/sqlite/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "sqlite", + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index 7a9bfa6dbd..ad8ddc7c0e 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -767,7 +767,7 @@ func (c *cc) convertInsert_stmtContext(n *parser.Insert_stmtContext) ast.Node { rel.Schemaname = &schemaName } if n.Table_alias() != nil { - tableAlias := n.Table_alias().GetText() + tableAlias := identifier(n.Table_alias().GetText()) rel.Alias = &ast.Alias{ Aliasname: &tableAlias, } @@ -837,7 +837,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast rv.Schemaname = &schema } if from.Table_alias() != nil { - alias := from.Table_alias().GetText() + alias := identifier(from.Table_alias().GetText()) rv.Alias = &ast.Alias{Aliasname: &alias} } if from.Table_alias_fallback() != nil { @@ -870,7 +870,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast } if from.Table_alias() != nil { - alias := from.Table_alias().GetText() + alias := identifier(from.Table_alias().GetText()) rf.Alias = &ast.Alias{Aliasname: &alias} } @@ -881,7 +881,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast } if from.Table_alias() != nil { - alias := from.Table_alias().GetText() + alias := identifier(from.Table_alias().GetText()) rs.Alias = &ast.Alias{Aliasname: &alias} } diff --git a/internal/source/code.go b/internal/source/code.go index 9a6ed077d3..f34e3e3684 100644 --- a/internal/source/code.go +++ b/internal/source/code.go @@ -12,6 +12,7 @@ type Edit struct { Location int Old string New string + OldFunc func(string) int } func LineNumber(source string, head int) (int, int) { @@ -63,8 +64,14 @@ func Mutate(raw string, a []Edit) (string, error) { if start > len(s) || start < 0 { return "", fmt.Errorf("edit start location is out of bounds") } + var oldLen int + if edit.OldFunc != nil { + oldLen = edit.OldFunc(s[start:]) + } else { + oldLen = len(edit.Old) + } - stop := edit.Location + len(edit.Old) + stop := edit.Location + oldLen if stop > len(s) { return "", fmt.Errorf("edit stop location is out of bounds") } @@ -73,7 +80,7 @@ func Mutate(raw string, a []Edit) (string, error) { // this edit overlaps the previous one (and is therefore a developer error) if idx != 0 { prevEdit := a[idx-1] - if prevEdit.Location < edit.Location+len(edit.Old) { + if prevEdit.Location < edit.Location+oldLen { return "", fmt.Errorf("2 edits overlap") } }