From 05aea846ba20fb6729375ab28059a938494ec1ea Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sat, 25 Jun 2022 02:00:53 +0900 Subject: [PATCH 1/2] Fix UDF - refactor sqlite3 connection management --- analyzer.go | 65 ++++++++++++++++++--------------- catalog.go | 101 ++++++++++++++++++---------------------------------- context.go | 13 +++++++ driver.go | 97 +++++++++++++++++++++++++++++++------------------ node.go | 18 ++++++++++ rows.go | 34 +++++++++++++++--- spec.go | 35 ------------------ stmt.go | 34 +++++++++++------- value.go | 20 +++++------ 9 files changed, 225 insertions(+), 192 deletions(-) diff --git a/analyzer.go b/analyzer.go index 8aa2e49..65177cb 100644 --- a/analyzer.go +++ b/analyzer.go @@ -2,6 +2,7 @@ package zetasqlite import ( "context" + "database/sql" "database/sql/driver" "fmt" @@ -25,9 +26,9 @@ type AnalyzerOutput struct { isQuery bool tableSpec *TableSpec outputColumns []*ColumnSpec - prepare func(conn driver.Conn) (driver.Stmt, error) - execContext func(context.Context, driver.ExecerContext, []driver.NamedValue) (driver.Result, error) - queryContext func(context.Context, driver.QueryerContext, []driver.NamedValue) (driver.Rows, error) + prepare func(context.Context, *sql.Conn) (driver.Stmt, error) + execContext func(context.Context, *sql.Conn, ...interface{}) (driver.Result, error) + queryContext func(context.Context, *sql.Conn, ...interface{}) (driver.Rows, error) } func newAnalyzer(catalog *Catalog) *Analyzer { @@ -74,8 +75,8 @@ func newAnalyzerOptions() *zetasql.AnalyzerOptions { return opt } -func (a *Analyzer) Analyze(query string) (*AnalyzerOutput, error) { - if err := a.catalog.Sync(); err != nil { +func (a *Analyzer) Analyze(ctx context.Context, query string) (*AnalyzerOutput, error) { + if err := a.catalog.Sync(ctx); err != nil { return nil, fmt.Errorf("failed to sync catalog: %w", err) } out, err := zetasql.AnalyzeStatement(query, a.catalog.catalog, a.opt) @@ -86,9 +87,14 @@ func (a *Analyzer) Analyze(query string) (*AnalyzerOutput, error) { if err != nil { return nil, fmt.Errorf("failed to get full name path %s: %w", query, err) } - ctx := withNamePath(context.Background(), a.namePath) + funcMap := map[string]*FunctionSpec{} + for _, spec := range a.catalog.functions { + funcMap[spec.FuncName()] = spec + } + ctx = withNamePath(ctx, a.namePath) ctx = withColumnRefMap(ctx, map[string]string{}) ctx = withFullNamePath(ctx, fullpath) + ctx = withFuncMap(ctx, funcMap) stmtNode := out.Statement() switch stmtNode.Kind() { case ast.CreateTableStmt: @@ -110,33 +116,30 @@ func (a *Analyzer) analyzeCreateTableStmt(query string, node *ast.CreateTableStm query: query, argsNum: a.getParamNumFromNode(node), tableSpec: spec, - prepare: func(conn driver.Conn) (driver.Stmt, error) { + prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) { if spec.CreateMode == ast.CreateOrReplaceMode { - stmt, err := conn.Prepare(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", spec.TableName())) - if err != nil { - return nil, err - } - if _, err := stmt.Exec(nil); err != nil { + query := fmt.Sprintf("DROP TABLE IF EXISTS `%s`", spec.TableName()) + if _, err := conn.ExecContext(ctx, query); err != nil { return nil, err } } - s, err := conn.Prepare(spec.SQLiteSchema()) + s, err := conn.PrepareContext(ctx, spec.SQLiteSchema()) if err != nil { return nil, fmt.Errorf("failed to prepare %s: %w", query, err) } return newCreateTableStmt(s, a.catalog, spec), nil }, - execContext: func(ctx context.Context, execer driver.ExecerContext, args []driver.NamedValue) (driver.Result, error) { + execContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Result, error) { if spec.CreateMode == ast.CreateOrReplaceMode { dropTableQuery := fmt.Sprintf("DROP TABLE IF EXISTS `%s`", spec.TableName()) - if _, err := execer.ExecContext(ctx, dropTableQuery, nil); err != nil { + if _, err := conn.ExecContext(ctx, dropTableQuery); err != nil { return nil, err } } - if _, err := execer.ExecContext(ctx, spec.SQLiteSchema(), args); err != nil { + if _, err := conn.ExecContext(ctx, spec.SQLiteSchema(), args...); err != nil { return nil, fmt.Errorf("failed to exec %s: %w", query, err) } - if err := a.catalog.AddNewTableSpec(spec); err != nil { + if err := a.catalog.AddNewTableSpec(ctx, spec); err != nil { return nil, fmt.Errorf("failed to add new table spec: %w", err) } return nil, nil @@ -152,15 +155,21 @@ func (a *Analyzer) analyzeCreateFunctionStmt(ctx context.Context, query string, return &AnalyzerOutput{ query: query, node: node, - prepare: func(conn driver.Conn) (driver.Stmt, error) { + prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) { return newCreateFunctionStmt(a.catalog, spec), nil }, - execContext: func(ctx context.Context, execer driver.ExecerContext, args []driver.NamedValue) (driver.Result, error) { - if err := a.catalog.AddNewFunctionSpec(spec); err != nil { + execContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Result, error) { + if err := a.catalog.AddNewFunctionSpec(ctx, spec); err != nil { return nil, fmt.Errorf("failed to add new function spec: %w", err) } return nil, nil }, + queryContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Rows, error) { + if err := a.catalog.AddNewFunctionSpec(ctx, spec); err != nil { + return nil, fmt.Errorf("failed to add new function spec: %w", err) + } + return &Rows{}, nil + }, }, nil } @@ -178,15 +187,15 @@ func (a *Analyzer) analyzeDMLStmt(ctx context.Context, query string, node ast.No query: query, formattedQuery: formattedQuery, argsNum: argsNum, - prepare: func(conn driver.Conn) (driver.Stmt, error) { - s, err := conn.Prepare(formattedQuery) + prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) { + s, err := conn.PrepareContext(ctx, formattedQuery) if err != nil { return nil, fmt.Errorf("failed to prepare %s: %w", query, err) } return newDMLStmt(s, argsNum, formattedQuery), nil }, - execContext: func(ctx context.Context, execer driver.ExecerContext, args []driver.NamedValue) (driver.Result, error) { - if _, err := execer.ExecContext(ctx, formattedQuery, args); err != nil { + execContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Result, error) { + if _, err := conn.ExecContext(ctx, formattedQuery, args...); err != nil { return nil, fmt.Errorf("failed to exec %s: %w", formattedQuery, err) } return nil, nil @@ -216,15 +225,15 @@ func (a *Analyzer) analyzeQueryStmt(ctx context.Context, query string, node *ast formattedQuery: formattedQuery, argsNum: argsNum, isQuery: true, - prepare: func(conn driver.Conn) (driver.Stmt, error) { - s, err := conn.Prepare(formattedQuery) + prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) { + s, err := conn.PrepareContext(ctx, formattedQuery) if err != nil { return nil, fmt.Errorf("failed to prepare %s: %w", query, err) } return newQueryStmt(s, argsNum, formattedQuery, outputColumns), nil }, - queryContext: func(ctx context.Context, queryer driver.QueryerContext, args []driver.NamedValue) (driver.Rows, error) { - rows, err := queryer.QueryContext(ctx, formattedQuery, args) + queryContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Rows, error) { + rows, err := conn.QueryContext(ctx, formattedQuery, args...) if err != nil { return nil, fmt.Errorf("failed to query %s: %w", formattedQuery, err) } diff --git a/catalog.go b/catalog.go index ff767e5..eb01cad 100644 --- a/catalog.go +++ b/catalog.go @@ -1,12 +1,11 @@ package zetasqlite import ( - "database/sql/driver" + "context" "encoding/json" "fmt" - "io" "reflect" - "sync" + "time" "github.com/goccy/go-zetasql/types" ) @@ -37,7 +36,6 @@ type Catalog struct { functions []*FunctionSpec tableMap map[string]*TableSpec funcMap map[string]*FunctionSpec - mu sync.Mutex } func newCatalog(conn *ZetaSQLiteConn) *Catalog { @@ -51,38 +49,25 @@ func newCatalog(conn *ZetaSQLiteConn) *Catalog { } } -func (c *Catalog) Sync() error { - if err := c.createCatalogTablesIfNotExists(); err != nil { +func (c *Catalog) Sync(ctx context.Context) error { + if err := c.createCatalogTablesIfNotExists(ctx); err != nil { return err } - c.mu.Lock() - defer c.mu.Unlock() - - stmt, err := c.conn.sqliteConn.Prepare(loadCatalogQuery) - if err != nil { - return fmt.Errorf("failed to prepare load catalog: %w", err) - } - defer stmt.Close() - rows, err := stmt.Query(nil) + rows, err := c.conn.conn.QueryContext(ctx, loadCatalogQuery) if err != nil { return fmt.Errorf("failed to query load catalog: %w", err) } defer rows.Close() - for { + for rows.Next() { var ( name string kind CatalogSpecKind spec string ) - values := []driver.Value{&name, &kind, &spec} - if err := rows.Next(values); err != nil { - if err != io.EOF { - return fmt.Errorf("failed to load catalog values: %w", err) - } - break + if err := rows.Scan(&name, &kind, &spec); err != nil { + return fmt.Errorf("failed to scan catalog values: %w", err) } - spec = values[2].(string) - switch CatalogSpecKind(values[1].(string)) { + switch kind { case TableSpecKind: if err := c.loadTableSpec(spec); err != nil { return fmt.Errorf("failed to load table spec: %w", err) @@ -98,61 +83,65 @@ func (c *Catalog) Sync() error { return nil } -func (c *Catalog) AddNewTableSpec(spec *TableSpec) error { +func (c *Catalog) AddNewTableSpec(ctx context.Context, spec *TableSpec) error { if err := c.addTableSpec(spec); err != nil { return err } - if err := c.saveTableSpec(spec); err != nil { + if err := c.saveTableSpec(ctx, spec); err != nil { return err } return nil } -func (c *Catalog) AddNewFunctionSpec(spec *FunctionSpec) error { +func (c *Catalog) AddNewFunctionSpec(ctx context.Context, spec *FunctionSpec) error { if err := c.addFunctionSpec(spec); err != nil { return err } - if err := c.saveFunctionSpec(spec); err != nil { + if err := c.saveFunctionSpec(ctx, spec); err != nil { return err } return nil } -func (c *Catalog) saveTableSpec(spec *TableSpec) error { +func (c *Catalog) saveTableSpec(ctx context.Context, spec *TableSpec) error { encoded, err := json.Marshal(spec) if err != nil { return fmt.Errorf("failed to encode table spec: %w", err) } - if err := c.exec(upsertCatalogQuery, []driver.Value{ + if err := c.exec( + ctx, + upsertCatalogQuery, spec.TableName(), string(TableSpecKind), string(encoded), string(encoded), - }); err != nil { + ); err != nil { return fmt.Errorf("failed to save a new table spec: %w", err) } return nil } -func (c *Catalog) saveFunctionSpec(spec *FunctionSpec) error { +func (c *Catalog) saveFunctionSpec(ctx context.Context, spec *FunctionSpec) error { encoded, err := json.Marshal(spec) if err != nil { return fmt.Errorf("failed to encode function spec: %w", err) } - if err := c.exec(upsertCatalogQuery, []driver.Value{ + if err := c.exec( + ctx, + upsertCatalogQuery, spec.FuncName(), string(FunctionSpecKind), string(encoded), string(encoded), - }); err != nil { + ); err != nil { return fmt.Errorf("failed to save a new function spec: %w", err) } return nil } -func (c *Catalog) createCatalogTablesIfNotExists() error { - if err := c.exec(createCatalogTableQuery, nil); err != nil { +func (c *Catalog) createCatalogTablesIfNotExists(ctx context.Context) error { + if err := c.exec(ctx, createCatalogTableQuery); err != nil { return fmt.Errorf("failed to create catalog table: %w", err) } return nil @@ -185,9 +174,6 @@ func (c *Catalog) addFunctionSpec(spec *FunctionSpec) error { if _, exists := c.funcMap[funcName]; exists { return nil } - if err := c.conn.sqliteConn.RegisterFunc(funcName, spec.FuncBody(c.conn), true); err != nil { - return fmt.Errorf("failed to register user defined function: %w", err) - } c.functions = append(c.functions, spec) c.funcMap[funcName] = spec return c.addFunctionSpecRecursive(c.catalog, spec) @@ -338,32 +324,15 @@ func (c *Catalog) copyFunctionSpec(spec *FunctionSpec, newNamePath []string) *Fu } } -func (c *Catalog) exec(query string, args []driver.Value) error { - c.mu.Lock() - defer c.mu.Unlock() - - stmt, err := c.conn.sqliteConn.Prepare(query) - if err != nil { - return fmt.Errorf("failed to prepare %s: %w", query, err) - } - defer stmt.Close() - if _, err := stmt.Exec(args); err != nil { - return fmt.Errorf("failed to exec %s: %w", query, err) - } - return nil -} - -func (c *Catalog) query(query string, args []driver.Value) (driver.Rows, error) { - c.mu.Lock() - defer c.mu.Unlock() - - stmt, err := c.conn.sqliteConn.Prepare(query) - if err != nil { - return nil, fmt.Errorf("failed to prepare %s: %w", query, err) - } - rows, err := stmt.Query(args) - if err != nil { - return nil, fmt.Errorf("failed to query %s: %w", query, err) +func (c *Catalog) exec(ctx context.Context, query string, args ...interface{}) error { + var retErr error + for i := 0; i < 5; i++ { + if _, err := c.conn.conn.ExecContext(ctx, query, args...); err != nil { + retErr = err + time.Sleep(100 * time.Millisecond) + continue + } + return nil } - return rows, nil + return retErr } diff --git a/context.go b/context.go index 68c7e22..de130a8 100644 --- a/context.go +++ b/context.go @@ -6,6 +6,7 @@ type ( namePathKey struct{} fullNamePathKey struct{} columnRefMapKey struct{} + funcMapKey struct{} ) func namePathFromContext(ctx context.Context) []string { @@ -48,3 +49,15 @@ func columnRefMap(ctx context.Context) map[string]string { } return value.(map[string]string) } + +func withFuncMap(ctx context.Context, m map[string]*FunctionSpec) context.Context { + return context.WithValue(ctx, funcMapKey{}, m) +} + +func funcMapFromContext(ctx context.Context) map[string]*FunctionSpec { + value := ctx.Value(funcMapKey{}) + if value == nil { + return nil + } + return value.(map[string]*FunctionSpec) +} diff --git a/driver.go b/driver.go index 2690def..e29711a 100644 --- a/driver.go +++ b/driver.go @@ -5,26 +5,59 @@ import ( "database/sql" "database/sql/driver" "fmt" + "sync" "github.com/mattn/go-sqlite3" ) -func init() { - sql.Register("zetasqlite", &ZetaSQLiteDriver{}) -} - var ( _ driver.Driver = &ZetaSQLiteDriver{} _ driver.Conn = &ZetaSQLiteConn{} _ driver.Tx = &ZetaSQLiteTx{} ) +var ( + nameToDBMap = map[string]*sql.DB{} + nameToDBMapMu sync.Mutex +) + +func init() { + sql.Register("zetasqlite", &ZetaSQLiteDriver{}) + sql.Register("zetasqlite_sqlite3", &sqlite3.SQLiteDriver{ + ConnectHook: func(conn *sqlite3.SQLiteConn) error { + if err := registerBuiltinFunctions(conn); err != nil { + return err + } + return nil + }, + }) +} + +func openDB(name string) (*sql.DB, error) { + nameToDBMapMu.Lock() + defer nameToDBMapMu.Unlock() + db, exists := nameToDBMap[name] + if exists { + return db, nil + } + db, err := sql.Open("zetasqlite_sqlite3", name) + if err != nil { + return nil, fmt.Errorf("failed to open database by %s: %w", name, err) + } + nameToDBMap[name] = db + return db, nil +} + type ZetaSQLiteDriver struct { ConnectHook func(*ZetaSQLiteConn) error } func (d *ZetaSQLiteDriver) Open(name string) (driver.Conn, error) { - conn, err := newZetaSQLiteConn(name) + db, err := openDB(name) + if err != nil { + return nil, err + } + conn, err := newZetaSQLiteConn(db) if err != nil { return nil, err } @@ -37,30 +70,16 @@ func (d *ZetaSQLiteDriver) Open(name string) (driver.Conn, error) { } type ZetaSQLiteConn struct { - sqliteConn *sqlite3.SQLiteConn - conn driver.Conn - analyzer *Analyzer + conn *sql.Conn + analyzer *Analyzer } -func newZetaSQLiteConn(name string) (*ZetaSQLiteConn, error) { - var sqliteConn *sqlite3.SQLiteConn - sqliteDriver := &sqlite3.SQLiteDriver{ - ConnectHook: func(conn *sqlite3.SQLiteConn) error { - if err := registerBuiltinFunctions(conn); err != nil { - return err - } - sqliteConn = conn - return nil - }, - } - conn, err := sqliteDriver.Open(name) +func newZetaSQLiteConn(db *sql.DB) (*ZetaSQLiteConn, error) { + conn, err := db.Conn(context.Background()) if err != nil { - return nil, fmt.Errorf("zetasqlite: failed to open database: %w", err) - } - c := &ZetaSQLiteConn{ - sqliteConn: sqliteConn, - conn: conn, + return nil, fmt.Errorf("failed to get sqlite3 connection: %w", err) } + c := &ZetaSQLiteConn{conn: conn} c.analyzer = newAnalyzer(newCatalog(c)) return c, nil } @@ -82,35 +101,43 @@ func (s *ZetaSQLiteConn) CheckNamedValue(value *driver.NamedValue) error { } func (c *ZetaSQLiteConn) Prepare(query string) (driver.Stmt, error) { - out, err := c.analyzer.Analyze(query) + out, err := c.analyzer.Analyze(context.Background(), query) if err != nil { return nil, fmt.Errorf("failed to analyze query: %w", err) } - return out.prepare(c.conn) + return out.prepare(context.Background(), c.conn) } func (c *ZetaSQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - out, err := c.analyzer.Analyze(query) + out, err := c.analyzer.Analyze(ctx, query) if err != nil { return nil, fmt.Errorf("failed to analyze query: %w", err) } - newArgs, err := convertNamedValues(args) + newNamedValues, err := convertNamedValues(args) if err != nil { return nil, err } - return out.execContext(ctx, c.conn.(driver.ExecerContext), newArgs) + newArgs := make([]interface{}, 0, len(args)) + for _, newNamedValue := range newNamedValues { + newArgs = append(newArgs, newNamedValue) + } + return out.execContext(ctx, c.conn, newArgs...) } func (c *ZetaSQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - out, err := c.analyzer.Analyze(query) + out, err := c.analyzer.Analyze(ctx, query) if err != nil { return nil, fmt.Errorf("failed to analyze query: %w", err) } - newArgs, err := convertNamedValues(args) + newNamedValues, err := convertNamedValues(args) if err != nil { return nil, err } - return out.queryContext(ctx, c.conn.(driver.QueryerContext), newArgs) + newArgs := make([]interface{}, 0, len(args)) + for _, newNamedValue := range newNamedValues { + newArgs = append(newArgs, newNamedValue) + } + return out.queryContext(ctx, c.conn, newArgs...) } func (c *ZetaSQLiteConn) Close() error { @@ -118,7 +145,7 @@ func (c *ZetaSQLiteConn) Close() error { } func (c *ZetaSQLiteConn) Begin() (driver.Tx, error) { - tx, err := c.conn.Begin() + tx, err := c.conn.BeginTx(context.Background(), nil) if err != nil { return nil, err } @@ -129,7 +156,7 @@ func (c *ZetaSQLiteConn) Begin() (driver.Tx, error) { } type ZetaSQLiteTx struct { - tx driver.Tx + tx *sql.Tx conn *ZetaSQLiteConn } diff --git a/node.go b/node.go index b2f27ca..70dd759 100644 --- a/node.go +++ b/node.go @@ -536,6 +536,15 @@ func (n *FunctionCallNode) FormatSQL(ctx context.Context) (string, error) { ), ) fullpath.idx++ + funcMap := funcMapFromContext(ctx) + if spec, exists := funcMap[funcName]; exists { + body := spec.Body + for _, arg := range args { + // TODO: Need to recognize the argument exactly. + body = strings.Replace(body, "?", arg, 1) + } + return fmt.Sprintf("( %s )", body), nil + } } return fmt.Sprintf( "%s(%s)", @@ -576,6 +585,15 @@ func (n *AggregateFunctionCallNode) FormatSQL(ctx context.Context) (string, erro ), ) fullpath.idx++ + funcMap := funcMapFromContext(ctx) + if spec, exists := funcMap[funcName]; exists { + body := spec.Body + for _, arg := range args { + // TODO: Need to recognize the argument exactly. + body = strings.Replace(body, "?", arg, 1) + } + return fmt.Sprintf("( %s )", body), nil + } } return fmt.Sprintf( "%s(%s)", diff --git a/rows.go b/rows.go index 1b2abbe..cf72e79 100644 --- a/rows.go +++ b/rows.go @@ -1,13 +1,16 @@ package zetasqlite import ( + "database/sql" "database/sql/driver" + "io" + "reflect" "github.com/goccy/go-zetasql/types" ) type Rows struct { - rows driver.Rows + rows *sql.Rows columns []*ColumnSpec } @@ -19,7 +22,14 @@ func (r *Rows) Columns() []string { return colNames } +func (r *Rows) ColumnTypeDatabaseTypeName(i int) string { + return r.columns[i].Type.Name +} + func (r *Rows) Close() error { + if r.rows == nil { + return nil + } return r.rows.Close() } @@ -32,14 +42,28 @@ func (r *Rows) columnTypes() ([]*Type, error) { } func (r *Rows) Next(dest []driver.Value) error { + if r.rows == nil { + return io.EOF + } + if !r.rows.Next() { + return io.EOF + } + if err := r.rows.Err(); err != nil { + return err + } colTypes, err := r.columnTypes() if err != nil { return err } - values := make([]driver.Value, len(colTypes)) - retErr := r.rows.Next(values) + values := make([]interface{}, 0, len(dest)) + for i := 0; i < len(dest); i++ { + var v interface{} + values = append(values, &v) + } + retErr := r.rows.Scan(values...) for idx, colType := range colTypes { - value, err := r.convertValue(values[idx], colType) + v := reflect.ValueOf(values[idx]).Elem().Interface() + value, err := r.convertValue(v, colType) if err != nil { return err } @@ -48,7 +72,7 @@ func (r *Rows) Next(dest []driver.Value) error { return retErr } -func (r *Rows) convertValue(value driver.Value, typ *Type) (driver.Value, error) { +func (r *Rows) convertValue(value interface{}, typ *Type) (driver.Value, error) { if typ.IsArray() { val, err := ValueOf(value) if err != nil { diff --git a/spec.go b/spec.go index 78382cb..7bf4391 100644 --- a/spec.go +++ b/spec.go @@ -2,9 +2,7 @@ package zetasqlite import ( "context" - "database/sql/driver" "fmt" - "io" "strings" ast "github.com/goccy/go-zetasql/resolved_ast" @@ -45,39 +43,6 @@ func (s *FunctionSpec) SQL() string { ) } -func (s *FunctionSpec) FuncBody(c *ZetaSQLiteConn) interface{} { - switch types.TypeKind(s.Return.Kind) { - case types.INT64: - return func(args ...interface{}) (int64, error) { - stmt, err := c.conn.Prepare(s.Body) - if err != nil { - return 0, err - } - driverArgs := []driver.Value{} - for _, arg := range args { - driverArgs = append(driverArgs, arg) - } - newArgs, err := convertValues(driverArgs) - if err != nil { - return 0, err - } - rows, err := stmt.Query(newArgs) - if err != nil { - return 0, err - } - defer rows.Close() - values := make([]driver.Value, 1) - if err := rows.Next(values); err != nil { - if err != io.EOF { - return 0, err - } - } - return values[0].(int64), nil - } - } - return nil -} - type TableSpec struct { NamePath []string `json:"namePath"` Columns []*ColumnSpec `json:"columns"` diff --git a/stmt.go b/stmt.go index 8fc6e23..47e69ee 100644 --- a/stmt.go +++ b/stmt.go @@ -2,6 +2,7 @@ package zetasqlite import ( "context" + "database/sql" "database/sql/driver" "fmt" ) @@ -14,7 +15,7 @@ var ( ) type CreateTableStmt struct { - stmt driver.Stmt + stmt *sql.Stmt catalog *Catalog spec *TableSpec } @@ -31,7 +32,7 @@ func (s *CreateTableStmt) Exec(args []driver.Value) (driver.Result, error) { if _, err := s.stmt.Exec(args); err != nil { return nil, err } - if err := s.catalog.AddNewTableSpec(s.spec); err != nil { + if err := s.catalog.AddNewTableSpec(context.Background(), s.spec); err != nil { return nil, fmt.Errorf("failed to add new table spec: %w", err) } return nil, nil @@ -41,7 +42,7 @@ func (s *CreateTableStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, fmt.Errorf("failed to query for CreateTableStmt") } -func newCreateTableStmt(stmt driver.Stmt, catalog *Catalog, spec *TableSpec) *CreateTableStmt { +func newCreateTableStmt(stmt *sql.Stmt, catalog *Catalog, spec *TableSpec) *CreateTableStmt { return &CreateTableStmt{ stmt: stmt, catalog: catalog, @@ -63,7 +64,7 @@ func (s *CreateFunctionStmt) NumInput() int { } func (s *CreateFunctionStmt) Exec(args []driver.Value) (driver.Result, error) { - if err := s.catalog.AddNewFunctionSpec(s.spec); err != nil { + if err := s.catalog.AddNewFunctionSpec(context.Background(), s.spec); err != nil { return nil, fmt.Errorf("failed to add new function spec: %w", err) } return nil, nil @@ -81,12 +82,12 @@ func newCreateFunctionStmt(catalog *Catalog, spec *FunctionSpec) *CreateFunction } type DMLStmt struct { - stmt driver.Stmt + stmt *sql.Stmt argsNum int formattedQuery string } -func newDMLStmt(stmt driver.Stmt, argsNum int, formattedQuery string) *DMLStmt { +func newDMLStmt(stmt *sql.Stmt, argsNum int, formattedQuery string) *DMLStmt { return &DMLStmt{ stmt: stmt, argsNum: argsNum, @@ -107,11 +108,15 @@ func (s *DMLStmt) NumInput() int { } func (s *DMLStmt) Exec(args []driver.Value) (driver.Result, error) { - newArgs, err := convertValues(args) + values := make([]interface{}, 0, len(args)) + for _, arg := range args { + values = append(values, arg) + } + newArgs, err := convertValues(values) if err != nil { return nil, err } - result, err := s.stmt.Exec(newArgs) + result, err := s.stmt.Exec(newArgs...) if err != nil { return nil, fmt.Errorf( "failed to execute query %s: args %v: %w", @@ -136,13 +141,13 @@ func (s *DMLStmt) QueryContext(ctx context.Context, query string, args []driver. } type QueryStmt struct { - stmt driver.Stmt + stmt *sql.Stmt argsNum int formattedQuery string outputColumns []*ColumnSpec } -func newQueryStmt(stmt driver.Stmt, argsNum int, formattedQuery string, outputColumns []*ColumnSpec) *QueryStmt { +func newQueryStmt(stmt *sql.Stmt, argsNum int, formattedQuery string, outputColumns []*ColumnSpec) *QueryStmt { return &QueryStmt{ stmt: stmt, argsNum: argsNum, @@ -176,12 +181,15 @@ func (s *QueryStmt) ExecContext(ctx context.Context, query string, args []driver } func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) { - newArgs, err := convertValues(args) + values := make([]interface{}, 0, len(args)) + for _, arg := range args { + values = append(values, arg) + } + newArgs, err := convertValues(values) if err != nil { return nil, err } - return s.stmt.Query(newArgs) - rows, err := s.stmt.Query(newArgs) + rows, err := s.stmt.Query(newArgs...) if err != nil { return nil, fmt.Errorf( "failed to query %s: args: %v: %w", diff --git a/value.go b/value.go index 25bac00..d50987a 100644 --- a/value.go +++ b/value.go @@ -2,6 +2,7 @@ package zetasqlite import ( "bytes" + "database/sql" "database/sql/driver" "encoding/base64" "encoding/json" @@ -972,8 +973,8 @@ func isNULLValue(v interface{}) bool { return len(vv) == 0 } -func convertNamedValues(v []driver.NamedValue) ([]driver.NamedValue, error) { - ret := make([]driver.NamedValue, 0, len(v)) +func convertNamedValues(v []driver.NamedValue) ([]sql.NamedArg, error) { + ret := make([]sql.NamedArg, 0, len(v)) for _, vv := range v { converted, err := convertNamedValue(vv) if err != nil { @@ -984,20 +985,19 @@ func convertNamedValues(v []driver.NamedValue) ([]driver.NamedValue, error) { return ret, nil } -func convertNamedValue(v driver.NamedValue) (driver.NamedValue, error) { +func convertNamedValue(v driver.NamedValue) (sql.NamedArg, error) { value, err := SQLiteValue(v.Value) if err != nil { - return driver.NamedValue{}, err + return sql.NamedArg{}, err } - return driver.NamedValue{ - Name: strings.ToLower(v.Name), - Ordinal: v.Ordinal, - Value: value, + return sql.NamedArg{ + Name: strings.ToLower(v.Name), + Value: value, }, nil } -func convertValues(v []driver.Value) ([]driver.Value, error) { - ret := make([]driver.Value, 0, len(v)) +func convertValues(v []interface{}) ([]interface{}, error) { + ret := make([]interface{}, 0, len(v)) for _, vv := range v { value, err := SQLiteValue(vv) if err != nil { From e8348950e43648d48a17fa7f06c9882fb4b03de7 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sat, 25 Jun 2022 17:00:15 +0900 Subject: [PATCH 2/2] fix instance management of catalog --- catalog.go | 131 +++++++++++++++++++++++++++++++++++------------------ driver.go | 32 +++++++------ 2 files changed, 105 insertions(+), 58 deletions(-) diff --git a/catalog.go b/catalog.go index eb01cad..71ec39c 100644 --- a/catalog.go +++ b/catalog.go @@ -2,9 +2,11 @@ package zetasqlite import ( "context" + "database/sql" "encoding/json" "fmt" "reflect" + "sync" "time" "github.com/goccy/go-zetasql/types" @@ -15,11 +17,28 @@ var ( CREATE TABLE IF NOT EXISTS zetasqlite_catalog( name STRING NOT NULL PRIMARY KEY, kind STRING NOT NULL, - spec STRING NOT NULL + spec STRING NOT NULL, + updatedAt TIMESTAMP NOT NULL, + createdAt TIMESTAMP NOT NULL ) ` - loadCatalogQuery = `SELECT name, kind, spec FROM zetasqlite_catalog` - upsertCatalogQuery = `INSERT INTO zetasqlite_catalog (name, kind, spec) VALUES (?, ?, ?) ON CONFLICT(name) DO UPDATE SET spec = ?` + upsertCatalogQuery = ` +INSERT INTO zetasqlite_catalog ( + name, + kind, + spec, + updatedAt, + createdAt +) VALUES ( + @name, + @kind, + @spec, + @updatedAt, + @createdAt +) ON CONFLICT(name) DO UPDATE SET + spec = @spec, + updatedAt = @updatedAt +` ) type CatalogSpecKind string @@ -30,19 +49,21 @@ const ( ) type Catalog struct { - conn *ZetaSQLiteConn - catalog *types.SimpleCatalog - tables []*TableSpec - functions []*FunctionSpec - tableMap map[string]*TableSpec - funcMap map[string]*FunctionSpec + db *sql.DB + catalog *types.SimpleCatalog + lastSyncedAt time.Time + mu sync.Mutex + tables []*TableSpec + functions []*FunctionSpec + tableMap map[string]*TableSpec + funcMap map[string]*FunctionSpec } -func newCatalog(conn *ZetaSQLiteConn) *Catalog { +func newCatalog(db *sql.DB) *Catalog { catalog := types.NewSimpleCatalog("zetasqlite") catalog.AddZetaSQLBuiltinFunctions() return &Catalog{ - conn: conn, + db: db, catalog: catalog, tableMap: map[string]*TableSpec{}, funcMap: map[string]*FunctionSpec{}, @@ -50,10 +71,23 @@ func newCatalog(conn *ZetaSQLiteConn) *Catalog { } func (c *Catalog) Sync(ctx context.Context) error { - if err := c.createCatalogTablesIfNotExists(ctx); err != nil { - return err + c.mu.Lock() + defer c.mu.Unlock() + + tx, err := c.db.Begin() + if err != nil { + return fmt.Errorf("failed to start transaction for zetasqlite_catalog: %w", err) + } + defer tx.Commit() + if err := c.createCatalogTablesIfNotExists(ctx, tx); err != nil { + return fmt.Errorf("failed to create catalog tables: %w", err) } - rows, err := c.conn.conn.QueryContext(ctx, loadCatalogQuery) + now := time.Now() + rows, err := tx.QueryContext( + ctx, + `SELECT name, kind, spec FROM zetasqlite_catalog WHERE updatedAt >= @lastUpdatedAt`, + c.lastSyncedAt, + ) if err != nil { return fmt.Errorf("failed to query load catalog: %w", err) } @@ -80,68 +114,88 @@ func (c *Catalog) Sync(ctx context.Context) error { return fmt.Errorf("unknown catalog spec kind %s", kind) } } + c.lastSyncedAt = now return nil } func (c *Catalog) AddNewTableSpec(ctx context.Context, spec *TableSpec) error { + c.mu.Lock() + defer c.mu.Unlock() + if err := c.addTableSpec(spec); err != nil { return err } - if err := c.saveTableSpec(ctx, spec); err != nil { + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Commit() + if err := c.saveTableSpec(ctx, tx, spec); err != nil { return err } return nil } func (c *Catalog) AddNewFunctionSpec(ctx context.Context, spec *FunctionSpec) error { + c.mu.Lock() + defer c.mu.Unlock() + if err := c.addFunctionSpec(spec); err != nil { return err } - if err := c.saveFunctionSpec(ctx, spec); err != nil { + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Commit() + if err := c.saveFunctionSpec(ctx, tx, spec); err != nil { return err } return nil } -func (c *Catalog) saveTableSpec(ctx context.Context, spec *TableSpec) error { +func (c *Catalog) saveTableSpec(ctx context.Context, tx *sql.Tx, spec *TableSpec) error { encoded, err := json.Marshal(spec) if err != nil { return fmt.Errorf("failed to encode table spec: %w", err) } - - if err := c.exec( + now := time.Now() + if _, err := tx.ExecContext( ctx, upsertCatalogQuery, - spec.TableName(), - string(TableSpecKind), - string(encoded), - string(encoded), + sql.Named("name", spec.TableName()), + sql.Named("kind", string(TableSpecKind)), + sql.Named("spec", string(encoded)), + sql.Named("updatedAt", now), + sql.Named("createdAt", now), ); err != nil { return fmt.Errorf("failed to save a new table spec: %w", err) } return nil } -func (c *Catalog) saveFunctionSpec(ctx context.Context, spec *FunctionSpec) error { +func (c *Catalog) saveFunctionSpec(ctx context.Context, tx *sql.Tx, spec *FunctionSpec) error { encoded, err := json.Marshal(spec) if err != nil { return fmt.Errorf("failed to encode function spec: %w", err) } - if err := c.exec( + now := time.Now() + if _, err := tx.ExecContext( ctx, upsertCatalogQuery, - spec.FuncName(), - string(FunctionSpecKind), - string(encoded), - string(encoded), + sql.Named("name", spec.FuncName()), + sql.Named("kind", string(FunctionSpecKind)), + sql.Named("spec", string(encoded)), + sql.Named("updatedAt", now), + sql.Named("createdAt", now), ); err != nil { return fmt.Errorf("failed to save a new function spec: %w", err) } return nil } -func (c *Catalog) createCatalogTablesIfNotExists(ctx context.Context) error { - if err := c.exec(ctx, createCatalogTableQuery); err != nil { +func (c *Catalog) createCatalogTablesIfNotExists(ctx context.Context, tx *sql.Tx) error { + if _, err := tx.ExecContext(ctx, createCatalogTableQuery); err != nil { return fmt.Errorf("failed to create catalog table: %w", err) } return nil @@ -172,6 +226,7 @@ func (c *Catalog) loadFunctionSpec(spec string) error { func (c *Catalog) addFunctionSpec(spec *FunctionSpec) error { funcName := spec.FuncName() if _, exists := c.funcMap[funcName]; exists { + c.funcMap[funcName] = spec // update current spec return nil } c.functions = append(c.functions, spec) @@ -182,6 +237,7 @@ func (c *Catalog) addFunctionSpec(spec *FunctionSpec) error { func (c *Catalog) addTableSpec(spec *TableSpec) error { tableName := spec.TableName() if _, exists := c.tableMap[tableName]; exists { + c.tableMap[tableName] = spec // update current spec return nil } c.tables = append(c.tables, spec) @@ -323,16 +379,3 @@ func (c *Catalog) copyFunctionSpec(spec *FunctionSpec, newNamePath []string) *Fu Body: spec.Body, } } - -func (c *Catalog) exec(ctx context.Context, query string, args ...interface{}) error { - var retErr error - for i := 0; i < 5; i++ { - if _, err := c.conn.conn.ExecContext(ctx, query, args...); err != nil { - retErr = err - time.Sleep(100 * time.Millisecond) - continue - } - return nil - } - return retErr -} diff --git a/driver.go b/driver.go index e29711a..25bbe1e 100644 --- a/driver.go +++ b/driver.go @@ -17,8 +17,9 @@ var ( ) var ( - nameToDBMap = map[string]*sql.DB{} - nameToDBMapMu sync.Mutex + nameToCatalogMap = map[string]*Catalog{} + nameToDBMap = map[string]*sql.DB{} + nameToValueMapMu sync.Mutex ) func init() { @@ -33,19 +34,21 @@ func init() { }) } -func openDB(name string) (*sql.DB, error) { - nameToDBMapMu.Lock() - defer nameToDBMapMu.Unlock() +func newDBAndCatalog(name string) (*sql.DB, *Catalog, error) { + nameToValueMapMu.Lock() + defer nameToValueMapMu.Unlock() db, exists := nameToDBMap[name] if exists { - return db, nil + return db, nameToCatalogMap[name], nil } db, err := sql.Open("zetasqlite_sqlite3", name) if err != nil { - return nil, fmt.Errorf("failed to open database by %s: %w", name, err) + return nil, nil, fmt.Errorf("failed to open database by %s: %w", name, err) } + catalog := newCatalog(db) nameToDBMap[name] = db - return db, nil + nameToCatalogMap[name] = catalog + return db, catalog, nil } type ZetaSQLiteDriver struct { @@ -53,11 +56,11 @@ type ZetaSQLiteDriver struct { } func (d *ZetaSQLiteDriver) Open(name string) (driver.Conn, error) { - db, err := openDB(name) + db, catalog, err := newDBAndCatalog(name) if err != nil { return nil, err } - conn, err := newZetaSQLiteConn(db) + conn, err := newZetaSQLiteConn(db, catalog) if err != nil { return nil, err } @@ -74,14 +77,15 @@ type ZetaSQLiteConn struct { analyzer *Analyzer } -func newZetaSQLiteConn(db *sql.DB) (*ZetaSQLiteConn, error) { +func newZetaSQLiteConn(db *sql.DB, catalog *Catalog) (*ZetaSQLiteConn, error) { conn, err := db.Conn(context.Background()) if err != nil { return nil, fmt.Errorf("failed to get sqlite3 connection: %w", err) } - c := &ZetaSQLiteConn{conn: conn} - c.analyzer = newAnalyzer(newCatalog(c)) - return c, nil + return &ZetaSQLiteConn{ + conn: conn, + analyzer: newAnalyzer(catalog), + }, nil } func (c *ZetaSQLiteConn) NamePath() []string {