diff --git a/README.md b/README.md index c532686..f95a9a8 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,7 @@ A list of ZetaSQL specifications and features supported by go-zetasqlite. - [x] ARRAY - [x] STRUCT - [ ] GEOGRAPHY -- [ ] JSON +- [x] JSON - [x] RECORD ## Statements @@ -369,8 +369,13 @@ A list of ZetaSQL specifications and features supported by go-zetasqlite. - [ ] JSON_EXTRACT_STRING_ARRAY - [ ] JSON_VALUE_ARRAY - [ ] PARSE_JSON -- [ ] TO_JSON +- [x] TO_JSON - [ ] TO_JSON_STRING +- [ ] STRING +- [x] BOOL +- [x] INT64 +- [x] FLOAT64 +- [x] JSON_TYPE ### Array functions diff --git a/driver.go b/driver.go index c6c5b0f..a784a8b 100644 --- a/driver.go +++ b/driver.go @@ -7,7 +7,6 @@ import ( "fmt" "sync" - "github.com/goccy/go-zetasql" internal "github.com/goccy/go-zetasqlite/internal" "github.com/mattn/go-sqlite3" ) @@ -103,10 +102,6 @@ func (c *ZetaSQLiteConn) AddNamePath(path string) { c.analyzer.AddNamePath(path) } -func (c *ZetaSQLiteConn) SetParameterMode(mode zetasql.ParameterMode) { - c.analyzer.SetParameterMode(mode) -} - func (s *ZetaSQLiteConn) CheckNamedValue(value *driver.NamedValue) error { return nil } diff --git a/driver_test.go b/driver_test.go index ac5ec16..f827f61 100644 --- a/driver_test.go +++ b/driver_test.go @@ -59,7 +59,7 @@ func TestRegisterCustomDriver(t *testing.T) { if _, err := db.Exec("INSERT `project-id`.datasetID.tableID (Id) VALUES (1)"); err != nil { t.Fatal(err) } - row := db.QueryRow("SELECT * FROM project-id.datasetID.tableID WHERE Id = @id", 1) + row := db.QueryRow("SELECT * FROM project-id.datasetID.tableID WHERE Id = ?", 1) if row.Err() != nil { t.Fatal(row.Err()) } diff --git a/exec_test.go b/exec_test.go index 0acffe5..b8e6ff4 100644 --- a/exec_test.go +++ b/exec_test.go @@ -28,6 +28,15 @@ func TestExec(t *testing.T) { name: "create table with all types", query: `CREATE TABLE _table_a ( doubleValue DOUBLE, floatValue FLOAT )`, }, + { + name: "recreate table", + query: ` +CREATE OR REPLACE TABLE recreate_table ( a string ); +DROP TABLE recreate_table; +CREATE TABLE recreate_table ( b string ); +INSERT recreate_table (b) VALUES ('hello'); +`, + }, { name: "transaction", query: ` diff --git a/go.mod b/go.mod index 468baf2..5642573 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/goccy/go-zetasqlite go 1.17 require ( - github.com/goccy/go-zetasql v0.3.2 + github.com/goccy/go-zetasql v0.3.3 github.com/mattn/go-sqlite3 v1.14.14 ) diff --git a/go.sum b/go.sum index 14a2eea..fbda56e 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/go-pdf/fpdf v0.5.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhO github.com/go-pdf/fpdf v0.6.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= github.com/goccy/go-json v0.9.10 h1:hCeNmprSNLB8B8vQKWl6DpuH0t60oEs+TAk9a7CScKc= github.com/goccy/go-json v0.9.10/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/goccy/go-zetasql v0.3.2 h1:+HxuazroaYJnZYZdBgeRwxnRmXO+OTE+EMQ9hQwHhUA= -github.com/goccy/go-zetasql v0.3.2/go.mod h1:6W14CJVKh7crrSPyj6NPk4c49L2NWnxvyDLsRkOm4BI= +github.com/goccy/go-zetasql v0.3.3 h1:+Ar/GZ4k2vNgaljRPt5lpD8JSIiq0WSEG38Fbowj5fM= +github.com/goccy/go-zetasql v0.3.3/go.mod h1:6W14CJVKh7crrSPyj6NPk4c49L2NWnxvyDLsRkOm4BI= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= diff --git a/internal/analyzer.go b/internal/analyzer.go index d9a5904..23dc9cb 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -111,10 +111,6 @@ func (a *Analyzer) AddNamePath(path string) { a.namePath = append(a.namePath, path) } -func (a *Analyzer) SetParameterMode(mode zetasql.ParameterMode) { - a.opt.SetParameterMode(mode) -} - func (a *Analyzer) parseScript(query string) ([]parsed_ast.StatementNode, error) { loc := zetasql.NewParseResumeLocation(query) var stmts []parsed_ast.StatementNode @@ -201,6 +197,32 @@ func (a *Analyzer) getFullNamePathMap(stmts []parsed_ast.StatementNode) (map[str return fullNamePathMap, nil } +func (a *Analyzer) getParameterMode(stmt parsed_ast.StatementNode) (zetasql.ParameterMode, error) { + var ( + enabledNamedParameter bool + enabledPositionalParameter bool + ) + parsed_ast.Walk(stmt, func(node parsed_ast.Node) error { + switch n := node.(type) { + case *parsed_ast.ParameterExprNode: + if n.Position() > 0 { + enabledPositionalParameter = true + } + if n.Name() != nil { + enabledNamedParameter = true + } + } + return nil + }) + if enabledNamedParameter && enabledPositionalParameter { + return zetasql.ParameterNone, fmt.Errorf("named parameter and positional parameter cannot be used together") + } + if enabledPositionalParameter { + return zetasql.ParameterPositional, nil + } + return zetasql.ParameterNamed, nil +} + func (a *Analyzer) AnalyzeIterator(ctx context.Context, conn *Conn, query string, args []driver.NamedValue) (*AnalyzerOutputIterator, error) { if err := a.catalog.Sync(ctx, conn); err != nil { return nil, fmt.Errorf("failed to sync catalog: %w", err) @@ -209,10 +231,21 @@ func (a *Analyzer) AnalyzeIterator(ctx context.Context, conn *Conn, query string if err != nil { return nil, err } + resultStmts := make([]*Statement, 0, len(stmts)) fullNamePathMap, err := a.getFullNamePathMap(stmts) if err != nil { return nil, fmt.Errorf("failed to get full name path map %s: %w", query, err) } + for _, stmt := range stmts { + mode, err := a.getParameterMode(stmt) + if err != nil { + return nil, err + } + resultStmts = append(resultStmts, &Statement{ + stmt: stmt, + mode: mode, + }) + } funcMap := map[string]*FunctionSpec{} for _, spec := range a.catalog.getFunctions(a.namePath) { funcMap[spec.FuncName()] = spec @@ -220,21 +253,26 @@ func (a *Analyzer) AnalyzeIterator(ctx context.Context, conn *Conn, query string return &AnalyzerOutputIterator{ query: query, args: args, - stmts: stmts, + stmts: resultStmts, analyzer: a, funcMap: funcMap, fullNamePathMap: fullNamePathMap, }, nil } +type Statement struct { + stmt parsed_ast.StatementNode + mode zetasql.ParameterMode +} + type AnalyzerOutputIterator struct { query string args []driver.NamedValue analyzer *Analyzer - stmts []parsed_ast.StatementNode + stmts []*Statement stmtIdx int - funcMap map[string]*FunctionSpec fullNamePathMap map[string][]string + funcMap map[string]*FunctionSpec out *zetasql.AnalyzerOutput isEnd bool err error @@ -244,9 +282,11 @@ func (it *AnalyzerOutputIterator) Next() bool { if it.stmtIdx >= len(it.stmts) { return false } + stmt := it.stmts[it.stmtIdx] + it.analyzer.opt.SetParameterMode(stmt.mode) out, err := zetasql.AnalyzeStatementFromParserAST( it.query, - it.stmts[it.stmtIdx], + stmt.stmt, it.analyzer.catalog.getCatalog(it.analyzer.namePath), it.analyzer.opt, ) @@ -341,7 +381,10 @@ func (it *AnalyzerOutputIterator) analyzeCreateTableStmt(ctx context.Context, no if err != nil { return nil, err } - it.stmts = append(it.stmts, stmt) + it.stmts = append(it.stmts, &Statement{ + stmt: stmt, + mode: zetasql.ParameterNamed, + }) } return nil, nil }, @@ -394,7 +437,10 @@ func (it *AnalyzerOutputIterator) analyzeCreateTableAsSelectStmt(ctx context.Con if err != nil { return nil, err } - it.stmts = append(it.stmts, stmt) + it.stmts = append(it.stmts, &Statement{ + stmt: stmt, + mode: zetasql.ParameterNamed, + }) } return nil, nil }, diff --git a/internal/catalog.go b/internal/catalog.go index 84f9eec..5f3d838 100644 --- a/internal/catalog.go +++ b/internal/catalog.go @@ -48,8 +48,9 @@ DELETE FROM zetasqlite_catalog WHERE name = @name type CatalogSpecKind string const ( - TableSpecKind CatalogSpecKind = "table" - FunctionSpecKind CatalogSpecKind = "function" + TableSpecKind CatalogSpecKind = "table" + FunctionSpecKind CatalogSpecKind = "function" + defaultCatalogName = "zetasqlite" ) type Catalog struct { @@ -64,12 +65,16 @@ type Catalog struct { funcMap map[string]*FunctionSpec } +func newSimpleCatalog(name string) *types.SimpleCatalog { + catalog := types.NewSimpleCatalog(name) + catalog.AddZetaSQLBuiltinFunctions(nil) + return catalog +} + func NewCatalog(db *sql.DB) *Catalog { - catalog := types.NewSimpleCatalog("zetasqlite") - catalog.AddZetaSQLBuiltinFunctions() return &Catalog{ db: db, - defaultCatalog: catalog, + defaultCatalog: newSimpleCatalog(defaultCatalogName), pathToCatalogMap: map[string]*types.SimpleCatalog{}, tableMap: map[string]*TableSpec{}, funcMap: map[string]*FunctionSpec{}, @@ -185,7 +190,9 @@ func (c *Catalog) DeleteTableSpec(ctx context.Context, conn *Conn, name string) c.mu.Lock() defer c.mu.Unlock() - // TODO: need to remove table from *types.SimpleCatalog + if err := c.deleteTableSpecByName(name); err != nil { + return err + } if _, err := conn.ExecContext(ctx, deleteCatalogQuery, sql.Named("name", name)); err != nil { return err } @@ -196,13 +203,71 @@ func (c *Catalog) DeleteFunctionSpec(ctx context.Context, conn *Conn, name strin c.mu.Lock() defer c.mu.Unlock() - // TODO: need to remove function from *types.SimpleCatalog + if err := c.deleteFunctionSpecByName(name); err != nil { + return err + } if _, err := conn.ExecContext(ctx, deleteCatalogQuery, sql.Named("name", name)); err != nil { return err } return nil } +func (c *Catalog) deleteTableSpecByName(name string) error { + spec, exists := c.tableMap[name] + if !exists { + return fmt.Errorf("failed to find table spec from map by %s", name) + } + tables := make([]*TableSpec, 0, len(c.tables)) + for _, table := range c.tables { + if spec == table { + continue + } + tables = append(tables, table) + } + c.tables = tables + delete(c.tableMap, name) + if err := c.resetCatalogs(); err != nil { + return fmt.Errorf("failed to reset catalogs: %w", err) + } + return nil +} + +func (c *Catalog) deleteFunctionSpecByName(name string) error { + spec, exists := c.funcMap[name] + if !exists { + return fmt.Errorf("failed to find function spec from map by %s", name) + } + functions := make([]*FunctionSpec, 0, len(c.functions)) + for _, function := range c.functions { + if spec == function { + continue + } + functions = append(functions, function) + } + c.functions = functions + delete(c.funcMap, name) + if err := c.resetCatalogs(); err != nil { + return fmt.Errorf("failed to reset catalogs: %w", err) + } + return nil +} + +func (c *Catalog) resetCatalogs() error { + c.defaultCatalog = newSimpleCatalog(defaultCatalogName) + c.pathToCatalogMap = map[string]*types.SimpleCatalog{} + for _, spec := range c.tables { + if err := c.addTableSpec(spec); err != nil { + return err + } + } + for _, spec := range c.functions { + if err := c.addFunctionSpec(spec); err != nil { + return err + } + } + return nil +} + func (c *Catalog) saveTableSpec(ctx context.Context, conn *Conn, spec *TableSpec) error { encoded, err := json.Marshal(spec) if err != nil { @@ -290,8 +355,7 @@ func (c *Catalog) addFunctionSpec(spec *FunctionSpec) error { catalogMapKey := c.pathToCatalogMapKey(c.trimmedLastPath(spec.NamePath)) cat, exists := c.pathToCatalogMap[catalogMapKey] if !exists { - cat = types.NewSimpleCatalog("zetasqlite") - cat.AddZetaSQLBuiltinFunctions() + cat = newSimpleCatalog(defaultCatalogName) c.pathToCatalogMap[catalogMapKey] = cat } if err := c.addFunctionSpecRecursive(cat, spec); err != nil { @@ -314,8 +378,7 @@ func (c *Catalog) addTableSpec(spec *TableSpec) error { catalogMapKey := c.pathToCatalogMapKey(c.trimmedLastPath(spec.NamePath)) cat, exists := c.pathToCatalogMap[catalogMapKey] if !exists { - cat = types.NewSimpleCatalog("zetasqlite") - cat.AddZetaSQLBuiltinFunctions() + cat = newSimpleCatalog(defaultCatalogName) c.pathToCatalogMap[catalogMapKey] = cat } if err := c.addTableSpecRecursive(cat, spec); err != nil { @@ -330,7 +393,7 @@ func (c *Catalog) addTableSpec(spec *TableSpec) error { func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpec) error { if len(spec.NamePath) > 1 { subCatalogName := spec.NamePath[0] - subCatalog := types.NewSimpleCatalog(subCatalogName) + subCatalog := newSimpleCatalog(subCatalogName) if !c.existsCatalog(cat, subCatalogName) { cat.AddCatalog(subCatalog) } @@ -386,7 +449,7 @@ func (c *Catalog) createSimpleTable(tableName string, spec *TableSpec) (*types.S func (c *Catalog) addFunctionSpecRecursive(cat *types.SimpleCatalog, spec *FunctionSpec) error { if len(spec.NamePath) > 1 { subCatalogName := spec.NamePath[0] - subCatalog := types.NewSimpleCatalog(subCatalogName) + subCatalog := newSimpleCatalog(subCatalogName) if !c.existsCatalog(cat, subCatalogName) { cat.AddCatalog(subCatalog) } diff --git a/internal/formatter.go b/internal/formatter.go index c047fea..5a87858 100644 --- a/internal/formatter.go +++ b/internal/formatter.go @@ -414,19 +414,34 @@ func (n *CastNode) FormatSQL(ctx context.Context) (string, error) { return fmt.Sprintf("zetasqlite_cast_%s(%s)", typeSuffix, expr), nil } +func extractColumnNameFromFormattedName(col string) string { + if len(col) == 0 { + return col + } + if col[0] == '`' { + // trimmed back quote + col = col[1 : len(col)-1] + } + return strings.Split(col, "#")[0] +} + func (n *MakeStructNode) FormatSQL(ctx context.Context) (string, error) { if n.node == nil { return "", nil } - var fields []string + var args []string for _, field := range n.node.FieldList() { col, err := newNode(field).FormatSQL(ctx) if err != nil { return "", err } - fields = append(fields, col) + args = append( + args, + fmt.Sprintf(`'%s'`, extractColumnNameFromFormattedName(col)), // field name + col, // field value + ) } - return fmt.Sprintf("zetasqlite_make_struct_struct(%s)", strings.Join(fields, ",")), nil + return fmt.Sprintf("zetasqlite_make_struct_struct(%s)", strings.Join(args, ",")), nil } func (n *MakeProtoNode) FormatSQL(ctx context.Context) (string, error) { diff --git a/internal/function.go b/internal/function.go index d9879f2..7d66799 100644 --- a/internal/function.go +++ b/internal/function.go @@ -516,16 +516,23 @@ func DECODE_ARRAY(v string) (Value, error) { } func MAKE_STRUCT(args ...Value) (Value, error) { - keys := make([]string, len(args)) + keys := make([]string, len(args)/2) + values := make([]Value, len(args)/2) fieldMap := map[string]Value{} - for i := 0; i < len(args); i++ { - key := fmt.Sprintf("_field_%d", i+1) - keys[i] = key - fieldMap[key] = args[i] + for i := 0; i < len(args)/2; i++ { + key := args[i*2] + value := args[i*2+1] + k, err := key.ToString() + if err != nil { + return nil, err + } + keys[i] = k + values[i] = value + fieldMap[k] = value } return &StructValue{ keys: keys, - values: args, + values: values, m: fieldMap, }, nil } diff --git a/internal/function_bind.go b/internal/function_bind.go index 44789ed..3c11799 100644 --- a/internal/function_bind.go +++ b/internal/function_bind.go @@ -204,6 +204,23 @@ func bindTimestampFunc(fn BindFunction) SQLiteFunction { } } +func bindJsonFunc(fn BindFunction) SQLiteFunction { + return func(args ...interface{}) (interface{}, error) { + values, err := convertArgs(args...) + if err != nil { + return nil, err + } + ret, err := fn(values...) + if err != nil { + return nil, err + } + if ret == nil { + return nil, nil + } + return ret.ToString() + } +} + func bindArrayFunc(fn BindFunction) SQLiteFunction { return func(args ...interface{}) (interface{}, error) { values, err := convertArgs(args...) @@ -289,6 +306,12 @@ var ( } return v.ToString() } + jsonValueConverter = func(v Value) (interface{}, error) { + if v == nil { + return nil, nil + } + return v.ToString() + } arrayValueConverter = func(v Value) (interface{}, error) { if v == nil { return nil, nil @@ -339,6 +362,10 @@ func bindAggregateTimestampFunc(bindFunc func(ReturnValueConverter) func() *Aggr return bindFunc(timestampValueConverter) } +func bindAggregateJsonFunc(bindFunc func(ReturnValueConverter) func() *Aggregator) func() *Aggregator { + return bindFunc(jsonValueConverter) +} + func bindAggregateArrayFunc(bindFunc func(ReturnValueConverter) func() *Aggregator) func() *Aggregator { return bindFunc(arrayValueConverter) } @@ -452,6 +479,10 @@ func bindWindowTimestampFunc(bindFunc func(ReturnValueConverter) func() *WindowA return bindFunc(timestampValueConverter) } +func bindWindowJsonFunc(bindFunc func(ReturnValueConverter) func() *WindowAggregator) func() *WindowAggregator { + return bindFunc(jsonValueConverter) +} + func bindWindowArrayFunc(bindFunc func(ReturnValueConverter) func() *WindowAggregator) func() *WindowAggregator { return bindFunc(arrayValueConverter) } @@ -970,6 +1001,78 @@ func bindFormat(args ...Value) (Value, error) { return FORMAT(format) } +func bindToJson(args ...Value) (Value, error) { + if len(args) != 1 && len(args) != 2 { + return nil, fmt.Errorf("TO_JSON: invalid argument num %d", len(args)) + } + var stringifyWideNumbers bool + if len(args) == 2 { + b, err := args[1].ToBool() + if err != nil { + return nil, err + } + stringifyWideNumbers = b + } + return TO_JSON(args[0], stringifyWideNumbers) +} + +func bindToJsonString(args ...Value) (Value, error) { + if len(args) != 1 && len(args) != 2 { + return nil, fmt.Errorf("TO_JSON_STRING: invalid argument num %d", len(args)) + } + var prettyPrint bool + if len(args) == 2 { + b, err := args[1].ToBool() + if err != nil { + return nil, err + } + prettyPrint = b + } + return TO_JSON_STRING(args[0], prettyPrint) +} + +func bindBool(args ...Value) (Value, error) { + if len(args) != 1 { + return nil, fmt.Errorf("BOOL: invalid argument num %d", len(args)) + } + return args[0], nil +} + +func bindInt64(args ...Value) (Value, error) { + if len(args) != 1 { + return nil, fmt.Errorf("INT64: invalid argument num %d", len(args)) + } + return args[0], nil +} + +func bindDouble(args ...Value) (Value, error) { + if len(args) != 2 { + return nil, fmt.Errorf("FLOAT64: invalid argument num %d", len(args)) + } + mode, err := args[1].ToString() + if err != nil { + return nil, err + } + switch mode { + case "exact": + return args[0], nil + case "round": + return args[0], nil + } + return nil, fmt.Errorf("unexpected wide_number_mode: %s", mode) +} + +func bindJsonType(args ...Value) (Value, error) { + if len(args) != 1 { + return nil, fmt.Errorf("JSON_TYPE: invalid argument num %d", len(args)) + } + value, ok := args[0].(JsonValue) + if !ok { + return nil, fmt.Errorf("JSON_TYPE: failed to convert %T to JSON value", args[0]) + } + return JSON_TYPE(value) +} + func bindAbs(args ...Value) (Value, error) { if len(args) != 1 { return nil, fmt.Errorf("ABS: invalid argument num %d", len(args)) @@ -1826,6 +1929,10 @@ func bindString(args ...Value) (Value, error) { if existsNull(args) { return nil, nil } + jsonValue, ok := args[0].(JsonValue) + if ok { + return StringValue(fmt.Sprint(jsonValue.Interface())), nil + } t, err := args[0].ToTime() if err != nil { return nil, err @@ -2180,6 +2287,9 @@ func bindArrayReverse(args ...Value) (Value, error) { } func bindMakeStruct(args ...Value) (Value, error) { + if len(args)%2 != 0 { + return nil, fmt.Errorf("MAKE_STRUCT: unexpected argument num %d", len(args)) + } return MAKE_STRUCT(args...) } diff --git a/internal/function_json.go b/internal/function_json.go new file mode 100644 index 0000000..d6ffa8e --- /dev/null +++ b/internal/function_json.go @@ -0,0 +1,21 @@ +package internal + +func TO_JSON(v Value, stringifyWideNumbers bool) (Value, error) { + s, err := v.ToJSON() + if err != nil { + return nil, err + } + return JsonValue(s), nil +} + +func TO_JSON_STRING(v Value, prettyPrint bool) (Value, error) { + s, err := v.ToJSON() + if err != nil { + return nil, err + } + return StringValue(s), nil +} + +func JSON_TYPE(v JsonValue) (Value, error) { + return StringValue(v.Type()), nil +} diff --git a/internal/function_register.go b/internal/function_register.go index a9abea6..c2b523f 100644 --- a/internal/function_register.go +++ b/internal/function_register.go @@ -538,6 +538,38 @@ var normalFuncs = []*FuncInfo{ ReturnTypes: []types.TypeKind{types.STRING}, }, + // json functions + { + Name: "to_json", + BindFunc: bindToJson, + ReturnTypes: []types.TypeKind{types.JSON}, + }, + { + Name: "to_json_string", + BindFunc: bindToJsonString, + ReturnTypes: []types.TypeKind{types.STRING}, + }, + { + Name: "bool", + BindFunc: bindBool, + ReturnTypes: []types.TypeKind{types.BOOL}, + }, + { + Name: "int64", + BindFunc: bindInt64, + ReturnTypes: []types.TypeKind{types.INT64}, + }, + { + Name: "double", + BindFunc: bindDouble, + ReturnTypes: []types.TypeKind{types.DOUBLE}, + }, + { + Name: "json_type", + BindFunc: bindJsonType, + ReturnTypes: []types.TypeKind{types.STRING}, + }, + // math functions { @@ -1171,6 +1203,9 @@ func setupNormalFuncMap(info *FuncInfo) error { case types.STRUCT: name = fmt.Sprintf("zetasqlite_%s_struct", info.Name) fn = bindStructFunc(info.BindFunc) + case types.JSON: + name = fmt.Sprintf("zetasqlite_%s_json", info.Name) + fn = bindJsonFunc(info.BindFunc) default: return fmt.Errorf("unsupported return type %s for function: %s", retType, info.Name) } @@ -1222,6 +1257,9 @@ func setupAggregateFuncMap(info *AggregateFuncInfo) error { case types.STRUCT: name = fmt.Sprintf("zetasqlite_%s_struct", info.Name) aggregator = bindAggregateStructFunc(info.BindFunc) + case types.JSON: + name = fmt.Sprintf("zetasqlite_%s_json", info.Name) + aggregator = bindAggregateJsonFunc(info.BindFunc) default: return fmt.Errorf("unsupported return type %s for aggregate function: %s", retType, info.Name) } @@ -1273,6 +1311,9 @@ func setupWindowFuncMap(info *WindowFuncInfo) error { case types.STRUCT: name = fmt.Sprintf("zetasqlite_window_%s_struct", info.Name) aggregator = bindWindowStructFunc(info.BindFunc) + case types.JSON: + name = fmt.Sprintf("zetasqlite_window_%s_json", info.Name) + aggregator = bindWindowJsonFunc(info.BindFunc) default: return fmt.Errorf("unsupported return type %s for window function: %s", retType, info.Name) } diff --git a/internal/json.go b/internal/json.go index fd21fa9..ff0558d 100644 --- a/internal/json.go +++ b/internal/json.go @@ -11,7 +11,7 @@ import ( func JSONFromZetaSQLValue(v types.Value) string { value := jsonFromZetaSQLValue(v) - if value == "null" { + if value == "null" && v.Type().Kind() != types.JSON { return "null" } switch v.Type().Kind() { @@ -31,6 +31,9 @@ func JSONFromZetaSQLValue(v types.Value) string { return toArrayValueFromJSONString(value) case types.STRUCT: return toStructValueFromJSONString(value) + case types.JSON: + v, _ := toJsonValueFromString(value) + return v } return value } @@ -68,6 +71,8 @@ func jsonFromZetaSQLValue(v types.Value) string { ) } return fmt.Sprintf("{%s}", strings.Join(fields, ",")) + case types.JSON: + return v.JSONString() default: vv := v.SQLLiteral(0) if vv == "NULL" { diff --git a/internal/rows.go b/internal/rows.go index 52912df..512dad7 100644 --- a/internal/rows.go +++ b/internal/rows.go @@ -169,6 +169,16 @@ func (r *Rows) convertValue(value interface{}, typ *Type) (driver.Value, error) return v, nil case types.STRUCT: return array.Interface(), nil + case types.JSON: + v := []string{} + for _, value := range array.values { + jv, err := value.ToJSON() + if err != nil { + return nil, err + } + v = append(v, jv) + } + return v, nil } case types.STRUCT: val, err := ValueOf(value) @@ -208,6 +218,12 @@ func (r *Rows) convertValue(value interface{}, typ *Type) (driver.Value, error) return nil, err } return t.UTC(), nil + case types.JSON: + val, err := ValueOf(value) + if err != nil { + return nil, err + } + return val.ToJSON() } return value, nil } diff --git a/internal/value.go b/internal/value.go index 608394a..3e58614 100644 --- a/internal/value.go +++ b/internal/value.go @@ -697,6 +697,130 @@ func (bv BoolValue) Interface() interface{} { return bool(bv) } +type JsonValue string + +func (jv JsonValue) Add(v Value) (Value, error) { + return nil, fmt.Errorf("add operation is unsupported for json %v", jv) +} + +func (jv JsonValue) Sub(v Value) (Value, error) { + return nil, fmt.Errorf("sub operation is unsupported for json %v", jv) +} + +func (jv JsonValue) Mul(v Value) (Value, error) { + return nil, fmt.Errorf("mul operation is unsupported for json %v", jv) +} + +func (jv JsonValue) Div(v Value) (Value, error) { + return nil, fmt.Errorf("div operation is unsupported for json %v", jv) +} + +func (jv JsonValue) EQ(v Value) (bool, error) { + return false, fmt.Errorf("eq operation is unsupported for json %v", jv) +} + +func (jv JsonValue) GT(v Value) (bool, error) { + return false, fmt.Errorf("gt operation is unsupported for json %v", jv) +} + +func (jv JsonValue) GTE(v Value) (bool, error) { + return false, fmt.Errorf("gte operation is unsupported for json %v", jv) +} + +func (jv JsonValue) LT(v Value) (bool, error) { + return false, fmt.Errorf("lt operation is unsupported for json %v", jv) +} + +func (jv JsonValue) LTE(v Value) (bool, error) { + return false, fmt.Errorf("lte operation is unsupported for json %v", jv) +} + +func (jv JsonValue) ToInt64() (int64, error) { + return strconv.ParseInt(string(jv), 10, 64) +} + +func (jv JsonValue) ToString() (string, error) { + return toJsonValueFromString(string(jv)) +} + +func (jv JsonValue) ToFloat64() (float64, error) { + return strconv.ParseFloat(string(jv), 64) +} + +func (jv JsonValue) ToBool() (bool, error) { + return strconv.ParseBool(string(jv)) +} + +func (jv JsonValue) ToArray() (*ArrayValue, error) { + return nil, fmt.Errorf("failed to convert json from array: %v", jv) +} + +func (jv JsonValue) ToStruct() (*StructValue, error) { + return nil, fmt.Errorf("failed to convert json from struct: %v", jv) +} + +func (jv JsonValue) ToJSON() (string, error) { + return string(jv), nil +} + +func (jv JsonValue) ToTime() (time.Time, error) { + return time.Time{}, fmt.Errorf("failed to convert json from time.Time: %v", jv) +} + +func (jv JsonValue) ToRat() (*big.Rat, error) { + i64, err := strconv.ParseInt(string(jv), 10, 64) + if err != nil { + return nil, err + } + r := new(big.Rat) + r.SetInt64(i64) + return r, nil +} + +func (jv JsonValue) Marshal() (string, error) { + return jv.ToString() +} + +func (jv JsonValue) Format(verb rune) string { + return string(jv) +} + +func (jv JsonValue) Interface() interface{} { + var v interface{} + if err := json.Unmarshal([]byte(jv), &v); err != nil { + return nil + } + return v +} + +func (jv JsonValue) reflectTypeToJsonType(t reflect.Type) string { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return "number" + case reflect.String: + return "string" + case reflect.Bool: + return "boolean" + case reflect.Slice, reflect.Array: + return "array" + case reflect.Struct, reflect.Map: + return "object" + case reflect.Ptr: + return jv.reflectTypeToJsonType(t.Elem()) + } + return "unknown" +} + +func (jv JsonValue) Type() string { + if string(jv) == "null" { + return "null" + } + rv := reflect.ValueOf(jv.Interface()) + return jv.reflectTypeToJsonType(rv.Type()) +} + type ArrayValue struct { values []Value } @@ -1864,6 +1988,7 @@ const ( DatetimeValueHeader = "zetasqlitedatetime:" TimeValueHeader = "zetasqlitetime:" TimestampValueHeader = "zetasqlitetimestamp:" + JsonValueHeader = "zetasqlitejson:" ) func ValueOf(v interface{}) (Value, error) { @@ -1912,6 +2037,8 @@ func ValueOf(v interface{}) (Value, error) { return TimeValueOf(vv) case isTimestampValue(vv): return TimestampValueOf(vv) + case isJsonValue(vv): + return JsonValueOf(vv) } return StringValue(vv), nil case []byte: @@ -1996,6 +2123,16 @@ func isTimestampValue(v string) bool { return strings.HasPrefix(v, TimestampValueHeader) } +func isJsonValue(v string) bool { + if len(v) < len(JsonValueHeader) { + return false + } + if v[0] == '"' { + return strings.HasPrefix(v[1:], JsonValueHeader) + } + return strings.HasPrefix(v, JsonValueHeader) +} + func NumericValueOf(v string) (Value, error) { numeric, err := numericValueFromEncodedString(v) if err != nil { @@ -2036,6 +2173,14 @@ func TimestampValueOf(v string) (Value, error) { return TimestampValue(date), nil } +func JsonValueOf(v string) (Value, error) { + json, err := jsonValueFromEncodedString(v) + if err != nil { + return nil, fmt.Errorf("failed to get json value from encoded string: %w", err) + } + return JsonValue(json), nil +} + func ArrayValueOf(v string) (Value, error) { arr, err := arrayValueFromEncodedString(v) if err != nil { @@ -2199,6 +2344,25 @@ func timestampValueFromEncodedString(v string) (time.Time, error) { return parseTimestamp(content, loc) } +func jsonValueFromEncodedString(v string) (string, error) { + if len(v) == 0 { + return "", nil + } + if v[0] == '"' { + unquoted, err := strconv.Unquote(v) + if err != nil { + return "", fmt.Errorf("failed to unquote value %q: %w", v, err) + } + v = unquoted + } + content := v[len(JsonValueHeader):] + decoded, err := base64.StdEncoding.DecodeString(content) + if err != nil { + return "", fmt.Errorf("failed to base64 decode for json value %q: %w", content, err) + } + return string(decoded), nil +} + func arrayValueFromEncodedString(v string) ([]interface{}, error) { if len(v) == 0 { return nil, nil @@ -2305,6 +2469,16 @@ func toTimestampValueFromString(s string) (string, error) { ), nil } +func toJsonValueFromString(s string) (string, error) { + return strconv.Quote( + fmt.Sprintf( + "%s%s", + JsonValueHeader, + base64.StdEncoding.EncodeToString([]byte(s)), + ), + ), nil +} + func formatTimestamp(s string) (string, error) { loc, err := time.LoadLocation("") if err != nil { @@ -2694,7 +2868,11 @@ func encodeValueWithType(v interface{}, t types.Type) (interface{}, error) { case types.EXTENDED: return nil, fmt.Errorf("failed to convert EXTENDED type from %T", v) case types.JSON: - return nil, fmt.Errorf("failed to convert JSON type from %T", v) + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + return string(b), nil case types.INTERVAL: return nil, fmt.Errorf("failed to convert INTERVAL type from %T", v) default: diff --git a/query_test.go b/query_test.go index 30016f0..d9aa709 100644 --- a/query_test.go +++ b/query_test.go @@ -1193,24 +1193,24 @@ FROM finishers`, []interface{}{ []map[string]interface{}{ map[string]interface{}{ - "_field_1": float64(1), + "$col1": float64(1), }, map[string]interface{}{ - "_field_2": float64(2), + "$col2": float64(2), }, map[string]interface{}{ - "_field_3": float64(3), + "$col3": float64(3), }, }, []map[string]interface{}{ map[string]interface{}{ - "_field_1": float64(4), + "$col1": float64(4), }, map[string]interface{}{ - "_field_2": float64(5), + "$col2": float64(5), }, map[string]interface{}{ - "_field_3": float64(6), + "$col3": float64(6), }, }, }, @@ -1225,7 +1225,7 @@ FROM finishers`, []interface{}{ []map[string]interface{}{ map[string]interface{}{ - "_field_1": []interface{}{ + "$col1": []interface{}{ float64(1), float64(2), float64(3), @@ -1234,7 +1234,7 @@ FROM finishers`, }, []map[string]interface{}{ map[string]interface{}{ - "_field_1": []interface{}{ + "$col1": []interface{}{ float64(4), float64(5), float64(6), @@ -2285,6 +2285,83 @@ SELECT Add(3, 4); query: `WITH orders AS (SELECT 5 as order_id, "sprocket" as item_name, 200 as quantity) SELECT * REPLACE (quantity/2 AS quantity) FROM orders`, expectedRows: [][]interface{}{{int64(5), "sprocket", float64(100)}}, }, + + // json + { + name: "to_json", + query: ` +With CoordinatesTable AS ( + (SELECT 1 AS id, [10,20] AS coordinates) UNION ALL + (SELECT 2 AS id, [30,40] AS coordinates) UNION ALL + (SELECT 3 AS id, [50,60] AS coordinates)) +SELECT TO_JSON(t) AS json_objects FROM CoordinatesTable AS t`, + expectedRows: [][]interface{}{ + {`{"id":1,"coordinates":[10,20]}`}, + {`{"id":2,"coordinates":[30,40]}`}, + {`{"id":3,"coordinates":[50,60]}`}, + }, + }, + { + name: "to_json_string", + query: ` +With CoordinatesTable AS ( + (SELECT 1 AS id, [10,20] AS coordinates) UNION ALL + (SELECT 2 AS id, [30,40] AS coordinates) UNION ALL + (SELECT 3 AS id, [50,60] AS coordinates)) +SELECT id, coordinates, TO_JSON_STRING(t) AS json_data +FROM CoordinatesTable AS t`, + expectedRows: [][]interface{}{ + {int64(1), []int64{10, 20}, `{"id":1,"coordinates":[10,20]}`}, + {int64(2), []int64{30, 40}, `{"id":2,"coordinates":[30,40]}`}, + {int64(3), []int64{50, 60}, `{"id":3,"coordinates":[50,60]}`}, + }, + }, + { + name: "json_string", + query: `SELECT STRING(JSON '"purple"') AS color`, + expectedRows: [][]interface{}{{"purple"}}, + }, + { + name: "json_bool", + query: `SELECT BOOL(JSON 'true') AS vacancy`, + expectedRows: [][]interface{}{{true}}, + }, + { + name: "json_int64", + query: `SELECT INT64(JSON '2005') AS flight_number`, + expectedRows: [][]interface{}{{int64(2005)}}, + }, + { + name: "json_float64", + query: `SELECT FLOAT64(JSON '9.8') AS velocity`, + expectedRows: [][]interface{}{{float64(9.8)}}, + }, + { + name: "json_type", + query: ` +SELECT json_val, JSON_TYPE(json_val) AS type +FROM + UNNEST( + [ + JSON '"apple"', + JSON '10', + JSON '3.14', + JSON 'null', + JSON '{"city": "New York", "State": "NY"}', + JSON '["apple", "banana"]', + JSON 'false' + ] + ) AS json_val`, + expectedRows: [][]interface{}{ + {`"apple"`, "string"}, + {"10", "number"}, + {"3.14", "number"}, + {"null", "null"}, + {`{"State":"NY","city":"New York"}`, "object"}, + {`["apple","banana"]`, "array"}, + {"false", "boolean"}, + }, + }, } { test := test t.Run(test.name, func(t *testing.T) {