Skip to content

Commit

Permalink
Integrate auth logic into mlflow, aim, admin, chooser parts (#1051)
Browse files Browse the repository at this point in the history
Integrate auth logic into `mlflow`, `aim`, `admin`, `chooser` parts.
  • Loading branch information
dsuhinin authored Mar 28, 2024
1 parent 61a3a1b commit c09ee2c
Show file tree
Hide file tree
Showing 33 changed files with 438 additions and 187 deletions.
15 changes: 0 additions & 15 deletions pkg/api/admin/controller/controller.go

This file was deleted.

33 changes: 0 additions & 33 deletions pkg/api/admin/controller/namespace.go

This file was deleted.

27 changes: 0 additions & 27 deletions pkg/api/admin/routes.go

This file was deleted.

14 changes: 13 additions & 1 deletion pkg/api/aim2/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,35 @@ package aim2
import (
"github.com/gofiber/fiber/v2"

mlflowConfig "github.com/G-Research/fasttrackml/pkg/api/mlflow/config"
"github.com/G-Research/fasttrackml/pkg/common/middleware"

"github.com/G-Research/fasttrackml/pkg/api/aim2/controller"
)

// Router represents `mlflow` router.
type Router struct {
config *mlflowConfig.ServiceConfig
controller *controller.Controller
}

// NewRouter creates new instance of `mlflow` router.
func NewRouter(controller *controller.Controller) *Router {
func NewRouter(config *mlflowConfig.ServiceConfig, controller *controller.Controller) *Router {
return &Router{
config: config,
controller: controller,
}
}

func (r Router) Init(server fiber.Router) {
mainGroup := server.Group("/aim/api")
// apply global auth middlewares.
switch {
case r.config.Auth.IsAuthTypeUser():
mainGroup.Use(middleware.NewUserMiddleware(r.config.Auth.AuthParsedUserPermissions))
}

// setup related routes.
apps := mainGroup.Group("apps")
apps.Get("/", r.controller.GetApps)
apps.Post("/", r.controller.CreateApp)
Expand Down
20 changes: 16 additions & 4 deletions pkg/api/mlflow/config/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
package auth

import (
"github.com/rotisserie/eris"

"github.com/G-Research/fasttrackml/pkg/common/db/models"
)

// supported list of authentication types.
const (
TypeOIDC string = "oidc"
TypeUser string = "user"
)

type Config struct {
AuthType string
AuthUsername string
AuthPassword string
AuthUsersConfig string
AuthType string
AuthUsername string
AuthPassword string
AuthUsersConfig string
AuthParsedUserPermissions *models.UserPermissions
}

// IsAuthTypeOIDC makes check that current auth is TypeOIDC.
Expand All @@ -33,6 +40,11 @@ func (c *Config) NormalizeConfiguration() error {
switch {
case c.AuthUsersConfig != "":
c.AuthType = TypeUser
parsedUserPermissions, err := Load(c.AuthUsersConfig)
if err != nil {
return eris.Wrapf(err, "error loading auth user configuration from file: %s", c.AuthUsersConfig)
}
c.AuthParsedUserPermissions = parsedUserPermissions
}
return nil
}
21 changes: 16 additions & 5 deletions pkg/api/mlflow/config/auth/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package auth

import (
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -9,22 +11,31 @@ import (
func TestConfig_NormalizeConfiguration(t *testing.T) {
tests := []struct {
name string
config *Config
init func() *Config
configType string
}{
{
name: "TestAuthTypeUser",
config: &Config{
AuthUsersConfig: "/path/to/file",
init: func() *Config {
configPath := fmt.Sprintf("%s/configuration.yml", t.TempDir())
// #nosec G304
f, err := os.Create(configPath)
assert.Nil(t, err)
assert.Nil(t, f.Close())

return &Config{
AuthUsersConfig: configPath,
}
},
configType: TypeUser,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Nil(t, tt.config.NormalizeConfiguration())
assert.Equal(t, tt.configType, tt.config.AuthType)
config := tt.init()
assert.Nil(t, config.NormalizeConfiguration())
assert.Equal(t, tt.configType, config.AuthType)
})
}
}
29 changes: 27 additions & 2 deletions pkg/api/mlflow/config/auth/config_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,24 @@ func TestUserPermissions_HasAccess_Ok(t *testing.T) {
},
}),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
authToken := tt.permissions.ValidateAuthToken(tt.token)
assert.NotNil(t, authToken)
assert.True(t, authToken.HasUserAccess(tt.namespace))
})
}
}

