diff --git a/go.mod b/go.mod index 5063bb3652..100c7c8064 100644 --- a/go.mod +++ b/go.mod @@ -1,18 +1,18 @@ module github.com/dolthub/doltgresql -go 1.25.3 +go 1.25.6 require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20260122084121-6b5d5373d1ec + github.com/dolthub/dolt/go v0.40.5-0.20260205001014-db7263eb669c github.com/dolthub/eventsapi_schema v0.0.0-20250915094920-eadfd39051ca github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.20.1-0.20260121234050-2f0507726303 + github.com/dolthub/go-mysql-server v0.20.1-0.20260204193159-86990113e4cc github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 - github.com/dolthub/vitess v0.0.0-20260121194826-a5ce52b608e4 + github.com/dolthub/vitess v0.0.0-20260202234501-b14ed9b1632b github.com/fatih/color v1.13.0 github.com/goccy/go-json v0.10.2 github.com/gogo/protobuf v1.3.2 diff --git a/go.sum b/go.sum index 1c000732f5..15b8600719 100644 --- a/go.sum +++ b/go.sum @@ -228,8 +228,8 @@ github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:I github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo= github.com/dolthub/dolt-mcp v0.2.2 h1:bpROmam74n95uU4EA3BpOIVlTDT0pzeFMBwe/YRq2mI= github.com/dolthub/dolt-mcp v0.2.2/go.mod h1:S++DJ4QWTAXq+0TNzFa7Oq3IhoT456DJHwAINFAHgDQ= -github.com/dolthub/dolt/go v0.40.5-0.20260122084121-6b5d5373d1ec h1:IzUyIbG7oX3Je53nxq0ZNrsp3C9n2THmiWMJHA9KeDo= -github.com/dolthub/dolt/go v0.40.5-0.20260122084121-6b5d5373d1ec/go.mod h1:UUKnXBHOBTr5CzxFLOZ/9Pm0O+F4ZArKRxs5kUMpXH8= +github.com/dolthub/dolt/go v0.40.5-0.20260205001014-db7263eb669c h1:lrTKtYUO5T5ka0/rgOx9TnSIFuzUv6R86wx7cRDVuUU= +github.com/dolthub/dolt/go v0.40.5-0.20260205001014-db7263eb669c/go.mod h1:GTxR5JEXqpGgYHQc6nBnu/m+TgEC/LP9IPB+b3tIrvc= github.com/dolthub/eventsapi_schema v0.0.0-20250915094920-eadfd39051ca h1:BGFz/0OlKIuC6qHIZQbvPapFvdAJkeEyGXWVgL5clmE= github.com/dolthub/eventsapi_schema v0.0.0-20250915094920-eadfd39051ca/go.mod h1:CoDLfgPqHyBtth0Cp+fi/CmC4R81zJNX4wPjShdZ+Bw= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -238,8 +238,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790 h1:zxMsH7RLiG+dlZ/y0LgJHTV26XoiSJcuWq+em6t6VVc= github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE= -github.com/dolthub/go-mysql-server v0.20.1-0.20260121234050-2f0507726303 h1:O4s+FF9G1Jv7uRnR/bB9jTzvPfvnZbcZOZMa3cip8d4= -github.com/dolthub/go-mysql-server v0.20.1-0.20260121234050-2f0507726303/go.mod h1:7L2EdzgWLnS7blMNuF+67RTZVMhRLyKQv36mZcjU8u8= +github.com/dolthub/go-mysql-server v0.20.1-0.20260204193159-86990113e4cc h1:gnUcLBhmGmUeoYTdkPuknx2X5Imjo1/O4onz3duwC90= +github.com/dolthub/go-mysql-server v0.20.1-0.20260204193159-86990113e4cc/go.mod h1:LEWdXw6LKjdonOv2X808RpUc8wZVtQx4ZEPvmDWkvY4= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= @@ -250,8 +250,8 @@ github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 h1:GY17cGA4 github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1/go.mod h1:qnrZP3/1slFl2Bq5yw38HLOsArZareGwdpEceriblLc= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4EHUcEVQCMRBej8DYxjYjRz/9MdF/NNQh0o70= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= -github.com/dolthub/vitess v0.0.0-20260121194826-a5ce52b608e4 h1:isOZRx9OfOdwZ4UmeYrS+UYiiu02e8taNJLAuqZjfcQ= -github.com/dolthub/vitess v0.0.0-20260121194826-a5ce52b608e4/go.mod h1:FLWqdXsAeeBQyFwDjmBVu0GnbjI2MKeRf3tRVdJEKlI= +github.com/dolthub/vitess v0.0.0-20260202234501-b14ed9b1632b h1:B8QS0U5EHtJTiOptjti1cH/OiE6uczyhePtvVFigf3w= +github.com/dolthub/vitess v0.0.0-20260202234501-b14ed9b1632b/go.mod h1:eLLslh1CSPMf89pPcaMG4yM72PQbTN9OUYJeAy0fAis= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= diff --git a/postgres/parser/parser/sql.y b/postgres/parser/parser/sql.y index 9698f94145..7f44ccc48c 100644 --- a/postgres/parser/parser/sql.y +++ b/postgres/parser/parser/sql.y @@ -13073,7 +13073,7 @@ d_expr: | '(' a_expr ')' '.' '@' ICONST { idx, err := $6.numVal().AsInt32() - if err != nil || idx <= 0 { return setErr(sqllex, err) } + if err != nil { return setErr(sqllex, err) } $$.val = &tree.ColumnAccessExpr{Expr: $2.expr(), ByIndex: true, ColIndex: int(idx-1)} } | '(' a_expr ')' diff --git a/server/analyzer/resolve_type.go b/server/analyzer/resolve_type.go index 98ee1916be..be3e6e812d 100644 --- a/server/analyzer/resolve_type.go +++ b/server/analyzer/resolve_type.go @@ -15,6 +15,7 @@ package analyzer import ( + "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" "github.com/dolthub/go-mysql-server/sql/plan" @@ -24,7 +25,7 @@ import ( "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/core/id" - "github.com/dolthub/doltgresql/server/expression" + pgexprs "github.com/dolthub/doltgresql/server/expression" pgtransform "github.com/dolthub/doltgresql/server/transform" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -118,26 +119,38 @@ func ResolveTypeForNodes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, // ResolveTypeForExprs replaces types.ResolvableType to appropriate pgtypes.DoltgresType. func ResolveTypeForExprs(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { - return pgtransform.NodeExprsWithOpaque(node, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - var same = transform.SameTree - switch e := expr.(type) { - case *expression.ExplicitCast: - if rt, ok := e.Type().(*pgtypes.DoltgresType); ok && !rt.IsResolvedType() { + return pgtransform.NodeExprsWithOpaque(node, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + switch expr := e.(type) { + case *pgexprs.ColumnAccess: + exprType, _ := expr.Type().(*pgtypes.DoltgresType) + if exprType == nil { + return nil, transform.NewTree, errors.New("column access is missing its child expression") + } else if exprType.IsResolvedType() { + // The type has already been resolved + return expr, transform.SameTree, nil + } + newType, err := resolveType(ctx, exprType) + if err != nil { + return nil, transform.NewTree, err + } + return expr.WithType(newType), transform.NewTree, nil + case *pgexprs.ExplicitCast: + if rt, ok := expr.Type().(*pgtypes.DoltgresType); ok && !rt.IsResolvedType() { dt, err := resolveType(ctx, rt) if err != nil { return nil, transform.NewTree, err } - same = transform.NewTree if !dt.IsDefined { return nil, transform.NewTree, pgtypes.ErrTypeIsOnlyAShell.New(dt.Name()) } else { - expr = e.WithCastToType(dt) + return expr.WithCastToType(dt), transform.NewTree, nil } + } else { + return expr, transform.SameTree, nil } - return expr, same, nil default: // TODO: add expressions that use unresolved types like domain - return e, transform.SameTree, nil + return expr, transform.SameTree, nil } }) } diff --git a/server/ast/expr.go b/server/ast/expr.go index 23a9643144..45033a49be 100644 --- a/server/ast/expr.go +++ b/server/ast/expr.go @@ -300,7 +300,18 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) { logrus.Warnf("collate is not yet supported, ignoring") return nodeExpr(ctx, node.Expr) case *tree.ColumnAccessExpr: - return nil, errors.Errorf("(E).x is not yet supported") + colAccess, err := pgexprs.NewColumnAccess(node.ColName, node.ColIndex) + if err != nil { + return nil, err + } + expr, err := nodeExpr(ctx, node.Expr) + if err != nil { + return nil, err + } + return vitess.InjectedExpr{ + Expression: colAccess, + Children: vitess.Exprs{expr}, + }, nil case *tree.ColumnItem: var tableName vitess.TableName if node.TableName != nil { diff --git a/server/expression/column_access.go b/server/expression/column_access.go new file mode 100644 index 0000000000..a67b2c14ec --- /dev/null +++ b/server/expression/column_access.go @@ -0,0 +1,184 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "fmt" + + "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql" + vitess "github.com/dolthub/vitess/go/vt/sqlparser" + + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// ColumnAccess represents an ARRAY[...] expression. +type ColumnAccess struct { + colName string + colNameIdx int + colTyp *pgtypes.DoltgresType + child sql.Expression +} + +var _ vitess.Injectable = (*ColumnAccess)(nil) +var _ sql.Expression = (*ColumnAccess)(nil) + +// NewColumnAccess returns a new *ColumnAccess. +func NewColumnAccess(colName string, colIdx int) (*ColumnAccess, error) { + if len(colName) > 0 { + return &ColumnAccess{ + colName: colName, + colNameIdx: -1, + colTyp: nil, + child: nil, + }, nil + } else { + return &ColumnAccess{ + colName: "", + colNameIdx: colIdx, + colTyp: nil, + child: nil, + }, nil + } +} + +// Children implements the sql.Expression interface. +func (expr *ColumnAccess) Children() []sql.Expression { + return []sql.Expression{expr.child} +} + +// Eval implements the sql.Expression interface. +func (expr *ColumnAccess) Eval(ctx *sql.Context, row sql.Row) (any, error) { + field, err := expr.child.Eval(ctx, row) + if err != nil { + return nil, err + } + if field == nil { + return nil, nil + } + recordVals, ok := field.([]pgtypes.RecordValue) + if !ok { + if len(expr.colName) > 0 { + return nil, errors.Errorf("column notation .%s applied to type %s, which is not a composite type", + expr.colName, expr.child.Type().String()) + } else { + return nil, errors.Errorf("column notation .@%d applied to type %s, which is not a composite type", + expr.colNameIdx+1, expr.child.Type().String()) + } + } + return recordVals[expr.colNameIdx].Value, nil +} + +// IsNullable implements the sql.Expression interface. +func (expr *ColumnAccess) IsNullable() bool { + return true +} + +// Resolved implements the sql.Expression interface. +func (expr *ColumnAccess) Resolved() bool { + return expr.child != nil && expr.child.Resolved() +} + +// String implements the sql.Expression interface. +func (expr *ColumnAccess) String() string { + if expr.child == nil { + return "COLUMN_ACCESS" + } + if len(expr.colName) > 0 { + return fmt.Sprintf("(%s).%s", expr.child.String(), expr.colName) + } else { + return fmt.Sprintf("(%s).@%d", expr.child.String(), expr.colNameIdx+1) + } +} + +// Type implements the sql.Expression interface. +func (expr *ColumnAccess) Type() sql.Type { + if expr.colTyp != nil { + return expr.colTyp + } + if expr.child == nil { + return nil + } + // We're technically returning a different type here since an unresolved type is not the same as a resolved one. + // However, for many early analyzer steps, we only check the ID, so this at least lets us get past those cases. + return pgtypes.NewUnresolvedDoltgresTypeFromID(expr.child.Type().(*pgtypes.DoltgresType).CompositeAttrs[expr.colNameIdx].TypeID) +} + +// WithChildren implements the sql.Expression interface. +func (expr *ColumnAccess) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(expr, len(children), 1) + } + childType := children[0].Type() + doltgresType, ok := childType.(*pgtypes.DoltgresType) + if !ok { + return nil, errors.New("column access is only valid for Doltgres types") + } + if !doltgresType.IsCompositeType() { + if len(expr.colName) > 0 { + return nil, errors.Errorf("column notation .%s applied to type %s, which is not a composite type", + expr.colName, children[0].Type().String()) + } else { + return nil, errors.Errorf("column notation .@%d applied to type %s, which is not a composite type", + expr.colNameIdx+1, children[0].Type().String()) + } + } + var idx int + if len(expr.colName) > 0 { + idx = -1 + for _, attr := range doltgresType.CompositeAttrs { + if attr.Name == expr.colName { + idx = int(attr.Num - 1) + break + } + } + if idx == -1 { + return nil, errors.Errorf(`column "%s" not found in data type %s`, + expr.colName, doltgresType.String()) + } + } else { + if expr.colNameIdx < 0 || expr.colNameIdx >= len(doltgresType.CompositeAttrs) { + return nil, errors.Errorf("column notation .@%d applied to type %s is out of bounds", + expr.colNameIdx+1, children[0].Type().String()) + } + idx = expr.colNameIdx + } + return &ColumnAccess{ + colName: expr.colName, + colNameIdx: idx, + colTyp: expr.colTyp, + child: children[0], + }, nil +} + +// WithResolvedChildren implements the vitess.InjectableExpression interface. +func (expr *ColumnAccess) WithResolvedChildren(children []any) (any, error) { + newExpressions := make([]sql.Expression, len(children)) + for i, resolvedChild := range children { + resolvedExpression, ok := resolvedChild.(sql.Expression) + if !ok { + return nil, errors.Errorf("expected vitess child to be an expression but has type `%T`", resolvedChild) + } + newExpressions[i] = resolvedExpression + } + return expr.WithChildren(newExpressions...) +} + +// WithType returns this expression with the given type set, as it must be set within the analyzer. +func (expr *ColumnAccess) WithType(typ *pgtypes.DoltgresType) sql.Expression { + ne := *expr + ne.colTyp = typ + return &ne +} diff --git a/server/functions/dolt_recordtrim.go b/server/functions/dolt_recordtrim.go new file mode 100644 index 0000000000..51fae2eb5c --- /dev/null +++ b/server/functions/dolt_recordtrim.go @@ -0,0 +1,53 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/doltgresql/server/functions/framework" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// initDoltRecordTrim registers the functions to the catalog. +func initDoltRecordTrim() { + framework.RegisterFunction(dolt_recordTrim) +} + +// dolt_recordTrim is used to remove a specific column within a composite type. This will generally lead to an invalid +// value for the composite type, however this is used within the DROP COLUMN table hook to fix data for all columns that +// reference the type, as that is the only time when the data is invalid. This is why this is a "dolt_" function as +// well, as it's not intended for general use. +var dolt_recordTrim = framework.Function2{ + Name: "dolt_recordtrim", + Return: pgtypes.AnyElement, + Parameters: [2]*pgtypes.DoltgresType{pgtypes.AnyElement, pgtypes.Int32}, + Strict: true, + Callable: func(ctx *sql.Context, types [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + if !types[0].IsCompositeType() { + return val1, nil + } + trimVal := val2.(int32) + recordVals := val1.([]pgtypes.RecordValue) + if trimVal < 0 || int(trimVal) >= len(recordVals) { + return nil, errors.New("record trim index is out of bounds") + } + newRecordVals := make([]pgtypes.RecordValue, len(recordVals)-1) + copy(newRecordVals, recordVals[:trimVal]) + copy(newRecordVals[trimVal:], recordVals[trimVal+1:]) + return newRecordVals, nil + }, +} diff --git a/server/functions/framework/cast.go b/server/functions/framework/cast.go index 1150ab931f..96b2d237b4 100644 --- a/server/functions/framework/cast.go +++ b/server/functions/framework/cast.go @@ -372,14 +372,16 @@ func getRecordCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType, return nil, err } outputVals[i].Type = outputType - positionCast := passthrough(valType, outputType) - if positionCast == nil { - // TODO: this should be the DETAIL, with the actual error being "cannot cast type to " - return nil, errors.Newf("Cannot cast type %s to %s in column %d", valType.Name(), outputType.Name(), i+1) - } - outputVals[i].Value, err = positionCast(ctx, vals[i].Value, outputType) - if err != nil { - return nil, err + if vals[i].Value != nil { + positionCast := passthrough(valType, outputType) + if positionCast == nil { + // TODO: this should be the DETAIL, with the actual error being "cannot cast type to " + return nil, errors.Newf("Cannot cast type %s to %s in column %d", valType.Name(), outputType.Name(), i+1) + } + outputVals[i].Value, err = positionCast(ctx, vals[i].Value, outputType) + if err != nil { + return nil, err + } } } return outputVals, nil diff --git a/server/functions/init.go b/server/functions/init.go index e98cc3edc7..45ff4154f2 100644 --- a/server/functions/init.go +++ b/server/functions/init.go @@ -106,6 +106,7 @@ func Init() { initDegrees() initDiv() initDoltProcedures() + initDoltRecordTrim() initExp() initExtract() initFactorial() diff --git a/server/hook/delete_table.go b/server/hook/delete_table.go index ad2f4a8efc..301001a0eb 100644 --- a/server/hook/delete_table.go +++ b/server/hook/delete_table.go @@ -30,7 +30,7 @@ import ( // BeforeTableDeletion performs all validation necessary to ensure that table deletion does not leave the database in an // invalid state. -func BeforeTableDeletion(ctx *sql.Context, nodeInterface sql.Node) (sql.Node, error) { +func BeforeTableDeletion(ctx *sql.Context, runner sql.StatementRunner, nodeInterface sql.Node) (sql.Node, error) { n, ok := nodeInterface.(*plan.DropTable) if !ok { return nil, errors.Newf("DROP TABLE pre-hook expected `*plan.DropTable` but received `%T`", nodeInterface) diff --git a/server/hook/table_add_column.go b/server/hook/table_add_column.go new file mode 100644 index 0000000000..7475a00693 --- /dev/null +++ b/server/hook/table_add_column.go @@ -0,0 +1,179 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hook + +import ( + "fmt" + "strings" + + "github.com/cockroachdb/errors" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" + + "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/id" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// BeforeTableAddColumn handles validation that's unique to Doltgres. +func BeforeTableAddColumn(ctx *sql.Context, runner sql.StatementRunner, nodeInterface sql.Node) (sql.Node, error) { + n, ok := nodeInterface.(*plan.AddColumn) + if !ok { + return nil, errors.Errorf("ADD COLUMN pre-hook expected `*plan.AddColumn` but received `%T`", nodeInterface) + } + // If the column being added doesn't have a default value, then we don't have anything to check (for now) + if n.Column().Default == nil { + return n, nil + } + + // Grab the table being altered + doltTable := core.SQLNodeToDoltTable(n.Table) + if doltTable == nil { + // If this table isn't a Dolt table then we don't have anything to do + return n, nil + } + _, root, err := core.GetRootFromContext(ctx) + if err != nil { + return n, nil + } + tableName := doltTable.TableName() + tableAsType := id.NewType(tableName.Schema, tableName.Name) + allTableNames, err := root.GetAllTableNames(ctx, false) + if err != nil { + return nil, err + } + + for _, otherTableName := range allTableNames { + if doltdb.IsSystemTable(otherTableName) { + // System tables don't use any table types + continue + } + otherTable, ok, err := root.GetTable(ctx, otherTableName) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.Errorf("root returned table name `%s` but it could not be found?", otherTableName.String()) + } + otherTableSch, err := otherTable.GetSchema(ctx) + if err != nil { + return nil, err + } + for _, otherCol := range otherTableSch.GetAllCols().GetColumns() { + colType := otherCol.TypeInfo.ToSqlType() + dgtype, ok := colType.(*pgtypes.DoltgresType) + if !ok { + // If this isn't a Doltgres type, then it can't be a table type so we can ignore it + continue + } + if dgtype.ID != tableAsType { + // This column isn't our table type, so we can ignore it + continue + } + return nil, errors.Errorf(`cannot alter table "%s" because column "%s.%s" uses its row type`, + tableName.Name, otherTableName.Name, otherCol.Name) + } + } + return n, nil +} + +// AfterTableAddColumn handles updating various table columns, alongside other validation that's unique to Doltgres. +func AfterTableAddColumn(ctx *sql.Context, runner sql.StatementRunner, nodeInterface sql.Node) error { + n, ok := nodeInterface.(*plan.AddColumn) + if !ok { + return errors.Errorf("ADD COLUMN post-hook expected `*plan.AddColumn` but received `%T`", nodeInterface) + } + + // Grab the table being altered + doltTable := core.SQLNodeToDoltTable(n.Table) + if doltTable == nil { + // If this table isn't a Dolt table then we don't have anything to do + return nil + } + _, root, err := core.GetRootFromContext(ctx) + if err != nil { + return err + } + tableName := doltTable.TableName() + tableAsType := id.NewType(tableName.Schema, tableName.Name) + allTableNames, err := root.GetAllTableNames(ctx, false) + if err != nil { + return err + } + sch := doltTable.Schema() + + for _, otherTableName := range allTableNames { + if doltdb.IsSystemTable(otherTableName) { + // System tables don't use any table types + continue + } + otherTable, ok, err := root.GetTable(ctx, otherTableName) + if err != nil { + return err + } + if !ok { + return errors.Errorf("root returned table name `%s` but it could not be found?", otherTableName.String()) + } + otherTableSch, err := otherTable.GetSchema(ctx) + if err != nil { + return err + } + for _, otherCol := range otherTableSch.GetAllCols().GetColumns() { + colType := otherCol.TypeInfo.ToSqlType() + dgtype, ok := colType.(*pgtypes.DoltgresType) + if !ok { + // If this isn't a Doltgres type, then it can't be a table type so we can ignore it + continue + } + if dgtype.ID != tableAsType { + // This column isn't our table type, so we can ignore it + continue + } + // Build the UPDATE statement that we'll run for this table + rowValues := make([]string, len(sch)+1) + for i, col := range sch { + rowValues[i] = fmt.Sprintf(`("%s")."%s"`, otherCol.Name, col.Name) + } + rowValues[len(rowValues)-1] = "NULL" + // The UPDATE changes the values in the table + updateStr := fmt.Sprintf(`UPDATE "%s"."%s" SET "%s" = ROW(%s)::"%s"."%s" WHERE length("%s"::text) > 0;`, + otherTableName.Schema, otherTableName.Name, otherCol.Name, strings.Join(rowValues, ","), tableName.Schema, tableName.Name, otherCol.Name) + // The ALTER updates the type on the schema since it still has the old one + alterStr := fmt.Sprintf(`ALTER TABLE "%s"."%s" ALTER COLUMN "%s" TYPE "%s"."%s";`, + otherTableName.Schema, otherTableName.Name, otherCol.Name, tableName.Schema, tableName.Name) + // We run the statements as though they were interpreted since we're running new statements inside the original + _, err = sql.RunInterpreted(ctx, func(subCtx *sql.Context) ([]sql.Row, error) { + _, rowIter, _, err := runner.QueryWithBindings(subCtx, updateStr, nil, nil, nil) + if err != nil { + return nil, err + } + _, err = sql.RowIterToRows(subCtx, rowIter) + if err != nil { + return nil, err + } + _, rowIter, _, err = runner.QueryWithBindings(subCtx, alterStr, nil, nil, nil) + if err != nil { + return nil, err + } + return sql.RowIterToRows(subCtx, rowIter) + }) + if err != nil { + return err + } + } + } + return nil +} diff --git a/server/hook/table_drop_column.go b/server/hook/table_drop_column.go new file mode 100644 index 0000000000..946d0ac775 --- /dev/null +++ b/server/hook/table_drop_column.go @@ -0,0 +1,121 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hook + +import ( + "fmt" + + "github.com/cockroachdb/errors" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" + + "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/id" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// AfterTableDropColumn handles updating various table columns, alongside other validation that's unique to Doltgres. +func AfterTableDropColumn(ctx *sql.Context, runner sql.StatementRunner, nodeInterface sql.Node) error { + n, ok := nodeInterface.(*plan.DropColumn) + if !ok { + return errors.Errorf("DROP COLUMN post-hook expected `*plan.DropColumn` but received `%T`", nodeInterface) + } + + // Grab the table being altered + doltTable := core.SQLNodeToDoltTable(n.Table) + if doltTable == nil { + // If this table isn't a Dolt table then we don't have anything to do + return nil + } + _, root, err := core.GetRootFromContext(ctx) + if err != nil { + return err + } + tableName := doltTable.TableName() + tableAsType := id.NewType(tableName.Schema, tableName.Name) + allTableNames, err := root.GetAllTableNames(ctx, false) + if err != nil { + return err + } + sch := n.TargetSchema() + + for _, otherTableName := range allTableNames { + if doltdb.IsSystemTable(otherTableName) { + // System tables don't use any table types + continue + } + otherTable, ok, err := root.GetTable(ctx, otherTableName) + if err != nil { + return err + } + if !ok { + return errors.Errorf("root returned table name `%s` but it could not be found?", otherTableName.String()) + } + otherTableSch, err := otherTable.GetSchema(ctx) + if err != nil { + return err + } + for _, otherCol := range otherTableSch.GetAllCols().GetColumns() { + colType := otherCol.TypeInfo.ToSqlType() + dgtype, ok := colType.(*pgtypes.DoltgresType) + if !ok { + // If this isn't a Doltgres type, then it can't be a table type so we can ignore it + continue + } + if dgtype.ID != tableAsType { + // This column isn't our table type, so we can ignore it + continue + } + // Build the UPDATE statement that we'll run for this table + trimIdx := -1 + for i, col := range sch { + if col.Name == n.Column { + trimIdx = i + break + } + } + if trimIdx == -1 { + return errors.New("DROP COLUMN post-hook could not find the index of the column to remove") + } + // The UPDATE changes the values in the table + updateStr := fmt.Sprintf(`UPDATE "%s"."%s" SET "%s" = dolt_recordtrim("%s", %d)::"%s"."%s";`, + otherTableName.Schema, otherTableName.Name, otherCol.Name, otherCol.Name, trimIdx, tableName.Schema, tableName.Name) + // The ALTER updates the type on the schema since it still has the old one + alterStr := fmt.Sprintf(`ALTER TABLE "%s"."%s" ALTER COLUMN "%s" TYPE "%s"."%s";`, + otherTableName.Schema, otherTableName.Name, otherCol.Name, tableName.Schema, tableName.Name) + // We run the statements as though they were interpreted since we're running new statements inside the original + _, err = sql.RunInterpreted(ctx, func(subCtx *sql.Context) ([]sql.Row, error) { + _, rowIter, _, err := runner.QueryWithBindings(subCtx, updateStr, nil, nil, nil) + if err != nil { + return nil, err + } + _, err = sql.RowIterToRows(subCtx, rowIter) + if err != nil { + return nil, err + } + _, rowIter, _, err = runner.QueryWithBindings(subCtx, alterStr, nil, nil, nil) + if err != nil { + return nil, err + } + return sql.RowIterToRows(subCtx, rowIter) + }) + if err != nil { + return err + } + } + } + return nil +} diff --git a/server/types/type.go b/server/types/type.go index 009a510492..927ee2d5f4 100644 --- a/server/types/type.go +++ b/server/types/type.go @@ -94,11 +94,16 @@ var _ sql.NullType = &DoltgresType{} var _ sql.StringType = &DoltgresType{} var _ sql.NumberType = &DoltgresType{} -// NewUnresolvedDoltgresType returns DoltgresType that is not resolved. +// NewUnresolvedDoltgresType returns a DoltgresType that is not resolved. // The type will have the schema and name defined with given values, with IsUnresolved == true. func NewUnresolvedDoltgresType(sch, name string) *DoltgresType { + return NewUnresolvedDoltgresTypeFromID(id.NewType(sch, name)) +} + +// NewUnresolvedDoltgresTypeFromID returns a DoltgresType that is not resolved. +func NewUnresolvedDoltgresTypeFromID(idType id.Type) *DoltgresType { return &DoltgresType{ - ID: id.NewType(sch, name), + ID: idType, IsUnresolved: true, } } @@ -286,7 +291,7 @@ func (t *DoltgresType) Compare(ctx context.Context, v1 interface{}, v2 interface return cmp.Compare(ab.OID(), v2.(id.Oid).OID()), nil case []any: if !t.IsArrayType() { - return 0, errors.Errorf("array value received in Compare for non array type") + return 0, errors.New("array value received in Compare for non array type") } bb := v2.([]any) minLength := utils.Min(len(ab), len(bb)) @@ -306,6 +311,36 @@ func (t *DoltgresType) Compare(ctx context.Context, v1 interface{}, v2 interface } else { return 1, nil } + case []RecordValue: + if !t.IsCompositeType() { + return 0, errors.New("record value received in Compare for non composite type") + } + bb := v2.([]RecordValue) + minLength := utils.Min(len(ab), len(bb)) + for i := 0; i < minLength; i++ { + dgType, isDgType1 := ab[i].Type.(*DoltgresType) + otherDgType, isDgType2 := bb[i].Type.(*DoltgresType) + if !isDgType1 || !isDgType2 { + return 0, errors.New("record values in Compare must use a Doltgres type") + } + if dgType.ID != otherDgType.ID { + return 0, errors.New("record values in Compare must use the same type as the same index") + } + res, err := dgType.Compare(ctx, ab[i].Value, bb[i].Value) + if err != nil { + return 0, err + } + if res != 0 { + return res, nil + } + } + if len(ab) == len(bb) { + return 0, nil + } else if len(ab) < len(bb) { + return -1, nil + } else { + return 1, nil + } default: return 0, errors.Errorf("unhandled type %T in Compare", v1) } diff --git a/servercfg/config.go b/servercfg/config.go index f9727f8a40..957e973cfd 100755 --- a/servercfg/config.go +++ b/servercfg/config.go @@ -49,6 +49,13 @@ func (*DoltgresConfig) Overrides() sql.EngineOverrides { DropTable: sql.DropTable{ PreSQLExecution: hook.BeforeTableDeletion, }, + TableAddColumn: sql.TableAddColumn{ + PreSQLExecution: hook.BeforeTableAddColumn, + PostSQLExecution: hook.AfterTableAddColumn, + }, + TableDropColumn: sql.TableDropColumn{ + PostSQLExecution: hook.AfterTableDropColumn, + }, }, SchemaFormatter: pgsql.NewPostgresSchemaFormatter(), CostedIndexScanExpressionFilter: &analyzer.LogicTreeWalker{}, diff --git a/testing/go/enginetest/doltgres_engine_test.go b/testing/go/enginetest/doltgres_engine_test.go index be09583680..d3439c902c 100755 --- a/testing/go/enginetest/doltgres_engine_test.go +++ b/testing/go/enginetest/doltgres_engine_test.go @@ -1687,6 +1687,7 @@ func TestDoltCommit(t *testing.T) { "CALL DOLT_COMMIT('-amend') works to remove changes from a commit", "CALL DOLT_COMMIT('-amend') works to update a merge commit", "CALL DOLT_COMMIT('--amend') works on initial commit", + "DOLT_COMMIT respects foreign_key_checks=0", }) denginetest.RunDoltCommitTests(t, harness) } diff --git a/testing/go/issues_test.go b/testing/go/issues_test.go index ae80326f3d..4f27f1a1b3 100644 --- a/testing/go/issues_test.go +++ b/testing/go/issues_test.go @@ -229,5 +229,142 @@ limit 1`, }, }, }, + { + Name: "Issue #2197 Part 2", + SetUpScript: []string{ + `CREATE TABLE t1a (a INT4, b VARCHAR(3));`, + `CREATE TABLE t1b (a INT4 NOT NULL, b VARCHAR(3) NOT NULL);`, + `CREATE TABLE t2 (id SERIAL, t1a t1a, t1b t1b);`, + `INSERT INTO t2 (t1a) VALUES (ROW(1, 'abc'));`, + `INSERT INTO t2 (t1b) VALUES (ROW(1, 'abc'));`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT * FROM t2;`, + Expected: []sql.Row{ + {1, "(1,abc)", nil}, + {2, nil, "(1,abc)"}, + }, + }, + { + Query: `ALTER TABLE t1a ADD COLUMN c VARCHAR(10);`, + Expected: []sql.Row{}, + }, + { + Query: `ALTER TABLE t1b ADD COLUMN c VARCHAR(10) NOT NULL;`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM t2 ORDER BY id;`, + Expected: []sql.Row{ + {1, "(1,abc,)", nil}, + {2, nil, "(1,abc,)"}, + }, + }, + { + Query: `ALTER TABLE t1a DROP COLUMN b;`, + Expected: []sql.Row{}, + }, + { + Query: `ALTER TABLE t1b DROP COLUMN b;`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM t2 ORDER BY id;`, + Expected: []sql.Row{ + {1, "(1,)", nil}, + {2, nil, "(1,)"}, + }, + }, + { + Query: `INSERT INTO t1a VALUES (2, 'def');`, + Expected: []sql.Row{}, + }, + { + Query: `INSERT INTO t1b VALUES (3, 'xyzzy');`, + Expected: []sql.Row{}, + }, + { + Query: `INSERT INTO t2 (t1a) SELECT ROW(a,c)::t1a FROM t1a;`, + Expected: []sql.Row{}, + }, + { + Query: `INSERT INTO t2 (t1b) SELECT ROW(a,c)::t1b FROM t1b;`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM t2 ORDER BY id;`, + Expected: []sql.Row{ + {1, "(1,)", nil}, + {2, nil, "(1,)"}, + {3, "(2,def)", nil}, + {4, nil, "(3,xyzzy)"}, + }, + }, + { + Query: `SELECT ((t1a).@1), ((t1b).@2) FROM t2 ORDER BY id;`, + Expected: []sql.Row{ + {1, nil}, + {nil, nil}, + {2, nil}, + {nil, "xyzzy"}, + }, + }, + { + Query: `UPDATE t2 SET t1a=ROW((t1a).a+100, (t1a).c)::t1a WHERE length(t1a::text) > 0;`, + Expected: []sql.Row{}, + }, + { + Query: `UPDATE t2 SET t1b=ROW((t1b).@1+100, (t1b).@2)::t1b WHERE length(t1b::text) > 0;`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM t2 ORDER BY id;`, + Expected: []sql.Row{ + {1, "(101,)", nil}, + {2, nil, "(101,)"}, + {3, "(102,def)", nil}, + {4, nil, "(103,xyzzy)"}, + }, + }, + { + Query: `SELECT (id).a FROM t2;`, + ExpectedErr: "column notation .a applied to type", + }, + { + Query: `SELECT (t1a).g FROM t2;`, + ExpectedErr: `column "g" not found in data type`, + }, + { + Query: `SELECT (t1a).@0 FROM t2;`, + ExpectedErr: "out of bounds", + }, + { + Query: `SELECT (t1a).@3 FROM t2;`, + ExpectedErr: "out of bounds", + }, + { + Query: `ALTER TABLE t1a ADD COLUMN d VARCHAR(10) DEFAULT 'abc';`, + ExpectedErr: `cannot alter table "t1a" because column "t2.t1a" uses its row type`, + }, + { + Query: `ALTER TABLE t1a ADD COLUMN d VARCHAR(10);`, + Expected: []sql.Row{}, + }, + { + Query: `ALTER TABLE t1a DROP COLUMN c;`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM t2 ORDER BY id;`, + Expected: []sql.Row{ + {1, "(101,)", nil}, + {2, nil, "(101,)"}, + {3, "(102,)", nil}, + {4, nil, "(103,xyzzy)"}, + }, + }, + }, + }, }) }