|
1 | 1 | package auth
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "chat/channel" |
4 | 5 | "chat/globals"
|
5 | 6 | "chat/utils"
|
6 | 7 | "database/sql"
|
7 | 8 | "errors"
|
8 | 9 | "fmt"
|
9 | 10 | "github.com/dgrijalva/jwt-go"
|
10 | 11 | "github.com/gin-gonic/gin"
|
| 12 | + "github.com/go-redis/redis/v8" |
11 | 13 | "github.com/spf13/viper"
|
| 14 | + "strings" |
12 | 15 | "time"
|
13 | 16 | )
|
14 | 17 |
|
@@ -54,9 +57,129 @@ func ParseApiKey(c *gin.Context, key string) *User {
|
54 | 57 | return &user
|
55 | 58 | }
|
56 | 59 |
|
| 60 | +func getCode(c *gin.Context, cache *redis.Client, email string) string { |
| 61 | + code, err := cache.Get(c, fmt.Sprintf("nio:otp:%s", email)).Result() |
| 62 | + if err != nil { |
| 63 | + return "" |
| 64 | + } |
| 65 | + return code |
| 66 | +} |
| 67 | + |
| 68 | +func checkCode(c *gin.Context, cache *redis.Client, email, code string) bool { |
| 69 | + storage := getCode(c, cache, email) |
| 70 | + if len(storage) == 0 { |
| 71 | + return false |
| 72 | + } |
| 73 | + |
| 74 | + if storage != code { |
| 75 | + return false |
| 76 | + } |
| 77 | + |
| 78 | + cache.Del(c, fmt.Sprintf("nio:top:%s", email)) |
| 79 | + return true |
| 80 | +} |
| 81 | + |
| 82 | +func setCode(c *gin.Context, cache *redis.Client, email, code string) { |
| 83 | + cache.Set(c, fmt.Sprintf("nio:otp:%s", email), code, 5*time.Minute) |
| 84 | +} |
| 85 | + |
| 86 | +func generateCode(c *gin.Context, cache *redis.Client, email string) string { |
| 87 | + code := utils.GenerateCode(6) |
| 88 | + setCode(c, cache, email, code) |
| 89 | + return code |
| 90 | +} |
| 91 | + |
| 92 | +func Verify(c *gin.Context, email string) error { |
| 93 | + cache := utils.GetCacheFromContext(c) |
| 94 | + code := generateCode(c, cache, email) |
| 95 | + |
| 96 | + provider := channel.SystemInstance.GetMail() |
| 97 | + return provider.SendMail( |
| 98 | + email, |
| 99 | + "Chat Nio | OTP Verification", |
| 100 | + fmt.Sprintf("Your OTP code is: %s", code), |
| 101 | + ) |
| 102 | +} |
| 103 | + |
| 104 | +func SignUp(c *gin.Context, form RegisterForm) (string, error) { |
| 105 | + db := utils.GetDBFromContext(c) |
| 106 | + cache := utils.GetCacheFromContext(c) |
| 107 | + |
| 108 | + username := strings.TrimSpace(form.Username) |
| 109 | + password := strings.TrimSpace(form.Password) |
| 110 | + email := strings.TrimSpace(form.Email) |
| 111 | + code := strings.TrimSpace(form.Code) |
| 112 | + |
| 113 | + if !utils.All( |
| 114 | + validateUsername(username), |
| 115 | + validatePassword(password), |
| 116 | + validateEmail(email), |
| 117 | + validateCode(code), |
| 118 | + ) { |
| 119 | + return "", errors.New("invalid username/password/email format") |
| 120 | + } |
| 121 | + |
| 122 | + if !IsUserExist(db, username) { |
| 123 | + return "", fmt.Errorf("username is already taken, please try another one username (your current username: %s)", username) |
| 124 | + } |
| 125 | + |
| 126 | + if !IsEmailExist(db, email) { |
| 127 | + return "", fmt.Errorf("email is already taken, please try another one email (your current email: %s)", email) |
| 128 | + } |
| 129 | + |
| 130 | + if !checkCode(c, cache, email, code) { |
| 131 | + return "", errors.New("invalid email verification code") |
| 132 | + } |
| 133 | + |
| 134 | + hash := utils.Sha2Encrypt(password) |
| 135 | + |
| 136 | + user := &User{ |
| 137 | + Username: username, |
| 138 | + Password: hash, |
| 139 | + Email: email, |
| 140 | + BindID: getMaxBindId(db) + 1, |
| 141 | + Token: utils.Sha2Encrypt(email + username), |
| 142 | + } |
| 143 | + |
| 144 | + if _, err := db.Exec(` |
| 145 | + INSERT INTO auth (username, password, email, bind_id, token) |
| 146 | + VALUES (?, ?, ?, ?, ?) |
| 147 | + `, user.Username, user.Password, user.Email, user.BindID, user.Token); err != nil { |
| 148 | + return "", err |
| 149 | + } |
| 150 | + |
| 151 | + return user.GenerateToken() |
| 152 | +} |
| 153 | + |
| 154 | +func Login(c *gin.Context, form LoginForm) (string, error) { |
| 155 | + db := utils.GetDBFromContext(c) |
| 156 | + username := strings.TrimSpace(form.Username) |
| 157 | + password := strings.TrimSpace(form.Password) |
| 158 | + |
| 159 | + if !utils.All( |
| 160 | + validateUsernameOrEmail(username), |
| 161 | + validatePassword(password), |
| 162 | + ) { |
| 163 | + return "", errors.New("invalid username or password format") |
| 164 | + } |
| 165 | + |
| 166 | + hash := utils.Sha2Encrypt(password) |
| 167 | + |
| 168 | + // get user from db by username (or email) and password |
| 169 | + var user User |
| 170 | + if err := db.QueryRow(` |
| 171 | + SELECT auth.id, auth.username, auth.password FROM auth |
| 172 | + WHERE (auth.username = ? OR auth.email = ?) AND auth.password = ? |
| 173 | + `, username, hash).Scan(&user.ID, &user.Username, &user.Password); err != nil { |
| 174 | + return "", errors.New("invalid username or password") |
| 175 | + } |
| 176 | + |
| 177 | + return user.GenerateToken() |
| 178 | +} |
| 179 | + |
57 | 180 | func DeepLogin(c *gin.Context, token string) (string, error) {
|
58 | 181 | if !useDeeptrain() {
|
59 |
| - return "", errors.New("deeptrain feature is disabled") |
| 182 | + return "", errors.New("deeptrain mode is disabled") |
60 | 183 | }
|
61 | 184 |
|
62 | 185 | user := Validate(token)
|
@@ -91,6 +214,41 @@ func DeepLogin(c *gin.Context, token string) (string, error) {
|
91 | 214 | return u.GenerateToken()
|
92 | 215 | }
|
93 | 216 |
|
| 217 | +func Reset(c *gin.Context, form ResetForm) error { |
| 218 | + db := utils.GetDBFromContext(c) |
| 219 | + cache := utils.GetCacheFromContext(c) |
| 220 | + |
| 221 | + email := strings.TrimSpace(form.Email) |
| 222 | + code := strings.TrimSpace(form.Code) |
| 223 | + password := strings.TrimSpace(form.Password) |
| 224 | + |
| 225 | + if !utils.All( |
| 226 | + validateEmail(email), |
| 227 | + validateCode(code), |
| 228 | + validatePassword(password), |
| 229 | + ) { |
| 230 | + return errors.New("invalid email/code/password format") |
| 231 | + } |
| 232 | + |
| 233 | + if !IsEmailExist(db, email) { |
| 234 | + return errors.New("email is not registered") |
| 235 | + } |
| 236 | + |
| 237 | + if !checkCode(c, cache, email, code) { |
| 238 | + return errors.New("invalid email verification code") |
| 239 | + } |
| 240 | + |
| 241 | + hash := utils.Sha2Encrypt(password) |
| 242 | + |
| 243 | + if _, err := db.Exec(` |
| 244 | + UPDATE auth SET password = ? WHERE email = ? |
| 245 | + `, hash, email); err != nil { |
| 246 | + return err |
| 247 | + } |
| 248 | + |
| 249 | + return nil |
| 250 | +} |
| 251 | + |
94 | 252 | func (u *User) Validate(c *gin.Context) bool {
|
95 | 253 | if u.Username == "" || u.Password == "" {
|
96 | 254 | return false
|
|
0 commit comments