diff --git a/upstream/go.mod b/upstream/go.mod index 51956d858..5b7fce197 100644 --- a/upstream/go.mod +++ b/upstream/go.mod @@ -1,6 +1,7 @@ module github.com/tektoncd/results -go 1.22.3 +go 1.22.7 + toolchain go1.23.5 require ( @@ -42,9 +43,9 @@ require ( google.golang.org/grpc v1.70.0 google.golang.org/protobuf v1.36.5 gorm.io/driver/mysql v1.5.1 - gorm.io/driver/postgres v1.5.2 + gorm.io/driver/postgres v1.5.11 gorm.io/driver/sqlite v1.5.7 - gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde + gorm.io/gorm v1.25.10 k8s.io/api v0.29.13 k8s.io/apimachinery v0.29.13 k8s.io/apiserver v0.29.13 diff --git a/upstream/go.sum b/upstream/go.sum index 746089e71..58ccddb8b 100644 --- a/upstream/go.sum +++ b/upstream/go.sum @@ -1306,13 +1306,13 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw= gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o= -gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= -gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= +gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314= +gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= +gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= diff --git a/upstream/vendor/gorm.io/driver/postgres/error_translator.go b/upstream/vendor/gorm.io/driver/postgres/error_translator.go index 285494c2d..5f813501c 100644 --- a/upstream/vendor/gorm.io/driver/postgres/error_translator.go +++ b/upstream/vendor/gorm.io/driver/postgres/error_translator.go @@ -2,26 +2,32 @@ package postgres import ( "encoding/json" - "github.com/jackc/pgx/v5/pgconn" + "gorm.io/gorm" + + "github.com/jackc/pgx/v5/pgconn" ) -var errCodes = map[string]string{ - "uniqueConstraint": "23505", +// The error codes to map PostgreSQL errors to gorm errors, here is the PostgreSQL error codes reference https://www.postgresql.org/docs/current/errcodes-appendix.html. +var errCodes = map[string]error{ + "23505": gorm.ErrDuplicatedKey, + "23503": gorm.ErrForeignKeyViolated, + "42703": gorm.ErrInvalidField, + "23514": gorm.ErrCheckConstraintViolated, } type ErrMessage struct { - Code string `json:"Code"` - Severity string `json:"Severity"` - Message string `json:"Message"` + Code string + Severity string + Message string } // Translate it will translate the error to native gorm errors. // Since currently gorm supporting both pgx and pg drivers, only checking for pgx PgError types is not enough for translating errors, so we have additional error json marshal fallback. func (dialector Dialector) Translate(err error) error { if pgErr, ok := err.(*pgconn.PgError); ok { - if pgErr.Code == errCodes["uniqueConstraint"] { - return gorm.ErrDuplicatedKey + if translatedErr, found := errCodes[pgErr.Code]; found { + return translatedErr } return err } @@ -37,8 +43,8 @@ func (dialector Dialector) Translate(err error) error { return err } - if errMsg.Code == errCodes["uniqueConstraint"] { - return gorm.ErrDuplicatedKey + if translatedErr, found := errCodes[errMsg.Code]; found { + return translatedErr } return err } diff --git a/upstream/vendor/gorm.io/driver/postgres/migrator.go b/upstream/vendor/gorm.io/driver/postgres/migrator.go index e4d8e9260..2293a7cf1 100644 --- a/upstream/vendor/gorm.io/driver/postgres/migrator.go +++ b/upstream/vendor/gorm.io/driver/postgres/migrator.go @@ -13,44 +13,71 @@ import ( "gorm.io/gorm/schema" ) +// See https://stackoverflow.com/questions/2204058/list-columns-with-indexes-in-postgresql +// Here are some changes: +// - use `LEFT JOIN` instead of `CROSS JOIN` +// - exclude indexes used to support constraints (they are auto-generated) const indexSql = ` -select - t.relname as table_name, - i.relname as index_name, - a.attname as column_name, - ix.indisunique as non_unique, - ix.indisprimary as primary -from - pg_class t, - pg_class i, - pg_index ix, - pg_attribute a -where - t.oid = ix.indrelid - and i.oid = ix.indexrelid - and a.attrelid = t.oid - and a.attnum = ANY(ix.indkey) - and t.relkind = 'r' - and t.relname = ? +SELECT + ct.relname AS table_name, + ci.relname AS index_name, + i.indisunique AS non_unique, + i.indisprimary AS primary, + a.attname AS column_name +FROM + pg_index i + LEFT JOIN pg_class ct ON ct.oid = i.indrelid + LEFT JOIN pg_class ci ON ci.oid = i.indexrelid + LEFT JOIN pg_attribute a ON a.attrelid = ct.oid + LEFT JOIN pg_constraint con ON con.conindid = i.indexrelid +WHERE + a.attnum = ANY(i.indkey) + AND con.oid IS NULL + AND ct.relkind = 'r' + AND ct.relname = ? ` var typeAliasMap = map[string][]string{ - "int2": {"smallint"}, - "int4": {"integer"}, - "int8": {"bigint"}, - "smallint": {"int2"}, - "integer": {"int4"}, - "bigint": {"int8"}, - "decimal": {"numeric"}, - "numeric": {"decimal"}, + "int": {"integer"}, + "int2": {"smallint"}, + "int4": {"integer"}, + "int8": {"bigint"}, + "smallint": {"int2"}, + "integer": {"int4"}, + "bigint": {"int8"}, + "decimal": {"numeric"}, + "numeric": {"decimal"}, + "timestamptz": {"timestamp with time zone"}, + "timestamp with time zone": {"timestamptz"}, + "bool": {"boolean"}, + "boolean": {"bool"}, + "serial2": {"smallserial"}, + "serial4": {"serial"}, + "serial8": {"bigserial"}, + "varbit": {"bit varying"}, + "char": {"character"}, + "varchar": {"character varying"}, + "float4": {"real"}, + "float8": {"double precision"}, + "timetz": {"time with time zone"}, } type Migrator struct { migrator.Migrator } +// select querys ignore dryrun +func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) { + queryTx := m.DB + if m.DB.DryRun { + queryTx = m.DB.Session(&gorm.Session{}) + queryTx.DryRun = false + } + return queryTx.Raw(sql, values...) +} + func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name) + m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name) return } @@ -82,7 +109,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { } } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) - return m.DB.Raw( + return m.queryRaw( "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema, ).Scan(&count).Error }) @@ -115,6 +142,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " ?" } + if idx.Option != "" { + createIndexSQL += " " + idx.Option + } + if idx.Where != "" { createIndexSQL += " WHERE " + idx.Where } @@ -150,7 +181,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) GetTables() (tableList []string, err error) { currentSchema, _ := m.CurrentSchema(m.DB.Statement, "") - return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error + return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error } func (m Migrator) CreateTable(values ...interface{}) (err error) { @@ -160,7 +191,8 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) { for _, value := range m.ReorderModels(values, false) { if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { - for _, field := range stmt.Schema.FieldsByDBName { + for _, fieldName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[fieldName] if field.Comment != "" { if err := m.DB.Exec( "COMMENT ON COLUMN ?.? IS ?", @@ -183,7 +215,7 @@ func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) - return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error + return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error }) return count > 0 } @@ -235,7 +267,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) - return m.DB.Raw( + return m.queryRaw( "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentSchema, curTable, name, ).Scan(&count).Error @@ -260,7 +292,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))" - m.DB.Raw(checkSQL, values...).Scan(&description) + m.queryRaw(checkSQL, values...).Scan(&description) comment := strings.Trim(field.Comment, "'") comment = strings.Trim(comment, `"`) @@ -294,7 +326,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { fileType := clause.Expr{SQL: m.DataTypeOf(field)} // check for typeName and SQL name isSameType := true - if fieldColumnType.DatabaseTypeName() != fileType.SQL { + if !strings.EqualFold(fieldColumnType.DatabaseTypeName(), fileType.SQL) { isSameType = false // if different, also check for aliases aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName()) @@ -326,8 +358,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return err } } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?"+m.genUsingExpression(fileType.SQL, fieldColumnType.DatabaseTypeName()), - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, clause.Column{Name: field.DBName}, fileType).Error; err != nil { + if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil { return err } } @@ -345,16 +376,6 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { } } - if uniq, _ := fieldColumnType.Unique(); !uniq && field.Unique { - idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)} - // Not a unique constraint but a unique index - if !m.HasIndex(stmt.Table, idxName.Name) { - if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil { - return err - } - } - } - if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue { if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { @@ -368,10 +389,16 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return err } } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } + } else if !field.HasDefaultValue { + // case - as-is column has default value and to-be column has no default value + // need to drop default + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } } } return nil @@ -387,28 +414,39 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return nil } -func (m Migrator) genUsingExpression(targetType, sourceType string) string { - if targetType == "boolean" { - switch sourceType { +func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error { + alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?" + isUncastableDefaultValue := false + + if targetType.SQL == "boolean" { + switch existingColumn.DatabaseTypeName() { case "int2", "int8", "numeric": - return " USING ?::INT::?" + alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?" + } + isUncastableDefaultValue = true + } + + if dv, _ := existingColumn.DefaultValue(); dv != "" && isUncastableDefaultValue { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err } } - return " USING ?::?" + if err := m.DB.Exec(alterSQL, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, targetType, clause.Column{Name: field.DBName}, targetType).Error; err != nil { + return err + } + return nil } func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) - currentSchema, curTable := m.CurrentSchema(stmt, table) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } + currentSchema, curTable := m.CurrentSchema(stmt, table) - return m.DB.Raw( + return m.queryRaw( "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?", currentSchema, curTable, name, ).Scan(&count).Error @@ -423,7 +461,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, var ( currentDatabase = m.DB.Migrator().CurrentDatabase() currentSchema, table = m.CurrentSchema(stmt, stmt.Table) - columns, err = m.DB.Raw( + columns, err = m.queryRaw( "SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?", currentDatabase, currentSchema, table).Rows() ) @@ -463,7 +501,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, } if column.DefaultValueValue.Valid { - column.DefaultValueValue.String = regexp.MustCompile(`'?(.*)\b'?:+[\w\s]+$`).ReplaceAllString(column.DefaultValueValue.String, "$1") + column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String) } if datetimePrecision.Valid { @@ -497,7 +535,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, // check primary, unique field { - columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() + columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() if err != nil { return err } @@ -509,7 +547,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, } columnTypeRows.Close() - columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() + columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() if err != nil { return err } @@ -536,7 +574,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, // check column type { - dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type + dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?) WHERE a.attnum > 0 -- hide internal columns AND NOT a.attisdropped -- hide deleted columns @@ -694,7 +732,7 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { err := m.RunWithValue(value, func(stmt *gorm.Statement) error { result := make([]*Index, 0) - scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error + scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error if scanErr != nil { return scanErr } @@ -769,3 +807,8 @@ func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error { m.resetPreparedStmts() return nil } + +func parseDefaultValueValue(defaultValue string) string { + value := regexp.MustCompile(`^(.*?)(?:::.*)?$`).ReplaceAllString(defaultValue, "$1") + return strings.Trim(value, "'") +} diff --git a/upstream/vendor/gorm.io/driver/postgres/postgres.go b/upstream/vendor/gorm.io/driver/postgres/postgres.go index dbeabf561..e865b0f85 100644 --- a/upstream/vendor/gorm.io/driver/postgres/postgres.go +++ b/upstream/vendor/gorm.io/driver/postgres/postgres.go @@ -3,11 +3,11 @@ package postgres import ( "database/sql" "fmt" - "github.com/jackc/pgx/v5" "regexp" "strconv" "strings" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" "gorm.io/gorm" "gorm.io/gorm/callbacks" @@ -24,11 +24,17 @@ type Dialector struct { type Config struct { DriverName string DSN string + WithoutQuotingCheck bool PreferSimpleProtocol bool WithoutReturning bool Conn gorm.ConnPool } +var ( + timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") + defaultIdentifierLength = 63 //maximum identifier length for postgres +) + func Open(dsn string) gorm.Dialector { return &Dialector{&Config{DSN: dsn}} } @@ -41,12 +47,33 @@ func (dialector Dialector) Name() string { return "postgres" } -var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") +func (dialector Dialector) Apply(config *gorm.Config) error { + if config.NamingStrategy == nil { + config.NamingStrategy = schema.NamingStrategy{ + IdentifierMaxLength: defaultIdentifierLength, + } + return nil + } + + switch v := config.NamingStrategy.(type) { + case *schema.NamingStrategy: + if v.IdentifierMaxLength <= 0 { + v.IdentifierMaxLength = defaultIdentifierLength + } + case schema.NamingStrategy: + if v.IdentifierMaxLength <= 0 { + v.IdentifierMaxLength = defaultIdentifierLength + config.NamingStrategy = v + } + } + + return nil +} func (dialector Dialector) Initialize(db *gorm.DB) (err error) { callbackConfig := &callbacks.Config{ CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"}, - UpdateClauses: []string{"UPDATE", "SET", "WHERE"}, + UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE"}, DeleteClauses: []string{"DELETE", "FROM", "WHERE"}, } // register callbacks @@ -94,10 +121,23 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { writer.WriteByte('$') - writer.WriteString(strconv.Itoa(len(stmt.Vars))) + index := 0 + varLen := len(stmt.Vars) + if varLen > 0 { + switch stmt.Vars[0].(type) { + case pgx.QueryExecMode: + index++ + } + } + writer.WriteString(strconv.Itoa(varLen - index)) } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + if dialector.WithoutQuotingCheck { + writer.WriteString(str) + return + } + var ( underQuoted, selfQuoted bool continuousBacktick int8 diff --git a/upstream/vendor/gorm.io/gorm/callbacks.go b/upstream/vendor/gorm.io/gorm/callbacks.go index 195d17203..50b5b0e93 100644 --- a/upstream/vendor/gorm.io/gorm/callbacks.go +++ b/upstream/vendor/gorm.io/gorm/callbacks.go @@ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error { func (p *processor) compile() (err error) { var callbacks []*callback + removedMap := map[string]bool{} for _, callback := range p.callbacks { if callback.match == nil || callback.match(p.db) { callbacks = append(callbacks, callback) } + if callback.remove { + removedMap[callback.name] = true + } + } + + if len(removedMap) > 0 { + callbacks = removeCallbacks(callbacks, removedMap) } p.callbacks = callbacks @@ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { return } + +func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback { + callbacks := make([]*callback, 0, len(cs)) + for _, callback := range cs { + if nameMap[callback.name] { + continue + } + callbacks = append(callbacks, callback) + } + return callbacks +} diff --git a/upstream/vendor/gorm.io/gorm/callbacks/create.go b/upstream/vendor/gorm.io/gorm/callbacks/create.go index b1488b082..8b7846b63 100644 --- a/upstream/vendor/gorm.io/gorm/callbacks/create.go +++ b/upstream/vendor/gorm.io/gorm/callbacks/create.go @@ -111,6 +111,17 @@ func Create(config *Config) func(db *gorm.DB) { pkField *schema.Field pkFieldName = "@id" ) + + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + + if !insertOk { + if !supportReturning { + db.AddError(err) + } + return + } + if db.Statement.Schema != nil { if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { return @@ -119,13 +130,6 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName } - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { - db.AddError(err) - return - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { @@ -142,6 +146,11 @@ func Create(config *Config) func(db *gorm.DB) { } } } + + if config.LastInsertIDReversed { + insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement + } + for _, mapValue := range mapValues { if mapValue != nil { mapValue[pkFieldName] = insertID @@ -293,13 +302,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - for field, vs := range defaultValueFieldsHavingValue { - values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - for idx := range values.Values { - if vs[idx] == nil { - values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) - } else { - values.Values[idx] = append(values.Values[idx], vs[idx]) + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if vs, ok := defaultValueFieldsHavingValue[field]; ok { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } } } } @@ -322,7 +333,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil { if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], rvOfvalue) @@ -351,7 +362,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { case schema.UnixNanosecond: assignment.Value = curTime.UnixNano() case schema.UnixMillisecond: - assignment.Value = curTime.UnixNano() / 1e6 + assignment.Value = curTime.UnixMilli() case schema.UnixSecond: assignment.Value = curTime.Unix() } diff --git a/upstream/vendor/gorm.io/gorm/callbacks/preload.go b/upstream/vendor/gorm.io/gorm/callbacks/preload.go index 25ecfe761..112343fa5 100644 --- a/upstream/vendor/gorm.io/gorm/callbacks/preload.go +++ b/upstream/vendor/gorm.io/gorm/callbacks/preload.go @@ -75,7 +75,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) for _, relation := range embeddedRelations.Relations { // skip first struct name - names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) + names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], ".")) } for _, relations := range embeddedRelations.EmbeddedRelations { names = append(names, embeddedValues(relations)...) @@ -121,10 +121,33 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati } } else if rel := relationships.Relations[name]; rel != nil { if joined, nestedJoins := isJoined(name); joined { - reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) - tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { - return err + switch rv := db.Statement.ReflectValue; rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() > 0 { + reflectValue := rel.FieldSchema.MakeSlice().Elem() + reflectValue.SetLen(rv.Len()) + for i := 0; i < rv.Len(); i++ { + frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + if frv.Kind() != reflect.Ptr { + reflectValue.Index(i).Set(frv.Addr()) + } else { + reflectValue.Index(i).Set(frv) + } + } + + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + } + case reflect.Struct: + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + default: + return gorm.ErrInvalidData } } else { tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) diff --git a/upstream/vendor/gorm.io/gorm/callbacks/update.go b/upstream/vendor/gorm.io/gorm/callbacks/update.go index ff075dcf2..7cde7f619 100644 --- a/upstream/vendor/gorm.io/gorm/callbacks/update.go +++ b/upstream/vendor/gorm.io/gorm/callbacks/update.go @@ -234,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()}) } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) } else { @@ -268,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 + value = stmt.DB.NowFunc().UnixMilli() } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() } else { diff --git a/upstream/vendor/gorm.io/gorm/chainable_api.go b/upstream/vendor/gorm.io/gorm/chainable_api.go index 1ec9b865f..333706032 100644 --- a/upstream/vendor/gorm.io/gorm/chainable_api.go +++ b/upstream/vendor/gorm.io/gorm/chainable_api.go @@ -429,6 +429,15 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { return } +// Unscoped disables the global scope of soft deletion in a query. +// By default, GORM uses soft deletion, marking records as "deleted" +// by setting a timestamp on a specific field (e.g., `deleted_at`). +// Unscoped allows queries to include records marked as deleted, +// overriding the soft deletion behavior. +// Example: +// var users []User +// db.Unscoped().Find(&users) +// // Retrieves all users, including deleted ones. func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() tx.Statement.Unscoped = true diff --git a/upstream/vendor/gorm.io/gorm/clause/limit.go b/upstream/vendor/gorm.io/gorm/clause/limit.go index abda00551..3edde4346 100644 --- a/upstream/vendor/gorm.io/gorm/clause/limit.go +++ b/upstream/vendor/gorm.io/gorm/clause/limit.go @@ -1,7 +1,5 @@ package clause -import "strconv" - // Limit limit clause type Limit struct { Limit *int @@ -17,14 +15,14 @@ func (limit Limit) Name() string { func (limit Limit) Build(builder Builder) { if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteString("LIMIT ") - builder.WriteString(strconv.Itoa(*limit.Limit)) + builder.AddVar(builder, *limit.Limit) } if limit.Offset > 0 { if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteByte(' ') } builder.WriteString("OFFSET ") - builder.WriteString(strconv.Itoa(limit.Offset)) + builder.AddVar(builder, limit.Offset) } } diff --git a/upstream/vendor/gorm.io/gorm/clause/where.go b/upstream/vendor/gorm.io/gorm/clause/where.go index 46d0b3193..2c3c90f18 100644 --- a/upstream/vendor/gorm.io/gorm/clause/where.go +++ b/upstream/vendor/gorm.io/gorm/clause/where.go @@ -21,11 +21,11 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { - if len(where.Exprs) == 1 { - if andCondition, ok := where.Exprs[0].(AndConditions); ok { - where.Exprs = andCondition.Exprs - } - } + if len(where.Exprs) == 1 { + if andCondition, ok := where.Exprs[0].(AndConditions); ok { + where.Exprs = andCondition.Exprs + } + } // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { @@ -166,19 +166,63 @@ type NotConditions struct { } func (not NotConditions) Build(builder Builder) { - if len(not.Exprs) > 1 { - builder.WriteByte('(') + anyNegationBuilder := false + for _, c := range not.Exprs { + if _, ok := c.(NegationExpressionBuilder); ok { + anyNegationBuilder = true + break + } } - for idx, c := range not.Exprs { - if idx > 0 { - builder.WriteString(AndWithSpace) + if anyNegationBuilder { + if len(not.Exprs) > 1 { + builder.WriteByte('(') } - if negationBuilder, ok := c.(NegationExpressionBuilder); ok { - negationBuilder.NegationBuild(builder) - } else { - builder.WriteString("NOT ") + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } + } else { + builder.WriteString("NOT ") + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + switch c.(type) { + case OrConditions: + builder.WriteString(OrWithSpace) + default: + builder.WriteString(AndWithSpace) + } + } + e, wrapInParentheses := c.(Expr) if wrapInParentheses { sql := strings.ToUpper(e.SQL) @@ -193,9 +237,9 @@ func (not NotConditions) Build(builder Builder) { builder.WriteByte(')') } } - } - if len(not.Exprs) > 1 { - builder.WriteByte(')') + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } } } diff --git a/upstream/vendor/gorm.io/gorm/errors.go b/upstream/vendor/gorm.io/gorm/errors.go index cd76f1f52..025f5d643 100644 --- a/upstream/vendor/gorm.io/gorm/errors.go +++ b/upstream/vendor/gorm.io/gorm/errors.go @@ -49,4 +49,6 @@ var ( ErrDuplicatedKey = errors.New("duplicated key not allowed") // ErrForeignKeyViolated occurs when there is a foreign key constraint violation ErrForeignKeyViolated = errors.New("violates foreign key constraint") + // ErrCheckConstraintViolated occurs when there is a check constraint violation + ErrCheckConstraintViolated = errors.New("violates check constraint") ) diff --git a/upstream/vendor/gorm.io/gorm/logger/sql.go b/upstream/vendor/gorm.io/gorm/logger/sql.go index 8ce8d8b17..ad4787956 100644 --- a/upstream/vendor/gorm.io/gorm/logger/sql.go +++ b/upstream/vendor/gorm.io/gorm/logger/sql.go @@ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO // RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) +func isNumeric(k reflect.Kind) bool { + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + default: + return false + } +} + // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var ( @@ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) + } else if isNumeric(rv.Kind()) { + if rv.CanInt() || rv.CanUint() { + vars[idx] = fmt.Sprintf("%d", rv.Interface()) + } else { + vars[idx] = fmt.Sprintf("%.6f", rv.Interface()) + } } else { for _, t := range convertibleTypes { if rv.Type().ConvertibleTo(t) { diff --git a/upstream/vendor/gorm.io/gorm/migrator/migrator.go b/upstream/vendor/gorm.io/gorm/migrator/migrator.go index d97fbf35c..189a141f5 100644 --- a/upstream/vendor/gorm.io/gorm/migrator/migrator.go +++ b/upstream/vendor/gorm.io/gorm/migrator/migrator.go @@ -7,6 +7,7 @@ import ( "fmt" "reflect" "regexp" + "strconv" "strings" "time" @@ -93,10 +94,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL += " NOT NULL" } - if field.Unique { - expr.SQL += " UNIQUE" - } - if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} @@ -130,6 +127,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + + if stmt.Schema == nil { + return errors.New("failed to get schema") + } + columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err @@ -214,6 +216,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + + if stmt.Schema == nil { + return errors.New("failed to get schema") + } + var ( createTableSQL = "CREATE TABLE ? (" values = []interface{}{m.CurrentTable(stmt)} @@ -366,6 +373,9 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { func (m Migrator) AddColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { // avoid using the same name field + if stmt.Schema == nil { + return errors.New("failed to get schema") + } f := stmt.Schema.LookUpField(name) if f == nil { return fmt.Errorf("failed to look up field with name: %s", name) @@ -385,8 +395,10 @@ func (m Migrator) AddColumn(value interface{}, name string) error { // DropColumn drop value's `name` column func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(name); field != nil { - name = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } } return m.DB.Exec( @@ -398,13 +410,15 @@ func (m Migrator) DropColumn(value interface{}, name string) error { // AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - fileType := m.FullDataTypeOf(field) - return m.DB.Exec( - "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, - ).Error + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + fileType := m.FullDataTypeOf(field) + return m.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? TYPE ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, + ).Error + } } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -416,8 +430,10 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() name := field - if field := stmt.Schema.LookUpField(field); field != nil { - name = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } } return m.DB.Raw( @@ -432,12 +448,14 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { // RenameColumn rename value's field name from oldName to newName func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(oldName); field != nil { - oldName = field.DBName - } + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName + } - if field := stmt.Schema.LookUpField(newName); field != nil { - newName = field.DBName + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } } return m.DB.Exec( @@ -512,14 +530,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - // check unique - if unique, ok := columnType.Unique(); ok && unique != (field.Unique || field.UniqueIndex != "") { - // not primary key - if !field.PrimaryKey { - alterColumn = true - } - } - // check default value if !field.PrimaryKey { currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) @@ -530,12 +540,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true - } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || - (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { - // default value not equal - // not both null - if currentDefaultNotNull || dvNotNull { - alterColumn = true + } else if currentDefaultNotNull || dvNotNull { + switch field.GORMDataType { + case schema.Time: + if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) { + alterColumn = true + } + case schema.Bool: + v1, _ := strconv.ParseBool(dv) + v2, _ := strconv.ParseBool(field.DefaultValue) + alterColumn = v1 != v2 + default: + alterColumn = dv != field.DefaultValue } } } @@ -548,8 +564,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.DBName) + if alterColumn { + if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil { + return err + } + } + + if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil { + return err } return nil @@ -793,6 +815,9 @@ type BuildIndexOptionsInterface interface { // CreateIndex create index `name` func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema == nil { + return errors.New("failed to get schema") + } if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} @@ -825,8 +850,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { // DropIndex drop index `name` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error @@ -838,8 +865,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } return m.DB.Raw( diff --git a/upstream/vendor/gorm.io/gorm/prepare_stmt.go b/upstream/vendor/gorm.io/gorm/prepare_stmt.go index aa944624c..4d533885e 100644 --- a/upstream/vendor/gorm.io/gorm/prepare_stmt.go +++ b/upstream/vendor/gorm.io/gorm/prepare_stmt.go @@ -3,6 +3,8 @@ package gorm import ( "context" "database/sql" + "database/sql/driver" + "errors" "reflect" "sync" ) @@ -147,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() go stmt.Close() @@ -161,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() @@ -180,6 +182,14 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg return &sql.Row{} } +func (db *PreparedStmtDB) Ping() error { + conn, err := db.GetDBConn() + if err != nil { + return err + } + return conn.Ping() +} + type PreparedStmtTX struct { Tx PreparedStmtDB *PreparedStmtDB @@ -207,7 +217,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() @@ -222,7 +232,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() @@ -240,3 +250,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg } return &sql.Row{} } + +func (tx *PreparedStmtTX) Ping() error { + conn, err := tx.GetDBConn() + if err != nil { + return err + } + return conn.Ping() +} diff --git a/upstream/vendor/gorm.io/gorm/scan.go b/upstream/vendor/gorm.io/gorm/scan.go index 736db4d3a..e95e6d30b 100644 --- a/upstream/vendor/gorm.io/gorm/scan.go +++ b/upstream/vendor/gorm.io/gorm/scan.go @@ -257,9 +257,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) { continue } } - values[idx] = &sql.RawBytes{} + var val interface{} + values[idx] = &val } else { - values[idx] = &sql.RawBytes{} + var val interface{} + values[idx] = &val } } } @@ -274,12 +276,16 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if !update || reflectValue.Len() == 0 { update = false - // if the slice cap is externally initialized, the externally initialized slice is directly used here - if reflectValue.Cap() == 0 { - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) - } else if !isArrayKind { - reflectValue.SetLen(0) - db.Statement.ReflectValue.Set(reflectValue) + if isArrayKind { + db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type())) + } else { + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } } } diff --git a/upstream/vendor/gorm.io/gorm/schema/constraint.go b/upstream/vendor/gorm.io/gorm/schema/constraint.go index 5f6beb89c..80a743a83 100644 --- a/upstream/vendor/gorm.io/gorm/schema/constraint.go +++ b/upstream/vendor/gorm.io/gorm/schema/constraint.go @@ -8,7 +8,7 @@ import ( ) // reg match english letters and midline -var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") +var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`) type CheckConstraint struct { Name string diff --git a/upstream/vendor/gorm.io/gorm/schema/field.go b/upstream/vendor/gorm.io/gorm/schema/field.go index 91e4c0abf..a16c98ab0 100644 --- a/upstream/vendor/gorm.io/gorm/schema/field.go +++ b/upstream/vendor/gorm.io/gorm/schema/field.go @@ -56,6 +56,7 @@ type Field struct { Name string DBName string BindNames []string + EmbeddedBindNames []string DataType DataType GORMDataType DataType PrimaryKey bool @@ -112,6 +113,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Name: fieldStruct.Name, DBName: tagSetting["COLUMN"], BindNames: []string{fieldStruct.Name}, + EmbeddedBindNames: []string{fieldStruct.Name}, FieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type, StructField: fieldStruct, @@ -403,6 +405,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.Schema = schema ef.OwnerSchema = field.EmbeddedSchema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous { + ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...) + } // index is negative means is pointer if field.FieldType.Kind() == reflect.Struct { ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) @@ -664,7 +669,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -673,7 +678,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -738,7 +743,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli())) } else { field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } diff --git a/upstream/vendor/gorm.io/gorm/schema/relationship.go b/upstream/vendor/gorm.io/gorm/schema/relationship.go index 2e94fc2cb..c11918a5e 100644 --- a/upstream/vendor/gorm.io/gorm/schema/relationship.go +++ b/upstream/vendor/gorm.io/gorm/schema/relationship.go @@ -150,12 +150,12 @@ func (schema *Schema) setRelation(relation *Relationship) { } // set embedded relation - if len(relation.Field.BindNames) <= 1 { + if len(relation.Field.EmbeddedBindNames) <= 1 { return } relationships := &schema.Relationships - for i, name := range relation.Field.BindNames { - if i < len(relation.Field.BindNames)-1 { + for i, name := range relation.Field.EmbeddedBindNames { + if i < len(relation.Field.EmbeddedBindNames)-1 { if relationships.EmbeddedRelations == nil { relationships.EmbeddedRelations = map[string]*Relationships{} } diff --git a/upstream/vendor/gorm.io/gorm/schema/serializer.go b/upstream/vendor/gorm.io/gorm/schema/serializer.go index 397edff03..f500521ef 100644 --- a/upstream/vendor/gorm.io/gorm/schema/serializer.go +++ b/upstream/vendor/gorm.io/gorm/schema/serializer.go @@ -126,12 +126,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect rv := reflect.ValueOf(fieldValue) switch v := fieldValue.(type) { case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0) + result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0) + result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() default: err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) } diff --git a/upstream/vendor/gorm.io/gorm/utils/utils.go b/upstream/vendor/gorm.io/gorm/utils/utils.go index a4d8ac250..b8d30b35b 100644 --- a/upstream/vendor/gorm.io/gorm/utils/utils.go +++ b/upstream/vendor/gorm.io/gorm/utils/utils.go @@ -32,12 +32,16 @@ func sourceDir(file string) string { // FileWithLineNum return the file name and line number of the current file func FileWithLineNum() string { - // the second caller usually from gorm internal, so set i start from 2 - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) && - !strings.HasSuffix(file, ".gen.go") { - return file + ":" + strconv.FormatInt(int64(line), 10) + pcs := [13]uintptr{} + // the third caller usually from gorm internal + len := runtime.Callers(3, pcs[:]) + frames := runtime.CallersFrames(pcs[:len]) + for i := 0; i < len; i++ { + // second return value is "more", not "ok" + frame, _ := frames.Next() + if (!strings.HasPrefix(frame.File, gormSourceDir) || + strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) } } @@ -74,7 +78,11 @@ func ToStringKey(values ...interface{}) string { case uint: results[idx] = strconv.FormatUint(uint64(v), 10) default: - results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) + results[idx] = "nil" + vv := reflect.ValueOf(v) + if vv.IsValid() && !vv.IsZero() { + results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface()) + } } } diff --git a/upstream/vendor/modules.txt b/upstream/vendor/modules.txt index 51b2d6e53..e2da86224 100644 --- a/upstream/vendor/modules.txt +++ b/upstream/vendor/modules.txt @@ -1504,13 +1504,13 @@ gopkg.in/yaml.v3 # gorm.io/driver/mysql v1.5.1 ## explicit; go 1.14 gorm.io/driver/mysql -# gorm.io/driver/postgres v1.5.2 -## explicit; go 1.18 +# gorm.io/driver/postgres v1.5.11 +## explicit; go 1.19 gorm.io/driver/postgres # gorm.io/driver/sqlite v1.5.7 ## explicit; go 1.20 gorm.io/driver/sqlite -# gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde +# gorm.io/gorm v1.25.10 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks