diff --git a/pkg/api/mlflow/api/request/metric.go b/pkg/api/mlflow/api/request/metric.go index 459231b6a..5a7c0c603 100644 --- a/pkg/api/mlflow/api/request/metric.go +++ b/pkg/api/mlflow/api/request/metric.go @@ -24,9 +24,10 @@ type GetMetricHistoryBulkRequest struct { // GetMetricHistoriesRequest is a request object for `POST /mlflow/metrics/get-histories` endpoint. type GetMetricHistoriesRequest struct { - ExperimentIDs []string `json:"experiment_ids"` - RunIDs []string `json:"run_ids"` - MetricKeys []string `json:"metric_keys"` - ViewType ViewType `json:"run_view_type"` - MaxResults int32 `json:"max_results"` + ExperimentIDs []string `json:"experiment_ids"` + RunIDs []string `json:"run_ids"` + MetricKeys []string `json:"metric_keys"` + ViewType ViewType `json:"run_view_type"` + MaxResults int32 `json:"max_results"` + Context map[string]string `json:"context"` } diff --git a/pkg/api/mlflow/dao/repositories/helpers.go b/pkg/api/mlflow/dao/repositories/helpers.go index 6d2124250..095b22454 100644 --- a/pkg/api/mlflow/dao/repositories/helpers.go +++ b/pkg/api/mlflow/dao/repositories/helpers.go @@ -4,6 +4,8 @@ import ( "fmt" "strings" + "gorm.io/driver/postgres" + "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" ) @@ -30,3 +32,39 @@ func makeParamConflictPlaceholdersAndValues(params []models.Param) (string, []in } return placeholders, valuesArray } + +// BuildJsonCondition creates sql and values for where condition to select items having the specified map of json paths +// and values in the given json column. Json path is expressed as "key" or "outerkey.nestedKey". +func BuildJsonCondition( + dialector string, + jsonColumnName string, + jsonPathValueMap map[string]string, +) (sql string, args []any) { + if len(jsonPathValueMap) == 0 { + return sql, args + } + var conditionTemplate string + args = make([]any, len(jsonPathValueMap)*2) + switch dialector { + case postgres.Dialector{}.Name(): + conditionTemplate = "%s#>>? = ?" + idx := 0 + for k, v := range jsonPathValueMap { + path := strings.ReplaceAll(k, ".", ",") + args[idx] = fmt.Sprintf("{%s}", path) + args[idx+1] = v + idx = idx + 2 + } + default: + conditionTemplate = "%s->>? = ?" + idx := 0 + for k, v := range jsonPathValueMap { + args[idx] = k + args[idx+1] = v + idx = idx + 2 + } + } + conditionTemplate = fmt.Sprintf(conditionTemplate, jsonColumnName) + sql = strings.Repeat(conditionTemplate+" AND ", len(jsonPathValueMap)-1) + conditionTemplate + return sql, args +} diff --git a/pkg/api/mlflow/dao/repositories/helpers_test.go b/pkg/api/mlflow/dao/repositories/helpers_test.go index 9ddca09de..e099cfe2f 100644 --- a/pkg/api/mlflow/dao/repositories/helpers_test.go +++ b/pkg/api/mlflow/dao/repositories/helpers_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" + "github.com/G-Research/fasttrackml/pkg/database" ) func Test_makeSqlPlaceholders(t *testing.T) { @@ -53,3 +54,53 @@ func Test_makeParamConflictPlaceholdersAndValues(t *testing.T) { assert.Equal(t, tt.expectedValues, values) } } + +func TestBuildJsonCondition(t *testing.T) { + tests := []struct { + name string + dialector string + jsonColumnName string + jsonPathValueMap map[string]string + expectedSQL string + expectedArgs []interface{} + }{ + { + name: "Postgres", + dialector: database.PostgresDialectorName, + jsonColumnName: "contexts.json", + jsonPathValueMap: map[string]string{ + "key1": "value1", + "key2.nested": "value2", + }, + expectedSQL: "contexts.json#>>? = ? AND contexts.json#>>? = ?", + expectedArgs: []interface{}{"{key1}", "value1", "{key2,nested}", "value2"}, + }, + { + name: "Sqlite", + dialector: database.SQLiteDialectorName, + jsonColumnName: "contexts.json", + jsonPathValueMap: map[string]string{ + "key1": "value1", + "key2.nested": "value2", + }, + expectedSQL: "contexts.json->>? = ? AND contexts.json->>? = ?", + expectedArgs: []interface{}{"key1", "value1", "key2.nested", "value2"}, + }, + { + name: "SqliteEmptyMap", + dialector: database.SQLiteDialectorName, + jsonColumnName: "contexts.json", + jsonPathValueMap: map[string]string{}, + expectedSQL: "", + expectedArgs: []interface{}(nil), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, args := BuildJsonCondition(tt.dialector, tt.jsonColumnName, tt.jsonPathValueMap) + assert.Equal(t, tt.expectedSQL, sql) + assert.ElementsMatch(t, tt.expectedArgs, args) + }) + } +} diff --git a/pkg/api/mlflow/dao/repositories/metric.go b/pkg/api/mlflow/dao/repositories/metric.go index 4ffdb26b1..2e5cfcd57 100644 --- a/pkg/api/mlflow/dao/repositories/metric.go +++ b/pkg/api/mlflow/dao/repositories/metric.go @@ -30,6 +30,7 @@ type MetricRepositoryProvider interface { experimentIDs []string, runIDs []string, metricKeys []string, viewType request.ViewType, limit int32, + jsonPathValueMap map[string]string, ) (*sql.Rows, func(*sql.Rows, interface{}) error, error) // GetMetricHistoryBulk returns metrics history bulk. GetMetricHistoryBulk( @@ -151,6 +152,7 @@ func (r MetricRepository) GetMetricHistories( experimentIDs []string, runIDs []string, metricKeys []string, viewType request.ViewType, limit int32, + jsonPathValueMap map[string]string, ) (*sql.Rows, func(*sql.Rows, interface{}) error, error) { // if experimentIDs has been provided then firstly get the runs by provided experimentIDs. if len(experimentIDs) > 0 { @@ -219,6 +221,12 @@ func (r MetricRepository) GetMetricHistories( query.Where("metrics.key IN ?", metricKeys) } + if len(jsonPathValueMap) > 0 { + query.Joins("LEFT JOIN contexts on metrics.context_id = contexts.id") + sql, args := BuildJsonCondition(query.Dialector.Name(), "contexts.json", jsonPathValueMap) + query.Where(sql, args...) + } + rows, err := query.Rows() if err != nil { return nil, nil, eris.Wrapf( diff --git a/pkg/api/mlflow/dao/repositories/mock_metric_repository_provider.go b/pkg/api/mlflow/dao/repositories/mock_metric_repository_provider.go index 60818dbb0..e528442cd 100644 --- a/pkg/api/mlflow/dao/repositories/mock_metric_repository_provider.go +++ b/pkg/api/mlflow/dao/repositories/mock_metric_repository_provider.go @@ -51,34 +51,34 @@ func (_m *MockMetricRepositoryProvider) GetDB() *gorm.DB { return r0 } -// GetMetricHistories provides a mock function with given fields: ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit -func (_m *MockMetricRepositoryProvider) GetMetricHistories(ctx context.Context, namespaceID uint, experimentIDs []string, runIDs []string, metricKeys []string, viewType request.ViewType, limit int32) (*sql.Rows, func(*sql.Rows, interface{}) error, error) { - ret := _m.Called(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit) +// GetMetricHistories provides a mock function with given fields: ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit, jsonPathValueMap +func (_m *MockMetricRepositoryProvider) GetMetricHistories(ctx context.Context, namespaceID uint, experimentIDs []string, runIDs []string, metricKeys []string, viewType request.ViewType, limit int32, jsonPathValueMap map[string]string) (*sql.Rows, func(*sql.Rows, interface{}) error, error) { + ret := _m.Called(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit, jsonPathValueMap) var r0 *sql.Rows var r1 func(*sql.Rows, interface{}) error var r2 error - if rf, ok := ret.Get(0).(func(context.Context, uint, []string, []string, []string, request.ViewType, int32) (*sql.Rows, func(*sql.Rows, interface{}) error, error)); ok { - return rf(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit) + if rf, ok := ret.Get(0).(func(context.Context, uint, []string, []string, []string, request.ViewType, int32, map[string]string) (*sql.Rows, func(*sql.Rows, interface{}) error, error)); ok { + return rf(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit, jsonPathValueMap) } - if rf, ok := ret.Get(0).(func(context.Context, uint, []string, []string, []string, request.ViewType, int32) *sql.Rows); ok { - r0 = rf(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit) + if rf, ok := ret.Get(0).(func(context.Context, uint, []string, []string, []string, request.ViewType, int32, map[string]string) *sql.Rows); ok { + r0 = rf(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit, jsonPathValueMap) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sql.Rows) } } - if rf, ok := ret.Get(1).(func(context.Context, uint, []string, []string, []string, request.ViewType, int32) func(*sql.Rows, interface{}) error); ok { - r1 = rf(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit) + if rf, ok := ret.Get(1).(func(context.Context, uint, []string, []string, []string, request.ViewType, int32, map[string]string) func(*sql.Rows, interface{}) error); ok { + r1 = rf(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit, jsonPathValueMap) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(func(*sql.Rows, interface{}) error) } } - if rf, ok := ret.Get(2).(func(context.Context, uint, []string, []string, []string, request.ViewType, int32) error); ok { - r2 = rf(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit) + if rf, ok := ret.Get(2).(func(context.Context, uint, []string, []string, []string, request.ViewType, int32, map[string]string) error); ok { + r2 = rf(ctx, namespaceID, experimentIDs, runIDs, metricKeys, viewType, limit, jsonPathValueMap) } else { r2 = ret.Error(2) } diff --git a/pkg/api/mlflow/service/metric/service.go b/pkg/api/mlflow/service/metric/service.go index 2f8ec12d0..35a0385b0 100644 --- a/pkg/api/mlflow/service/metric/service.go +++ b/pkg/api/mlflow/service/metric/service.go @@ -89,6 +89,7 @@ func (s Service) GetMetricHistories( req.MetricKeys, req.ViewType, req.MaxResults, + req.Context, ) if err != nil { return nil, nil, api.NewInternalError("Unable to search runs: %s", err) diff --git a/pkg/api/mlflow/service/metric/service_test.go b/pkg/api/mlflow/service/metric/service_test.go index f550f213f..b506c436f 100644 --- a/pkg/api/mlflow/service/metric/service_test.go +++ b/pkg/api/mlflow/service/metric/service_test.go @@ -335,6 +335,7 @@ func TestNewService_GetMetricHistories_Ok(t *testing.T) { []string{"key1", "key2"}, request.ViewTypeActiveOnly, int32(1), + map[string]string(nil), ).Return( tt.expectedRows, tt.expectedIter, @@ -423,6 +424,7 @@ func TestNewService_GetMetricHistories_Error(t *testing.T) { []string{"key1", "key2"}, request.ViewTypeAll, int32(1), + map[string]string(nil), ).Return( nil, nil, diff --git a/tests/integration/golang/fixtures/base.go b/tests/integration/golang/fixtures/base.go index 32d69fe17..58aba2988 100644 --- a/tests/integration/golang/fixtures/base.go +++ b/tests/integration/golang/fixtures/base.go @@ -22,6 +22,7 @@ func (f baseFixtures) TruncateTables() error { models.Param{}, models.LatestMetric{}, models.Metric{}, + models.Context{}, models.Run{}, models.ExperimentTag{}, models.Experiment{}, diff --git a/tests/integration/golang/fixtures/metric.go b/tests/integration/golang/fixtures/metric.go index 44c966eb9..75aa4f942 100644 --- a/tests/integration/golang/fixtures/metric.go +++ b/tests/integration/golang/fixtures/metric.go @@ -8,6 +8,7 @@ import ( "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/repositories" + "github.com/G-Research/fasttrackml/pkg/database" ) // MetricFixtures represents data fixtures object. @@ -26,18 +27,59 @@ func NewMetricFixtures(db *gorm.DB) (*MetricFixtures, error) { // CreateMetric creates new test Metric. func (f MetricFixtures) CreateMetric(ctx context.Context, metric *models.Metric) (*models.Metric, error) { + if metric.Context != nil { + if err := f.baseFixtures.db.WithContext(ctx).FirstOrCreate(&metric.Context).Error; err != nil { + return nil, eris.Wrap(err, "error creating metric context") + } + metric.ContextID = &metric.Context.ID + } if err := f.baseFixtures.db.WithContext(ctx).Create(metric).Error; err != nil { - return nil, eris.Wrap(err, "error creating test metric") + return nil, eris.Wrap(err, "error creating metric") } return metric, nil } +// GetMetricsByRunID returns the metrics by Run ID. +func (f MetricFixtures) GetMetricsByRunID(ctx context.Context, runID string) ([]*models.Metric, error) { + var metrics []*models.Metric + if err := f.db.WithContext(ctx).Where( + "run_uuid = ?", runID, + ).Find(&metrics).Error; err != nil { + return nil, eris.Wrapf(err, "error getting metric by run_uuid: %v", runID) + } + return metrics, nil +} + +// GetMetricsByContext returns metric by a context partial match. +func (f MetricFixtures) GetMetricsByContext( + ctx context.Context, + metricContext map[string]string, +) ([]*models.Metric, error) { + var metrics []*models.Metric + tx := f.db.WithContext(ctx).Model( + &database.Metric{}, + ).Joins( + "LEFT JOIN contexts ON metrics.context_id = contexts.id", + ) + sql, args := repositories.BuildJsonCondition(tx.Dialector.Name(), "contexts.json", metricContext) + if err := tx.Where(sql, args...).Find(&metrics).Error; err != nil { + return nil, eris.Wrapf(err, "error getting metrics by context: %v", metricContext) + } + return metrics, nil +} + // CreateLatestMetric creates new test Latest Metric. func (f MetricFixtures) CreateLatestMetric( ctx context.Context, metric *models.LatestMetric, ) (*models.LatestMetric, error) { + if metric.Context != nil { + if err := f.baseFixtures.db.WithContext(ctx).FirstOrCreate(&metric.Context).Error; err != nil { + return nil, eris.Wrap(err, "error creating latest metric context") + } + metric.ContextID = &metric.Context.ID + } if err := f.baseFixtures.db.WithContext(ctx).Create(metric).Error; err != nil { - return nil, eris.Wrap(err, "error creating test latest metric") + return nil, eris.Wrap(err, "error creating latest metric") } return metric, nil } diff --git a/tests/integration/golang/helpers/arrow.go b/tests/integration/golang/helpers/arrow.go new file mode 100644 index 000000000..6b36f9124 --- /dev/null +++ b/tests/integration/golang/helpers/arrow.go @@ -0,0 +1,50 @@ +package helpers + +import ( + "bytes" + + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/ipc" + "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/rotisserie/eris" + + "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" +) + +func DecodeArrowMetrics(buf *bytes.Buffer) ([]models.Metric, error) { + pool := memory.NewGoAllocator() + + // Create a new reader + reader, err := ipc.NewReader(buf, ipc.WithAllocator(pool)) + if err != nil { + return nil, eris.Wrap(err, "error creating reader for arrow decode") + } + defer reader.Release() + + var metrics []models.Metric + + // Iterate over all records in the reader + for reader.Next() { + rec := reader.Record() + for i := 0; i < int(rec.NumRows()); i++ { + metric := models.Metric{ + RunID: rec.Column(0).(*array.String).Value(i), + Key: rec.Column(1).(*array.String).Value(i), + Step: rec.Column(2).(*array.Int64).Value(i), + Timestamp: rec.Column(3).(*array.Int64).Value(i), + IsNan: rec.Column(4).(*array.Float64).IsNull(i), + } + if !metric.IsNan { + metric.Value = rec.Column(4).(*array.Float64).Value(i) + } + metrics = append(metrics, metric) + } + rec.Release() + } + + if reader.Err() != nil { + return nil, eris.Wrap(reader.Err(), "error processing reader in arrow decode") + } + + return metrics, nil +} diff --git a/tests/integration/golang/mlflow/metric/get_histories_test.go b/tests/integration/golang/mlflow/metric/get_histories_test.go index 41fe0f7ee..42f420786 100644 --- a/tests/integration/golang/mlflow/metric/get_histories_test.go +++ b/tests/integration/golang/mlflow/metric/get_histories_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/stretchr/testify/suite" + "gorm.io/datatypes" "github.com/G-Research/fasttrackml/pkg/api/mlflow" "github.com/G-Research/fasttrackml/pkg/api/mlflow/api" @@ -35,6 +36,13 @@ func (s *GetHistoriesTestSuite) Test_Ok() { }) s.Require().Nil(err) + experiment2, err := s.ExperimentFixtures.CreateExperiment(context.Background(), &models.Experiment{ + Name: "Test Experiment2", + NamespaceID: s.DefaultNamespace.ID, + LifecycleStage: models.LifecycleStageActive, + }) + s.Require().Nil(err) + run1, err := s.RunFixtures.CreateRun(context.Background(), &models.Run{ ID: "run1", Name: "chill-run", @@ -45,16 +53,67 @@ func (s *GetHistoriesTestSuite) Test_Ok() { }) s.Require().Nil(err) - _, err = s.MetricFixtures.CreateMetric(context.Background(), &models.Metric{ + metric, err := s.MetricFixtures.CreateMetric(context.Background(), &models.Metric{ Key: "key1", Value: 1.1, Timestamp: 1234567890, RunID: run1.ID, Step: 1, - IsNan: false, Iter: 1, }) s.Require().Nil(err) + s.Require().Nil(metric.ContextID) + s.Require().Nil(metric.Context) + + metric, err = s.MetricFixtures.CreateMetric(context.Background(), &models.Metric{ + Key: "key2", + Value: 1.1, + Timestamp: 2234567890, + RunID: run1.ID, + Step: 1, + Iter: 1, + Context: &models.Context{ + Json: datatypes.JSON([]byte(` + { + "metrickey1": "metricvalue1", + "metrickey2": "metricvalue2", + "metricnested": { "metricnestedkey": "metricnestedvalue" } + }`, + )), + }, + }) + s.Require().Nil(err) + s.Require().NotNil(metric.ContextID) + s.Require().NotNil(metric.Context) + + // verify metric contexts are persisting + metrics, err := s.MetricFixtures.GetMetricsByRunID(context.Background(), run1.ID) + s.Require().Nil(err) + s.Require().Nil(metrics[0].ContextID) + s.Require().NotNil(metrics[1].ContextID) + + // verify metric contexts can be used for selection (toplevel key) + metrics, err = s.MetricFixtures.GetMetricsByContext(context.Background(), map[string]string{ + "metrickey1": "metricvalue1", + }) + s.Require().Nil(err) + s.Require().Len(metrics, 1) + s.Require().NotNil(metrics[0].ContextID) + + // nested key + metrics, err = s.MetricFixtures.GetMetricsByContext(context.Background(), map[string]string{ + "metricnested.metricnestedkey": "metricnestedvalue", + }) + s.Require().Nil(err) + s.Require().Len(metrics, 1) + s.Require().NotNil(metrics[0].ContextID) + + metrics, err = s.MetricFixtures.GetMetricsByContext( + context.Background(), + map[string]string{"metrickey2": "metricvalue1"}, + ) + s.Require().Nil(err) + s.Require().Len(metrics, 0) run2, err := s.RunFixtures.CreateRun(context.Background(), &models.Run{ ID: "run2", @@ -62,7 +121,7 @@ func (s *GetHistoriesTestSuite) Test_Ok() { Status: models.StatusScheduled, SourceType: "JOB", LifecycleStage: models.LifecycleStageActive, - ExperimentID: *experiment.ID, + ExperimentID: *experiment2.ID, }) s.Require().Nil(err) @@ -78,20 +137,64 @@ func (s *GetHistoriesTestSuite) Test_Ok() { s.Require().Nil(err) tests := []struct { - name string - request *request.GetMetricHistoriesRequest + name string + request *request.GetMetricHistoriesRequest + verifyResponse func(metrics []models.Metric) }{ { name: "GetMetricHistoriesByRunIDs", request: &request.GetMetricHistoriesRequest{ RunIDs: []string{run1.ID, run2.ID}, }, + verifyResponse: func(metrics []models.Metric) { + s.Equal(3, len(metrics)) + s.Equal("run1", metrics[0].RunID) + s.Equal("run1", metrics[1].RunID) + s.Equal("run2", metrics[2].RunID) + }, }, { name: "GetMetricHistoriesByExperimentIDs", request: &request.GetMetricHistoriesRequest{ ExperimentIDs: []string{fmt.Sprintf("%d", *experiment.ID)}, }, + verifyResponse: func(metrics []models.Metric) { + s.Equal(2, len(metrics)) + s.Equal("run1", metrics[0].RunID) + s.Equal("run1", metrics[1].RunID) + }, + }, + { + name: "GetMetricHistoriesByContextMatch", + request: &request.GetMetricHistoriesRequest{ + ExperimentIDs: []string{fmt.Sprintf("%d", *experiment.ID)}, + Context: map[string]string{"metrickey1": "metricvalue1"}, + }, + verifyResponse: func(metrics []models.Metric) { + s.Equal(1, len(metrics)) + s.Equal("run1", metrics[0].RunID) + }, + }, + { + name: "GetMetricHistoriesByNestedContextMatch", + request: &request.GetMetricHistoriesRequest{ + ExperimentIDs: []string{fmt.Sprintf("%d", *experiment.ID)}, + Context: map[string]string{"metricnested.metricnestedkey": "metricnestedvalue"}, + }, + verifyResponse: func(metrics []models.Metric) { + s.Equal(1, len(metrics)) + s.Equal("run1", metrics[0].RunID) + }, + }, + { + name: "GetMetricHistoriesByContextNoMatch", + request: &request.GetMetricHistoriesRequest{ + ExperimentIDs: []string{fmt.Sprintf("%d", *experiment.ID)}, + Context: map[string]string{"metrickey1": "metricvalue2"}, + }, + verifyResponse: func(metrics []models.Metric) { + s.Equal(0, len(metrics)) + }, }, } for _, tt := range tests { @@ -111,9 +214,9 @@ func (s *GetHistoriesTestSuite) Test_Ok() { ), ) - // TODO:DSuhinin - data is encoded so we need a bit more smart way to check the data. - // right now we can go with this simple approach. - s.NotEmpty(resp.String()) + metrics, err := helpers.DecodeArrowMetrics(resp) + s.Require().Nil(err) + tt.verifyResponse(metrics) }) } }