Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Display and Update experiment description in aim UI #578

Merged
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
ebd48cd
Admin create namespace integration tests
fabiovincenzi Nov 6, 2023
72d05e7
delete tests
fabiovincenzi Nov 6, 2023
0e68678
update create tests
fabiovincenzi Nov 6, 2023
93a2836
uopdate tests
fabiovincenzi Nov 6, 2023
42c96c3
Merge branch 'main' into admin-integration-tests
fabiovincenzi Nov 6, 2023
bc99ae8
merge conflicts
fabiovincenzi Nov 6, 2023
b29e9fb
adressing comments
fabiovincenzi Nov 7, 2023
a7f1a64
Merge branch 'main' into admin-integration-tests
fabiovincenzi Nov 7, 2023
c7e0854
addressing comments
fabiovincenzi Nov 7, 2023
0e484a6
Merge branch 'admin-integration-tests' of https://github.com/fabiovin…
fabiovincenzi Nov 7, 2023
8af027a
fix duplicate namespace code error
fabiovincenzi Nov 7, 2023
1567bac
add html response type in client
fabiovincenzi Nov 8, 2023
b12dea9
fix code formatting
fabiovincenzi Nov 8, 2023
65cf02f
remove unnecessary dependancy
fabiovincenzi Nov 8, 2023
b73dbdf
implementing updateExperiment end point
fabiovincenzi Nov 13, 2023
1d8cf70
Merge branch 'G-Research:main' into experiment-description-aim-ui
fabiovincenzi Nov 13, 2023
f83531f
get experiments endpoint
fabiovincenzi Nov 13, 2023
fc43c9d
Merge branch 'experiment-description-aim-ui' of https://github.com/fa…
fabiovincenzi Nov 13, 2023
0b29563
fix style
fabiovincenzi Nov 13, 2023
d8fc5ee
update experiment mock
fabiovincenzi Nov 13, 2023
ecd8707
fix get experiment test
fabiovincenzi Nov 14, 2023
bfd96cc
Merge branch 'main' into experiment-description-aim-ui
fabiovincenzi Nov 14, 2023
c9f44db
fix lint problems
fabiovincenzi Nov 14, 2023
cb074d6
fix lint problems
fabiovincenzi Nov 14, 2023
d07b06e
fix black lines
fabiovincenzi Nov 14, 2023
8a15215
rename file
fabiovincenzi Nov 15, 2023
0d6b27e
using constant for experiment description
fabiovincenzi Nov 15, 2023
8c223e9
description tag key const
fabiovincenzi Nov 15, 2023
a7bd05a
Merge branch 'main' into experiment-description-aim-ui
fabiovincenzi Nov 17, 2023
11a04da
Merge branch 'main' into experiment-description-aim-ui
fabiovincenzi Nov 20, 2023
65be212
change assert.Nil with require.Nil
fabiovincenzi Nov 20, 2023
07d0b12
Merge branch 'main' into experiment-description-aim-ui
fabiovincenzi Nov 21, 2023
02ec3f1
Merge branch 'main' into experiment-description-aim-ui
fabiovincenzi Nov 21, 2023
f3a6f55
fix formatting
fabiovincenzi Nov 21, 2023
3f351d0
fix get_experiments test
fabiovincenzi Nov 21, 2023
fc125ad
removing MAX function
fabiovincenzi Nov 21, 2023
5054d46
restore MAX function
fabiovincenzi Nov 21, 2023
73e430f
Merge branch 'main' into experiment-description-aim-ui
fabiovincenzi Nov 22, 2023
3a75a69
Merge branch 'main' into experiment-description-aim-ui
fabiovincenzi Nov 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 89 additions & 7 deletions pkg/api/aim/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"
"gorm.io/gorm"

