diff --git a/sqlserver.go b/sqlserver.go index 6b32e8f..072df08 100644 --- a/sqlserver.go +++ b/sqlserver.go @@ -40,9 +40,13 @@ func New(config Config) gorm.Dialector { } func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"}, + QueryClauses: []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}, + UpdateClauses: []string{"UPDATE", "SET", "RETURNING", "FROM", "WHERE"}, + DeleteClauses: []string{"DELETE", "FROM", "RETURNING", "WHERE"}, + }) db.Callback().Create().Replace("gorm:create", Create) db.Callback().Update().Replace("gorm:update", Update) @@ -97,6 +101,34 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { } } }, + "RETURNING": func(c clause.Clause, builder clause.Builder) { + if returning, ok := c.Expression.(clause.Returning); ok { + if stmt, ok := builder.(*gorm.Statement); ok { + var outputTable string + if _, ok := stmt.Clauses["UPDATE"]; ok { + outputTable = "INSERTED" + } else if _, ok := stmt.Clauses["DELETE"]; ok { + outputTable = "DELETED" + } + + if outputTable != "" { + stmt.WriteString("OUTPUT ") + + if len(returning.Columns) > 0 { + columns := []clause.Column{} + for _, column := range returning.Columns { + column.Table = outputTable + columns = append(columns, column) + } + returning.Columns = columns + returning.Build(stmt) + } else { + stmt.WriteString(outputTable + ".*") + } + } + } + } + }, } } diff --git a/update.go b/update.go index 991ae68..7f59b9d 100644 --- a/update.go +++ b/update.go @@ -5,7 +5,9 @@ import ( "gorm.io/gorm/callbacks" ) -var updateFunc = callbacks.Update(&callbacks.Config{}) +var updateFunc = callbacks.Update(&callbacks.Config{ + UpdateClauses: []string{"UPDATE", "SET", "RETURNING", "FROM", "WHERE"}, +}) func Update(db *gorm.DB) { if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement {