diff --git a/oracle/clause_builder.go b/oracle/clause_builder.go index 26cfbd7..b588d59 100644 --- a/oracle/clause_builder.go +++ b/oracle/clause_builder.go @@ -325,7 +325,16 @@ func OnConflictClauseBuilder(c clause.Clause, builder clause.Builder) { missingColumns = append(missingColumns, conflictCol.Name) } } + if len(missingColumns) > 0 { + // primary keys with auto increment will always be missing from create values columns + for _, missingCol := range missingColumns { + field := stmt.Schema.LookUpField(missingCol) + if field != nil && field.PrimaryKey && field.AutoIncrement { + return + } + } + var selectedColumns []string for col := range selectedColumnSet { selectedColumns = append(selectedColumns, col) @@ -335,6 +344,34 @@ func OnConflictClauseBuilder(c clause.Clause, builder clause.Builder) { return } + // exclude primary key, default value columns from merge update clause + if len(onConflict.DoUpdates) > 0 { + hasPrimaryKey := false + + for _, assignment := range onConflict.DoUpdates { + field := stmt.Schema.LookUpField(assignment.Column.Name) + if field != nil && field.PrimaryKey { + hasPrimaryKey = true + break + } + } + + if hasPrimaryKey { + onConflict.DoUpdates = nil + columns := make([]string, 0, len(values.Columns)-1) + for _, col := range values.Columns { + field := stmt.Schema.LookUpField(col.Name) + + if field != nil && !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil || + strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 { + columns = append(columns, col.Name) + } + + } + onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) + } + } + // Build MERGE statement buildMergeInClause(stmt, onConflict, values, conflictColumns) } diff --git a/oracle/create.go b/oracle/create.go index 28a8118..c1415ae 100644 --- a/oracle/create.go +++ b/oracle/create.go @@ -267,9 +267,11 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau valuesColumnMap[strings.ToUpper(column.Name)] = true } + // Filter conflict columns to remove non unique columns var filteredConflictColumns []clause.Column for _, conflictCol := range conflictColumns { - if valuesColumnMap[strings.ToUpper(conflictCol.Name)] { + field := stmt.Schema.LookUpField(conflictCol.Name) + if valuesColumnMap[strings.ToUpper(conflictCol.Name)] && (field.Unique || field.AutoIncrement) { filteredConflictColumns = append(filteredConflictColumns, conflictCol) } } @@ -336,6 +338,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau // Build ON clause using conflict columns plsqlBuilder.WriteString(" ON (") + for idx, conflictCol := range conflictColumns { if idx > 0 { plsqlBuilder.WriteString(" AND ") @@ -425,7 +428,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau } plsqlBuilder.WriteString(" WHEN MATCHED THEN UPDATE SET t.") writeQuotedIdentifier(&plsqlBuilder, noopCol) - plsqlBuilder.WriteString(" = t.") + plsqlBuilder.WriteString(" = s.") writeQuotedIdentifier(&plsqlBuilder, noopCol) plsqlBuilder.WriteString("\n") } diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index aa3939a..d36acd9 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -75,18 +75,17 @@ func compareTags(tags []Tag, contents []string) bool { } func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - t.Skip() if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } - if name := DB.Dialector.Name(); name == "postgres" { + if name := DB.Dialector.Name(); name == "postgres" || name == "oracle" { stmt := gorm.Statement{DB: DB} stmt.Parse(&Blog{}) stmt.Schema.LookUpField("ID").Unique = true stmt.Parse(&Tag{}) stmt.Schema.LookUpField("ID").Unique = true - // postgers only allow unique constraint matching given keys + // postgers and oracle only allow unique constraint matching given keys } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") @@ -300,7 +299,6 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { - t.Skip() if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } @@ -309,6 +307,15 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { t.Skip("skip postgres due to it only allow unique constraint matching given keys") } + if name := DB.Dialector.Name(); name == "oracle" { + stmt := gorm.Statement{DB: DB} + stmt.Parse(&Blog{}) + stmt.Schema.LookUpField("ID").Unique = true + stmt.Parse(&Tag{}) + stmt.Schema.LookUpField("ID").Unique = true + // oracle only allow unique constraint matching given keys + } + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto migrate, got error: %v", err) @@ -326,7 +333,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { DB.Save(&blog) blog2 := Blog{ - ID: blog.ID, + ID: 2, Locale: "EN", } DB.Create(&blog2) @@ -358,7 +365,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { } var blog1 Blog - DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) + DB.Preload("LocaleTags").Find(&blog1, "\"locale\" = ? AND \"id\" = ?", "ZH", blog.ID) if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Preload many2many relations") } @@ -388,7 +395,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { } var blog11 Blog - DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) + DB.Preload("LocaleTags").First(&blog11, "\"id\" = ? AND \"locale\" = ?", blog.ID, blog.Locale) if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") } @@ -399,7 +406,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { } var blog21 Blog - DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) + DB.Preload("LocaleTags").First(&blog21, "\"id\" = ? AND \"locale\" = ?", blog2.ID, blog2.Locale) if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { t.Fatalf("EN Blog's tags should be changed after Replace") } @@ -454,8 +461,6 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { } func TestCompositePrimaryKeysAssociations(t *testing.T) { - t.Skip() - type Label struct { BookID *uint `gorm:"primarykey"` Name string `gorm:"primarykey"`