Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate auth logic into mlflow, aim, admin, chooser parts #1051

Merged
merged 11 commits into from
Mar 28, 2024
15 changes: 0 additions & 15 deletions pkg/api/admin/controller/controller.go

This file was deleted.

33 changes: 0 additions & 33 deletions pkg/api/admin/controller/namespace.go

This file was deleted.

27 changes: 0 additions & 27 deletions pkg/api/admin/routes.go

This file was deleted.

14 changes: 13 additions & 1 deletion pkg/api/aim2/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,35 @@ package aim2
import (
"github.com/gofiber/fiber/v2"

mlflowConfig "github.com/G-Research/fasttrackml/pkg/api/mlflow/config"
suprjinx marked this conversation as resolved.
Show resolved Hide resolved
"github.com/G-Research/fasttrackml/pkg/common/middleware"

"github.com/G-Research/fasttrackml/pkg/api/aim2/controller"
)

// Router represents `mlflow` router.
type Router struct {
config *mlflowConfig.ServiceConfig
controller *controller.Controller
}

// NewRouter creates new instance of `mlflow` router.
func NewRouter(controller *controller.Controller) *Router {
func NewRouter(config *mlflowConfig.ServiceConfig, controller *controller.Controller) *Router {
return &Router{
config: config,
controller: controller,
}
}

func (r Router) Init(server fiber.Router) {
mainGroup := server.Group("/aim/api")
// apply global auth middlewares.
switch {
case r.config.Auth.IsAuthTypeUser():
mainGroup.Use(middleware.NewUserMiddleware(r.config.Auth.AuthParsedUserPermissions))
}

// setup related routes.
apps := mainGroup.Group("apps")
apps.Get("/", r.controller.GetApps)
apps.Post("/", r.controller.CreateApp)
Expand Down
20 changes: 16 additions & 4 deletions pkg/api/mlflow/config/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
package auth

import (
"github.com/rotisserie/eris"

"github.com/G-Research/fasttrackml/pkg/common/db/models"
)

// supported list of authentication types.
const (
TypeOIDC string = "oidc"
TypeUser string = "user"
)

type Config struct {
AuthType string
AuthUsername string
AuthPassword string
AuthUsersConfig string
AuthType string
AuthUsername string
AuthPassword string
AuthUsersConfig string
AuthParsedUserPermissions *models.UserPermissions
}

// IsAuthTypeOIDC makes check that current auth is TypeOIDC.
Expand All @@ -33,6 +40,11 @@ func (c *Config) NormalizeConfiguration() error {
switch {
case c.AuthUsersConfig != "":
c.AuthType = TypeUser
parsedUserPermissions, err := Load(c.AuthUsersConfig)
if err != nil {
return eris.Wrapf(err, "error loading auth user configuration from file: %s", c.AuthUsersConfig)
}
c.AuthParsedUserPermissions = parsedUserPermissions
}
return nil
}
21 changes: 16 additions & 5 deletions pkg/api/mlflow/config/auth/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package auth

import (
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -9,22 +11,31 @@ import (
func TestConfig_NormalizeConfiguration(t *testing.T) {
tests := []struct {
name string
config *Config
init func() *Config
configType string
}{
{
name: "TestAuthTypeUser",
config: &Config{
AuthUsersConfig: "/path/to/file",
init: func() *Config {
configPath := fmt.Sprintf("%s/configuration.yml", t.TempDir())
// #nosec G304
f, err := os.Create(configPath)
assert.Nil(t, err)
assert.Nil(t, f.Close())

return &Config{
AuthUsersConfig: configPath,
}
},
configType: TypeUser,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Nil(t, tt.config.NormalizeConfiguration())
assert.Equal(t, tt.configType, tt.config.AuthType)
config := tt.init()
assert.Nil(t, config.NormalizeConfiguration())
assert.Equal(t, tt.configType, config.AuthType)
})
}
}
29 changes: 27 additions & 2 deletions pkg/api/mlflow/config/auth/config_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,24 @@ func TestUserPermissions_HasAccess_Ok(t *testing.T) {
},
}),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
authToken := tt.permissions.ValidateAuthToken(tt.token)
assert.NotNil(t, authToken)
assert.True(t, authToken.HasUserAccess(tt.namespace))
})
}
}

