Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add partial ALTER TABLE support (adding column and DEFAULT) #113 #1

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 184 additions & 9 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,196 @@ import (
zetasqlite "github.com/goccy/go-zetasqlite"
)

func TestDriverAlter(t *testing.T) {
db, err := sql.Open("zetasqlite", ":memory:")
if err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS Artists (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX)
)
`); err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`INSERT Artists (SingerId, FirstName, LastName) VALUES (1, 'John', 'Titor')`); err != nil {
t.Fatal(err)
}
row := db.QueryRow(`SELECT SingerId, FirstName, LastName FROM Artists WHERE SingerId = @id`, 1)
if row.Err() != nil {
t.Fatal(row.Err())
}
var (
singerID int64
firstName string
lastName string
)
if err := row.Scan(&singerID, &firstName, &lastName); err != nil {
t.Fatal(err)
}
if singerID != 1 || firstName != "John" || lastName != "Titor" {
t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName)
}

if _, err := db.Exec(`
CREATE VIEW IF NOT EXISTS
SingerNames AS SELECT FirstName || ' ' || LastName AS Name
FROM Artists
`); err != nil {
t.Fatal(err)
}

viewRow := db.QueryRow(`SELECT Name FROM SingerNames LIMIT 1`)
if viewRow.Err() != nil {
t.Fatal(viewRow.Err())
}

var name string

if err := viewRow.Scan(&name); err != nil {
t.Fatal(err)
}
if name != "John Titor" {
t.Fatalf("failed to find view row")
}

// Test ALTER TABLE SET OPTIONS
if _, err := db.Exec(`ALTER TABLE Artists SET OPTIONS (description="Famous Artists")`); err != nil {
t.Fatal(err)
}

// Test ALTER TABLE ADD COLUMN
if _, err := db.Exec(`ALTER TABLE Artists ADD COLUMN Age INT64, ADD COLUMN IsSingle BOOL`); err != nil {
t.Fatal(err)
}

// Verify the changes
row = db.QueryRow(`
SELECT SingerId, FirstName, LastName, Age, IsSingle
FROM Artists
WHERE SingerId = @id`,
1,
)
if row.Err() != nil {
t.Fatal(row.Err())
}

var age sql.NullInt64
var isSingle sql.NullBool
if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle); err != nil {
t.Fatal(err)
}
if singerID != 1 || firstName != "John" || lastName != "Titor" || age.Valid || isSingle.Valid {
t.Fatalf("failed to find row after ALTER TABLE statements")
}

if _, err := db.Exec(`
INSERT Artists (SingerId, FirstName, LastName, Age, IsSingle)
VALUES (2, 'Mike', 'Bit', 11, TRUE)
`); err != nil {
t.Fatal(err)
}
row = db.QueryRow(`
SELECT SingerId, FirstName, LastName, Age, isSingle
FROM Artists
WHERE SingerId = @id AND isSingle IS NOT NULL`,
2,
)
if row.Err() != nil {
t.Fatal(row.Err())
}
if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle); err != nil {
t.Fatal(err)
}
if singerID != 2 || firstName != "Mike" || lastName != "Bit" || age.Int64 != 11 || isSingle.Bool != true {
t.Fatalf("Failed to find row %v %v %v %v %v", singerID, firstName, lastName, age, isSingle)
}

if _, err := db.Exec(`
ALTER TABLE Artists
ADD COLUMN Nationality STRING
`); err != nil {
t.Fatal(err)
}

if _, err := db.Exec(`
ALTER TABLE Artists
ALTER COLUMN Nationality SET DEFAULT 'Unknown'
`); err != nil {
t.Fatal(err)
}

// Verify the changes
row = db.QueryRow(`
SELECT SingerID, FirstName, LastName, Age, IsSingle, Nationality
FROM Artists
WHERE SingerId = @id`,
2,
)
if row.Err() != nil {
t.Fatal(row.Err())
}

var nationality sql.NullString
if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle, &nationality); err != nil {
t.Fatal(err)
}

if singerID != 2 || firstName != "Mike" || lastName != "Bit" || age.Int64 != 11 || isSingle.Bool != true || nationality.Valid {
t.Fatalf("failed to find row after multi-action ALTER TABLE statement")
}

if _, err := db.Exec(`
INSERT Artists (SingerId, FirstName, LastName, Age, IsSingle)
VALUES (3, 'Mark', 'Byte', 12, FALSE)
`); err != nil {
t.Fatal(err)
}

// Verify the changes
row = db.QueryRow(`
SELECT SingerID, FirstName, LastName, Age, IsSingle, Nationality
FROM Artists
WHERE SingerId = @id`,
3,
)
if row.Err() != nil {
t.Fatal(row.Err())
}

if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle, &nationality); err != nil {
t.Fatal(err)
}
if singerID != 3 || firstName != "Mark" || lastName != "Byte" || age.Int64 != 12 || isSingle.Bool != false || nationality.String != "Unknown" {
t.Fatalf("failed to find row after multi-action ALTER TABLE statement")
}
}

func TestDriver(t *testing.T) {
db, err := sql.Open("zetasqlite", ":memory:")
if err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS Singers (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX)
)`); err != nil {
CREATE TABLE IF NOT EXISTS Singers (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX)
)
`); err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`INSERT Singers (SingerId, FirstName, LastName) VALUES (1, 'John', 'Titor')`); err != nil {
if _, err := db.Exec(`
INSERT Singers (SingerId, FirstName, LastName)
VALUES (1, 'John', 'Titor')
`); err != nil {
t.Fatal(err)
}
row := db.QueryRow("SELECT SingerID, FirstName, LastName FROM Singers WHERE SingerId = @id", 1)
row := db.QueryRow(`SELECT SingerID, FirstName, LastName FROM Singers WHERE SingerId = @id`, 1)
if row.Err() != nil {
t.Fatal(row.Err())
}
Expand All @@ -43,7 +215,10 @@ CREATE TABLE IF NOT EXISTS Singers (
t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName)
}
if _, err := db.Exec(`
CREATE VIEW IF NOT EXISTS SingerNames AS SELECT FirstName || ' ' || LastName AS Name FROM Singers`); err != nil {
CREATE VIEW IF NOT EXISTS
SingerNames AS SELECT FirstName || ' ' || LastName AS Name
FROM Singers
`); err != nil {
t.Fatal(err)
}

