diff --git a/server/analyzer/serial.go b/server/analyzer/serial.go index e4372fc343..b78de7596f 100644 --- a/server/analyzer/serial.go +++ b/server/analyzer/serial.go @@ -16,6 +16,7 @@ package analyzer import ( "fmt" + "strings" "github.com/cockroachdb/errors" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" @@ -27,6 +28,7 @@ import ( "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/core/sequences" + "github.com/dolthub/doltgresql/server/ast" pgexprs "github.com/dolthub/doltgresql/server/expression" "github.com/dolthub/doltgresql/server/functions/framework" pgnodes "github.com/dolthub/doltgresql/server/node" @@ -43,81 +45,124 @@ func ReplaceSerial(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope var ctSequences []*pgnodes.CreateSequence for _, col := range createTable.PkSchema().Schema { - if doltgresType, ok := col.Type.(*pgtypes.DoltgresType); ok { - if doltgresType.IsSerial { - var maxValue int64 - switch doltgresType.Name() { - case "smallserial": - col.Type = pgtypes.Int16 - maxValue = 32767 - case "serial": - col.Type = pgtypes.Int32 - maxValue = 2147483647 - case "bigserial": - col.Type = pgtypes.Int64 - maxValue = 9223372036854775807 - } - - baseSequenceName := fmt.Sprintf("%s_%s_seq", createTable.Name(), col.Name) - sequenceName := baseSequenceName - schemaName, err := core.GetSchemaName(ctx, createTable.Db, "") - if err != nil { - return nil, false, err - } + doltgresType, isDoltgresType := col.Type.(*pgtypes.DoltgresType) + if !isDoltgresType || !doltgresType.IsSerial { + continue + } - relationType, err := core.GetRelationType(ctx, schemaName, baseSequenceName) - if err != nil { - return nil, transform.NewTree, err - } - if relationType != core.RelationType_DoesNotExist { - seqIndex := 1 - for ; seqIndex <= 100; seqIndex++ { - sequenceName = fmt.Sprintf("%s%d", baseSequenceName, seqIndex) - relationType, err = core.GetRelationType(ctx, schemaName, baseSequenceName) - if err != 
nil { - return nil, transform.NewTree, err - } - if relationType == core.RelationType_DoesNotExist { - break - } + // For always-generated columns we insert a placeholder sequence to be replaced by the actual sequence name. We + // detect that here and treat these generated columns differently than other generated columns on serial types. + isGeneratedFromSequence := false + if col.Generated != nil { + seenNextVal := false + transform.InspectExpr(col.Generated, func(expr sql.Expression) bool { + switch expr := expr.(type) { + case *framework.CompiledFunction: + if strings.ToLower(expr.Name) == "nextval" { + seenNextVal = true } - if seqIndex > 100 { - return nil, transform.NewTree, errors.Errorf("SERIAL sequence name reached max iterations") + case *pgexprs.Literal: + placeholderName := fmt.Sprintf("'%s'", ast.DoltCreateTablePlaceholderSequenceName) + if expr.String() == placeholderName { + isGeneratedFromSequence = true } } + return false + }) - seqName := doltdb.TableName{Name: sequenceName, Schema: schemaName}.String() - nextVal, ok, err := framework.GetFunction("nextval", pgexprs.NewTextLiteral(seqName)) - if err != nil { - return nil, transform.NewTree, err - } - if !ok { - return nil, transform.NewTree, errors.Errorf(`function "nextval" could not be found for SERIAL default`) - } - col.Default = &sql.ColumnDefaultValue{ - Expr: nextVal, - OutType: pgtypes.Int64, - Literal: false, - ReturnNil: false, - Parenthesized: false, - } - ctSequences = append(ctSequences, pgnodes.NewCreateSequence(false, "", &sequences.Sequence{ - Id: id.NewSequence("", sequenceName), - DataTypeID: col.Type.(*pgtypes.DoltgresType).ID, - Persistence: sequences.Persistence_Permanent, - Start: 1, - Current: 1, - Increment: 1, - Minimum: 1, - Maximum: maxValue, - Cache: 1, - Cycle: false, - IsAtEnd: false, - OwnerTable: id.NewTable("", createTable.Name()), - OwnerColumn: col.Name, - })) + if !seenNextVal && !isGeneratedFromSequence { + continue } } + + schemaName, err := 
core.GetSchemaName(ctx, createTable.Db, "") + if err != nil { + return nil, false, err + } + + sequenceName, err := generateSequenceName(ctx, createTable, col, schemaName) + if err != nil { + return nil, transform.NewTree, err + } + + seqName := doltdb.TableName{Name: sequenceName, Schema: schemaName}.String() + nextVal, isDoltgresType, err := framework.GetFunction("nextval", pgexprs.NewTextLiteral(seqName)) + if err != nil { + return nil, transform.NewTree, err + } + if !isDoltgresType { + return nil, transform.NewTree, errors.Errorf(`function "nextval" could not be found for SERIAL default`) + } + + nextValExpr := &sql.ColumnDefaultValue{ + Expr: nextVal, + OutType: pgtypes.Int64, + Literal: false, + ReturnNil: false, + Parenthesized: false, + } + + if isGeneratedFromSequence { + col.Generated = nextValExpr + } else { + col.Default = nextValExpr + } + + var maxValue int64 + switch doltgresType.Name() { + case "smallserial": + col.Type = pgtypes.Int16 + maxValue = 32767 + case "serial": + col.Type = pgtypes.Int32 + maxValue = 2147483647 + case "bigserial": + col.Type = pgtypes.Int64 + maxValue = 9223372036854775807 + } + + ctSequences = append(ctSequences, pgnodes.NewCreateSequence(false, "", &sequences.Sequence{ + Id: id.NewSequence("", sequenceName), + DataTypeID: col.Type.(*pgtypes.DoltgresType).ID, + Persistence: sequences.Persistence_Permanent, + Start: 1, + Current: 1, + Increment: 1, + Minimum: 1, + Maximum: maxValue, + Cache: 1, + Cycle: false, + IsAtEnd: false, + OwnerTable: id.NewTable("", createTable.Name()), + OwnerColumn: col.Name, + })) } return pgnodes.NewCreateTable(createTable, ctSequences), transform.NewTree, nil } + +// generateSequenceName returns the first unused relation name among "<table>_<col>_seq" and that name suffixed with 1 through 100, or an error if all are taken +func generateSequenceName(ctx *sql.Context, createTable *plan.CreateTable, col *sql.Column, schemaName string) (string, error) { + baseSequenceName := fmt.Sprintf("%s_%s_seq", createTable.Name(), col.Name) + sequenceName := 
baseSequenceName + relationType, err := core.GetRelationType(ctx, schemaName, baseSequenceName) + if err != nil { + return "", err + } + if relationType != core.RelationType_DoesNotExist { + seqIndex := 1 + for ; seqIndex <= 100; seqIndex++ { + sequenceName = fmt.Sprintf("%s%d", baseSequenceName, seqIndex) + relationType, err = core.GetRelationType(ctx, schemaName, sequenceName) + if err != nil { + return "", err + } + if relationType == core.RelationType_DoesNotExist { + break + } + } + if seqIndex > 100 { + return "", errors.Errorf("SERIAL sequence name reached max iterations") + } + } + return sequenceName, nil +} diff --git a/server/ast/column_table_def.go b/server/ast/column_table_def.go index 0646ad7e46..c5b43d9c33 100644 --- a/server/ast/column_table_def.go +++ b/server/ast/column_table_def.go @@ -23,6 +23,10 @@ import ( pgtypes "github.com/dolthub/doltgresql/server/types" ) +// DoltCreateTablePlaceholderSequenceName is a placeholder name used in translating computed columns to generated +// columns that involve a sequence, used later in analysis +const DoltCreateTablePlaceholderSequenceName = "dolt_create_table_placeholder_sequence" + // nodeColumnTableDef handles *tree.ColumnTableDef nodes. 
func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.ColumnDefinition, error) { if node == nil { @@ -87,13 +91,34 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column return nil, err } } + + if len(node.Computed.Options) > 0 { + return nil, errors.Errorf("sequence options are not yet supported, create a sequence separately") + } + var generated vitess.Expr - if node.Computed.Computed { + hasGeneratedExpr := node.IsComputed() && node.Computed.Expr != nil + computedByDefaultAsIdentity := node.IsComputed() && !hasGeneratedExpr && node.Computed.ByDefault + computedAsIdentity := node.IsComputed() && !hasGeneratedExpr && !node.Computed.ByDefault + + if hasGeneratedExpr { generated, err = nodeExpr(ctx, node.Computed.Expr) if err != nil { return nil, err } + } else if computedAsIdentity { + generated, err = nodeExpr(ctx, &tree.FuncExpr{ + Func: tree.WrapFunction("nextval"), + Exprs: tree.Exprs{ + tree.NewStrVal(DoltCreateTablePlaceholderSequenceName), + }, + }) + if err != nil { + return nil, err + } + } + if generated != nil { // GMS requires the AST to wrap function expressions in parens if _, ok := generated.(*vitess.FuncExpr); ok { generated = &vitess.ParenExpr{Expr: generated} @@ -103,7 +128,8 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column // appropriate in this context. 
generated = clearAliases(generated) } - if node.IsSerial { + + if node.IsSerial || computedByDefaultAsIdentity || computedAsIdentity { if resolvedType.IsEmptyType() { return nil, errors.Errorf("serial type was not resolvable") } @@ -121,6 +147,7 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column return nil, errors.Errorf(`multiple default values specified for column "%s"`, node.Name) } } + colDef := &vitess.ColumnDefinition{ Name: vitess.NewColIdent(string(node.Name)), Type: vitess.ColumnType{ diff --git a/testing/generation/command_docs/output/alter_table_test.go b/testing/generation/command_docs/output/alter_table_test.go old mode 100644 new mode 100755 index 0b0bfd221b..9ae8cb9d90 --- a/testing/generation/command_docs/output/alter_table_test.go +++ b/testing/generation/command_docs/output/alter_table_test.go @@ -207,7 +207,7 @@ func TestAlterTable(t *testing.T) { Parses("ALTER TABLE ONLY name ADD column_name data_type CONSTRAINT constraint_name REFERENCES reftable MATCH SIMPLE ON DELETE NO ACTION ON UPDATE NO ACTION DEFERRABLE INITIALLY IMMEDIATE CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) ON DELETE NO ACTION ON UPDATE SET NULL DEFERRABLE , ADD column_name data_type REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE SET DEFAULT ( column_name ) DEFERRABLE INITIALLY IMMEDIATE REFERENCES reftable ( refcolumn ) MATCH PARTIAL ON DELETE SET NULL ON UPDATE NO ACTION"), Parses("ALTER TABLE name ADD COLUMN column_name data_type REFERENCES reftable ON DELETE NO ACTION ON UPDATE CASCADE REFERENCES reftable MATCH PARTIAL ON DELETE SET DEFAULT ON UPDATE SET DEFAULT NOT DEFERRABLE INITIALLY DEFERRED , ADD COLUMN IF NOT EXISTS column_name data_type COLLATE en_US REFERENCES reftable ON DELETE SET NULL ( column_name ) ON UPDATE SET DEFAULT DEFERRABLE INITIALLY IMMEDIATE REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"), Parses("ALTER TABLE name ADD COLUMN column_name 
data_type COLLATE en_US REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET DEFAULT ON UPDATE SET NULL ( column_name ) NOT DEFERRABLE REFERENCES reftable ( refcolumn ) ON DELETE SET DEFAULT ON UPDATE SET DEFAULT ( column_name , column_name ) DEFERRABLE INITIALLY IMMEDIATE , ADD column_name data_type REFERENCES reftable ( refcolumn ) ON DELETE SET NULL ( column_name , column_name ) ON UPDATE SET DEFAULT NOT DEFERRABLE CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"), - Converts("ALTER TABLE IF EXISTS name ADD COLUMN column_name data_type CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE CASCADE ON UPDATE SET DEFAULT INITIALLY DEFERRED REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE NO ACTION ON UPDATE NO ACTION DEFERRABLE , ADD column_name data_type COLLATE en_US CONSTRAINT constraint_name GENERATED ALWAYS AS IDENTITY ( NO MINVALUE ) NOT DEFERRABLE INITIALLY DEFERRED CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"), + Parses("ALTER TABLE IF EXISTS name ADD COLUMN column_name data_type CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE CASCADE ON UPDATE SET DEFAULT INITIALLY DEFERRED REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE NO ACTION ON UPDATE NO ACTION DEFERRABLE , ADD column_name data_type COLLATE en_US CONSTRAINT constraint_name GENERATED ALWAYS AS IDENTITY ( NO MINVALUE ) NOT DEFERRABLE INITIALLY DEFERRED CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"), Parses("ALTER TABLE IF EXISTS ONLY name ADD IF NOT EXISTS column_name data_type REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE SET NULL ( column_name ) ON UPDATE SET DEFAULT ( column_name , column_name ) CONSTRAINT constraint_name REFERENCES reftable ON DELETE SET DEFAULT ( column_name ) ON UPDATE NO ACTION 
INITIALLY DEFERRED , ADD column_name data_type CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH PARTIAL ON DELETE RESTRICT ON UPDATE NO ACTION INITIALLY IMMEDIATE CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE SET NULL ON UPDATE NO ACTION"), Parses("ALTER TABLE ONLY name ADD IF NOT EXISTS column_name data_type COLLATE en_US REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE RESTRICT ON UPDATE SET NULL ( column_name , column_name ) NOT DEFERRABLE REFERENCES reftable MATCH SIMPLE ON DELETE RESTRICT ON UPDATE SET DEFAULT ( column_name , column_name ) INITIALLY IMMEDIATE , ADD COLUMN column_name data_type REFERENCES reftable MATCH FULL ON DELETE SET NULL ( column_name ) ON UPDATE SET DEFAULT REFERENCES reftable ON DELETE SET NULL ( column_name ) ON UPDATE NO ACTION"), Parses("ALTER TABLE name ADD column_name data_type COLLATE en_US CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) ON DELETE SET NULL ( column_name , column_name ) ON UPDATE SET DEFAULT INITIALLY DEFERRED CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH PARTIAL ON DELETE NO ACTION ON UPDATE SET DEFAULT ( column_name , column_name ) NOT DEFERRABLE , ADD COLUMN column_name data_type COLLATE en_US REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE SET NULL ON UPDATE RESTRICT DEFERRABLE INITIALLY IMMEDIATE CONSTRAINT constraint_name REFERENCES reftable ON DELETE SET NULL ( column_name ) ON UPDATE NO ACTION"), @@ -10019,5 +10019,7 @@ func TestAlterTable(t *testing.T) { Parses("ALTER TABLE name * ADD COLUMN column_name data_type COLLATE en_US CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) ON DELETE SET DEFAULT ( column_name ) ON UPDATE SET NULL ( column_name ) DEFERRABLE CONSTRAINT constraint_name REFERENCES reftable MATCH PARTIAL ON UPDATE CASCADE NOT DEFERRABLE INITIALLY DEFERRED , ADD EXCLUDE ( column_name DESC NULLS FIRST WITH + , ( expression ) NULLS LAST WITH + ) INCLUDE ( column_name , column_name ) 
WITH ( fillfactor ) USING INDEX TABLESPACE tablespace_name NOT DEFERRABLE INITIALLY IMMEDIATE NOT VALID"), Parses("ALTER TABLE IF EXISTS name ADD column_name data_type CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH SIMPLE ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY IMMEDIATE CONSTRAINT constraint_name REFERENCES reftable ( refcolumn ) MATCH FULL ON DELETE SET NULL ( column_name , column_name ) ON UPDATE SET NULL NOT DEFERRABLE INITIALLY DEFERRED , ADD EXCLUDE ( column_name NULLS LAST WITH + , column_name opclass ASC NULLS LAST WITH + ) INCLUDE ( column_name , column_name ) WITH ( fillfactor , fillfactor = value ) USING INDEX TABLESPACE tablespace_name NOT DEFERRABLE INITIALLY IMMEDIATE NOT VALID"), } + RunTests(t, tests) + // RewriteTests(t, tests, "alter_table_test.go") } diff --git a/testing/generation/command_docs/output/framework_test.go b/testing/generation/command_docs/output/framework_test.go index 36ccf11099..66c8626824 100644 --- a/testing/generation/command_docs/output/framework_test.go +++ b/testing/generation/command_docs/output/framework_test.go @@ -195,6 +195,10 @@ func RewriteTests(t *testing.T, tests []QueryParses, file string) { return ast.Convert(statement) }() + if err != nil { + line = testStatementRegex.ReplaceAll(line, []byte("${1}Parses(")) + } + if !test.ShouldConvert() { if err == nil && vitessAST != nil { line = testStatementRegex.ReplaceAll(line, []byte("${1}Converts(")) diff --git a/testing/go/create_table_test.go b/testing/go/create_table_test.go index 6eb43de3f1..6cff5ead98 100755 --- a/testing/go/create_table_test.go +++ b/testing/go/create_table_test.go @@ -303,6 +303,27 @@ func TestCreateTable(t *testing.T) { }, }, }, + { + Name: "primary key GENERATED ALWAYS AS IDENTITY", + SetUpScript: []string{ + `create table t1 ( + a BIGINT NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + b varchar(100) + );`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t1 (b) values ('foo') returning 
a;", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "insert into t1 (a, b) values (2, 'foo') returning a;", + ExpectedErr: "The value specified for generated column \"a\" in table \"t1\" is not allowed", + }, + }, + }, { Name: "create table with default value", SetUpScript: []string{ diff --git a/testing/go/sequences_test.go b/testing/go/sequences_test.go index 0c9101f185..1290376742 100644 --- a/testing/go/sequences_test.go +++ b/testing/go/sequences_test.go @@ -850,5 +850,60 @@ func TestSequences(t *testing.T) { }, }, }, + { + Name: "identity generated by default", + SetUpScript: []string{ + `CREATE TABLE "django_migrations" ( + "id" bigint NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, + "app" varchar(255) NOT NULL, + "name" varchar(255) NOT NULL, + "applied" timestamp with time zone NOT NULL)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `INSERT INTO "django_migrations" ("app", "name", "applied") VALUES ('contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`, + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: `INSERT INTO "django_migrations" ("app", "name", "applied") VALUES ('contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`, + Expected: []sql.Row{ + {2}, + }, + }, + { + Query: `INSERT INTO "django_migrations" ("id", "app", "name", "applied") VALUES (100, 'contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`, + Expected: []sql.Row{ + {100}, + }, + }, + }, + }, + { + Name: "identity generated by default with sequence options", + Skip: true, // not supported yet, need to add sequence info into DML node given to GMS + SetUpScript: []string{ + `CREATE TABLE "django_migrations" ( + "id" bigint NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY (START WITH 100 INCREMENT BY 2), + "app" varchar(255) NOT NULL, + "name" varchar(255) NOT NULL, + 
"applied" timestamp with time zone NOT NULL)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `INSERT INTO "django_migrations" ("app", "name", "applied") VALUES ('contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`, + Expected: []sql.Row{ + {100}, + }, + }, + { + Query: `INSERT INTO "django_migrations" ("app", "name", "applied") VALUES ('contenttypes', '0001_initial', '2025-03-25T17:45:54.794344+00:00'::timestamptz) RETURNING "django_migrations"."id"`, + Expected: []sql.Row{ + {102}, + }, + }, + }, + }, }) }