diff --git a/README.md b/README.md index c35b2df..1bfb383 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,11 @@ A list of ZetaSQL specifications and features supported by go-zetasqlite. - [x] WINDOW - [x] WITH - [x] UNION +- [X] HAVING +- [x] ORDER BY +- [X] GROUP BY - ROLLUP +- [X] INNER/LEFT JOIN +- [x] QUALIFY ### Aggregate functions @@ -195,7 +200,7 @@ A list of ZetaSQL specifications and features supported by go-zetasqlite. ### Numbering functions - [x] RANK -- [ ] DENSE_RANK +- [x] DENSE_RANK - [ ] PERCENT_RANK - [ ] CUME_DIST - [ ] NTILE @@ -276,7 +281,7 @@ A list of ZetaSQL specifications and features supported by go-zetasqlite. - [x] LAST_VALUE - [ ] NTH_VALUE - [ ] LEAD -- [ ] LAG +- [x] LAG - [ ] PERCENTILE_CONT - [ ] PERCENTILE_DISC diff --git a/cmd/zetasqlite-cli/go.mod b/cmd/zetasqlite-cli/go.mod new file mode 100644 index 0000000..f246b1f --- /dev/null +++ b/cmd/zetasqlite-cli/go.mod @@ -0,0 +1,14 @@ +module github.com/goccy/go-zetasqlite/cmd/zetasqlite-cli + +go 1.18 + +require ( + github.com/chzyer/readline v1.5.1 + github.com/goccy/go-zetasqlite v0.3.1 +) + +require ( + github.com/goccy/go-zetasql v0.2.8 // indirect + github.com/mattn/go-sqlite3 v1.14.14 // indirect + golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect +) diff --git a/cmd/zetasqlite-cli/go.sum b/cmd/zetasqlite-cli/go.sum new file mode 100644 index 0000000..9abeda5 --- /dev/null +++ b/cmd/zetasqlite-cli/go.sum @@ -0,0 +1,15 @@ +github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/goccy/go-zetasql v0.2.8 h1:Kh8zQgBxN7Kg+cNJ3Z724G0oyaGoUYngeATSp6IOF/A= +github.com/goccy/go-zetasql v0.2.8/go.mod h1:6W14CJVKh7crrSPyj6NPk4c49L2NWnxvyDLsRkOm4BI= +github.com/goccy/go-zetasqlite v0.3.1 h1:rOWZZrQp59iu0FLeYkrOjdXBQbF/XPKNq3wGdbsTNYE= +github.com/goccy/go-zetasqlite v0.3.1/go.mod h1:PT3M02F7jVGdLmrczfr3ZaLnjxH9bowlsP6qy77OG1g= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw= +github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 h1:y/woIyUBFbpQGKS0u1aHF/40WUDnek3fPOyD08H5Vng= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/cmd/zetasqlite-cli/main.go b/cmd/zetasqlite-cli/main.go new file mode 100644 index 0000000..2e83cc9 --- /dev/null +++ b/cmd/zetasqlite-cli/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "io" + "log" + "reflect" + "strings" + + "github.com/chzyer/readline" + _ "github.com/goccy/go-zetasqlite" +) + +func main() { + if err := run(context.Background()); err != nil { + log.Fatalf("%+v", err) + } +} + +func run(ctx context.Context) error { + db, err := sql.Open("zetasqlite_sqlite3", ":memory:") + if err != nil { + return err + } + rl, err := readline.NewEx(&readline.Config{ + Prompt: ">> ", + HistoryFile: "./zetasqlite.history", + }) + if err != nil { + return err + } + defer rl.Close() + for { + line, err := 
rl.Readline() + if err == io.EOF || err == readline.ErrInterrupt { + break + } + line = strings.TrimSpace(line) + rows, err := db.QueryContext(ctx, line) + if err != nil { + fmt.Printf("ERROR: %v\n", err) + continue + } + columns, err := rows.Columns() + if err != nil { + fmt.Printf("ERROR: %v\n", err) + continue + } + columnNum := len(columns) + args := make([]interface{}, columnNum) + for i := 0; i < columnNum; i++ { + var v interface{} + args[i] = &v + } + header := strings.Join(columns, "|") + fmt.Printf("%s\n", header) + fmt.Printf("%s\n", strings.Repeat("-", len(header))) + for rows.Next() { + if err := rows.Scan(args...); err != nil { + fmt.Printf("ERROR: %v", err) + break + } + values := make([]string, 0, len(args)) + for _, arg := range args { + v := reflect.ValueOf(arg).Elem().Interface() + values = append(values, fmt.Sprint(v)) + } + fmt.Printf("%s\n", strings.Join(values, "|")) + } + } + return nil +} diff --git a/driver.go b/driver.go index 1c9df4f..50166c7 100644 --- a/driver.go +++ b/driver.go @@ -118,7 +118,7 @@ func (c *ZetaSQLiteConn) ExecContext(ctx context.Context, query string, args []d if err != nil { return nil, fmt.Errorf("failed to analyze query: %w", err) } - newNamedValues, err := internal.ConvertNamedValues(args) + newNamedValues, err := internal.EncodeNamedValues(args, out.Params()) if err != nil { return nil, err } @@ -134,7 +134,7 @@ func (c *ZetaSQLiteConn) QueryContext(ctx context.Context, query string, args [] if err != nil { return nil, fmt.Errorf("failed to analyze query: %w", err) } - newNamedValues, err := internal.ConvertNamedValues(args) + newNamedValues, err := internal.EncodeNamedValues(args, out.Params()) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index 5c66b44..6538636 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/goccy/go-zetasqlite go 1.17 require ( - github.com/goccy/go-zetasql v0.2.7 + github.com/goccy/go-zetasql v0.2.8 github.com/mattn/go-sqlite3 v1.14.14 ) diff --git a/go.sum b/go.sum index d0c06c8..f4c0fd8 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/goccy/go-zetasql v0.2.7 h1:zpWBg0hb9LqDRpllrPvs8Ei6LsVVLNWLlVV1RgnQvx4= -github.com/goccy/go-zetasql v0.2.7/go.mod h1:6W14CJVKh7crrSPyj6NPk4c49L2NWnxvyDLsRkOm4BI= +github.com/goccy/go-zetasql v0.2.8 h1:Kh8zQgBxN7Kg+cNJ3Z724G0oyaGoUYngeATSp6IOF/A= +github.com/goccy/go-zetasql v0.2.8/go.mod h1:6W14CJVKh7crrSPyj6NPk4c49L2NWnxvyDLsRkOm4BI= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw= diff --git a/internal/analyzer.go b/internal/analyzer.go index d7bb88f..5276ca4 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -22,7 +22,7 @@ type AnalyzerOutput struct { node ast.Node query string formattedQuery string - argsNum int + params []*ast.ParameterNode isQuery bool tableSpec *TableSpec outputColumns []*ColumnSpec @@ -31,6 +31,10 @@ type AnalyzerOutput struct { QueryContext func(context.Context, *sql.Conn, ...interface{}) (driver.Rows, error) } +func (o *AnalyzerOutput) Params() []*ast.ParameterNode { + return o.params +} + func NewAnalyzer(catalog *Catalog) *Analyzer { return &Analyzer{ catalog: catalog, @@ -67,6 +71,9 @@ func newAnalyzerOptions() *zetasql.AnalyzerOptions { zetasql.FeatureV13ExtendedDateTimeSignatures, zetasql.FeatureV12CivilTime, zetasql.FeatureIntervalType, + zetasql.FeatureGroupByRollup, + 
zetasql.FeatureV13NullsFirstLastInOrderBy, + zetasql.FeatureV13Qualify, }) langOpt.SetSupportedStatementKinds([]ast.Kind{ ast.QueryStmt, @@ -136,7 +143,7 @@ func (a *Analyzer) analyzeCreateTableStmt(query string, node *ast.CreateTableStm return &AnalyzerOutput{ node: node, query: query, - argsNum: a.getParamNumFromNode(node), + params: a.getParamsFromNode(node), tableSpec: spec, Prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) { if spec.CreateMode == ast.CreateOrReplaceMode { @@ -203,18 +210,18 @@ func (a *Analyzer) analyzeDMLStmt(ctx context.Context, query string, node ast.No if formattedQuery == "" { return nil, fmt.Errorf("failed to format query %s", query) } - argsNum := a.getParamNumFromNode(node) + params := a.getParamsFromNode(node) return &AnalyzerOutput{ node: node, query: query, formattedQuery: formattedQuery, - argsNum: argsNum, + params: params, Prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) { s, err := conn.PrepareContext(ctx, formattedQuery) if err != nil { return nil, fmt.Errorf("failed to prepare %s: %w", query, err) } - return newDMLStmt(s, argsNum, formattedQuery), nil + return newDMLStmt(s, params, formattedQuery), nil }, ExecContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Result, error) { if _, err := conn.ExecContext(ctx, formattedQuery, args...); err != nil { @@ -240,19 +247,19 @@ func (a *Analyzer) analyzeQueryStmt(ctx context.Context, query string, node *ast if formattedQuery == "" { return nil, fmt.Errorf("failed to format query %s", query) } - argsNum := a.getParamNumFromNode(node) + params := a.getParamsFromNode(node) return &AnalyzerOutput{ node: node, query: query, formattedQuery: formattedQuery, - argsNum: argsNum, + params: params, isQuery: true, Prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) { s, err := conn.PrepareContext(ctx, formattedQuery) if err != nil { return nil, fmt.Errorf("failed to prepare %s: %w", query, err) } - return newQueryStmt(s, argsNum, formattedQuery, outputColumns), nil + return newQueryStmt(s, params, formattedQuery, outputColumns), nil }, QueryContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Rows, error) { rows, err := conn.QueryContext(ctx, formattedQuery, args...) 
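A hedged sketch of the new parameter flow (the helper name below is illustrative and not part of this patch): the analyzer now records every *ast.ParameterNode it walks, AnalyzerOutput.Params() exposes that list, and driver.go hands it to internal.EncodeNamedValues so each bound argument is encoded according to its ZetaSQL parameter type rather than by reflection alone.

// Sketch only, assumed to live in package internal next to EncodeNamedValues.
package internal

import (
	"database/sql"
	"database/sql/driver"
)

// bindQueryArgs is a hypothetical helper; the real call sites are
// ZetaSQLiteConn.ExecContext and QueryContext in driver.go.
func bindQueryArgs(out *AnalyzerOutput, args []driver.NamedValue) ([]sql.NamedArg, error) {
	// Params() returns one *ast.ParameterNode per placeholder in the order
	// ast.Walk visited them; EncodeNamedValues verifies the counts match and
	// encodes each value against the corresponding parameter's type.
	return EncodeNamedValues(args, out.Params())
}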
@@ -266,7 +273,7 @@ func (a *Analyzer) analyzeQueryStmt(ctx context.Context, query string, node *ast func (a *Analyzer) getFullNamePath(query string) (*fullNamePath, error) { fullpath := &fullNamePath{} - parsedAST, err := zetasql.ParseStatement(query) + parsedAST, err := zetasql.ParseStatement(query, a.opt.ParserOptions()) if err != nil { return nil, fmt.Errorf("failed to parse statement: %w", err) } @@ -311,14 +318,14 @@ func (a *Analyzer) getFullNamePath(query string) (*fullNamePath, error) { return fullpath, nil } -func (a *Analyzer) getParamNumFromNode(node ast.Node) int { - var numInput int +func (a *Analyzer) getParamsFromNode(node ast.Node) []*ast.ParameterNode { + var params []*ast.ParameterNode ast.Walk(node, func(n ast.Node) error { - _, ok := n.(*ast.ParameterNode) + param, ok := n.(*ast.ParameterNode) if ok { - numInput++ + params = append(params, param) } return nil }) - return numInput + return params } diff --git a/internal/catalog.go b/internal/catalog.go index 1ff1182..e045e7d 100644 --- a/internal/catalog.go +++ b/internal/catalog.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "reflect" + "strings" "sync" "time" @@ -252,6 +253,14 @@ func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpe if !c.existsCatalog(cat, subCatalogName) { cat.AddCatalog(subCatalog) } + fullTableName := strings.Join(spec.NamePath, ".") + if !c.existsTable(cat, fullTableName) { + table, err := c.createSimpleTable(fullTableName, spec) + if err != nil { + return err + } + cat.AddTable(table) + } newNamePath := spec.NamePath[1:] // add sub catalog to root catalog if err := c.addTableSpecRecursive(cat, c.copyTableSpec(spec, newNamePath)); err != nil { @@ -271,18 +280,26 @@ func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpe if c.existsTable(cat, tableName) { return nil } + table, err := c.createSimpleTable(tableName, spec) + if err != nil { + return err + } + cat.AddTable(table) + return nil +} + +func (c *Catalog) createSimpleTable(tableName string, spec *TableSpec) (*types.SimpleTable, error) { columns := []types.Column{} for _, column := range spec.Columns { typ, err := column.Type.ToZetaSQLType() if err != nil { - return err + return nil, err } columns = append(columns, types.NewSimpleColumn( tableName, column.Name, typ, )) } - cat.AddTable(types.NewSimpleTable(tableName, columns)) - return nil + return types.NewSimpleTable(tableName, columns), nil } func (c *Catalog) addFunctionSpecRecursive(cat *types.SimpleCatalog, spec *FunctionSpec) error { diff --git a/internal/context.go b/internal/context.go index dcbaec0..8ae9b2b 100644 --- a/internal/context.go +++ b/internal/context.go @@ -13,8 +13,10 @@ type ( analyticOrderColumnNamesKey struct{} analyticPartitionColumnNamesKey struct{} analyticTableNameKey struct{} + analyticInputScanKey struct{} arraySubqueryColumnNameKey struct{} currentTimeKey struct{} + existsGroupByKey struct{} ) func namePathFromContext(ctx context.Context) []string { @@ -70,8 +72,13 @@ func funcMapFromContext(ctx context.Context) map[string]*FunctionSpec { return value.(map[string]*FunctionSpec) } +type analyticOrderBy struct { + column string + isAsc bool +} + type analyticOrderColumnNames struct { - values []string + values []*analyticOrderBy } func withAnalyticOrderColumnNames(ctx context.Context, v *analyticOrderColumnNames) context.Context { @@ -110,6 +117,18 @@ func analyticTableNameFromContext(ctx context.Context) string { return value.(string) } +func withAnalyticInputScan(ctx context.Context, input string) 
context.Context { + return context.WithValue(ctx, analyticInputScanKey{}, input) +} + +func analyticInputScanFromContext(ctx context.Context) string { + value := ctx.Value(analyticInputScanKey{}) + if value == nil { + return "" + } + return value.(string) +} + type arraySubqueryColumnNames struct { names []string } @@ -126,6 +145,22 @@ func arraySubqueryColumnNameFromContext(ctx context.Context) *arraySubqueryColum return value.(*arraySubqueryColumnNames) } +type existsGroupBy struct { + exists bool +} + +func withExistsGroupBy(ctx context.Context, v *existsGroupBy) context.Context { + return context.WithValue(ctx, existsGroupByKey{}, v) +} + +func existsGroupByFromContext(ctx context.Context) *existsGroupBy { + value := ctx.Value(existsGroupByKey{}) + if value == nil { + return nil + } + return value.(*existsGroupBy) +} + func WithCurrentTime(ctx context.Context, now time.Time) context.Context { return context.WithValue(ctx, currentTimeKey{}, &now) } diff --git a/internal/formatter.go b/internal/formatter.go index 1e1e1b2..eeec001 100644 --- a/internal/formatter.go +++ b/internal/formatter.go @@ -18,10 +18,22 @@ func New(node ast.Node) Formatter { } func FormatName(namePath []string) string { + namePath = FormatPath(namePath) return strings.Join(namePath, "_") } +func FormatPath(path []string) []string { + ret := []string{} + for _, p := range path { + splitted := strings.Split(p, ".") + ret = append(ret, splitted...) + } + return ret +} + func MergeNamePath(namePath []string, queryPath []string) []string { + namePath = FormatPath(namePath) + queryPath = FormatPath(queryPath) if len(queryPath) == 0 { return namePath } @@ -62,6 +74,7 @@ func getFuncNameAndArgs(ctx context.Context, node *ast.BaseFunctionCallNode, isW _, existsAggregateFunc := aggregateFuncMap[funcName] _, existsWindowFunc := windowFuncMap[funcName] currentTime := CurrentTime(ctx) + fullpath := fullNamePathFromContext(ctx) if strings.HasPrefix(funcName, "$") { if isWindowFunc { funcName = fmt.Sprintf("zetasqlite_window_%s_%s", funcName[1:], suffixName) @@ -83,7 +96,6 @@ func getFuncNameAndArgs(ctx context.Context, node *ast.BaseFunctionCallNode, isW } else if isWindowFunc && existsWindowFunc { funcName = fmt.Sprintf("zetasqlite_window_%s_%s", funcName, suffixName) } else { - fullpath := fullNamePathFromContext(ctx) path := fullpath.paths[fullpath.idx] funcName = FormatName( MergeNamePath( @@ -91,8 +103,8 @@ func getFuncNameAndArgs(ctx context.Context, node *ast.BaseFunctionCallNode, isW path, ), ) - fullpath.idx++ } + fullpath.idx++ return funcName, args, nil } @@ -115,7 +127,13 @@ func (n *ColumnRefNode) FormatSQL(ctx context.Context) (string, error) { if n.node == nil { return "", nil } - return fmt.Sprintf("`%s`", n.node.Column().Name()), nil + columnMap := columnRefMap(ctx) + colName := n.node.Column().Name() + if ref, exists := columnMap[colName]; exists { + delete(columnMap, colName) + return ref, nil + } + return fmt.Sprintf("`%s`", colName), nil } func (n *ConstantNode) FormatSQL(ctx context.Context) (string, error) { @@ -224,7 +242,9 @@ func (n *AnalyticFunctionCallNode) FormatSQL(ctx context.Context) (string, error for _, a := range n.node.ArgumentList() { switch t := a.(type) { case *ast.ColumnRefNode: - ctx = withAnalyticTableName(ctx, t.Column().TableName()) + ctx = withAnalyticTableName(ctx, FormatName([]string{t.Column().TableName()})) + case *ast.LiteralNode: + continue default: return "", fmt.Errorf("unexpected argument node type %T for analytic function", a) } @@ -232,11 +252,10 @@ func (n 
*AnalyticFunctionCallNode) FormatSQL(ctx context.Context) (string, error if err != nil { return "", err } - orderColumnNames.values = append(orderColumnNames.values, arg) - } - tableName := analyticTableNameFromContext(ctx) - if tableName == "" { - return "", fmt.Errorf("failed to find table name from analytic query") + orderColumnNames.values = append(orderColumnNames.values, &analyticOrderBy{ + column: arg, + isAsc: true, + }) } funcName, args, err := getFuncNameAndArgs(ctx, n.node.BaseFunctionCallNode, true) if err != nil { @@ -259,8 +278,8 @@ func (n *AnalyticFunctionCallNode) FormatSQL(ctx context.Context) (string, error for _, column := range analyticPartitionColumnNamesFromContext(ctx) { args = append(args, getWindowPartitionOptionFuncSQL(column)) } - for _, column := range orderColumns { - args = append(args, getWindowOrderByOptionFuncSQL(column)) + for _, col := range orderColumns { + args = append(args, getWindowOrderByOptionFuncSQL(col.column, col.isAsc)) } windowFrame := n.node.WindowFrame() if windowFrame != nil { @@ -277,11 +296,12 @@ func (n *AnalyticFunctionCallNode) FormatSQL(ctx context.Context) (string, error args = append(args, endSQL) } args = append(args, getWindowRowIDOptionFuncSQL()) + input := analyticInputScanFromContext(ctx) return fmt.Sprintf( - "( SELECT %s(%s) FROM %s )", + "( SELECT %s(%s) %s )", funcName, strings.Join(args, ","), - tableName, + input, ), nil return "", nil @@ -449,7 +469,56 @@ func (n *TableScanNode) FormatSQL(ctx context.Context) (string, error) { } func (n *JoinScanNode) FormatSQL(ctx context.Context) (string, error) { - return "", nil + if n.node == nil { + return "", nil + } + leftRef, ok := n.node.LeftScan().(*ast.WithRefScanNode) + if !ok { + return "", fmt.Errorf("unexpected leftscan node %T", n.node.LeftScan()) + } + rightRef, ok := n.node.RightScan().(*ast.WithRefScanNode) + if !ok { + return "", fmt.Errorf("unexpected rightscan node %T", n.node.RightScan()) + } + equalFunc, ok := n.node.JoinExpr().(*ast.FunctionCallNode) + if !ok { + return "", fmt.Errorf("unexpected joinexpr node %T", n.node.JoinExpr()) + } + args := equalFunc.ArgumentList() + if len(args) != 2 { + return "", fmt.Errorf("join argument must be two arguments but got %d", len(args)) + } + leftColumn, err := newNode(args[0]).FormatSQL(ctx) + if err != nil { + return "", err + } + rightColumn, err := newNode(args[1]).FormatSQL(ctx) + if err != nil { + return "", err + } + leftTableName := leftRef.WithQueryName() + rightTableName := rightRef.WithQueryName() + var joinType string + switch n.node.JoinType() { + case ast.JoinTypeInner: + joinType = "JOIN" + case ast.JoinTypeLeft: + joinType = "LEFT JOIN" + case ast.JoinTypeRight: + joinType = "RIGHT JOIN" + case ast.JoinTypeFull: + joinType = "FULL JOIN" + } + return fmt.Sprintf( + "FROM `%s` %s `%s` ON `%s`.%s = `%s`.%s", + leftTableName, + joinType, + rightTableName, + leftTableName, + leftColumn, + rightTableName, + rightColumn, + ), nil } func (n *ArrayScanNode) FormatSQL(ctx context.Context) (string, error) { @@ -461,6 +530,18 @@ func (n *ArrayScanNode) FormatSQL(ctx context.Context) (string, error) { return "", err } colName := n.node.ElementColumn().Name() + if n.node.InputScan() != nil { + input, err := newNode(n.node.InputScan()).FormatSQL(ctx) + if err != nil { + return "", err + } + return fmt.Sprintf( + "SELECT *, json_each.value AS `%s` %s, json_each(zetasqlite_decode_array_string(%s))", + colName, + input, + arrayExpr, + ), nil + } return fmt.Sprintf( "FROM ( SELECT json_each.value AS `%s` FROM 
json_each(zetasqlite_decode_array_string(%s)) )", colName, @@ -476,14 +557,23 @@ func (n *FilterScanNode) FormatSQL(ctx context.Context) (string, error) { if n.node == nil { return "", nil } + ctx = withExistsGroupBy(ctx, &existsGroupBy{}) input, err := newNode(n.node.InputScan()).FormatSQL(ctx) if err != nil { return "", err } + usedGroupBy := existsGroupByFromContext(ctx).exists filter, err := newNode(n.node.FilterExpr()).FormatSQL(ctx) if err != nil { return "", err } + if usedGroupBy { + return fmt.Sprintf("%s HAVING %s", input, filter), nil + } + if strings.Contains(input, "WHERE") && input[len(input)-1] != ')' { + // expected to qualify clause + return fmt.Sprintf("FROM ( %s ) WHERE %s", input, filter), nil + } return fmt.Sprintf("%s WHERE %s", input, filter), nil } @@ -509,17 +599,83 @@ func (n *AggregateScanNode) FormatSQL(ctx context.Context) (string, error) { columns = append(columns, ref) delete(columnMap, colName) } else { - columns = append(columns, fmt.Sprintf("`%s`", colName)) + columns = append( + columns, + fmt.Sprintf("`%s`", colName), + ) } } input, err := newNode(n.node.InputScan()).FormatSQL(ctx) if err != nil { return "", err } + groupByColumns := []string{} + groupByColumnMap := map[string]struct{}{} + for _, col := range n.node.GroupByList() { + colName := fmt.Sprintf("`%s`", col.Column().Name()) + groupByColumns = append(groupByColumns, colName) + groupByColumnMap[colName] = struct{}{} + } + if len(groupByColumns) != 0 { + existsGroupBy := existsGroupByFromContext(ctx) + if existsGroupBy != nil { + existsGroupBy.exists = true + } + } + if len(n.node.GroupingSetList()) != 0 { + columnPatterns := [][]string{} + groupByColumnPatterns := [][]string{} + for _, set := range n.node.GroupingSetList() { + groupBySetColumns := []string{} + groupBySetColumnMap := map[string]struct{}{} + for _, col := range set.GroupByColumnList() { + colName := fmt.Sprintf("`%s`", col.Column().Name()) + groupBySetColumns = append(groupBySetColumns, colName) + groupBySetColumnMap[colName] = struct{}{} + } + nullColumnNameMap := map[string]struct{}{} + for _, col := range groupByColumns { + if _, exists := groupBySetColumnMap[col]; !exists { + nullColumnNameMap[col] = struct{}{} + } + } + groupBySetColumnPattern := []string{} + for _, col := range columns { + if _, exists := nullColumnNameMap[col]; exists { + groupBySetColumnPattern = append(groupBySetColumnPattern, "NULL") + } else { + groupBySetColumnPattern = append(groupBySetColumnPattern, col) + } + } + columnPatterns = append(columnPatterns, groupBySetColumnPattern) + groupByColumnPatterns = append(groupByColumnPatterns, groupBySetColumns) + } + stmts := []string{} + for i := 0; i < len(columnPatterns); i++ { + var groupBy string + if len(groupByColumnPatterns[i]) != 0 { + groupBy = fmt.Sprintf("GROUP BY %s", strings.Join(groupByColumnPatterns[i], ",")) + } + if strings.HasPrefix(input, "SELECT") { + stmts = append(stmts, fmt.Sprintf("SELECT %s FROM ( %s %s )", strings.Join(columnPatterns[i], ","), input, groupBy)) + } else { + stmts = append(stmts, fmt.Sprintf("SELECT %s %s %s", strings.Join(columnPatterns[i], ","), input, groupBy)) + } + } + return fmt.Sprintf( + "%s ORDER BY %s", + strings.Join(stmts, " UNION ALL "), + strings.Join(groupByColumns, ","), + ), nil + } + var groupBy string + if len(groupByColumns) > 0 { + groupBy = fmt.Sprintf("GROUP BY %s", strings.Join(groupByColumns, ",")) + } if strings.HasPrefix(input, "SELECT") { - return fmt.Sprintf("SELECT %s FROM ( %s )", strings.Join(columns, ","), input), nil + return 
fmt.Sprintf("SELECT %s FROM ( %s %s )", strings.Join(columns, ","), input, groupBy), nil } - return fmt.Sprintf("SELECT %s %s", strings.Join(columns, ","), input), nil + return fmt.Sprintf("SELECT %s %s %s", strings.Join(columns, ","), input, groupBy), nil } func (n *AnonymizedAggregateScanNode) FormatSQL(ctx context.Context) (string, error) { @@ -566,7 +722,62 @@ func (n *SetOperationScanNode) FormatSQL(ctx context.Context) (string, error) { } func (n *OrderByScanNode) FormatSQL(ctx context.Context) (string, error) { - return "", nil + if n.node == nil { + return "", nil + } + input, err := newNode(n.node.InputScan()).FormatSQL(ctx) + if err != nil { + return "", err + } + columns := []string{} + columnMap := columnRefMap(ctx) + for _, col := range n.node.ColumnList() { + colName := col.Name() + if ref, exists := columnMap[colName]; exists { + columns = append(columns, ref) + delete(columnMap, colName) + } else { + columns = append( + columns, + fmt.Sprintf("`%s`", colName), + ) + } + } + orderByColumns := []string{} + for _, item := range n.node.OrderByItemList() { + colName := item.ColumnRef().Column().Name() + switch item.NullOrder() { + case ast.NullOrderModeNullsFirst: + orderByColumns = append( + orderByColumns, + fmt.Sprintf("(`%s` IS NOT NULL)", colName), + ) + case ast.NullOrderModeNullsLast: + orderByColumns = append( + orderByColumns, + fmt.Sprintf("(`%s` IS NULL)", colName), + ) + } + if item.IsDescending() { + orderByColumns = append(orderByColumns, fmt.Sprintf("`%s` DESC", colName)) + } else { + orderByColumns = append(orderByColumns, fmt.Sprintf("`%s`", colName)) + } + } + if strings.HasPrefix(input, "SELECT") { + return fmt.Sprintf( + "SELECT %s FROM ( %s ) ORDER BY %s", + strings.Join(columns, ","), + input, + strings.Join(orderByColumns, ","), + ), nil + } + return fmt.Sprintf( + "SELECT %s %s ORDER BY %s", + strings.Join(columns, ","), + input, + strings.Join(orderByColumns, ","), + ), nil } func (n *LimitOffsetScanNode) FormatSQL(ctx context.Context) (string, error) { @@ -589,38 +800,80 @@ func (n *AnalyticScanNode) FormatSQL(ctx context.Context) (string, error) { if err != nil { return "", err } + ctx = withAnalyticInputScan(ctx, input) orderColumnNames := analyticOrderColumnNamesFromContext(ctx) for _, group := range n.node.FunctionGroupList() { if group.PartitionBy() != nil { var partitionColumns []string for _, columnRef := range group.PartitionBy().PartitionByList() { - ctx = withAnalyticTableName(ctx, columnRef.Column().TableName()) + ctx = withAnalyticTableName(ctx, FormatName([]string{columnRef.Column().TableName()})) + colName := fmt.Sprintf("`%s`", columnRef.Column().Name()) partitionColumns = append( partitionColumns, - fmt.Sprintf("`%s`", columnRef.Column().Name()), + colName, ) + orderColumnNames.values = append(orderColumnNames.values, &analyticOrderBy{ + column: colName, + isAsc: true, + }) } - orderColumnNames.values = append(orderColumnNames.values, partitionColumns...) 
ctx = withAnalyticPartitionColumnNames(ctx, partitionColumns) } if group.OrderBy() != nil { var orderByColumns []string for _, item := range group.OrderBy().OrderByItemList() { - ctx = withAnalyticTableName(ctx, item.ColumnRef().Column().TableName()) + ctx = withAnalyticTableName(ctx, FormatName([]string{item.ColumnRef().Column().TableName()})) + colName := fmt.Sprintf("`%s`", item.ColumnRef().Column().Name()) orderByColumns = append( orderByColumns, - fmt.Sprintf("`%s`", item.ColumnRef().Column().Name()), + colName, ) + orderColumnNames.values = append(orderColumnNames.values, &analyticOrderBy{ + column: colName, + isAsc: !item.IsDescending(), + }) } - orderColumnNames.values = append(orderColumnNames.values, orderByColumns...) } if _, err := newNode(group).FormatSQL(ctx); err != nil { return "", err } } - orderBy := fmt.Sprintf("ORDER BY %s", strings.Join(orderColumnNames.values, ",")) - orderColumnNames.values = []string{} - return fmt.Sprintf("FROM ( SELECT *, ROW_NUMBER() OVER() AS `rowid` %s ) %s", input, orderBy), nil + columns := []string{} + columnMap := columnRefMap(ctx) + for _, col := range n.node.ColumnList() { + colName := col.Name() + if ref, exists := columnMap[colName]; exists { + columns = append(columns, ref) + delete(columnMap, colName) + } else { + columns = append( + columns, + fmt.Sprintf("`%s`", colName), + ) + } + } + var orderColumnFormattedNames []string + for _, col := range orderColumnNames.values { + if col.isAsc { + orderColumnFormattedNames = append( + orderColumnFormattedNames, + col.column, + ) + } else { + orderColumnFormattedNames = append( + orderColumnFormattedNames, + fmt.Sprintf("%s DESC", col.column), + ) + } + } + orderBy := fmt.Sprintf("ORDER BY %s", strings.Join(orderColumnFormattedNames, ",")) + orderColumnNames.values = []*analyticOrderBy{} + return fmt.Sprintf( + "SELECT %s FROM ( SELECT *, ROW_NUMBER() OVER() AS `row_id` %s ) %s", + strings.Join(columns, ","), + input, + orderBy, + ), nil } func (n *SampleScanNode) FormatSQL(ctx context.Context) (string, error) { @@ -703,15 +956,25 @@ func (n *ProjectScanNode) FormatSQL(ctx context.Context) (string, error) { if err != nil { return "", err } + _, isJoinScan := n.node.InputScan().(*ast.JoinScanNode) columns := []string{} columnMap := columnRefMap(ctx) for _, col := range n.node.ColumnList() { + tableName := FormatName([]string{col.TableName()}) colName := col.Name() if ref, exists := columnMap[colName]; exists { columns = append(columns, ref) delete(columnMap, colName) + } else if isJoinScan { + columns = append( + columns, + fmt.Sprintf("`%s`.`%s`", tableName, colName), + ) } else { - columns = append(columns, fmt.Sprintf("`%s`", colName)) + columns = append( + columns, + fmt.Sprintf("`%s`", colName), + ) } } if strings.HasPrefix(input, "SELECT") { @@ -914,7 +1177,7 @@ func (n *AnalyticFunctionGroupNode) FormatSQL(ctx context.Context) (string, erro if n.node == nil { return "", nil } - orderColumnNames := analyticOrderColumnNamesFromContext(ctx) + var queries []string for _, column := range n.node.AnalyticFunctionList() { sql, err := newNode(column).FormatSQL(ctx) @@ -922,10 +1185,6 @@ func (n *AnalyticFunctionGroupNode) FormatSQL(ctx context.Context) (string, erro return "", err } queries = append(queries, sql) - orderColumnNames.values = append( - orderColumnNames.values, - fmt.Sprintf("`%s`", column.Column().Name()), - ) } return strings.Join(queries, ","), nil } diff --git a/internal/function.go b/internal/function.go index 8f576b3..540613d 100644 --- a/internal/function.go +++ 
b/internal/function.go @@ -7,18 +7,30 @@ import ( ) func ADD(a, b Value) (Value, error) { + if a == nil || b == nil { + return nil, nil + } return a.Add(b) } func SUB(a, b Value) (Value, error) { + if a == nil || b == nil { + return nil, nil + } return a.Sub(b) } func MUL(a, b Value) (Value, error) { + if a == nil || b == nil { + return nil, nil + } return a.Mul(b) } func OP_DIV(a, b Value) (Value, error) { + if a == nil || b == nil { + return nil, nil + } return a.Div(b) } diff --git a/internal/function_bind.go b/internal/function_bind.go index 5557d37..1c63c3c 100644 --- a/internal/function_bind.go +++ b/internal/function_bind.go @@ -792,10 +792,7 @@ func bindLength(args ...Value) (Value, error) { } func timeFromUnixNano(unixNano int64) time.Time { - return time.Unix( - unixNano/int64(time.Second), - unixNano%int64(time.Nanosecond), - ) + return time.Unix(0, unixNano) } func bindCurrentDate(args ...Value) (Value, error) { @@ -1331,10 +1328,14 @@ func bindWindowRowID(args ...Value) (Value, error) { } func bindWindowOrderBy(args ...Value) (Value, error) { - if len(args) != 1 { + if len(args) != 2 { return nil, fmt.Errorf("WINDOW_ORDER_BY: invalid argument num %d", len(args)) } - return WINDOW_ORDER_BY(args[0]) + isAsc, err := args[1].ToBool() + if err != nil { + return nil, err + } + return WINDOW_ORDER_BY(args[0], isAsc) } func bindArrayAgg(converter ReturnValueConverter) func() *Aggregator { @@ -1675,6 +1676,39 @@ func bindWindowLastValue(converter ReturnValueConverter) func() *WindowAggregato } } +func bindWindowLag(converter ReturnValueConverter) func() *WindowAggregator { + return func() *WindowAggregator { + fn := &WINDOW_LAG{} + return newWindowAggregator( + func(args []Value, opt *AggregatorOption, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { + if len(args) != 1 && len(args) != 2 && len(args) != 3 { + return fmt.Errorf("WINDOW_LAG: invalid argument num %d", len(args)) + } + var offset int64 = 1 + if len(args) >= 2 { + v, err := args[1].ToInt64() + if err != nil { + return err + } + offset = v + } + if offset < 0 { + return fmt.Errorf("WINDOW_LAG: offset is must be positive value %d", offset) + } + var defaultValue Value + if len(args) == 3 { + defaultValue = args[2] + } + return fn.Step(args[0], offset, defaultValue, windowOpt, agg) + }, + func(agg *WindowFuncAggregatedStatus) (Value, error) { + return fn.Done(agg) + }, + converter, + ) + } +} + func bindWindowRank(converter ReturnValueConverter) func() *WindowAggregator { return func() *WindowAggregator { fn := &WINDOW_RANK{} @@ -1692,3 +1726,21 @@ func bindWindowRank(converter ReturnValueConverter) func() *WindowAggregator { ) } } + +func bindWindowDenseRank(converter ReturnValueConverter) func() *WindowAggregator { + return func() *WindowAggregator { + fn := &WINDOW_DENSE_RANK{} + return newWindowAggregator( + func(args []Value, opt *AggregatorOption, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { + if len(args) != 0 { + return fmt.Errorf("WINDOW_DENSE_RANK: invalid argument num %d", len(args)) + } + return fn.Step(windowOpt, agg) + }, + func(agg *WindowFuncAggregatedStatus) (Value, error) { + return fn.Done(agg) + }, + converter, + ) + } +} diff --git a/internal/function_register.go b/internal/function_register.go index abf2bef..2347afd 100644 --- a/internal/function_register.go +++ b/internal/function_register.go @@ -207,7 +207,7 @@ var normalFuncs = []*FuncInfo{ { Name: "if", BindFunc: bindIf, - ReturnTypes: []types.TypeKind{types.INT64, types.DOUBLE, types.STRING}, + 
ReturnTypes: []types.TypeKind{types.INT64, types.DOUBLE, types.STRING, types.BOOL}, }, { Name: "ifnull", @@ -661,11 +661,24 @@ var windowFuncs = []*WindowFuncInfo{ BindFunc: bindWindowLastValue, ReturnTypes: []types.TypeKind{types.STRING}, }, + { + Name: "lag", + BindFunc: bindWindowLag, + ReturnTypes: []types.TypeKind{ + types.INT64, types.DOUBLE, types.STRING, types.BOOL, + types.DATE, types.DATETIME, types.TIMESTAMP, + }, + }, { Name: "rank", BindFunc: bindWindowRank, ReturnTypes: []types.TypeKind{types.INT64}, }, + { + Name: "dense_rank", + BindFunc: bindWindowDenseRank, + ReturnTypes: []types.TypeKind{types.INT64}, + }, } type NameAndFunc struct { diff --git a/internal/function_window.go b/internal/function_window.go index d481a47..53decd1 100644 --- a/internal/function_window.go +++ b/internal/function_window.go @@ -2,6 +2,7 @@ package internal import ( "fmt" + "math" "sync" ) @@ -186,6 +187,43 @@ func (f *WINDOW_LAST_VALUE) Done(agg *WindowFuncAggregatedStatus) (Value, error) return lastValue, nil } +type WINDOW_LAG struct { + lagOnce sync.Once + offset int64 + defaultValue Value +} + +func (f *WINDOW_LAG) Step(v Value, offset int64, defaultValue Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { + if v == nil { + return nil + } + f.lagOnce.Do(func() { + f.offset = offset + f.defaultValue = defaultValue + }) + return agg.Step(v, opt) +} + +func (f *WINDOW_LAG) Done(agg *WindowFuncAggregatedStatus) (Value, error) { + var lagValue Value + if err := agg.Done(func(values []Value, start, end int) error { + if len(values) == 0 { + return nil + } + if start-int(f.offset) < 0 { + return nil + } + lagValue = values[start-int(f.offset)] + return nil + }); err != nil { + return nil, err + } + if lagValue == nil { + return f.defaultValue, nil + } + return lagValue, nil +} + type WINDOW_RANK struct { } @@ -196,9 +234,16 @@ func (f *WINDOW_RANK) Step(opt *WindowFuncStatus, agg *WindowFuncAggregatedStatu func (f *WINDOW_RANK) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var rankValue Value if err := agg.Done(func(_ []Value, start, end int) error { - var orderByValues []Value + var ( + orderByValues []Value + isAsc bool = true + isAscOnce sync.Once + ) for _, value := range agg.SortedValues { - orderByValues = append(orderByValues, value.OrderBy[len(value.OrderBy)-1]) + orderByValues = append(orderByValues, value.OrderBy[len(value.OrderBy)-1].Value) + isAscOnce.Do(func() { + isAsc = value.OrderBy[len(value.OrderBy)-1].IsAsc + }) } if start >= len(orderByValues) || end < 0 { return nil @@ -215,17 +260,101 @@ func (f *WINDOW_RANK) Done(agg *WindowFuncAggregatedStatus) (Value, error) { sameRankNum = 1 maxValue int64 ) - for idx := 0; idx <= lastIdx; idx++ { - curValue, err := orderByValues[idx].ToInt64() - if err != nil { - return err + if isAsc { + for idx := 0; idx <= lastIdx; idx++ { + curValue, err := orderByValues[idx].ToInt64() + if err != nil { + return err + } + if maxValue < curValue { + maxValue = curValue + rank += sameRankNum + sameRankNum = 1 + } else { + sameRankNum++ + } } - if maxValue < curValue { - maxValue = curValue - rank += sameRankNum - sameRankNum = 1 - } else { - sameRankNum++ + } else { + maxValue = math.MaxInt64 + for idx := 0; idx <= lastIdx; idx++ { + curValue, err := orderByValues[idx].ToInt64() + if err != nil { + return err + } + if maxValue > curValue { + maxValue = curValue + rank += sameRankNum + sameRankNum = 1 + } else { + sameRankNum++ + } + } + } + rankValue = IntValue(rank) + return nil + }); err != nil { + return nil, err + } + 
return rankValue, nil +} + +type WINDOW_DENSE_RANK struct { +} + +func (f *WINDOW_DENSE_RANK) Step(opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { + return agg.Step(IntValue(1), opt) +} + +func (f *WINDOW_DENSE_RANK) Done(agg *WindowFuncAggregatedStatus) (Value, error) { + var rankValue Value + if err := agg.Done(func(_ []Value, start, end int) error { + var ( + orderByValues []Value + isAscOnce sync.Once + isAsc bool = true + ) + for _, value := range agg.SortedValues { + orderByValues = append(orderByValues, value.OrderBy[len(value.OrderBy)-1].Value) + isAscOnce.Do(func() { + isAsc = value.OrderBy[len(value.OrderBy)-1].IsAsc + }) + } + if start >= len(orderByValues) || end < 0 { + return nil + } + if len(orderByValues) == 0 { + return nil + } + if start != end { + return fmt.Errorf("Rank must be same value of start and end") + } + lastIdx := start + var ( + rank = 0 + maxValue int64 + ) + if isAsc { + for idx := 0; idx <= lastIdx; idx++ { + curValue, err := orderByValues[idx].ToInt64() + if err != nil { + return err + } + if maxValue < curValue { + maxValue = curValue + rank++ + } + } + } else { + maxValue = math.MaxInt64 + for idx := 0; idx <= lastIdx; idx++ { + curValue, err := orderByValues[idx].ToInt64() + if err != nil { + return err + } + if maxValue > curValue { + maxValue = curValue + rank++ + } } } rankValue = IntValue(rank) diff --git a/internal/function_window_option.go b/internal/function_window_option.go index 007276c..98782ec 100644 --- a/internal/function_window_option.go +++ b/internal/function_window_option.go @@ -59,12 +59,20 @@ func (o *WindowFuncOption) UnmarshalJSON(b []byte) error { return err } o.Value = value.Value - case WindowFuncOptionPartition, WindowFuncOptionOrderBy: + case WindowFuncOptionPartition: value, err := ValueOf(v.Value) if err != nil { return fmt.Errorf("failed to convert %v to Value: %w", v.Value, err) } o.Value = value + case WindowFuncOptionOrderBy: + var value struct { + Value *WindowOrderBy `json:"value"` + } + if err := json.Unmarshal(b, &value); err != nil { + return err + } + o.Value = value.Value } return nil } @@ -141,11 +149,11 @@ func getWindowPartitionOptionFuncSQL(column string) string { } func getWindowRowIDOptionFuncSQL() string { - return "zetasqlite_window_rowid_string(`rowid`)" + return "zetasqlite_window_rowid_string(`row_id`)" } -func getWindowOrderByOptionFuncSQL(column string) string { - return fmt.Sprintf("zetasqlite_window_order_by_string(%s)", column) +func getWindowOrderByOptionFuncSQL(column string, isAsc bool) string { + return fmt.Sprintf("zetasqlite_window_order_by_string(%s, %t)", column, isAsc) } func WINDOW_FRAME_UNIT(frameUnit int64) (Value, error) { @@ -179,9 +187,62 @@ func WINDOW_BOUNDARY_END(boundaryType, offset int64) (Value, error) { } func WINDOW_PARTITION(partition Value) (Value, error) { + var v interface{} + switch vv := partition.(type) { + case IntValue: + i64, err := vv.ToInt64() + if err != nil { + return nil, err + } + v = i64 + case FloatValue: + f64, err := vv.ToFloat64() + if err != nil { + return nil, err + } + v = f64 + case StringValue: + s, err := vv.ToString() + if err != nil { + return nil, err + } + v = s + case BoolValue: + b, err := vv.ToBool() + if err != nil { + return nil, err + } + v = b + case DateValue: + i64, err := vv.ToInt64() + if err != nil { + return nil, err + } + v = i64 + case DatetimeValue: + i64, err := vv.ToInt64() + if err != nil { + return nil, err + } + v = i64 + case TimeValue: + i64, err := vv.ToInt64() + if err != nil { + return nil, err + } + v = 
i64 + case TimestampValue: + i64, err := vv.ToInt64() + if err != nil { + return nil, err + } + v = i64 + default: + return nil, fmt.Errorf("unsupported %T type for order by value", vv) + } b, _ := json.Marshal(&WindowFuncOption{ Type: WindowFuncOptionPartition, - Value: partition, + Value: v, }) return StringValue(string(b)), nil } @@ -194,7 +255,29 @@ func WINDOW_ROWID(id int64) (Value, error) { return StringValue(string(b)), nil } -func WINDOW_ORDER_BY(value Value) (Value, error) { +type WindowOrderBy struct { + Value Value `json:"value"` + IsAsc bool `json:"isAsc"` +} + +func (w *WindowOrderBy) UnmarshalJSON(b []byte) error { + var v struct { + Value interface{} `json:"value"` + IsAsc bool `json:"isAsc"` + } + if err := json.Unmarshal(b, &v); err != nil { + return err + } + value, err := ValueOf(v.Value) + if err != nil { + return err + } + w.Value = value + w.IsAsc = v.IsAsc + return nil +} + +func WINDOW_ORDER_BY(value Value, isAsc bool) (Value, error) { var v interface{} switch vv := value.(type) { case IntValue: @@ -227,10 +310,36 @@ func WINDOW_ORDER_BY(value Value) (Value, error) { return nil, err } v = i64 + case DatetimeValue: + i64, err := vv.ToInt64() + if err != nil { + return nil, err + } + v = i64 + case TimeValue: + i64, err := vv.ToInt64() + if err != nil { + return nil, err + } + v = i64 + case TimestampValue: + i64, err := vv.ToInt64() + if err != nil { + return nil, err + } + v = i64 + default: + return nil, fmt.Errorf("unsupported %T type for order by value", vv) } b, _ := json.Marshal(&WindowFuncOption{ - Type: WindowFuncOptionOrderBy, - Value: v, + Type: WindowFuncOptionOrderBy, + Value: struct { + Value interface{} `json:"value"` + IsAsc bool `json:"isAsc"` + }{ + Value: v, + IsAsc: isAsc, + }, }) return StringValue(string(b)), nil } @@ -241,7 +350,7 @@ type WindowFuncStatus struct { End *WindowBoundary Partition Value RowID int64 - OrderBy []Value + OrderBy []*WindowOrderBy } func parseWindowOptions(args ...interface{}) ([]interface{}, *WindowFuncStatus, error) { @@ -272,7 +381,7 @@ func parseWindowOptions(args ...interface{}) ([]interface{}, *WindowFuncStatus, case WindowFuncOptionRowID: opt.RowID = v.Value.(int64) case WindowFuncOptionOrderBy: - opt.OrderBy = append(opt.OrderBy, v.Value.(Value)) + opt.OrderBy = append(opt.OrderBy, v.Value.(*WindowOrderBy)) default: filteredArgs = append(filteredArgs, arg) continue @@ -282,7 +391,7 @@ func parseWindowOptions(args ...interface{}) ([]interface{}, *WindowFuncStatus, } type WindowOrderedValue struct { - OrderBy []Value + OrderBy []*WindowOrderBy Value Value } @@ -362,10 +471,18 @@ func (s *WindowFuncAggregatedStatus) Done(cb func([]Value, int, int) error) erro } if len(sortedValues) != 0 { for orderBy := 0; orderBy < len(sortedValues[0].OrderBy); orderBy++ { - sort.Slice(sortedValues, func(i, j int) bool { - cond, _ := sortedValues[i].OrderBy[orderBy].LT(sortedValues[j].OrderBy[orderBy]) - return cond - }) + isAsc := sortedValues[0].OrderBy[orderBy].IsAsc + if isAsc { + sort.Slice(sortedValues, func(i, j int) bool { + cond, _ := sortedValues[i].OrderBy[orderBy].Value.LT(sortedValues[j].OrderBy[orderBy].Value) + return cond + }) + } else { + sort.Slice(sortedValues, func(i, j int) bool { + cond, _ := sortedValues[i].OrderBy[orderBy].Value.GT(sortedValues[j].OrderBy[orderBy].Value) + return cond + }) + } } } s.SortedValues = sortedValues @@ -492,7 +609,7 @@ func (s *WindowFuncAggregatedStatus) currentRangeValue() (int64, error) { if len(curValue.OrderBy) == 0 { return 0, fmt.Errorf("required order by column for 
analytic range scanning") } - return curValue.OrderBy[len(curValue.OrderBy)-1].ToInt64() + return curValue.OrderBy[len(curValue.OrderBy)-1].Value.ToInt64() } func (s *WindowFuncAggregatedStatus) partitionedCurrentRangeValue() (int64, error) { @@ -501,7 +618,7 @@ func (s *WindowFuncAggregatedStatus) partitionedCurrentRangeValue() (int64, erro if len(curValue.Value.OrderBy) == 0 { return 0, fmt.Errorf("required order by column for analytic range scanning") } - return curValue.Value.OrderBy[len(curValue.Value.OrderBy)-1].ToInt64() + return curValue.Value.OrderBy[len(curValue.Value.OrderBy)-1].Value.ToInt64() } func (s *WindowFuncAggregatedStatus) lookupMinIndexFromRangeValue(rangeValue int64) (int, error) { @@ -511,7 +628,7 @@ func (s *WindowFuncAggregatedStatus) lookupMinIndexFromRangeValue(rangeValue int if len(value.OrderBy) != 1 { continue } - target, err := value.OrderBy[len(value.OrderBy)-1].ToInt64() + target, err := value.OrderBy[len(value.OrderBy)-1].Value.ToInt64() if err != nil { return 0, err } @@ -529,7 +646,7 @@ func (s *WindowFuncAggregatedStatus) lookupMaxIndexFromRangeValue(rangeValue int if len(value.OrderBy) != 1 { continue } - target, err := value.OrderBy[len(value.OrderBy)-1].ToInt64() + target, err := value.OrderBy[len(value.OrderBy)-1].Value.ToInt64() if err != nil { return 0, err } diff --git a/internal/json.go b/internal/json.go index 78a651a..f59f259 100644 --- a/internal/json.go +++ b/internal/json.go @@ -4,7 +4,6 @@ import ( "fmt" "strconv" "strings" - "time" "github.com/goccy/go-zetasql/types" ) @@ -43,8 +42,7 @@ func jsonFromZetaSQLValue(v types.Value) string { case types.TIME: return toTimeValueFromInt64(v.ToInt64()) case types.TIMESTAMP: - // TODO: 7 hours added time will be returned - return toTimestampValueFromTime(v.ToTime().Add(-7 * time.Hour)) + return toTimestampValueFromTime(v.ToTime()) case types.ARRAY: elems := []string{} for i := 0; i < v.NumElements(); i++ { diff --git a/internal/rows.go b/internal/rows.go index 204b190..c5227e0 100644 --- a/internal/rows.go +++ b/internal/rows.go @@ -162,7 +162,7 @@ func (r *Rows) convertValue(value interface{}, typ *Type) (driver.Value, error) if err != nil { return nil, err } - v = append(v, t) + v = append(v, t.UTC()) } return v, nil case types.STRUCT: @@ -201,7 +201,11 @@ func (r *Rows) convertValue(value interface{}, typ *Type) (driver.Value, error) if err != nil { return nil, err } - return val.ToJSON() + t, err := val.ToTime() + if err != nil { + return nil, err + } + return t.UTC(), nil } return value, nil } diff --git a/internal/spec.go b/internal/spec.go index 641a31e..a412e32 100644 --- a/internal/spec.go +++ b/internal/spec.go @@ -127,42 +127,42 @@ func (s *ColumnSpec) SQLiteSchema() string { switch types.TypeKind(s.Type.Kind) { case types.INT32, types.INT64, types.UINT32, types.UINT64: typ = "INT" + case types.ENUM: + typ = "INT" case types.BOOL: typ = "BOOLEAN" case types.FLOAT: typ = "FLOAT" + case types.BYTES: + typ = "BLOB" case types.DOUBLE: typ = "DOUBLE" + case types.JSON: + typ = "JSON" case types.STRING: typ = "TEXT" - case types.BYTES: - typ = "BLOB" case types.DATE: - typ = "DATE" + typ = "TEXT" case types.TIMESTAMP: - typ = "DATETIME" - case types.ENUM: - typ = "INT" + typ = "TEXT" case types.ARRAY: - typ = "JSON" + typ = "TEXT" case types.STRUCT: - typ = "JSON" + typ = "TEXT" case types.PROTO: - typ = "JSON" + typ = "TEXT" case types.TIME: - typ = "DATETIME" + typ = "TEXT" case types.DATETIME: - typ = "DATETIME" + typ = "TEXT" case types.GEOGRAPHY: - typ = "JSON" + typ = "TEXT" case 
types.NUMERIC: - typ = "NUMERIC" + typ = "TEXT" case types.BIG_NUMERIC: typ = "TEXT" case types.EXTENDED: typ = "TEXT" - case types.JSON: - typ = "JSON" case types.INTERVAL: typ = "TEXT" case types.UNKNOWN: diff --git a/internal/stmt.go b/internal/stmt.go index 919cc05..759911d 100644 --- a/internal/stmt.go +++ b/internal/stmt.go @@ -5,6 +5,8 @@ import ( "database/sql" "database/sql/driver" "fmt" + + ast "github.com/goccy/go-zetasql/resolved_ast" ) var ( @@ -83,14 +85,14 @@ func newCreateFunctionStmt(catalog *Catalog, spec *FunctionSpec) *CreateFunction type DMLStmt struct { stmt *sql.Stmt - argsNum int + args []*ast.ParameterNode formattedQuery string } -func newDMLStmt(stmt *sql.Stmt, argsNum int, formattedQuery string) *DMLStmt { +func newDMLStmt(stmt *sql.Stmt, args []*ast.ParameterNode, formattedQuery string) *DMLStmt { return &DMLStmt{ stmt: stmt, - argsNum: argsNum, + args: args, formattedQuery: formattedQuery, } } @@ -104,7 +106,7 @@ func (s *DMLStmt) Close() error { } func (s *DMLStmt) NumInput() int { - return s.argsNum + return len(s.args) } func (s *DMLStmt) Exec(args []driver.Value) (driver.Result, error) { @@ -112,7 +114,7 @@ func (s *DMLStmt) Exec(args []driver.Value) (driver.Result, error) { for _, arg := range args { values = append(values, arg) } - newArgs, err := convertValues(values) + newArgs, err := encodeValues(values, s.args) if err != nil { return nil, err } @@ -142,15 +144,15 @@ func (s *DMLStmt) QueryContext(ctx context.Context, query string, args []driver. type QueryStmt struct { stmt *sql.Stmt - argsNum int + args []*ast.ParameterNode formattedQuery string outputColumns []*ColumnSpec } -func newQueryStmt(stmt *sql.Stmt, argsNum int, formattedQuery string, outputColumns []*ColumnSpec) *QueryStmt { +func newQueryStmt(stmt *sql.Stmt, args []*ast.ParameterNode, formattedQuery string, outputColumns []*ColumnSpec) *QueryStmt { return &QueryStmt{ stmt: stmt, - argsNum: argsNum, + args: args, formattedQuery: formattedQuery, outputColumns: outputColumns, } @@ -165,7 +167,7 @@ func (s *QueryStmt) Close() error { } func (s *QueryStmt) NumInput() int { - return s.argsNum + return len(s.args) } func (s *QueryStmt) OutputColumns() []*ColumnSpec { @@ -185,7 +187,7 @@ func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) { for _, arg := range args { values = append(values, arg) } - newArgs, err := convertValues(values) + newArgs, err := encodeValues(values, s.args) if err != nil { return nil, err } diff --git a/internal/value.go b/internal/value.go index 341797a..7011aa3 100644 --- a/internal/value.go +++ b/internal/value.go @@ -13,6 +13,9 @@ import ( "strconv" "strings" "time" + + ast "github.com/goccy/go-zetasql/resolved_ast" + "github.com/goccy/go-zetasql/types" ) type Value interface { @@ -1449,7 +1452,7 @@ func (d TimestampValue) ToStruct() (*StructValue, error) { } func (d TimestampValue) ToJSON() (string, error) { - return time.Time(d).Format(time.RFC3339), nil + return time.Time(d).Format(time.RFC3339Nano), nil } func (d TimestampValue) ToTime() (time.Time, error) { @@ -1710,54 +1713,6 @@ func StructValueOf(v string) (Value, error) { return &StructValue{keys: keys, values: values, m: valMap}, nil } -func SQLiteValue(v interface{}) (interface{}, error) { - rv := reflect.TypeOf(v) - switch rv.Kind() { - case reflect.Int: - return int64(v.(int)), nil - case reflect.Int8: - return int64(v.(int8)), nil - case reflect.Int16: - return int64(v.(int16)), nil - case reflect.Int32: - return int64(v.(int32)), nil - case reflect.Uint: - return int64(v.(uint)), nil - 
case reflect.Uint8: - return int64(v.(uint8)), nil - case reflect.Uint16: - return int64(v.(uint16)), nil - case reflect.Uint32: - return int64(v.(uint32)), nil - case reflect.Uint64: - return int64(v.(uint64)), nil - case reflect.Float32: - return float64(v.(float32)), nil - case reflect.Slice: - if rv.Elem().Kind() == reflect.Uint8 { - return string(v.([]byte)), nil - } - b, err := json.Marshal(v) - if err != nil { - return nil, fmt.Errorf("failed to encode value %v: %w", v, err) - } - return toArrayValueFromJSONString(string(b)), nil - case reflect.Array: - b, err := json.Marshal(v) - if err != nil { - return nil, fmt.Errorf("failed to encode value %v: %w", v, err) - } - return toArrayValueFromJSONString(string(b)), nil - case reflect.Struct: - b, err := json.Marshal(v) - if err != nil { - return nil, fmt.Errorf("failed to encode value %v: %w", v, err) - } - return toStructValueFromJSONString(string(b)), nil - } - return v, nil -} - func toArrayValueFromJSONString(json string) string { return strconv.Quote( fmt.Sprintf( @@ -2000,10 +1955,16 @@ func toTimeValue(s string) (time.Time, error) { return time.Time{}, fmt.Errorf("unsupported time format %s", s) } -func ConvertNamedValues(v []driver.NamedValue) ([]sql.NamedArg, error) { +func EncodeNamedValues(v []driver.NamedValue, params []*ast.ParameterNode) ([]sql.NamedArg, error) { + if len(v) != len(params) { + return nil, fmt.Errorf( + "failed to match named values num (%d) and params num (%d)", + len(v), len(params), + ) + } ret := make([]sql.NamedArg, 0, len(v)) - for _, vv := range v { - converted, err := convertNamedValue(vv) + for idx, vv := range v { + converted, err := encodeNamedValue(vv, params[idx]) if err != nil { return nil, fmt.Errorf("failed to convert value from %+v: %w", vv, err) } @@ -2012,8 +1973,8 @@ func ConvertNamedValues(v []driver.NamedValue) ([]sql.NamedArg, error) { return ret, nil } -func convertNamedValue(v driver.NamedValue) (sql.NamedArg, error) { - value, err := SQLiteValue(v.Value) +func encodeNamedValue(v driver.NamedValue, param *ast.ParameterNode) (sql.NamedArg, error) { + value, err := encodeValueWithType(v.Value, param.Type()) if err != nil { return sql.NamedArg{}, err } @@ -2023,10 +1984,16 @@ func convertNamedValue(v driver.NamedValue) (sql.NamedArg, error) { }, nil } -func convertValues(v []interface{}) ([]interface{}, error) { +func encodeValues(v []interface{}, params []*ast.ParameterNode) ([]interface{}, error) { + if len(v) != len(params) { + return nil, fmt.Errorf( + "failed to match args values num (%d) and params num (%d)", + len(v), len(params), + ) + } ret := make([]interface{}, 0, len(v)) - for _, vv := range v { - value, err := SQLiteValue(vv) + for idx, vv := range v { + value, err := encodeValueWithType(vv, params[idx].Type()) if err != nil { return nil, err } @@ -2034,3 +2001,84 @@ func convertValues(v []interface{}) ([]interface{}, error) { } return ret, nil } + +func encodeValueWithType(v interface{}, t types.Type) (interface{}, error) { + switch t.Kind() { + case types.INT32, types.INT64, types.UINT32, types.UINT64, types.ENUM: + vv, err := ValueOf(v) + if err != nil { + return nil, err + } + return vv.ToInt64() + case types.BOOL: + vv, err := ValueOf(v) + if err != nil { + return nil, err + } + return vv.ToBool() + case types.FLOAT, types.DOUBLE: + vv, err := ValueOf(v) + if err != nil { + return nil, err + } + return vv.ToFloat64() + case types.STRING: + vv, err := ValueOf(v) + if err != nil { + return nil, err + } + return vv.ToString() + case types.BYTES: + vv, err := ValueOf(v) 
+ if err != nil { + return nil, err + } + s, err := vv.ToString() + if err != nil { + return nil, err + } + return []byte(s), nil + case types.DATE: + text, ok := v.(string) + if !ok { + return nil, fmt.Errorf("failed to convert DATE from %T", v) + } + return toDateValueFromString(text), nil + case types.TIMESTAMP: + text, ok := v.(string) + if !ok { + return nil, fmt.Errorf("failed to convert TIMESTAMP from %T", v) + } + return toTimestampValueFromString(text), nil + case types.ARRAY: + b, err := json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("failed to encode array value %v: %w", v, err) + } + return toArrayValueFromJSONString(string(b)), nil + case types.STRUCT: + b, err := json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("failed to encode struct value %v: %w", v, err) + } + return toStructValueFromJSONString(string(b)), nil + case types.TIME: + case types.DATETIME: + case types.PROTO: + return nil, fmt.Errorf("failed to convert PROTO type from %T", v) + case types.GEOGRAPHY: + return nil, fmt.Errorf("failed to convert GEOGRAPHY type from %T", v) + case types.NUMERIC: + return nil, fmt.Errorf("failed to convert NUMERIC type from %T", v) + case types.BIG_NUMERIC: + return nil, fmt.Errorf("failed to convert BIGNUMERIC type from %T", v) + case types.EXTENDED: + return nil, fmt.Errorf("failed to convert EXTENDED type from %T", v) + case types.JSON: + return nil, fmt.Errorf("failed to convert JSON type from %T", v) + case types.INTERVAL: + return nil, fmt.Errorf("failed to convert INTERVAL type from %T", v) + default: + } + return nil, fmt.Errorf("unexpected type %s to convert from %T", t.Kind(), v) +} diff --git a/query_test.go b/query_test.go index ac22fd7..981cbd3 100644 --- a/query_test.go +++ b/query_test.go @@ -3,6 +3,7 @@ package zetasqlite_test import ( "context" "database/sql" + "math" "reflect" "testing" "time" @@ -20,6 +21,11 @@ func TestQuery(t *testing.T) { t.Fatal(err) } defer db.Close() + floatCmpOpt := cmp.Comparer(func(x, y float64) bool { + delta := math.Abs(x - y) + mean := math.Abs(x+y) / 2.0 + return delta/mean < 0.00001 + }) for _, test := range []struct { name string query string @@ -295,7 +301,7 @@ FROM UNNEST([1, 2, 3, 4]) AS val`, { name: "nullif true", query: `SELECT NULLIF(0, 0)`, - expectedRows: [][]interface{}{}, + expectedRows: [][]interface{}{{nil}}, }, { name: "nullif false", @@ -566,12 +572,12 @@ SELECT ARRAY_CONCAT_AGG(x) AS array_concat_agg FROM ( { name: "sum null", query: `SELECT SUM(x) AS sum FROM UNNEST([]) AS x`, - expectedRows: [][]interface{}{}, + expectedRows: [][]interface{}{{nil}}, }, { name: "null", query: `SELECT NULL`, - expectedRows: [][]interface{}{}, + expectedRows: [][]interface{}{{nil}}, }, // window function @@ -898,6 +904,167 @@ FROM Numbers`, {int64(10), int64(6)}, }, }, + { + name: "window dense_rank", + query: ` +WITH Numbers AS + (SELECT 1 as x + UNION ALL SELECT 2 + UNION ALL SELECT 2 + UNION ALL SELECT 5 + UNION ALL SELECT 8 + UNION ALL SELECT 10 + UNION ALL SELECT 10 +) +SELECT x, + DENSE_RANK() OVER (ORDER BY x ASC) AS dense_rank +FROM Numbers`, + expectedRows: [][]interface{}{ + {int64(1), int64(1)}, + {int64(2), int64(2)}, + {int64(2), int64(2)}, + {int64(5), int64(3)}, + {int64(8), int64(4)}, + {int64(10), int64(5)}, + {int64(10), int64(5)}, + }, + }, + { + name: "window dense_rank with group", + query: ` +WITH finishers AS + (SELECT 'Sophia Liu' as name, + TIMESTAMP '2016-10-18 2:51:45' as finish_time, + 'F30-34' as division + UNION ALL SELECT 'Lisa Stelzner', TIMESTAMP '2016-10-18 2:54:11', 'F35-39' + 
UNION ALL SELECT 'Nikki Leith', TIMESTAMP '2016-10-18 2:59:01', 'F30-34' + UNION ALL SELECT 'Lauren Matthews', TIMESTAMP '2016-10-18 3:01:17', 'F35-39' + UNION ALL SELECT 'Desiree Berry', TIMESTAMP '2016-10-18 3:05:42', 'F35-39' + UNION ALL SELECT 'Suzy Slane', TIMESTAMP '2016-10-18 3:06:24', 'F35-39' + UNION ALL SELECT 'Jen Edwards', TIMESTAMP '2016-10-18 3:06:36', 'F30-34' + UNION ALL SELECT 'Meghan Lederer', TIMESTAMP '2016-10-18 2:59:01', 'F30-34') +SELECT name, + finish_time, + division, + DENSE_RANK() OVER (PARTITION BY division ORDER BY finish_time ASC) AS finish_rank +FROM finishers +`, + expectedRows: [][]interface{}{ + {"Sophia Liu", createTimeFromString("2016-10-18 09:51:45+00"), "F30-34", int64(1)}, + {"Nikki Leith", createTimeFromString("2016-10-18 09:59:01+00"), "F30-34", int64(2)}, + {"Meghan Lederer", createTimeFromString("2016-10-18 09:59:01+00"), "F30-34", int64(2)}, + {"Jen Edwards", createTimeFromString("2016-10-18 10:06:36+00"), "F30-34", int64(3)}, + {"Lisa Stelzner", createTimeFromString("2016-10-18 09:54:11+00"), "F35-39", int64(1)}, + {"Lauren Matthews", createTimeFromString("2016-10-18 10:01:17+00"), "F35-39", int64(2)}, + {"Desiree Berry", createTimeFromString("2016-10-18 10:05:42+00"), "F35-39", int64(3)}, + {"Suzy Slane", createTimeFromString("2016-10-18 10:06:24+00"), "F35-39", int64(4)}, + }, + }, + { + name: "window lag", + query: ` +WITH finishers AS + (SELECT 'Sophia Liu' as name, + TIMESTAMP '2016-10-18 2:51:45+00' as finish_time, + 'F30-34' as division + UNION ALL SELECT 'Lisa Stelzner', TIMESTAMP '2016-10-18 2:54:11+00', 'F35-39' + UNION ALL SELECT 'Nikki Leith', TIMESTAMP '2016-10-18 2:59:01+00', 'F30-34' + UNION ALL SELECT 'Lauren Matthews', TIMESTAMP '2016-10-18 3:01:17+00', 'F35-39' + UNION ALL SELECT 'Desiree Berry', TIMESTAMP '2016-10-18 3:05:42+00', 'F35-39' + UNION ALL SELECT 'Suzy Slane', TIMESTAMP '2016-10-18 3:06:24+00', 'F35-39' + UNION ALL SELECT 'Jen Edwards', TIMESTAMP '2016-10-18 3:06:36+00', 'F30-34' + UNION ALL SELECT 'Meghan Lederer', TIMESTAMP '2016-10-18 3:07:41+00', 'F30-34' + UNION ALL SELECT 'Carly Forte', TIMESTAMP '2016-10-18 3:08:58+00', 'F25-29' + UNION ALL SELECT 'Lauren Reasoner', TIMESTAMP '2016-10-18 3:10:14+00', 'F30-34') +SELECT name, + finish_time, + division, + LAG(name) + OVER (PARTITION BY division ORDER BY finish_time ASC) AS preceding_runner +FROM finishers`, + expectedRows: [][]interface{}{ + {"Carly Forte", createTimeFromString("2016-10-18 03:08:58+00"), "F25-29", nil}, + {"Sophia Liu", createTimeFromString("2016-10-18 02:51:45+00"), "F30-34", nil}, + {"Nikki Leith", createTimeFromString("2016-10-18 02:59:01+00"), "F30-34", "Sophia Liu"}, + {"Jen Edwards", createTimeFromString("2016-10-18 03:06:36+00"), "F30-34", "Nikki Leith"}, + {"Meghan Lederer", createTimeFromString("2016-10-18 03:07:41+00"), "F30-34", "Jen Edwards"}, + {"Lauren Reasoner", createTimeFromString("2016-10-18 03:10:14+00"), "F30-34", "Meghan Lederer"}, + {"Lisa Stelzner", createTimeFromString("2016-10-18 02:54:11+00"), "F35-39", nil}, + {"Lauren Matthews", createTimeFromString("2016-10-18 03:01:17+00"), "F35-39", "Lisa Stelzner"}, + {"Desiree Berry", createTimeFromString("2016-10-18 03:05:42+00"), "F35-39", "Lauren Matthews"}, + {"Suzy Slane", createTimeFromString("2016-10-18 03:06:24+00"), "F35-39", "Desiree Berry"}, + }, + }, + { + name: "window lag with offset", + query: ` +WITH finishers AS + (SELECT 'Sophia Liu' as name, + TIMESTAMP '2016-10-18 2:51:45+00' as finish_time, + 'F30-34' as division + UNION ALL SELECT 'Lisa Stelzner', TIMESTAMP 
'2016-10-18 2:54:11+00', 'F35-39' + UNION ALL SELECT 'Nikki Leith', TIMESTAMP '2016-10-18 2:59:01+00', 'F30-34' + UNION ALL SELECT 'Lauren Matthews', TIMESTAMP '2016-10-18 3:01:17+00', 'F35-39' + UNION ALL SELECT 'Desiree Berry', TIMESTAMP '2016-10-18 3:05:42+00', 'F35-39' + UNION ALL SELECT 'Suzy Slane', TIMESTAMP '2016-10-18 3:06:24+00', 'F35-39' + UNION ALL SELECT 'Jen Edwards', TIMESTAMP '2016-10-18 3:06:36+00', 'F30-34' + UNION ALL SELECT 'Meghan Lederer', TIMESTAMP '2016-10-18 3:07:41+00', 'F30-34' + UNION ALL SELECT 'Carly Forte', TIMESTAMP '2016-10-18 3:08:58+00', 'F25-29' + UNION ALL SELECT 'Lauren Reasoner', TIMESTAMP '2016-10-18 3:10:14+00', 'F30-34') +SELECT name, + finish_time, + division, + LAG(name, 2) + OVER (PARTITION BY division ORDER BY finish_time ASC) AS two_runners_ahead +FROM finishers`, + expectedRows: [][]interface{}{ + {"Carly Forte", createTimeFromString("2016-10-18 03:08:58+00"), "F25-29", nil}, + {"Sophia Liu", createTimeFromString("2016-10-18 02:51:45+00"), "F30-34", nil}, + {"Nikki Leith", createTimeFromString("2016-10-18 02:59:01+00"), "F30-34", nil}, + {"Jen Edwards", createTimeFromString("2016-10-18 03:06:36+00"), "F30-34", "Sophia Liu"}, + {"Meghan Lederer", createTimeFromString("2016-10-18 03:07:41+00"), "F30-34", "Nikki Leith"}, + {"Lauren Reasoner", createTimeFromString("2016-10-18 03:10:14+00"), "F30-34", "Jen Edwards"}, + {"Lisa Stelzner", createTimeFromString("2016-10-18 02:54:11+00"), "F35-39", nil}, + {"Lauren Matthews", createTimeFromString("2016-10-18 03:01:17+00"), "F35-39", nil}, + {"Desiree Berry", createTimeFromString("2016-10-18 03:05:42+00"), "F35-39", "Lisa Stelzner"}, + {"Suzy Slane", createTimeFromString("2016-10-18 03:06:24+00"), "F35-39", "Lauren Matthews"}, + }, + }, + { + name: "window lag with offset and default value", + query: ` +WITH finishers AS + (SELECT 'Sophia Liu' as name, + TIMESTAMP '2016-10-18 2:51:45+00' as finish_time, + 'F30-34' as division + UNION ALL SELECT 'Lisa Stelzner', TIMESTAMP '2016-10-18 2:54:11+00', 'F35-39' + UNION ALL SELECT 'Nikki Leith', TIMESTAMP '2016-10-18 2:59:01+00', 'F30-34' + UNION ALL SELECT 'Lauren Matthews', TIMESTAMP '2016-10-18 3:01:17+00', 'F35-39' + UNION ALL SELECT 'Desiree Berry', TIMESTAMP '2016-10-18 3:05:42+00', 'F35-39' + UNION ALL SELECT 'Suzy Slane', TIMESTAMP '2016-10-18 3:06:24+00', 'F35-39' + UNION ALL SELECT 'Jen Edwards', TIMESTAMP '2016-10-18 3:06:36+00', 'F30-34' + UNION ALL SELECT 'Meghan Lederer', TIMESTAMP '2016-10-18 3:07:41+00', 'F30-34' + UNION ALL SELECT 'Carly Forte', TIMESTAMP '2016-10-18 3:08:58+00', 'F25-29' + UNION ALL SELECT 'Lauren Reasoner', TIMESTAMP '2016-10-18 3:10:14+00', 'F30-34') +SELECT name, + finish_time, + division, + LAG(name, 2, 'NoBody') + OVER (PARTITION BY division ORDER BY finish_time ASC) AS two_runners_ahead +FROM finishers`, + expectedRows: [][]interface{}{ + {"Carly Forte", createTimeFromString("2016-10-18 03:08:58+00"), "F25-29", "NoBody"}, + {"Sophia Liu", createTimeFromString("2016-10-18 02:51:45+00"), "F30-34", "NoBody"}, + {"Nikki Leith", createTimeFromString("2016-10-18 02:59:01+00"), "F30-34", "NoBody"}, + {"Jen Edwards", createTimeFromString("2016-10-18 03:06:36+00"), "F30-34", "Sophia Liu"}, + {"Meghan Lederer", createTimeFromString("2016-10-18 03:07:41+00"), "F30-34", "Nikki Leith"}, + {"Lauren Reasoner", createTimeFromString("2016-10-18 03:10:14+00"), "F30-34", "Jen Edwards"}, + {"Lisa Stelzner", createTimeFromString("2016-10-18 02:54:11+00"), "F35-39", "NoBody"}, + {"Lauren Matthews", createTimeFromString("2016-10-18 
03:01:17+00"), "F35-39", "NoBody"}, + {"Desiree Berry", createTimeFromString("2016-10-18 03:05:42+00"), "F35-39", "Lisa Stelzner"}, + {"Suzy Slane", createTimeFromString("2016-10-18 03:06:24+00"), "F35-39", "Lauren Matthews"}, + }, + }, { name: "sign", query: `SELECT SIGN(25) UNION ALL SELECT SIGN(0) UNION ALL SELECT SIGN(-25)`, @@ -930,7 +1097,7 @@ FROM Numbers`, name: "current_timestamp", query: `SELECT CURRENT_TIMESTAMP()`, expectedRows: [][]interface{}{ - {now.Format(time.RFC3339)}, + {now.UTC()}, }, }, // INVALID_ARGUMENT: No matching signature for operator - for argument types: TIMESTAMP, TIMESTAMP. Supported signatures: INT64 - INT64; NUMERIC - NUMERIC; FLOAT64 - FLOAT64; DATE - INT64 [at 1:8] @@ -1142,7 +1309,7 @@ FROM ( }, { name: "generate_timestamp_array function", - query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00', '2016-10-07 00:00:00', INTERVAL 1 DAY) AS timestamp_array`, + query: `SELECT GENERATE_TIMESTAMP_ARRAY(TIMESTAMP '2016-10-05 00:00:00+00', '2016-10-07 00:00:00+00', INTERVAL 1 DAY) AS timestamp_array`, expectedRows: [][]interface{}{ { []time.Time{ @@ -1155,7 +1322,7 @@ FROM ( }, { name: "generate_timestamp_array function interval 1 second", - query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00', '2016-10-05 00:00:02', INTERVAL 1 SECOND) AS timestamp_array`, + query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00+00', '2016-10-05 00:00:02+00', INTERVAL 1 SECOND) AS timestamp_array`, expectedRows: [][]interface{}{ { []time.Time{ @@ -1168,7 +1335,7 @@ FROM ( }, { name: "generate_timestamp_array function negative interval", - query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-06 00:00:00', '2016-10-01 00:00:00', INTERVAL -2 DAY) AS timestamp_array`, + query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-06 00:00:00+00', '2016-10-01 00:00:00+00', INTERVAL -2 DAY) AS timestamp_array`, expectedRows: [][]interface{}{ { []time.Time{ @@ -1181,7 +1348,7 @@ FROM ( }, { name: "generate_timestamp_array function same value", - query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00', '2016-10-05 00:00:00', INTERVAL 1 HOUR) AS timestamp_array`, + query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00+00', '2016-10-05 00:00:00+00', INTERVAL 1 HOUR) AS timestamp_array`, expectedRows: [][]interface{}{ { []time.Time{ @@ -1192,14 +1359,14 @@ FROM ( }, { name: "generate_timestamp_array function over step", - query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-06 00:00:00', '2016-10-05 00:00:00', INTERVAL 1 HOUR) AS timestamp_array`, + query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-06 00:00:00+00', '2016-10-05 00:00:00+00', INTERVAL 1 HOUR) AS timestamp_array`, expectedRows: [][]interface{}{ {[]time.Time{}}, }, }, { name: "generate_timestamp_array function with null", - query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00', NULL, INTERVAL 1 HOUR) AS timestamp_array`, + query: `SELECT GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00+00', NULL, INTERVAL 1 HOUR) AS timestamp_array`, expectedRows: [][]interface{}{ {nil}, }, @@ -1211,16 +1378,16 @@ SELECT GENERATE_TIMESTAMP_ARRAY(start_timestamp, end_timestamp, INTERVAL 1 HOUR) AS timestamp_array FROM (SELECT - TIMESTAMP '2016-10-05 00:00:00' AS start_timestamp, - TIMESTAMP '2016-10-05 02:00:00' AS end_timestamp + TIMESTAMP '2016-10-05 00:00:00+00' AS start_timestamp, + TIMESTAMP '2016-10-05 02:00:00+00' AS end_timestamp UNION ALL SELECT - TIMESTAMP '2016-10-05 12:00:00' AS start_timestamp, - TIMESTAMP '2016-10-05 14:00:00' AS end_timestamp + TIMESTAMP '2016-10-05 12:00:00+00' AS start_timestamp, + TIMESTAMP 
'2016-10-05 14:00:00+00' AS end_timestamp UNION ALL SELECT - TIMESTAMP '2016-10-05 23:59:00' AS start_timestamp, - TIMESTAMP '2016-10-06 01:59:00' AS end_timestamp)`, + TIMESTAMP '2016-10-05 23:59:00+00' AS start_timestamp, + TIMESTAMP '2016-10-06 01:59:00+00' AS end_timestamp)`, expectedRows: [][]interface{}{ { []time.Time{ @@ -1259,6 +1426,252 @@ WITH example AS ( {[]int64{}}, }, }, + { + name: "group by", + query: ` +WITH Sales AS ( + SELECT 123 AS sku, 1 AS day, 9.99 AS price UNION ALL + SELECT 123, 1, 8.99 UNION ALL + SELECT 456, 1, 4.56 UNION ALL + SELECT 123, 2, 9.99 UNION ALL + SELECT 789, 3, 1.00 UNION ALL + SELECT 456, 3, 4.25 UNION ALL + SELECT 789, 3, 0.99 +) +SELECT + day, + SUM(price) AS total +FROM Sales +GROUP BY day`, + expectedRows: [][]interface{}{ + {int64(1), float64(23.54)}, + {int64(2), float64(9.99)}, + {int64(3), float64(6.24)}, + }, + }, + { + name: "group by rollup with one column", + query: ` +WITH Sales AS ( + SELECT 123 AS sku, 1 AS day, 9.99 AS price UNION ALL + SELECT 123, 1, 8.99 UNION ALL + SELECT 456, 1, 4.56 UNION ALL + SELECT 123, 2, 9.99 UNION ALL + SELECT 789, 3, 1.00 UNION ALL + SELECT 456, 3, 4.25 UNION ALL + SELECT 789, 3, 0.99 +) +SELECT + day, + SUM(price) AS total +FROM Sales +GROUP BY ROLLUP(day)`, + expectedRows: [][]interface{}{ + {nil, float64(39.77)}, + {int64(1), float64(23.54)}, + {int64(2), float64(9.99)}, + {int64(3), float64(6.24)}, + }, + }, + { + name: "group by rollup with two columns", + query: ` +WITH Sales AS ( + SELECT 123 AS sku, 1 AS day, 9.99 AS price UNION ALL + SELECT 123, 1, 8.99 UNION ALL + SELECT 456, 1, 4.56 UNION ALL + SELECT 123, 2, 9.99 UNION ALL + SELECT 789, 3, 1.00 UNION ALL + SELECT 456, 3, 4.25 UNION ALL + SELECT 789, 3, 0.99 +) +SELECT + sku, + day, + SUM(price) AS total +FROM Sales +GROUP BY ROLLUP(sku, day) +ORDER BY sku, day`, + expectedRows: [][]interface{}{ + {nil, nil, float64(39.77)}, + {int64(123), nil, float64(28.97)}, + {int64(123), int64(1), float64(18.98)}, + {int64(123), int64(2), float64(9.99)}, + {int64(456), nil, float64(8.81)}, + {int64(456), int64(1), float64(4.56)}, + {int64(456), int64(3), float64(4.25)}, + {int64(789), nil, float64(1.99)}, + {int64(789), int64(3), float64(1.99)}, + }, + }, + { + name: "group by having", + query: ` +WITH Sales AS ( + SELECT 123 AS sku, 1 AS day, 9.99 AS price UNION ALL + SELECT 123, 1, 8.99 UNION ALL + SELECT 456, 1, 4.56 UNION ALL + SELECT 123, 2, 9.99 UNION ALL + SELECT 789, 2, 1.00 UNION ALL + SELECT 456, 3, 4.25 UNION ALL + SELECT 789, 3, 0.99 +) +SELECT + day, + SUM(price) AS total +FROM Sales +GROUP BY day HAVING SUM(price) > 10`, + expectedRows: [][]interface{}{ + {int64(1), float64(23.54)}, + {int64(2), float64(10.99)}, + }, + }, + { + name: "order by", + query: `SELECT x, y FROM (SELECT 1 AS x, true AS y UNION ALL SELECT 9, true UNION ALL SELECT NULL, false) ORDER BY x`, + expectedRows: [][]interface{}{ + {nil, false}, + {int64(1), true}, + {int64(9), true}, + }, + }, + { + name: "order by with nulls last", + query: `SELECT x, y FROM (SELECT 1 AS x, true AS y UNION ALL SELECT 9, true UNION ALL SELECT NULL, false) ORDER BY x NULLS LAST`, + expectedRows: [][]interface{}{ + {int64(1), true}, + {int64(9), true}, + {nil, false}, + }, + }, + { + name: "order by desc", + query: `SELECT x, y FROM (SELECT 1 AS x, true AS y UNION ALL SELECT 9, true UNION ALL SELECT NULL, false) ORDER BY x DESC`, + expectedRows: [][]interface{}{ + {int64(9), true}, + {int64(1), true}, + {nil, false}, + }, + }, + { + name: "order by nulls first", + query: `SELECT x, y FROM 
(SELECT 1 AS x, true AS y UNION ALL SELECT 9, true UNION ALL SELECT NULL, false) ORDER BY x DESC NULLS FIRST`, + expectedRows: [][]interface{}{ + {nil, false}, + {int64(9), true}, + {int64(1), true}, + }, + }, + { + name: "inner join with using", + query: ` +WITH Roster AS + (SELECT 'Adams' as LastName, 50 as SchoolID UNION ALL + SELECT 'Buchanan', 52 UNION ALL + SELECT 'Coolidge', 52 UNION ALL + SELECT 'Davis', 51 UNION ALL + SELECT 'Eisenhower', 77), + TeamMascot AS + (SELECT 50 as SchoolID, 'Jaguars' as Mascot UNION ALL + SELECT 51, 'Knights' UNION ALL + SELECT 52, 'Lakers' UNION ALL + SELECT 53, 'Mustangs') +SELECT * FROM Roster INNER JOIN TeamMascot USING (SchoolID) +`, + expectedRows: [][]interface{}{ + {int64(50), "Adams", "Jaguars"}, + {int64(52), "Buchanan", "Lakers"}, + {int64(52), "Coolidge", "Lakers"}, + {int64(51), "Davis", "Knights"}, + }, + }, + { + name: "left join", + query: ` +WITH Roster AS + (SELECT 'Adams' as LastName, 50 as SchoolID UNION ALL + SELECT 'Buchanan', 52 UNION ALL + SELECT 'Coolidge', 52 UNION ALL + SELECT 'Davis', 51 UNION ALL + SELECT 'Eisenhower', 77), + TeamMascot AS + (SELECT 50 as SchoolID, 'Jaguars' as Mascot UNION ALL + SELECT 51, 'Knights' UNION ALL + SELECT 52, 'Lakers' UNION ALL + SELECT 53, 'Mustangs') +SELECT Roster.LastName, TeamMascot.Mascot FROM Roster LEFT JOIN TeamMascot ON Roster.SchoolID = TeamMascot.SchoolID +`, + expectedRows: [][]interface{}{ + {"Adams", "Jaguars"}, + {"Buchanan", "Lakers"}, + {"Coolidge", "Lakers"}, + {"Davis", "Knights"}, + {"Eisenhower", nil}, + }, + }, + /* + { + name: "right join", + query: ` + WITH Roster AS + (SELECT 'Adams' as LastName, 50 as SchoolID UNION ALL + SELECT 'Buchanan', 52 UNION ALL + SELECT 'Coolidge', 52 UNION ALL + SELECT 'Davis', 51 UNION ALL + SELECT 'Eisenhower', 77), + TeamMascot AS + (SELECT 50 as SchoolID, 'Jaguars' as Mascot UNION ALL + SELECT 51, 'Knights' UNION ALL + SELECT 52, 'Lakers' UNION ALL + SELECT 53, 'Mustangs') + SELECT Roster.LastName, TeamMascot.Mascot FROM Roster RIGHT JOIN TeamMascot ON Roster.SchoolID = TeamMascot.SchoolID + `, + expectedRows: [][]interface{}{ + {"Adams", "Jaguars"}, + {"Buchanan", "Lakers"}, + {"Coolidge", "Lakers"}, + {"Davis", "Knights"}, + {nil, "Mustangs"}, + }, + }, + */ + { + name: "qualify", + query: ` +WITH Produce AS + (SELECT 'kale' as item, 23 as purchases, 'vegetable' as category + UNION ALL SELECT 'banana', 2, 'fruit' + UNION ALL SELECT 'cabbage', 9, 'vegetable' + UNION ALL SELECT 'apple', 8, 'fruit' + UNION ALL SELECT 'leek', 2, 'vegetable' + UNION ALL SELECT 'lettuce', 10, 'vegetable') +SELECT + item, + RANK() OVER (PARTITION BY category ORDER BY purchases DESC) as rank +FROM Produce WHERE Produce.category = 'vegetable' QUALIFY rank <= 3`, + expectedRows: [][]interface{}{ + {"kale", int64(1)}, + {"lettuce", int64(2)}, + {"cabbage", int64(3)}, + }, + }, + { + name: "qualify direct", + query: ` +WITH Produce AS + (SELECT 'kale' as item, 23 as purchases, 'vegetable' as category + UNION ALL SELECT 'banana', 2, 'fruit' + UNION ALL SELECT 'cabbage', 9, 'vegetable' + UNION ALL SELECT 'apple', 8, 'fruit' + UNION ALL SELECT 'leek', 2, 'vegetable' + UNION ALL SELECT 'lettuce', 10, 'vegetable') +SELECT item FROM Produce WHERE Produce.category = 'vegetable' QUALIFY RANK() OVER (PARTITION BY category ORDER BY purchases DESC) <= 3`, + expectedRows: [][]interface{}{ + {"kale"}, + {"lettuce"}, + {"cabbage"}, + }, + }, } { test := test t.Run(test.name, func(t *testing.T) { @@ -1267,10 +1680,11 @@ WITH example AS ( t.Fatal(err) } defer rows.Close() - if 
len(test.expectedRows) == 0 { - return + columns, err := rows.Columns() + if err != nil { + t.Fatal(err) } - columnNum := len(test.expectedRows[0]) + columnNum := len(columns) args := []interface{}{} for i := 0; i < columnNum; i++ { var v interface{} @@ -1281,13 +1695,20 @@ WITH example AS ( if err := rows.Scan(args...); err != nil { t.Fatal(err) } - expectedRow := test.expectedRows[rowNum] - if len(args) != len(expectedRow) { - t.Fatalf("failed to get columns. expected %d but got %d", len(expectedRow), len(args)) - } + derefArgs := []interface{}{} for i := 0; i < len(args); i++ { value := reflect.ValueOf(args[i]).Elem().Interface() - if diff := cmp.Diff(expectedRow[i], value); diff != "" { + derefArgs = append(derefArgs, value) + } + if len(test.expectedRows) <= rowNum { + t.Fatalf("unexpected row %v", derefArgs) + } + expectedRow := test.expectedRows[rowNum] + if len(derefArgs) != len(expectedRow) { + t.Fatalf("failed to get columns. expected %d but got %d", len(expectedRow), len(derefArgs)) + } + for i := 0; i < len(derefArgs); i++ { + if diff := cmp.Diff(expectedRow[i], derefArgs[i], floatCmpOpt); diff != "" { t.Errorf("(-want +got):\n%s", diff) } }
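
As a rough usage sketch of the window-function support exercised by the tests above, the newly covered DENSE_RANK can be driven through plain `database/sql`. This is illustrative only: it assumes the driver is registered under the name `zetasqlite_sqlite3` (the name used by the CLI added in this change) and reuses a trimmed-down version of the first DENSE_RANK test query; it is not part of the patch itself.

```go
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/goccy/go-zetasqlite" // registers the zetasqlite_sqlite3 driver (assumed name)
)

func main() {
	// Open an in-memory database through the ZetaSQL-backed driver.
	db, err := sql.Open("zetasqlite_sqlite3", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// DENSE_RANK over a small inline data set, mirroring the
	// "window dense_rank" test case above.
	rows, err := db.Query(`
WITH Numbers AS (
  SELECT 1 AS x UNION ALL SELECT 2 UNION ALL SELECT 2 UNION ALL SELECT 5
)
SELECT x, DENSE_RANK() OVER (ORDER BY x ASC) AS dense_rank FROM Numbers`)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for rows.Next() {
		var x, rank int64
		if err := rows.Scan(&x, &rank); err != nil {
			log.Fatal(err)
		}
		fmt.Println(x, rank)
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}
```

With the data above, the expected output is each x alongside its dense rank (1→1, 2→2, 2→2, 5→3), matching the first DENSE_RANK test case in query_test.go.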