Expand Down
24 changes: 24 additions & 0 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) {
zetasql.FeatureV11WithOnSubquery,
zetasql.FeatureV13Pivot,
zetasql.FeatureV13Unpivot,
zetasql.FeatureV13ColumnDefaultValue,
})
langOpt.SetSupportedStatementKinds([]ast.Kind{
ast.BeginStmt,
Expand All @@ -87,6 +88,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) {
ast.DropStmt,
ast.TruncateStmt,
ast.CreateTableStmt,
ast.AlterTableStmt,
ast.CreateTableAsSelectStmt,
ast.CreateProcedureStmt,
ast.CreateFunctionStmt,
Expand Down Expand Up @@ -290,10 +292,32 @@ func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []drive
return a.newBeginStmtAction(ctx, query, args, node)
case ast.CommitStmt:
return a.newCommitStmtAction(ctx, query, args, node)
case ast.AlterTableStmt:
return a.alterTableStmtAction(ctx, query, args, node.(*ast.AlterTableStmtNode))
}
return nil, fmt.Errorf("unsupported stmt %s", node.DebugString())
}

func (a *Analyzer) alterTableStmtAction(ctx context.Context, query string, args []driver.NamedValue, node *ast.AlterTableStmtNode) (*AlterTableStmtAction, error) {
spec, err := newAlterSpec(ctx, a.namePath, node)
if err != nil {
return nil, err
}
params := getParamsFromNode(node)
queryArgs, err := getArgsFromParams(args, params)
if err != nil {
return nil, err
}
return &AlterTableStmtAction{
query: query,
spec: spec,
node: node,
args: queryArgs,
rawArgs: args,
catalog: a.catalog,
}, nil
}