"github.com/G-Research/fasttrackml/pkg/api/aim/request"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/api"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/common"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
Expand All @@ -27,7 +28,8 @@ func GetExperiments(c *fiber.Ctx) error {

var experiments []struct {
database.Experiment
RunCount int
RunCount int
Description string `gorm:"column:description"`
}
if tx := database.DB.Model(&database.Experiment{}).
Select(
Expand All @@ -36,10 +38,13 @@ func GetExperiments(c *fiber.Ctx) error {
"experiments.lifecycle_stage",
"experiments.creation_time",
"COUNT(runs.run_uuid) AS run_count",
"COALESCE(MAX(experiment_tags.value), '') AS description",
suprjinx marked this conversation as resolved.
Show resolved Hide resolved
).
Where("experiments.namespace_id = ?", ns.ID).
Where("experiments.lifecycle_stage = ?", database.LifecycleStageActive).
Joins("LEFT JOIN runs USING(experiment_id)").
Joins("LEFT JOIN experiment_tags ON experiments.experiment_id = experiment_tags.experiment_id AND"+
" experiment_tags.key = ?", common.DescriptionTagKey).
Group("experiments.experiment_id").
Find(&experiments); tx.Error != nil {
return fmt.Errorf("error fetching experiments: %w", tx.Error)
Expand All @@ -50,7 +55,7 @@ func GetExperiments(c *fiber.Ctx) error {
resp[i] = fiber.Map{
"id": strconv.Itoa(int(*e.ID)),
"name": e.Name,
"description": nil,
"description": e.Description,
"archived": e.LifecycleStage == database.LifecycleStageDeleted,
"run_count": e.RunCount,
"creation_time": float64(e.CreationTime.Int64) / 1000,
Expand Down Expand Up @@ -92,7 +97,8 @@ func GetExperiment(c *fiber.Ctx) error {

var exp struct {
database.Experiment
RunCount int
RunCount int
Description string `gorm:"column:description"`
}
if err := database.DB.Model(&database.Experiment{}).
Select(
Expand All @@ -101,22 +107,24 @@ func GetExperiment(c *fiber.Ctx) error {
"experiments.lifecycle_stage",
"experiments.creation_time",
"COUNT(runs.run_uuid) AS run_count",
"COALESCE(MAX(experiment_tags.value), '') AS description",
suprjinx marked this conversation as resolved.
Show resolved Hide resolved
).
Joins("LEFT JOIN runs USING(experiment_id)").
Group("experiments.experiment_id").
Joins("LEFT JOIN experiment_tags ON experiments.experiment_id = experiment_tags.experiment_id AND"+
" experiment_tags.key = ?", common.DescriptionTagKey).
Where("experiments.namespace_id = ?", ns.ID).
Where("experiments.experiment_id = ?", id).
Group("experiments.experiment_id").
First(&exp).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fiber.ErrNotFound
}
return fmt.Errorf("error fetching experiment %q: %w", p.ID, err)
}

return c.JSON(fiber.Map{
"id": id,
"id": strconv.Itoa(int(id)),
"name": exp.Name,
"description": nil,
"description": exp.Description,
"archived": exp.LifecycleStage == database.LifecycleStageDeleted,
"run_count": exp.RunCount,
"creation_time": float64(exp.CreationTime.Int64) / 1000,
Expand Down Expand Up @@ -327,3 +335,77 @@ func DeleteExperiment(c *fiber.Ctx) error {
"status": "OK",
})
}

func UpdateExperiment(c *fiber.Ctx) error {
dave-gantenbein marked this conversation as resolved.
Show resolved Hide resolved
ns, err := namespace.GetNamespaceFromContext(c.Context())
if err != nil {
return api.NewInternalError("error getting namespace from context")
}
log.Debugf("updateExperiment namespace: %s", ns.Code)

params := struct {
ID string `params:"id"`
}{}
if err = c.ParamsParser(&params); err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}
id, err := strconv.ParseInt(params.ID, 10, 32)
if err != nil {
return api.NewBadRequestError("Unable to parse experiment id '%s': %s", params.ID, err)
}

var updateRequest request.UpdateExperimentRequest
if err = c.BodyParser(&updateRequest); err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

experimentRepository := repositories.NewExperimentRepository(database.DB)
tagRepository := repositories.NewTagRepository(database.DB)
experiment, err := experimentRepository.GetByNamespaceIDAndExperimentID(c.Context(), ns.ID, int32(id))
if err != nil {
return fiber.NewError(
fiber.StatusInternalServerError, fmt.Sprintf("unable to find experiment '%s': %s", params.ID, err),
)
}
if experiment == nil {
return fiber.NewError(fiber.StatusNotFound, fmt.Sprintf("unable to find experiment '%s'", params.ID))
}
if updateRequest.Archived != nil {
if *updateRequest.Archived {
experiment.LifecycleStage = models.LifecycleStageDeleted
} else {
experiment.LifecycleStage = models.LifecycleStageActive
}
}

if updateRequest.Name != nil {
experiment.Name = *updateRequest.Name
}

if updateRequest.Archived != nil || updateRequest.Name != nil {
if err := database.DB.Transaction(func(tx *gorm.DB) error {
if err := experimentRepository.UpdateWithTransaction(c.Context(), tx, experiment); err != nil {
return err
}
return nil
}); err != nil {
return fiber.NewError(fiber.StatusInternalServerError,
fmt.Sprintf("unable to update experiment %q: %s", params.ID, err))
}
}
if updateRequest.Description != nil {
description := models.ExperimentTag{
Key: common.DescriptionTagKey,
Value: *updateRequest.Description,
ExperimentID: *experiment.ID,
}
if err := tagRepository.CreateExperimentTag(c.Context(), &description); err != nil {
suprjinx marked this conversation as resolved.
Show resolved Hide resolved
return err
}
}

return c.JSON(fiber.Map{
"id": params.ID,
"status": "OK",
})
}
8 changes: 8 additions & 0 deletions pkg/api/aim/request/experiment.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package request

// UpdateExperimentRequest is a request struct for `PUT /experiments/:id` endpoint.
type UpdateExperimentRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
Archived *bool `json:"archived"`
}
2 changes: 1 addition & 1 deletion pkg/api/aim/response/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package response

// GetExperiment represents the response json fot the GetExperimnt endpoint.
type GetExperiment struct {
ID int32 `json:"id"`
ID string `json:"id"`
dave-gantenbein marked this conversation as resolved.
Show resolved Hide resolved
Name string `json:"name"`
Description string `json:"description"`
Archived bool `json:"archived"`
Expand Down
1 change: 1 addition & 0 deletions pkg/api/aim/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func AddRoutes(r fiber.Router) {
experiments.Get("/:id/activity/", GetExperimentActivity)
experiments.Get("/:id/runs/", GetExperimentRuns)
experiments.Delete("/:id/", DeleteExperiment)
experiments.Put("/:id/", UpdateExperiment)

projects := r.Group("/projects")
projects.Get("/", GetProject)
Expand Down
5 changes: 5 additions & 0 deletions pkg/api/mlflow/common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,8 @@ const (
NANPositiveInfinity = "Infinity"
NANNegativeInfinity = "-Infinity"
)

// Constants for experiment tags keys.
const (
DescriptionTagKey = "mlflow.note.content"
)
15 changes: 15 additions & 0 deletions pkg/api/mlflow/dao/repositories/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type ExperimentRepositoryProvider interface {
GetByNamespaceIDAndExperimentID(
ctx context.Context, namespaceID uint, experimentID int32,
) (*models.Experiment, error)
// UpdateWithTransaction updates existing models.Experiment entity in scope of transaction.
UpdateWithTransaction(ctx context.Context, tx *gorm.DB, experiment *models.Experiment) error
}

// ExperimentRepository repository to work with `experiment` entity.
Expand Down Expand Up @@ -175,3 +177,16 @@ func (r ExperimentRepository) DeleteBatch(ctx context.Context, ids []*int32) err

return nil
}

// UpdateWithTransaction updates existing models.Experiment entity in scope of transaction.
func (r ExperimentRepository) UpdateWithTransaction(
ctx context.Context,
tx *gorm.DB,
experiment *models.Experiment,
) error {
if err := tx.WithContext(ctx).Model(&experiment).Updates(experiment).Error; err != nil {
return eris.Wrapf(err, "error updating existing experiment with id: %d", experiment.ID)
}

return nil
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package experiment
import (
"context"
"database/sql"
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -43,7 +44,7 @@ func (s *GetExperimentTestSuite) Test_Ok() {
Name: "Test Experiment",
Tags: []models.ExperimentTag{
{
Key: "key1",
Key: common.DescriptionTagKey,
Value: "value1",
},
},
Expand All @@ -62,11 +63,11 @@ func (s *GetExperimentTestSuite) Test_Ok() {
require.Nil(s.T(), err)

var resp response.GetExperiment
require.Nil(s.T(), s.AIMClient.WithResponse(&resp).DoRequest("/experiments/%d", *experiment.ID))
assert.Nil(s.T(), s.AIMClient.WithResponse(&resp).DoRequest("/experiments/%d", *experiment.ID))

assert.Equal(s.T(), *experiment.ID, resp.ID)
assert.Equal(s.T(), fmt.Sprintf("%d", *experiment.ID), resp.ID)
assert.Equal(s.T(), experiment.Name, resp.Name)
assert.Equal(s.T(), "", resp.Description)
assert.Equal(s.T(), helpers.GetDescriptionFromExperiment(*experiment), resp.Description)
assert.Equal(s.T(), float64(experiment.CreationTime.Int64)/1000, resp.CreationTime)
assert.Equal(s.T(), false, resp.Archived)
assert.Equal(s.T(), len(experiment.Runs), resp.RunCount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"database/sql"
"fmt"
"strconv"
"testing"
"time"

Expand All @@ -24,7 +25,7 @@ type GetExperimentsTestSuite struct {
}

func TestGetExperimentsTestSuite(t *testing.T) {
suite.Run(t, new(GetExperimentTestSuite))
suite.Run(t, new(GetExperimentsTestSuite))
}

func (s *GetExperimentsTestSuite) Test_Ok() {
Expand Down Expand Up @@ -70,9 +71,12 @@ func (s *GetExperimentsTestSuite) Test_Ok() {
require.Nil(s.T(), s.AIMClient.WithResponse(&resp).DoRequest("/experiments/"))
assert.Equal(s.T(), len(experiments), len(resp))
for _, actualExperiment := range resp {
expectedExperiment := experiments[actualExperiment.ID]
id, err := strconv.ParseInt(actualExperiment.ID, 10, 32)
assert.Nil(s.T(), err)
expectedExperiment := experiments[int32(id)]
assert.Equal(s.T(), fmt.Sprintf("%d", *expectedExperiment.ID), actualExperiment.ID)
assert.Equal(s.T(), expectedExperiment.Name, actualExperiment.Name)
assert.Equal(s.T(), helpers.GetDescriptionFromExperiment(*expectedExperiment), actualExperiment.Description)
assert.Equal(s.T(), float64(expectedExperiment.CreationTime.Int64)/1000, actualExperiment.CreationTime)
assert.Equal(s.T(), expectedExperiment.LifecycleStage == models.LifecycleStageDeleted, actualExperiment.Archived)
assert.Equal(s.T(), len(expectedExperiment.Runs), actualExperiment.RunCount)
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/golang/helpers/experiment.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package helpers

import (
"github.com/G-Research/fasttrackml/pkg/api/mlflow/common"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

// GetDescriptionFromexperiment returns the description of a given experiment.
func GetDescriptionFromExperiment(experiment models.Experiment) string {
for _, tag := range experiment.Tags {
if tag.Key == common.DescriptionTagKey {
return tag.Value
}
}
return ""
dave-gantenbein marked this conversation as resolved.
Show resolved Hide resolved
}