Skip to content

Commit

Permalink
Merge pull request #2 from goccy/fix-udf
Browse files Browse the repository at this point in the history
Fix UDF and connection management
  • Loading branch information
goccy authored Jun 25, 2022
2 parents f9abc41 + e834895 commit f2a6812
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 214 deletions.
65 changes: 37 additions & 28 deletions analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zetasqlite

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"

Expand All @@ -25,9 +26,9 @@ type AnalyzerOutput struct {
isQuery bool
tableSpec *TableSpec
outputColumns []*ColumnSpec
prepare func(conn driver.Conn) (driver.Stmt, error)
execContext func(context.Context, driver.ExecerContext, []driver.NamedValue) (driver.Result, error)
queryContext func(context.Context, driver.QueryerContext, []driver.NamedValue) (driver.Rows, error)
prepare func(context.Context, *sql.Conn) (driver.Stmt, error)
execContext func(context.Context, *sql.Conn, ...interface{}) (driver.Result, error)
queryContext func(context.Context, *sql.Conn, ...interface{}) (driver.Rows, error)
}

func newAnalyzer(catalog *Catalog) *Analyzer {
Expand Down Expand Up @@ -74,8 +75,8 @@ func newAnalyzerOptions() *zetasql.AnalyzerOptions {
return opt
}

func (a *Analyzer) Analyze(query string) (*AnalyzerOutput, error) {
if err := a.catalog.Sync(); err != nil {
func (a *Analyzer) Analyze(ctx context.Context, query string) (*AnalyzerOutput, error) {
if err := a.catalog.Sync(ctx); err != nil {
return nil, fmt.Errorf("failed to sync catalog: %w", err)
}
out, err := zetasql.AnalyzeStatement(query, a.catalog.catalog, a.opt)
Expand All @@ -86,9 +87,14 @@ func (a *Analyzer) Analyze(query string) (*AnalyzerOutput, error) {
if err != nil {
return nil, fmt.Errorf("failed to get full name path %s: %w", query, err)
}
ctx := withNamePath(context.Background(), a.namePath)
funcMap := map[string]*FunctionSpec{}
for _, spec := range a.catalog.functions {
funcMap[spec.FuncName()] = spec
}
ctx = withNamePath(ctx, a.namePath)
ctx = withColumnRefMap(ctx, map[string]string{})
ctx = withFullNamePath(ctx, fullpath)
ctx = withFuncMap(ctx, funcMap)
stmtNode := out.Statement()
switch stmtNode.Kind() {
case ast.CreateTableStmt:
Expand All @@ -110,33 +116,30 @@ func (a *Analyzer) analyzeCreateTableStmt(query string, node *ast.CreateTableStm
query: query,
argsNum: a.getParamNumFromNode(node),
tableSpec: spec,
prepare: func(conn driver.Conn) (driver.Stmt, error) {
prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) {
if spec.CreateMode == ast.CreateOrReplaceMode {
stmt, err := conn.Prepare(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", spec.TableName()))
if err != nil {
return nil, err
}
if _, err := stmt.Exec(nil); err != nil {
query := fmt.Sprintf("DROP TABLE IF EXISTS `%s`", spec.TableName())
if _, err := conn.ExecContext(ctx, query); err != nil {
return nil, err
}
}
s, err := conn.Prepare(spec.SQLiteSchema())
s, err := conn.PrepareContext(ctx, spec.SQLiteSchema())
if err != nil {
return nil, fmt.Errorf("failed to prepare %s: %w", query, err)
}
return newCreateTableStmt(s, a.catalog, spec), nil
},
execContext: func(ctx context.Context, execer driver.ExecerContext, args []driver.NamedValue) (driver.Result, error) {
execContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Result, error) {
if spec.CreateMode == ast.CreateOrReplaceMode {
dropTableQuery := fmt.Sprintf("DROP TABLE IF EXISTS `%s`", spec.TableName())
if _, err := execer.ExecContext(ctx, dropTableQuery, nil); err != nil {
if _, err := conn.ExecContext(ctx, dropTableQuery); err != nil {
return nil, err
}
}
if _, err := execer.ExecContext(ctx, spec.SQLiteSchema(), args); err != nil {
if _, err := conn.ExecContext(ctx, spec.SQLiteSchema(), args...); err != nil {
return nil, fmt.Errorf("failed to exec %s: %w", query, err)
}
if err := a.catalog.AddNewTableSpec(spec); err != nil {
if err := a.catalog.AddNewTableSpec(ctx, spec); err != nil {
return nil, fmt.Errorf("failed to add new table spec: %w", err)
}
return nil, nil
Expand All @@ -152,15 +155,21 @@ func (a *Analyzer) analyzeCreateFunctionStmt(ctx context.Context, query string,
return &AnalyzerOutput{
query: query,
node: node,
prepare: func(conn driver.Conn) (driver.Stmt, error) {
prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) {
return newCreateFunctionStmt(a.catalog, spec), nil
},
execContext: func(ctx context.Context, execer driver.ExecerContext, args []driver.NamedValue) (driver.Result, error) {
if err := a.catalog.AddNewFunctionSpec(spec); err != nil {
execContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Result, error) {
if err := a.catalog.AddNewFunctionSpec(ctx, spec); err != nil {
return nil, fmt.Errorf("failed to add new function spec: %w", err)
}
return nil, nil
},
queryContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Rows, error) {
if err := a.catalog.AddNewFunctionSpec(ctx, spec); err != nil {
return nil, fmt.Errorf("failed to add new function spec: %w", err)
}
return &Rows{}, nil
},
}, nil
}

Expand All @@ -178,15 +187,15 @@ func (a *Analyzer) analyzeDMLStmt(ctx context.Context, query string, node ast.No
query: query,
formattedQuery: formattedQuery,
argsNum: argsNum,
prepare: func(conn driver.Conn) (driver.Stmt, error) {
s, err := conn.Prepare(formattedQuery)
prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) {
s, err := conn.PrepareContext(ctx, formattedQuery)
if err != nil {
return nil, fmt.Errorf("failed to prepare %s: %w", query, err)
}
return newDMLStmt(s, argsNum, formattedQuery), nil
},
execContext: func(ctx context.Context, execer driver.ExecerContext, args []driver.NamedValue) (driver.Result, error) {
if _, err := execer.ExecContext(ctx, formattedQuery, args); err != nil {
execContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Result, error) {
if _, err := conn.ExecContext(ctx, formattedQuery, args...); err != nil {
return nil, fmt.Errorf("failed to exec %s: %w", formattedQuery, err)
}
return nil, nil
Expand Down Expand Up @@ -216,15 +225,15 @@ func (a *Analyzer) analyzeQueryStmt(ctx context.Context, query string, node *ast
formattedQuery: formattedQuery,
argsNum: argsNum,
isQuery: true,
prepare: func(conn driver.Conn) (driver.Stmt, error) {
s, err := conn.Prepare(formattedQuery)
prepare: func(ctx context.Context, conn *sql.Conn) (driver.Stmt, error) {
s, err := conn.PrepareContext(ctx, formattedQuery)
if err != nil {
return nil, fmt.Errorf("failed to prepare %s: %w", query, err)
}
return newQueryStmt(s, argsNum, formattedQuery, outputColumns), nil
},
queryContext: func(ctx context.Context, queryer driver.QueryerContext, args []driver.NamedValue) (driver.Rows, error) {
rows, err := queryer.QueryContext(ctx, formattedQuery, args)
queryContext: func(ctx context.Context, conn *sql.Conn, args ...interface{}) (driver.Rows, error) {
rows, err := conn.QueryContext(ctx, formattedQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to query %s: %w", formattedQuery, err)
}
Expand Down
Loading

0 comments on commit f2a6812

Please sign in to comment.