func (a *Analyzer) newCreateTableStmtAction(_ context.Context, query string, args []driver.NamedValue, node *ast.CreateTableStmtNode) (*CreateTableStmtAction, error) {
spec := newTableSpec(a.namePath, node)
params := getParamsFromNode(node)
Expand Down
36 changes: 36 additions & 0 deletions internal/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,42 @@ func (c *Catalog) addTableSpec(spec *TableSpec) error {
return nil
}

func (c *Catalog) modifyTableSpec(spec *AlterTableSpec) error {
tableName := spec.TableName()
foundSpecToUpdate, exists := c.tableMap[tableName]

if !exists {
return fmt.Errorf("table %s does not exist", tableName)
}

formattedPath := formatPath(spec.NamePath)

err := c.deleteTableSpecByName(formattedPath)
if err != nil {
return err
}

for _, column := range spec.ColumnsWithNewDefaultValue {
if foundSpecToUpdate.Column(column.ColumnName) == nil {
return fmt.Errorf("cannot update column %s to have a default value, table %s does not have this column", tableName, column.ColumnName)
}
}

addedColumns := make([]*ColumnSpec, len(foundSpecToUpdate.Columns))
copy(addedColumns, foundSpecToUpdate.Columns)
addedColumns = append(addedColumns, spec.AddedColumns...)

foundSpecToUpdate.Columns = addedColumns
foundSpecToUpdate.UpdatedAt = spec.UpdatedAt

err = c.addTableSpec(foundSpecToUpdate)
if err != nil {
return err
}

return nil
}

func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpec) error {
if len(spec.NamePath) > 1 {
subCatalogName := spec.NamePath[0]
Expand Down
64 changes: 64 additions & 0 deletions internal/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ type TableSpec struct {
CreatedAt time.Time `json:"createdAt"`
}

type ColumnWithDefaultSpec struct {
ColumnName string
DefaultValue string
}

type AlterTableSpec struct {
NamePath []string `json:"namePath"`
AddedColumns []*ColumnSpec `json:"addedColumns"`
ColumnsWithNewDefaultValue []*ColumnWithDefaultSpec `json:"columnsWithNewDefaultValue"`
UpdatedAt time.Time `json:"updatedAt"`
}

func (s *TableSpec) Column(name string) *ColumnSpec {
for _, col := range s.Columns {
if col.Name == name {
Expand All @@ -123,6 +135,10 @@ func (s *TableSpec) Column(name string) *ColumnSpec {
return nil
}

func (s *AlterTableSpec) TableName() string {
return formatPath(s.NamePath)
}

func (s *TableSpec) TableName() string {
return formatPath(s.NamePath)
}
Expand Down Expand Up @@ -513,6 +529,54 @@ func newPrimaryKey(key *ast.PrimaryKeyNode) []string {
return key.ColumnNameList()
}

func newAlterSpec(ctx context.Context, namePath *NamePath, stmt *ast.AlterTableStmtNode) (*AlterTableSpec, error) {
list := stmt.AlterActionList()
var columns []*ast.ColumnDefinitionNode
var columnsAddDefault []*ColumnWithDefaultSpec

var err error

for i := range list {
action := list[i]
if err != nil {
return nil, err
}
switch action.Kind() {
case ast.AddColumnAction | ast.AlterColumnSetDefaultAction:
err = fmt.Errorf("adding field with default value to an existing table schema is not supported")
case ast.AddColumnAction:
addColumnAction := action.(*ast.AddColumnActionNode)
columns = append(columns, addColumnAction.ColumnDefinition())
case ast.AlterColumnSetDefaultAction:
setDefaultAction := action.(*ast.AlterColumnSetDefaultActionNode)
columnName := setDefaultAction.Column()
defaultValueExpr := setDefaultAction.DefaultValue().Expression()
var defaultValue string
if defaultValueExpr != nil {
// TODO: figure out the timestamp thing here?
defaultValue, err = newNode(defaultValueExpr).FormatSQL(ctx) // assuming newNode has a method to format SQL
if err != nil {
return nil, fmt.Errorf("failed to format default value: %w", err)
}
}
columnsAddDefault = append(columnsAddDefault, &ColumnWithDefaultSpec{
ColumnName: columnName,
DefaultValue: defaultValue,
})
default:
err = fmt.Errorf("unknown alter action kind: %v", action.Kind())
}
}

now := time.Now()
return &AlterTableSpec{
NamePath: namePath.mergePath(stmt.NamePath()),
AddedColumns: newColumnsFromDef(columns),
ColumnsWithNewDefaultValue: columnsAddDefault,
UpdatedAt: now,
}, nil
}

func newTableSpec(namePath *NamePath, stmt *ast.CreateTableStmtNode) *TableSpec {
now := time.Now()
return &TableSpec{
Expand Down
Loading
Loading