@@ -3,8 +3,10 @@ package connection
3
3
import (
4
4
"chat/globals"
5
5
"chat/utils"
6
+ "crypto/tls"
6
7
"database/sql"
7
8
"fmt"
9
+ "github.com/go-sql-driver/mysql"
8
10
_ "github.com/go-sql-driver/mysql"
9
11
_ "github.com/mattn/go-sqlite3"
10
12
"github.com/spf13/viper"
@@ -32,15 +34,25 @@ func getConn() *sql.DB {
32
34
return db
33
35
}
34
36
35
- // connect to MySQL
36
- db , err := sql .Open ("mysql" , fmt .Sprintf (
37
+ mysqlUrl := fmt .Sprintf (
37
38
"%s:%s@tcp(%s:%d)/%s" ,
38
39
viper .GetString ("mysql.user" ),
39
40
viper .GetString ("mysql.password" ),
40
41
viper .GetString ("mysql.host" ),
41
42
viper .GetInt ("mysql.port" ),
42
43
viper .GetString ("mysql.db" ),
43
- ))
44
+ )
45
+ if viper .GetBool ("mysql.tls" ) {
46
+ mysql .RegisterTLSConfig ("tls" , & tls.Config {
47
+ MinVersion : tls .VersionTLS12 ,
48
+ ServerName : viper .GetString ("mysql.host" ),
49
+ })
50
+
51
+ mysqlUrl += "?tls=tls"
52
+ }
53
+
54
+ // connect to MySQL
55
+ db , err := sql .Open ("mysql" , mysqlUrl )
44
56
45
57
if pingErr := db .Ping (); err != nil || pingErr != nil {
46
58
errMsg := utils .Multi [string ](err != nil , utils .GetError (err ), utils .GetError (pingErr )) // err.Error() may contain nil pointer
0 commit comments