@@ -3,6 +3,7 @@ use std::ops::Range;
3
3
4
4
use async_trait:: async_trait;
5
5
use atuin_common:: record:: { EncryptedData , HostId , Record , RecordIdx , RecordStatus } ;
6
+ use atuin_common:: utils:: crypto_random_string;
6
7
use atuin_server_database:: models:: { History , NewHistory , NewSession , NewUser , Session , User } ;
7
8
use atuin_server_database:: { Database , DbError , DbResult } ;
8
9
use futures_util:: TryStreamExt ;
@@ -11,7 +12,7 @@ use sqlx::postgres::PgPoolOptions;
11
12
use sqlx:: Row ;
12
13
13
14
use time:: { OffsetDateTime , PrimitiveDateTime , UtcOffset } ;
14
- use tracing:: instrument;
15
+ use tracing:: { instrument, trace } ;
15
16
use uuid:: Uuid ;
16
17
use wrappers:: { DbHistory , DbRecord , DbSession , DbUser } ;
17
18
@@ -100,18 +101,100 @@ impl Database for Postgres {
100
101
101
102
#[ instrument( skip_all) ]
102
103
async fn get_user ( & self , username : & str ) -> DbResult < User > {
103
- sqlx:: query_as ( "select id, username, email, password from users where username = $1" )
104
- . bind ( username)
105
- . fetch_one ( & self . pool )
106
- . await
107
- . map_err ( fix_error)
108
- . map ( |DbUser ( user) | user)
104
+ sqlx:: query_as (
105
+ "select id, username, email, password, verified_at from users where username = $1" ,
106
+ )
107
+ . bind ( username)
108
+ . fetch_one ( & self . pool )
109
+ . await
110
+ . map_err ( fix_error)
111
+ . map ( |DbUser ( user) | user)
112
+ }
113
+
114
+ #[ instrument( skip_all) ]
115
+ async fn user_verified ( & self , id : i64 ) -> DbResult < bool > {
116
+ let res: ( bool , ) =
117
+ sqlx:: query_as ( "select verified_at is not null from users where id = $1" )
118
+ . bind ( id)
119
+ . fetch_one ( & self . pool )
120
+ . await
121
+ . map_err ( fix_error) ?;
122
+
123
+ Ok ( res. 0 )
124
+ }
125
+
126
+ #[ instrument( skip_all) ]
127
+ async fn verify_user ( & self , id : i64 ) -> DbResult < ( ) > {
128
+ sqlx:: query (
129
+ "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1" ,
130
+ )
131
+ . bind ( id)
132
+ . execute ( & self . pool )
133
+ . await
134
+ . map_err ( fix_error) ?;
135
+
136
+ Ok ( ( ) )
137
+ }
138
+
139
+ /// Return a valid verification token for the user
140
+ /// If the user does not have any token, create one, insert it, and return
141
+ /// If the user has a token, but it's invalid, delete it, create a new one, return
142
+ /// If the user already has a valid token, return it
143
+ #[ instrument( skip_all) ]
144
+ async fn user_verification_token ( & self , id : i64 ) -> DbResult < String > {
145
+ const TOKEN_VALID_MINUTES : i64 = 15 ;
146
+
147
+ // First we check if there is a verification token
148
+ let token: Option < ( String , sqlx:: types:: time:: OffsetDateTime ) > = sqlx:: query_as (
149
+ "select token, valid_until from user_verification_token where user_id = $1" ,
150
+ )
151
+ . bind ( id)
152
+ . fetch_optional ( & self . pool )
153
+ . await
154
+ . map_err ( fix_error) ?;
155
+
156
+ let token = if let Some ( ( token, valid_until) ) = token {
157
+ trace ! ( "Token for user {id} valid until {valid_until}" ) ;
158
+
159
+ // We have a token, AND it's still valid
160
+ if valid_until > time:: OffsetDateTime :: now_utc ( ) {
161
+ token
162
+ } else {
163
+ // token has expired. generate a new one, return it
164
+ let token = crypto_random_string :: < 24 > ( ) ;
165
+
166
+ sqlx:: query ( "update user_verification_token set token = $2, valid_until = $3 where user_id=$1" )
167
+ . bind ( id)
168
+ . bind ( & token)
169
+ . bind ( time:: OffsetDateTime :: now_utc ( ) + time:: Duration :: minutes ( TOKEN_VALID_MINUTES ) )
170
+ . execute ( & self . pool )
171
+ . await
172
+ . map_err ( fix_error) ?;
173
+
174
+ token
175
+ }
176
+ } else {
177
+ // No token in the database! Generate one, insert it
178
+ let token = crypto_random_string :: < 24 > ( ) ;
179
+
180
+ sqlx:: query ( "insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)" )
181
+ . bind ( id)
182
+ . bind ( & token)
183
+ . bind ( time:: OffsetDateTime :: now_utc ( ) + time:: Duration :: minutes ( TOKEN_VALID_MINUTES ) )
184
+ . execute ( & self . pool )
185
+ . await
186
+ . map_err ( fix_error) ?;
187
+
188
+ token
189
+ } ;
190
+
191
+ Ok ( token)
109
192
}
110
193
111
194
#[ instrument( skip_all) ]
112
195
async fn get_session_user ( & self , token : & str ) -> DbResult < User > {
113
196
sqlx:: query_as (
114
- "select users.id, users.username, users.email, users.password from users
197
+ "select users.id, users.username, users.email, users.password, users.verified_at from users
115
198
inner join sessions
116
199
on users.id = sessions.user_id
117
200
and sessions.token = $1" ,
0 commit comments