@@ -18,7 +18,7 @@ type Repository interface {
18
18
FindChatRoomInfoByID (ctx context.Context , chatRoomID uuid.UUID ) (* ChatRoomInfo , error )
19
19
JoinChatRoomByID (ctx context.Context , chatRoomID uuid.UUID , userID uuid.UUID ) (* ChatRoom , error )
20
20
FindChatRoomList (ctx context.Context ) ([]* ChatRoom , error )
21
- SaveMessage ( ctx context. Context , message * Message ) error
21
+
22
22
GetPaginatedMessages (ctx context.Context , chatRoomID uuid.UUID , cursor * time.Time , pageSize int ) ([]Message , error )
23
23
GetFirstPageMessages (ctx context.Context , chatRoomID uuid.UUID , pageSize int ) ([]Message , error )
24
24
}
@@ -28,6 +28,7 @@ type DBTX interface {
28
28
PrepareContext (context.Context , string ) (* sql.Stmt , error )
29
29
QueryContext (context.Context , string , ... interface {}) (* sql.Rows , error )
30
30
QueryRowContext (context.Context , string , ... interface {}) * sql.Row
31
+ BeginTx (ctx context.Context , opts * sql.TxOptions ) (* sql.Tx , error )
31
32
}
32
33
33
34
type repository struct {
@@ -39,6 +40,11 @@ func NewRepository(db DBTX) Repository {
39
40
}
40
41
41
42
func (r * repository ) CreateChatRoom (ctx context.Context , chatRoom * ChatRoom ) (* ChatRoom , error ) {
43
+ tx , err := r .db .BeginTx (ctx , nil )
44
+ if err != nil {
45
+ slog .Error ("Creating Chatroom transaction failed" )
46
+ return nil , err // handle error appropriately
47
+ }
42
48
// Generate a UUID for user ID
43
49
chatRoom .ID = uuid .New ()
44
50
@@ -47,10 +53,10 @@ func (r *repository) CreateChatRoom(ctx context.Context, chatRoom *ChatRoom) (*C
47
53
48
54
query := "INSERT INTO chat_rooms(id, name, created_at) VALUES ($1, $2, $3) RETURNING id"
49
55
50
- err : = r .db .QueryRowContext (ctx , query , chatRoom .ID , chatRoom .Name , chatRoom .CreatedAt ).Scan (& chatRoom .ID )
56
+ err = r .db .QueryRowContext (ctx , query , chatRoom .ID , chatRoom .Name , chatRoom .CreatedAt ).Scan (& chatRoom .ID )
51
57
if err != nil {
52
58
log .Printf ("Error creating chat room: %v" , err )
53
-
59
+ tx . Rollback ()
54
60
return nil , errors .New ("failed to create chat room" )
55
61
}
56
62
@@ -122,11 +128,33 @@ func (r *repository) JoinChatRoomByID(ctx context.Context, chatRoomID uuid.UUID,
122
128
// INSERT into chat rooms, UNIQUE constraint will prevent duplicates
123
129
query := `INSERT INTO users_in_chat_rooms (user_id, chat_room_id) VALUES ($1, $2)`
124
130
125
- _ , err2 := r .db .ExecContext (ctx , query , userID , chatRoomID )
131
+ tx , err := r .db .BeginTx (ctx , nil )
132
+ if err != nil {
133
+ slog .Error ("Joining chat room transaction failed" )
134
+ return nil , err // handle error appropriately
135
+ }
136
+
137
+ // Ensure rollback in case of error
138
+ defer func () {
139
+ if err != nil {
140
+ if rbErr := tx .Rollback (); rbErr != nil {
141
+ slog .Error ("Transaction rollback failed: %v" , rbErr )
142
+ }
143
+ }
144
+ }()
126
145
127
- if err2 != nil {
128
- return nil , err2
146
+ _ , err = r .db .ExecContext (ctx , query , userID , chatRoomID )
147
+ if err != nil {
148
+ slog .Error ("Error joining chat room, db execcontext: " , err )
149
+ return nil , err
129
150
}
151
+
152
+ // Commit transaction
153
+ if err = tx .Commit (); err != nil {
154
+ slog .Error ("Transaction commit failed: " , err )
155
+ return nil , err
156
+ }
157
+
130
158
return chatRoom , nil
131
159
132
160
}
@@ -167,21 +195,6 @@ func (r *repository) FindChatRoomList(ctx context.Context) ([]*ChatRoom, error)
167
195
168
196
}
169
197
170
- func (r * repository ) SaveMessage (ctx context.Context , message * Message ) error {
171
- query := `
172
- INSERT INTO messages (id, chat_room_id, sender_id, content, media_url, created_at, read_at, deleted_by_user_id)
173
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
174
- `
175
-
176
- _ , err := r .db .ExecContext (ctx , query , message .ID , message .ChatRoomID , message .SenderID , message .Content , message .MediaURL , message .CreatedAt , message .ReadAt , message .DeletedByUserID )
177
- if err != nil {
178
- log .Printf ("Problem occured related to saving the message into the db, err: %v" , err )
179
- return err
180
- }
181
-
182
- return nil
183
- }
184
-
185
198
func (r * repository ) GetPaginatedMessages (ctx context.Context , chatRoomID uuid.UUID , cursor * time.Time , pageSize int ) ([]Message , error ) {
186
199
query := `
187
200
SELECT id, chat_room_id, sender_id, content, media_url, created_at, read_at, deleted_by_user_id
0 commit comments