func TestUserPermissions_HasAdminAccess_Ok(t *testing.T) {
tests := []struct {
name string
token string
namespace string
permissions *models.UserPermissions
}{
{
name: "TestUserPermissionsUserHasAdminRole",
token: "token",
Expand All @@ -174,7 +192,9 @@ func TestUserPermissions_HasAccess_Ok(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.True(t, tt.permissions.HasAccess(tt.namespace, tt.token))
authToken := tt.permissions.ValidateAuthToken(tt.token)
assert.NotNil(t, authToken)
assert.True(t, authToken.HasAdminAccess())
})
}
}
Expand Down Expand Up @@ -212,7 +232,12 @@ func TestUserPermissions_HasAccess_Error(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.False(t, tt.permissions.HasAccess(tt.namespace, tt.token))
authToken := tt.permissions.ValidateAuthToken(tt.token)
if authToken != nil {
assert.False(t, authToken.HasUserAccess(tt.namespace))
} else {
assert.Nil(t, authToken)
}
})
}
}
18 changes: 14 additions & 4 deletions pkg/api/mlflow/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package mlflow
import (
"github.com/gofiber/fiber/v2"

mlflowConfig "github.com/G-Research/fasttrackml/pkg/api/mlflow/config"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/controller"
"github.com/G-Research/fasttrackml/pkg/common/api"
"github.com/G-Research/fasttrackml/pkg/common/middleware"
)

// List of route prefixes.
Expand Down Expand Up @@ -58,13 +60,15 @@ const (

// Router represents `mlflow` router.
type Router struct {
config *mlflowConfig.ServiceConfig
prefixList []string
controller *controller.Controller
}

// NewRouter creates new instance of `mlflow` router.
func NewRouter(controller *controller.Controller) *Router {
func NewRouter(config *mlflowConfig.ServiceConfig, controller *controller.Controller) *Router {
return &Router{
config: config,
prefixList: []string{
"/api/2.0/mlflow/",
"/ajax-api/2.0/mlflow/",
Expand All @@ -74,10 +78,16 @@ func NewRouter(controller *controller.Controller) *Router {
}

// Init makes initialization of all `mlflow` routes.
func (r Router) Init(server fiber.Router) {
func (r Router) Init(router fiber.Router) {
for _, prefix := range r.prefixList {
mainGroup := server.Group(prefix)

mainGroup := router.Group(prefix)
// apply global auth middlewares.
switch {
case r.config.Auth.IsAuthTypeUser():
mainGroup.Use(middleware.NewUserMiddleware(r.config.Auth.AuthParsedUserPermissions))
}

// setup related routes.
artifacts := mainGroup.Group(ArtifactsRoutePrefix)
artifacts.Get(ArtifactsGetRoute, r.controller.GetArtifact)
artifacts.Get(ArtifactsListRoute, r.controller.ListArtifacts)
Expand Down
1 change: 0 additions & 1 deletion pkg/common/api/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ type ErrorCode string

const (
ErrorCodeInternalError = "INTERNAL_ERROR"
ErrorAccessForbiddenError = "FORBIDDEN_ERROR"
ErrorCodeTemporarilyUnavailable = "TEMPORARILY_UNAVAILABLE"
ErrorCodeBadRequest = "BAD_REQUEST"
ErrorCodeInvalidParameterValue = "INVALID_PARAMETER_VALUE"
Expand Down
43 changes: 32 additions & 11 deletions pkg/common/db/models/user_permission.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,32 @@ package models

import "fmt"

// BasicAuthToken represents object to store auth information related to Basic Auth.
type BasicAuthToken struct {
roles map[string]struct{}
}

// HasAdminAccess makes check that user has admin permissions to access to the requested resource.
func (p BasicAuthToken) HasAdminAccess() bool {
if _, ok := p.roles["admin"]; ok {
return true
}
return false
}

// HasUserAccess makes check that user has permission to access to the requested namespace.
func (p BasicAuthToken) HasUserAccess(namespace string) bool {
if _, ok := p.roles[fmt.Sprintf("ns:%s", namespace)]; !ok {
return ok
}
return true
}

// GetRoles returns User roles assigned to current Auth token.
func (p BasicAuthToken) GetRoles() map[string]struct{} {
return p.roles
}

// UserPermissions represents model to store user permissions data.
type UserPermissions struct {
data map[string]map[string]struct{}
Expand All @@ -19,23 +45,18 @@ func (p UserPermissions) GetData() map[string]map[string]struct{} {
return p.data
}

// HasAccess makes check that user has permission to access to the requested namespace.
func (p UserPermissions) HasAccess(namespace string, authToken string) bool {
// ValidateAuthToken makes basic validation of auth token.
func (p UserPermissions) ValidateAuthToken(authToken string) *BasicAuthToken {
if authToken == "" {
return false
return nil
}

roles, ok := p.data[authToken]
if !ok {
return ok
return nil
}

if _, ok := roles["admin"]; ok {
return true
return &BasicAuthToken{
roles: roles,
}

if _, ok := roles[fmt.Sprintf("ns:%s", namespace)]; !ok {
return ok
}
return true
}
Loading

0 comments on commit c09ee2c

Please sign in to comment.