Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT: Add Support for MongoDB as a Local JWT backend #275

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 56 additions & 28 deletions backends/jwt_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package backends
import (
"database/sql"
"strings"

"context"

"github.com/iegomez/mosquitto-go-auth/hashing"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
Expand All @@ -13,6 +14,7 @@ type localJWTChecker struct {
db string
postgres Postgres
mysql Mysql
mongo Mongo
userQuery string
hasher hashing.HashComparer
options tokenOptions
Expand All @@ -21,10 +23,11 @@ type localJWTChecker struct {
const (
mysqlDB = "mysql"
postgresDB = "postgres"
mongoDB = "mongo"
)

// NewLocalJWTChecker initializes a checker with a local DB.
func NewLocalJWTChecker(authOpts map[string]string, logLevel log.Level, hasher hashing.HashComparer, options tokenOptions) (jwtChecker, error) {
func NewLocalJWTChecker(authOpts map[string]string, logLevel log.Level, hasher hashing.HashComparer, options tokenOptions) (JWTChecker, error) {
checker := &localJWTChecker{
hasher: hasher,
db: postgresDB,
Expand All @@ -35,7 +38,7 @@ func NewLocalJWTChecker(authOpts map[string]string, logLevel log.Level, hasher h
localOk := true

if options.secret == "" {
return nil, errors.New("JWT backend error: missing jwt secret")
return nil, errors.New("JWT backend error: missing JWT secret")
}

if db, ok := authOpts["jwt_db"]; ok {
Expand All @@ -59,29 +62,36 @@ func NewLocalJWTChecker(authOpts map[string]string, logLevel log.Level, hasher h
if checker.db == mysqlDB {
mysql, err := NewMysql(dbAuthOpts, logLevel, hasher)
if err != nil {
return nil, errors.Errorf("JWT backend error: couldn't create mysql connector for local jwt: %s", err)
return nil, errors.Errorf("JWT backend error: couldn't create mysql connector for local JWT: %s", err)
}

checker.mysql = mysql
} else if checker.db == mongoDB {
mongodb, err := NewMongo(dbAuthOpts, logLevel, hasher)

return checker, nil
}
if err != nil {
return nil, errors.Errorf("JWT backend error: couldn't create mysql connector for local JWT: %s", err)
}

checker.mongo = mongodb
} else {
postgres, err := NewPostgres(dbAuthOpts, logLevel, hasher)

postgres, err := NewPostgres(dbAuthOpts, logLevel, hasher)
checker.postgres = postgres
}

if err != nil {
return nil, errors.Errorf("JWT backend error: couldn't create postgres connector for local jwt: %s", err)
return nil, errors.Errorf("JWT backend error: couldn't create postgres connector for local JWT: %s", err)
}

checker.postgres = postgres

return checker, nil
}

func (o *localJWTChecker) GetUser(token string) (bool, error) {
username, err := getUsernameForToken(o.options, token, o.options.skipUserExpiration)

if err != nil {
log.Printf("jwt local get user error: %s", err)
log.Printf("JWT local get user error: %s", err)
return false, err
}

Expand All @@ -92,43 +102,47 @@ func (o *localJWTChecker) GetSuperuser(token string) (bool, error) {
username, err := getUsernameForToken(o.options, token, o.options.skipUserExpiration)

if err != nil {
log.Printf("jwt local get superuser error: %s", err)
log.Printf("JWT local get superuser error: %s", err)
return false, err
}

if o.db == mysqlDB {
return o.mysql.GetSuperuser(username)
} else if o.db == mongoDB {
return o.mongo.GetSuperuser(username)
} else {
return o.postgres.GetSuperuser(username)
}

return o.postgres.GetSuperuser(username)
}

func (o *localJWTChecker) CheckAcl(token, topic, clientid string, acc int32) (bool, error) {
username, err := getUsernameForToken(o.options, token, o.options.skipACLExpiration)

if err != nil {
log.Printf("jwt local check acl error: %s", err)
log.Printf("JWT local check acl error: %s", err)
return false, err
}

if o.db == mysqlDB {
return o.mysql.CheckAcl(username, topic, clientid, acc)
} else if o.db == mongoDB {
return o.mongo.CheckAcl(username)
} else {
return o.postgres.CheckAcl(username)
}

return o.postgres.CheckAcl(username, topic, clientid, acc)
}

func (o *localJWTChecker) Halt() {
if o.postgres != (Postgres{}) && o.postgres.DB != nil {
err := o.postgres.DB.Close()
if err != nil {
log.Errorf("JWT cleanup error: %s", err)
}
} else if o.mysql != (Mysql{}) && o.mysql.DB != nil {
err := o.mysql.DB.Close()
if err != nil {
log.Errorf("JWT cleanup error: %s", err)
}
} else if o.mongo != (Mongo{}) && o.mongo.Conn != nil {
err := o.mongo.Conn.Disconnect(context.TODO())
}

if err != nil {
log.Errorf("JWT cleanup error: %s", err)
}
}

Expand All @@ -137,25 +151,36 @@ func (o *localJWTChecker) getLocalUser(username string) (bool, error) {
return false, nil
}

var count sql.NullInt64
var err error
var sqlCount sql.NullInt64
var count Int64
var valid boolean

if o.db == mysqlDB {
err = o.mysql.DB.Get(&count, o.userQuery, username)
valid = sqlCount.Valid
count = sqlCount.Int64
} else if o.db == mongoDB {
var uc := o.mongo.Conn.Database(o.mongo.DBName).Collection(o.mongo.UsersCollection)

count, err := uc.CountDocuments(context.TODO(), bson.M{"username": username})
} else {
err = o.postgres.DB.Get(&count, o.userQuery, username)
}
valid = sqlCount.Valid
count = sqlCount.Int64
}

if err != nil {
log.Debugf("local JWT get user error: %s", err)
return false, err
}

if !count.Valid {
if !valid {
log.Debugf("local JWT get user error: user %s not found", username)
return false, nil
}

if count.Int64 > 0 {
if count > 0 {
return true, nil
}

Expand All @@ -165,9 +190,12 @@ func (o *localJWTChecker) getLocalUser(username string) (bool, error) {
func extractOpts(authOpts map[string]string, db string) map[string]string {
dbAuthOpts := make(map[string]string)

dbPrefix := "pg"
if db == mysqlDB {
dbPrefix = mysqlDB
} else if db == mongoDB {
dbPrefix = mongoDB
} else {
dbPrefix := "pg"
}

prefix := "jwt_" + dbPrefix
Expand Down