func TestUserPermissions_HasAdminAccess_Ok(t *testing.T) {
tests := []struct {
name string
token string
namespace string
permissions *models.UserPermissions
}{
{
name: "TestUserPermissionsUserHasAdminRole",
token: "token",
Expand All @@ -174,7 +192,9 @@ func TestUserPermissions_HasAccess_Ok(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.True(t, tt.permissions.HasAccess(tt.namespace, tt.token))
authToken := tt.permissions.ValidateAuthToken(tt.token)
assert.NotNil(t, authToken)
assert.True(t, authToken.HasAdminAccess())
})
}
}
Expand Down Expand Up @@ -212,7 +232,12 @@ func TestUserPermissions_HasAccess_Error(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.False(t, tt.permissions.HasAccess(tt.namespace, tt.token))
authToken := tt.permissions.ValidateAuthToken(tt.token)
if authToken != nil {
assert.False(t, authToken.HasUserAccess(tt.namespace))
} else {
assert.Nil(t, authToken)
}
})
}
}
18 changes: 14 additions & 4 deletions pkg/api/mlflow/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package mlflow
import (
"github.com/gofiber/fiber/v2"

mlflowConfig "github.com/G-Research/fasttrackml/pkg/api/mlflow/config"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/controller"
"github.com/G-Research/fasttrackml/pkg/common/api"
"github.com/G-Research/fasttrackml/pkg/common/middleware"
)

// List of route prefixes.
Expand Down Expand Up @@ -58,13 +60,15 @@ const (

// Router represents `mlflow` router.
type Router struct {
config *mlflowConfig.ServiceConfig
prefixList []string
controller *controller.Controller
}

// NewRouter creates new instance of `mlflow` router.
func NewRouter(controller *controller.Controller) *Router {
func NewRouter(config *mlflowConfig.ServiceConfig, controller *controller.Controller) *Router {
return &Router{
config: config,
prefixList: []string{
"/api/2.0/mlflow/",
"/ajax-api/2.0/mlflow/",
Expand All @@ -74,10 +78,16 @@ func NewRouter(controller *controller.Controller) *Router {
}

// Init makes initialization of all `mlflow` routes.
func (r Router) Init(server fiber.Router) {
func (r Router) Init(router fiber.Router) {
for _, prefix := range r.prefixList {
mainGroup := server.Group(prefix)

mainGroup := router.Group(prefix)
// apply global auth middlewares.
switch {
case r.config.Auth.IsAuthTypeUser():
mainGroup.Use(middleware.NewUserMiddleware(r.config.Auth.AuthParsedUserPermissions))
}

// setup related routes.
artifacts := mainGroup.Group(ArtifactsRoutePrefix)
artifacts.Get(ArtifactsGetRoute, r.controller.GetArtifact)
artifacts.Get(ArtifactsListRoute, r.controller.ListArtifacts)
Expand Down
1 change: 0 additions & 1 deletion pkg/common/api/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ type ErrorCode string

const (
ErrorCodeInternalError = "INTERNAL_ERROR"
ErrorAccessForbiddenError = "FORBIDDEN_ERROR"
ErrorCodeTemporarilyUnavailable = "TEMPORARILY_UNAVAILABLE"
ErrorCodeBadRequest = "BAD_REQUEST"
ErrorCodeInvalidParameterValue = "INVALID_PARAMETER_VALUE"
Expand Down
43 changes: 32 additions & 11 deletions pkg/common/db/models/user_permission.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,32 @@ package models

import "fmt"

// BasicAuthToken represents object to store auth information related to Basic Auth.
type BasicAuthToken struct {
roles map[string]struct{}
}

// HasAdminAccess makes check that user has admin permissions to access to the requested resource.
func (p BasicAuthToken) HasAdminAccess() bool {
if _, ok := p.roles["admin"]; ok {
return true
}
return false
}

// HasUserAccess makes check that user has permission to access to the requested namespace.
func (p BasicAuthToken) HasUserAccess(namespace string) bool {
if _, ok := p.roles[fmt.Sprintf("ns:%s", namespace)]; !ok {
return ok
}
return true
}

// GetRoles returns User roles assigned to current Auth token.
func (p BasicAuthToken) GetRoles() map[string]struct{} {
return p.roles
}

// UserPermissions represents model to store user permissions data.
type UserPermissions struct {
data map[string]map[string]struct{}
Expand All @@ -19,23 +45,18 @@ func (p UserPermissions) GetData() map[string]map[string]struct{} {
return p.data
}

// HasAccess makes check that user has permission to access to the requested namespace.
func (p UserPermissions) HasAccess(namespace string, authToken string) bool {
// ValidateAuthToken makes basic validation of auth token.
func (p UserPermissions) ValidateAuthToken(authToken string) *BasicAuthToken {
if authToken == "" {
return false
return nil
}

roles, ok := p.data[authToken]
if !ok {
return ok
return nil
}

if _, ok := roles["admin"]; ok {
return true
return &BasicAuthToken{
roles: roles,
}

if _, ok := roles[fmt.Sprintf("ns:%s", namespace)]; !ok {
return ok
}
return true
}
Loading