Skip to content

Commit

Permalink
Merge pull request #12 from goccy/support-join
Browse files Browse the repository at this point in the history
Support all JOIN types
  • Loading branch information
goccy authored Aug 1, 2022
2 parents f689f9b + 9ce08b0 commit b49d99a
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 37 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ A list of ZetaSQL specifications and features supported by go-zetasqlite.
- [X] HAVING
- [x] ORDER BY
- [X] GROUP BY - ROLLUP
- [X] INNER/LEFT JOIN
- [X] INNER/LEFT/RIGHT/FULL/CROSS JOIN
- [x] QUALIFY

### Aggregate functions
Expand Down Expand Up @@ -380,8 +380,8 @@ A list of ZetaSQL specifications and features supported by go-zetasqlite.
- [x] DATE
- [x] DATE_ADD
- [x] DATE_SUB
- [ ] DATE_DIFF
- [ ] DATE_TRUNC
- [x] DATE_DIFF
- [x] DATE_TRUNC
- [ ] DATE_FROM_UNIX_DATE
- [ ] FORMAT_DATE
- [ ] LAST_DAY
Expand Down
13 changes: 13 additions & 0 deletions internal/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type (
needsTableNameForColumnKey struct{}
tableNameToColumnListMapKey struct{}
useColumnIDKey struct{}
rowIDColumnKey struct{}
)

func namePathFromContext(ctx context.Context) []string {
Expand Down Expand Up @@ -197,6 +198,18 @@ func tableNameToColumnListMap(ctx context.Context) map[string][]*ast.Column {
return value.(map[string][]*ast.Column)
}

// withRowIDColumn returns a context derived from ctx in which the value true
// is stored under rowIDColumnKey. needsRowIDColumn reports true for the
// returned context.
func withRowIDColumn(ctx context.Context) context.Context {
	key := rowIDColumnKey{}
	return context.WithValue(ctx, key, true)
}

// needsRowIDColumn reports the boolean stored in ctx under rowIDColumnKey,
// the key that withRowIDColumn writes. A context with no value under that
// key yields false.
func needsRowIDColumn(ctx context.Context) bool {
	if flag := ctx.Value(rowIDColumnKey{}); flag != nil {
		// Anything stored under this key is asserted to be a bool.
		return flag.(bool)
	}
	return false
}

// WithCurrentTime returns a context derived from ctx that stores a pointer
// to now under currentTimeKey.
func WithCurrentTime(ctx context.Context, now time.Time) context.Context {
	current := &now
	return context.WithValue(ctx, currentTimeKey{}, current)
}
Expand Down
52 changes: 45 additions & 7 deletions internal/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func (n *JoinScanNode) FormatSQL(ctx context.Context) (string, error) {
if n.node == nil {
return "", nil
}
left, err := newNode(n.node.LeftScan()).FormatSQL(ctx)
left, err := newNode(n.node.LeftScan()).FormatSQL(withRowIDColumn(ctx))
if err != nil {
return "", err
}
Expand All @@ -561,18 +561,38 @@ func (n *JoinScanNode) FormatSQL(ctx context.Context) (string, error) {
if err != nil {
return "", err
}
var joinType string
switch n.node.JoinType() {
case ast.JoinTypeInner:
joinType = "JOIN"
return fmt.Sprintf("%s JOIN %s ON %s", left, right, joinExpr), nil
case ast.JoinTypeLeft:
joinType = "LEFT JOIN"
return fmt.Sprintf("%s LEFT JOIN %s ON %s", left, right, joinExpr), nil
case ast.JoinTypeRight:
joinType = "RIGHT JOIN"
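// SQLite does not support RIGHT JOIN as of v3.38.0, so emulate it with a LEFT JOIN whose operands are swapped.
// The left scan is formatted with withRowIDColumn so that it can expose ROW_NUMBER() OVER() AS `row_id`;
// ordering by `row_id` NULLS LAST is meant to keep the left side's row order, with unmatched rows last.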
return fmt.Sprintf("%s LEFT JOIN %s ON %s ORDER BY `row_id` NULLS LAST", right, left, joinExpr), nil
case ast.JoinTypeFull:
joinType = "FULL JOIN"
// SQLite does not support FULL OUTER JOIN as of v3.38.0,
// so emulate it by combining LEFT JOIN, UNION ALL and DISTINCT.
var (
columns []string
columnMap = columnRefMap(ctx)
)
for _, col := range n.node.ColumnList() {
colName := string(uniqueColumnName(ctx, col))
if ref, exists := columnMap[colName]; exists {
columns = append(columns, ref)
delete(columnMap, colName)
} else {
columns = append(columns, fmt.Sprintf("`%s`", colName))
}
}
return fmt.Sprintf(
"SELECT DISTINCT %[1]s FROM (SELECT %[1]s FROM %[2]s LEFT JOIN %[3]s ON %[4]s UNION ALL SELECT %[1]s FROM %[3]s LEFT JOIN %[2]s ON %[4]s)",
strings.Join(columns, ","),
left, right, joinExpr,
), nil
}
return fmt.Sprintf("%s %s %s ON %s", left, joinType, right, joinExpr), nil
return "", fmt.Errorf("unexpected join type %s", n.node.JoinType())
}

func (n *ArrayScanNode) FormatSQL(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -682,6 +702,12 @@ func (n *AggregateScanNode) FormatSQL(ctx context.Context) (string, error) {
columns = append(columns, fmt.Sprintf("`%s`", colName))
}
}
if needsRowIDColumn(ctx) {
columns = append(
columns,
"ROW_NUMBER() OVER() AS `row_id`",
)
}
if len(n.node.GroupingSetList()) != 0 {
columnPatterns := [][]string{}
groupByColumnPatterns := [][]string{}
Expand Down Expand Up @@ -888,6 +914,12 @@ func (n *WithRefScanNode) FormatSQL(ctx context.Context) (string, error) {
fmt.Sprintf("`%s` AS `%s`", uniqueColumnName(ctx, columnDefs[i]), uniqueColumnName(ctx, columns[i])),
)
}
if needsRowIDColumn(ctx) {
formattedColumns = append(
formattedColumns,
"ROW_NUMBER() OVER() AS `row_id`",
)
}
return fmt.Sprintf("(SELECT %s FROM %s)", strings.Join(formattedColumns, ","), tableName), nil
}

Expand Down Expand Up @@ -1075,6 +1107,12 @@ func (n *ProjectScanNode) FormatSQL(ctx context.Context) (string, error) {
)
}
}
if needsRowIDColumn(ctx) {
columns = append(
columns,
"ROW_NUMBER() OVER() AS `row_id`",
)
}
formattedInput, err := formatInput(input)
if err != nil {
return "", err
Expand Down
2 changes: 1 addition & 1 deletion internal/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql"
"database/sql/driver"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"reflect"
Expand All @@ -14,6 +13,7 @@ import (
"strings"
"time"

"github.com/goccy/go-json"
ast "github.com/goccy/go-zetasql/resolved_ast"
"github.com/goccy/go-zetasql/types"
)
Expand Down
75 changes: 49 additions & 26 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1609,32 +1609,55 @@ SELECT Roster.LastName, TeamMascot.Mascot FROM Roster LEFT JOIN TeamMascot ON Ro
{"Eisenhower", nil},
},
},
/*
{
name: "right join",
query: `
WITH Roster AS
(SELECT 'Adams' as LastName, 50 as SchoolID UNION ALL
SELECT 'Buchanan', 52 UNION ALL
SELECT 'Coolidge', 52 UNION ALL
SELECT 'Davis', 51 UNION ALL
SELECT 'Eisenhower', 77),
TeamMascot AS
(SELECT 50 as SchoolID, 'Jaguars' as Mascot UNION ALL
SELECT 51, 'Knights' UNION ALL
SELECT 52, 'Lakers' UNION ALL
SELECT 53, 'Mustangs')
SELECT Roster.LastName, TeamMascot.Mascot FROM Roster RIGHT JOIN TeamMascot ON Roster.SchoolID = TeamMascot.SchoolID
`,
expectedRows: [][]interface{}{
{"Adams", "Jaguars"},
{"Buchanan", "Lakers"},
{"Coolidge", "Lakers"},
{"Davis", "Knights"},
{nil, "Mustangs"},
},
},
*/
{
name: "right join",
query: `
WITH Roster AS
(SELECT 'Adams' as LastName, 50 as SchoolID UNION ALL
SELECT 'Buchanan', 52 UNION ALL
SELECT 'Coolidge', 52 UNION ALL
SELECT 'Davis', 51 UNION ALL
SELECT 'Eisenhower', 77),
TeamMascot AS
(SELECT 50 as SchoolID, 'Jaguars' as Mascot UNION ALL
SELECT 51, 'Knights' UNION ALL
SELECT 52, 'Lakers' UNION ALL
SELECT 53, 'Mustangs')
SELECT Roster.LastName, TeamMascot.Mascot FROM Roster RIGHT JOIN TeamMascot ON Roster.SchoolID = TeamMascot.SchoolID
`,
expectedRows: [][]interface{}{
{"Adams", "Jaguars"},
{"Buchanan", "Lakers"},
{"Coolidge", "Lakers"},
{"Davis", "Knights"},
{nil, "Mustangs"},
},
},
{
name: "full join",
query: `
WITH Roster AS
(SELECT 'Adams' as LastName, 50 as SchoolID UNION ALL
SELECT 'Buchanan', 52 UNION ALL
SELECT 'Coolidge', 52 UNION ALL
SELECT 'Davis', 51 UNION ALL
SELECT 'Eisenhower', 77),
TeamMascot AS
(SELECT 50 as SchoolID, 'Jaguars' as Mascot UNION ALL
SELECT 51, 'Knights' UNION ALL
SELECT 52, 'Lakers' UNION ALL
SELECT 53, 'Mustangs')
SELECT Roster.LastName, TeamMascot.Mascot FROM Roster FULL JOIN TeamMascot ON Roster.SchoolID = TeamMascot.SchoolID
`,
expectedRows: [][]interface{}{
{"Adams", "Jaguars"},
{"Buchanan", "Lakers"},
{"Coolidge", "Lakers"},
{"Davis", "Knights"},
{"Eisenhower", nil},
{nil, "Mustangs"},
},
},
{
name: "qualify",
query: `
Expand Down

0 comments on commit b49d99a

Please sign in to comment.