Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 113 additions & 68 deletions server/analyzer/serial.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package analyzer

import (
"fmt"
"strings"

"github.com/cockroachdb/errors"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
Expand All @@ -27,6 +28,7 @@ import (
"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/sequences"
"github.com/dolthub/doltgresql/server/ast"
pgexprs "github.com/dolthub/doltgresql/server/expression"
"github.com/dolthub/doltgresql/server/functions/framework"
pgnodes "github.com/dolthub/doltgresql/server/node"
Expand All @@ -43,81 +45,124 @@ func ReplaceSerial(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope

var ctSequences []*pgnodes.CreateSequence
for _, col := range createTable.PkSchema().Schema {
if doltgresType, ok := col.Type.(*pgtypes.DoltgresType); ok {
if doltgresType.IsSerial {
var maxValue int64
switch doltgresType.Name() {
case "smallserial":
col.Type = pgtypes.Int16
maxValue = 32767
case "serial":
col.Type = pgtypes.Int32
maxValue = 2147483647
case "bigserial":
col.Type = pgtypes.Int64
maxValue = 9223372036854775807
}

baseSequenceName := fmt.Sprintf("%s_%s_seq", createTable.Name(), col.Name)
sequenceName := baseSequenceName
schemaName, err := core.GetSchemaName(ctx, createTable.Db, "")
if err != nil {
return nil, false, err
}
doltgresType, isDoltgresType := col.Type.(*pgtypes.DoltgresType)
if !isDoltgresType || !doltgresType.IsSerial {
continue
}

relationType, err := core.GetRelationType(ctx, schemaName, baseSequenceName)
if err != nil {
return nil, transform.NewTree, err
}
if relationType != core.RelationType_DoesNotExist {
seqIndex := 1
for ; seqIndex <= 100; seqIndex++ {
sequenceName = fmt.Sprintf("%s%d", baseSequenceName, seqIndex)
relationType, err = core.GetRelationType(ctx, schemaName, baseSequenceName)
if err != nil {
return nil, transform.NewTree, err
}
if relationType == core.RelationType_DoesNotExist {
break
}
// For always-generated columns we insert a placeholder sequence to be replaced by the actual sequence name. We
// detect that here and treat these generated columns differently than other generated columns on serial types.
isGeneratedFromSequence := false
if col.Generated != nil {
seenNextVal := false
transform.InspectExpr(col.Generated, func(expr sql.Expression) bool {
switch expr := expr.(type) {
case *framework.CompiledFunction:
if strings.ToLower(expr.Name) == "nextval" {
seenNextVal = true
}
if seqIndex > 100 {
return nil, transform.NewTree, errors.Errorf("SERIAL sequence name reached max iterations")
case *pgexprs.Literal:
placeholderName := fmt.Sprintf("'%s'", ast.DoltCreateTablePlaceholderSequenceName)
if expr.String() == placeholderName {
isGeneratedFromSequence = true
}
}
return false
})

seqName := doltdb.TableName{Name: sequenceName, Schema: schemaName}.String()
nextVal, ok, err := framework.GetFunction("nextval", pgexprs.NewTextLiteral(seqName))
if err != nil {
return nil, transform.NewTree, err
}
if !ok {
return nil, transform.NewTree, errors.Errorf(`function "nextval" could not be found for SERIAL default`)
}
col.Default = &sql.ColumnDefaultValue{
Expr: nextVal,
OutType: pgtypes.Int64,
Literal: false,
ReturnNil: false,
Parenthesized: false,
}
ctSequences = append(ctSequences, pgnodes.NewCreateSequence(false, "", &sequences.Sequence{
Id: id.NewSequence("", sequenceName),
DataTypeID: col.Type.(*pgtypes.DoltgresType).ID,
Persistence: sequences.Persistence_Permanent,
Start: 1,
Current: 1,
Increment: 1,
Minimum: 1,
Maximum: maxValue,
Cache: 1,
Cycle: false,
IsAtEnd: false,
OwnerTable: id.NewTable("", createTable.Name()),
OwnerColumn: col.Name,
}))
if !seenNextVal && !isGeneratedFromSequence {
continue
}
}

schemaName, err := core.GetSchemaName(ctx, createTable.Db, "")
if err != nil {
return nil, false, err
}

sequenceName, err := generateSequenceName(ctx, createTable, col, schemaName)
if err != nil {
return nil, transform.NewTree, err
}

seqName := doltdb.TableName{Name: sequenceName, Schema: schemaName}.String()
nextVal, isDoltgresType, err := framework.GetFunction("nextval", pgexprs.NewTextLiteral(seqName))
if err != nil {
return nil, transform.NewTree, err
}
if !isDoltgresType {
return nil, transform.NewTree, errors.Errorf(`function "nextval" could not be found for SERIAL default`)
}

nextValExpr := &sql.ColumnDefaultValue{
Expr: nextVal,
OutType: pgtypes.Int64,
Literal: false,
ReturnNil: false,
Parenthesized: false,
}

if isGeneratedFromSequence {
col.Generated = nextValExpr
} else {
col.Default = nextValExpr
}

var maxValue int64
switch doltgresType.Name() {
case "smallserial":
col.Type = pgtypes.Int16
maxValue = 32767
case "serial":
col.Type = pgtypes.Int32
maxValue = 2147483647
case "bigserial":
col.Type = pgtypes.Int64
maxValue = 9223372036854775807
}

ctSequences = append(ctSequences, pgnodes.NewCreateSequence(false, "", &sequences.Sequence{
Id: id.NewSequence("", sequenceName),
DataTypeID: col.Type.(*pgtypes.DoltgresType).ID,
Persistence: sequences.Persistence_Permanent,
Start: 1,
Current: 1,
Increment: 1,
Minimum: 1,
Maximum: maxValue,
Cache: 1,
Cycle: false,
IsAtEnd: false,
OwnerTable: id.NewTable("", createTable.Name()),
OwnerColumn: col.Name,
}))
}
return pgnodes.NewCreateTable(createTable, ctSequences), transform.NewTree, nil
}

// generateSequenceName generates a unique sequence name for a SERIAL column in the table given
func generateSequenceName(ctx *sql.Context, createTable *plan.CreateTable, col *sql.Column, schemaName string) (string, error) {
baseSequenceName := fmt.Sprintf("%s_%s_seq", createTable.Name(), col.Name)
sequenceName := baseSequenceName
relationType, err := core.GetRelationType(ctx, schemaName, baseSequenceName)
if err != nil {
return "", err
}
if relationType != core.RelationType_DoesNotExist {
seqIndex := 1
for ; seqIndex <= 100; seqIndex++ {
sequenceName = fmt.Sprintf("%s%d", baseSequenceName, seqIndex)
relationType, err = core.GetRelationType(ctx, schemaName, baseSequenceName)
if err != nil {
return "", err
}
if relationType == core.RelationType_DoesNotExist {
break
}
}
if seqIndex > 100 {
return "", errors.Errorf("SERIAL sequence name reached max iterations")
}
}
return sequenceName, nil
}
31 changes: 29 additions & 2 deletions server/ast/column_table_def.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ import (
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// DoltCreateTablePlaceholderSequenceName is a Placeholder name used in translating computed columns to generated
// columns that involve a sequence, used later in analysis
const DoltCreateTablePlaceholderSequenceName = "dolt_create_table_placeholder_sequence"

// nodeColumnTableDef handles *tree.ColumnTableDef nodes.
func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.ColumnDefinition, error) {
if node == nil {
Expand Down Expand Up @@ -87,13 +91,34 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column
return nil, err
}
}

if len(node.Computed.Options) > 0 {
return nil, errors.Errorf("sequence options are not yet supported, create a sequence separately")
}

var generated vitess.Expr
if node.Computed.Computed {
hasGeneratedExpr := node.IsComputed() && node.Computed.Expr != nil
computedByDefaultAsIdentity := node.IsComputed() && !hasGeneratedExpr && node.Computed.ByDefault
computedAsIdentity := node.IsComputed() && !hasGeneratedExpr && !node.Computed.ByDefault

if hasGeneratedExpr {
generated, err = nodeExpr(ctx, node.Computed.Expr)
if err != nil {
return nil, err
}
} else if computedAsIdentity {
generated, err = nodeExpr(ctx, &tree.FuncExpr{
Func: tree.WrapFunction("nextval"),
Exprs: tree.Exprs{
tree.NewStrVal(DoltCreateTablePlaceholderSequenceName),
},
})
if err != nil {
return nil, err
}
}

if generated != nil {
// GMS requires the AST to wrap function expressions in parens
if _, ok := generated.(*vitess.FuncExpr); ok {
generated = &vitess.ParenExpr{Expr: generated}
Expand All @@ -103,7 +128,8 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column
// appropriate in this context.
generated = clearAliases(generated)
}
if node.IsSerial {

if node.IsSerial || computedByDefaultAsIdentity || computedAsIdentity {
if resolvedType.IsEmptyType() {
return nil, errors.Errorf("serial type was not resolvable")
}
Expand All @@ -121,6 +147,7 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column
return nil, errors.Errorf(`multiple default values specified for column "%s"`, node.Name)
}
}

colDef := &vitess.ColumnDefinition{
Name: vitess.NewColIdent(string(node.Name)),
Type: vitess.ColumnType{
Expand Down
4 changes: 3 additions & 1 deletion testing/generation/command_docs/output/alter_table_test.go
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func TestAlterTable(t *testing.T) {
Parses("ALTER TABLE ONLY name ADD column_name data_type CONSTRAINT constraint_name REFERENCES reftable MATCH SIMPLE ON DELETE NO ACTION ON UPDATE NO ACTION DEFERRABLE INITIALLY IMMEDIATE CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) ON DELETE NO ACTION ON UPDATE SET NULL DEFERRABLE , ADD column_name data_type REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE SET DEFAULT ( column_name ) DEFERRABLE INITIALLY IMMEDIATE REFERENCES reftable ( refcolumn ) MATCH PARTIAL ON DELETE SET NULL ON UPDATE NO ACTION"),
Parses("ALTER TABLE name ADD COLUMN column_name data_type REFERENCES reftable ON DELETE NO ACTION ON UPDATE CASCADE REFERENCES reftable MATCH PARTIAL ON DELETE SET DEFAULT ON UPDATE SET DEFAULT NOT DEFERRABLE INITIALLY DEFERRED , ADD COLUMN IF NOT EXISTS column_name data_type COLLATE en_US REFERENCES reftable ON DELETE SET NULL ( column_name ) ON UPDATE SET DEFAULT DEFERRABLE INITIALLY IMMEDIATE REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"),
Parses("ALTER TABLE name ADD COLUMN column_name data_type COLLATE en_US REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET DEFAULT ON UPDATE SET NULL ( column_name ) NOT DEFERRABLE REFERENCES reftable ( refcolumn ) ON DELETE SET DEFAULT ON UPDATE SET DEFAULT ( column_name , column_name ) DEFERRABLE INITIALLY IMMEDIATE , ADD column_name data_type REFERENCES reftable ( refcolumn ) ON DELETE SET NULL ( column_name , column_name ) ON UPDATE SET DEFAULT NOT DEFERRABLE CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"),
Converts("ALTER TABLE IF EXISTS name ADD COLUMN column_name data_type CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE CASCADE ON UPDATE SET DEFAULT INITIALLY DEFERRED REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE NO ACTION ON UPDATE NO ACTION DEFERRABLE , ADD column_name data_type COLLATE en_US CONSTRAINT constraint_name GENERATED ALWAYS AS IDENTITY ( NO MINVALUE ) NOT DEFERRABLE INITIALLY DEFERRED CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"),
Parses("ALTER TABLE IF EXISTS name ADD COLUMN column_name data_type CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE CASCADE ON UPDATE SET DEFAULT INITIALLY DEFERRED REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE NO ACTION ON UPDATE NO ACTION DEFERRABLE , ADD column_name data_type COLLATE en_US CONSTRAINT constraint_name GENERATED ALWAYS AS IDENTITY ( NO MINVALUE ) NOT DEFERRABLE INITIALLY DEFERRED CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"),
Parses("ALTER TABLE IF EXISTS ONLY name ADD IF NOT EXISTS column_name data_type REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE SET NULL ( column_name ) ON UPDATE SET DEFAULT ( column_name , column_name ) CONSTRAINT constraint_name REFERENCES reftable ON DELETE SET DEFAULT ( column_name ) ON UPDATE NO ACTION INITIALLY DEFERRED , ADD column_name data_type CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH PARTIAL ON DELETE RESTRICT ON UPDATE NO ACTION INITIALLY IMMEDIATE CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"),
Parses("ALTER TABLE ONLY name ADD IF NOT EXISTS column_name data_type COLLATE en_US REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE RESTRICT ON UPDATE SET NULL ( column_name , column_name ) NOT DEFERRABLE REFERENCES reftable MATCH SIMPLE ON DELETE RESTRICT ON UPDATE SET DEFAULT ( column_name , column_name ) INITIALLY IMMEDIATE , ADD COLUMN column_name data_type REFERENCES reftable MATCH FULL ON DELETE SET NULL ( column_name ) ON UPDATE SET DEFAULT REFERENCES reftable ON DELETE SET NULL ( column_name ) ON UPDATE NO ACTION"),
Parses("ALTER TABLE name ADD column_name data_type COLLATE en_US CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) ON DELETE SET NULL ( column_name , column_name ) ON UPDATE SET DEFAULT INITIALLY DEFERRED CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH PARTIAL ON DELETE NO ACTION ON UPDATE SET DEFAULT ( column_name , column_name ) NOT DEFERRABLE , ADD COLUMN column_name data_type COLLATE en_US REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE SET NULL ON UPDATE RESTRICT DEFERRABLE INITIALLY IMMEDIATE CONSTRAINT constraint_name REFERENCES reftable ON DELETE SET NULL ( column_name ) ON UPDATE NO ACTION"),
Expand Down Expand Up @@ -10019,5 +10019,7 @@ func TestAlterTable(t *testing.T) {
Parses("ALTER TABLE name * ADD COLUMN column_name data_type COLLATE en_US CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) ON DELETE SET DEFAULT ( column_name ) ON UPDATE SET NULL ( column_name ) DEFERRABLE CONSTRAINT constraint_name REFERENCES reftable MATCH PARTIAL ON UPDATE CASCADE NOT DEFERRABLE INITIALLY DEFERRED , ADD EXCLUDE ( column_name DESC NULLS FIRST WITH + , ( expression ) NULLS LAST WITH + ) INCLUDE ( column_name , column_name ) WITH ( fillfactor ) USING INDEX TABLESPACE tablespace_name NOT DEFERRABLE INITIALLY IMMEDIATE NOT VALID"),
Parses("ALTER TABLE IF EXISTS name ADD column_name data_type CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY IMMEDIATE CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE SET NULL ( column_name , column_name ) ON UPDATE SET NULL NOT DEFERRABLE INITIALLY DEFERRED , ADD EXCLUDE ( column_name NULLS LAST WITH + , column_name opclass ASC NULLS LAST WITH + ) INCLUDE ( column_name , column_name ) WITH ( fillfactor , fillfactor = value ) USING INDEX TABLESPACE tablespace_name NOT DEFERRABLE INITIALLY IMMEDIATE NOT VALID"),
}

RunTests(t, tests)
// RewriteTests(t, tests, "alter_table_test.go")
}
4 changes: 4 additions & 0 deletions testing/generation/command_docs/output/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ func RewriteTests(t *testing.T, tests []QueryParses, file string) {
return ast.Convert(statement)
}()

if err != nil {
line = testStatementRegex.ReplaceAll(line, []byte("${1}Parses("))
}

if !test.ShouldConvert() {
if err == nil && vitessAST != nil {
line = testStatementRegex.ReplaceAll(line, []byte("${1}Converts("))
Expand Down
21 changes: 21 additions & 0 deletions testing/go/create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,27 @@ func TestCreateTable(t *testing.T) {
},
},
},
{
Name: "primary key GENERATED ALWAYS AS IDENTITY",
SetUpScript: []string{
`create table t1 (
a BIGINT NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
b varchar(100)
);`,
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into t1 (b) values ('foo') returning a;",
Expected: []sql.Row{
{1},
},
},
{
Query: "insert into t1 (a, b) values (2, 'foo') returning a;",
ExpectedErr: "The value specified for generated column \"a\" in table \"t1\" is not allowed",
},
},
},
{
Name: "create table with default value",
SetUpScript: []string{
Expand Down
55 changes: 55 additions & 0 deletions testing/go/sequences_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -850,5 +850,60 @@ func TestSequences(t *testing.T) {
},
},
},
{
Name: "identity generated by default",
SetUpScript: []string{
`CREATE TABLE "django_migrations" (
"id" bigint NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
"app" varchar(255) NOT NULL,
"name" varchar(255) NOT NULL,
"applied" timestamp with time zone NOT NULL)`,
},
Assertions: []ScriptTestAssertion{
{
Query: `INSERT INTO "django_migrations" ("app", "name", "applied") VALUES ('contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`,
Expected: []sql.Row{
{1},
},
},
{
Query: `INSERT INTO "django_migrations" ("app", "name", "applied") VALUES ('contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`,
Expected: []sql.Row{
{2},
},
},
{
Query: `INSERT INTO "django_migrations" ("id", "app", "name", "applied") VALUES (100, 'contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`,
Expected: []sql.Row{
{100},
},
},
},
},
{
Name: "identity generated by default with sequence options",
Skip: true, // not supported yet, need to add sequence info into DML node given to GMS
SetUpScript: []string{
`CREATE TABLE "django_migrations" (
"id" bigint NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY (START WITH 100 INCREMENT BY 2),
"app" varchar(255) NOT NULL,
"name" varchar(255) NOT NULL,
"applied" timestamp with time zone NOT NULL)`,
},
Assertions: []ScriptTestAssertion{
{
Query: `INSERT INTO "django_migrations" ("app", "name", "applied") VALUES ('contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`,
Expected: []sql.Row{
{100},
},
},
{
Query: `INSERT INTO "django_migrations" ("app", "name", "applied") VALUES ('contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`,
Expected: []sql.Row{
{102},
},
},
},
},
})
}