Skip to content

Commit

Permalink
Metric context query (#638)
Browse files Browse the repository at this point in the history
Introduce json-path style selection of metrics, where a map of JsonPath: Value can be used to filter metric results.
  • Loading branch information
suprjinx authored Dec 2, 2023
1 parent e1eef7a commit 0b87a35
Show file tree
Hide file tree
Showing 11 changed files with 323 additions and 26 deletions.
11 changes: 6 additions & 5 deletions pkg/api/mlflow/api/request/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ type GetMetricHistoryBulkRequest struct {

// GetMetricHistoriesRequest is a request object for `POST /mlflow/metrics/get-histories` endpoint.
//
// Context is an optional filter that maps a JSON path ("key" or "outerKey.nestedKey") to the
// string value a metric's context must hold at that path; an empty or missing map applies no
// context filtering.
type GetMetricHistoriesRequest struct {
	ExperimentIDs []string          `json:"experiment_ids"`
	RunIDs        []string          `json:"run_ids"`
	MetricKeys    []string          `json:"metric_keys"`
	ViewType      ViewType          `json:"run_view_type"`
	MaxResults    int32             `json:"max_results"`
	Context       map[string]string `json:"context"`
}
38 changes: 38 additions & 0 deletions pkg/api/mlflow/dao/repositories/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"strings"

"gorm.io/driver/postgres"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

Expand All @@ -30,3 +32,39 @@ func makeParamConflictPlaceholdersAndValues(params []models.Param) (string, []in
}
return placeholders, valuesArray
}

// BuildJsonCondition assembles a WHERE fragment, plus its bind arguments, that matches rows whose
// JSON column holds every given value at the given path. A path is written as "key" or
// "outerkey.nestedKey".
//
// One "<column><op>? = ?" clause is produced per map entry and the clauses are joined with AND.
// For the Postgres dialector the operator is #>> and the path is bound as a "{a,b}" literal; for
// any other dialector the operator is ->> and the path is bound unchanged. Map iteration order is
// unspecified, so the order of pairs in args may differ between calls, but each path is always
// immediately followed by its value. An empty map yields an empty sql string and nil args.
func BuildJsonCondition(
	dialector string,
	jsonColumnName string,
	jsonPathValueMap map[string]string,
) (sql string, args []any) {
	pairCount := len(jsonPathValueMap)
	if pairCount == 0 {
		return "", nil
	}

	// Decide the dialect-specific operator once instead of per entry.
	isPostgres := dialector == postgres.Dialector{}.Name()
	operator := "->>"
	if isPostgres {
		operator = "#>>"
	}
	clause := fmt.Sprintf("%s%s? = ?", jsonColumnName, operator)

	clauses := make([]string, 0, pairCount)
	args = make([]any, 0, pairCount*2)
	for path, value := range jsonPathValueMap {
		if isPostgres {
			// "outer.inner" becomes "{outer,inner}".
			path = "{" + strings.ReplaceAll(path, ".", ",") + "}"
		}
		clauses = append(clauses, clause)
		args = append(args, path, value)
	}
	return strings.Join(clauses, " AND "), args
}
51 changes: 51 additions & 0 deletions pkg/api/mlflow/dao/repositories/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
"github.com/G-Research/fasttrackml/pkg/database"
)

func Test_makeSqlPlaceholders(t *testing.T) {
Expand Down Expand Up @@ -53,3 +54,53 @@ func Test_makeParamConflictPlaceholdersAndValues(t *testing.T) {
assert.Equal(t, tt.expectedValues, values)
}
}

// TestBuildJsonCondition verifies the SQL fragment and bind arguments that BuildJsonCondition
// produces for the Postgres and SQLite dialectors, including the empty-map case.
func TestBuildJsonCondition(t *testing.T) {
	type testCase struct {
		name             string
		dialector        string
		jsonColumnName   string
		jsonPathValueMap map[string]string
		expectedSQL      string
		expectedArgs     []any
	}

	const column = "contexts.json"
	// Shared input: one top-level path and one nested path.
	paths := map[string]string{
		"key1":        "value1",
		"key2.nested": "value2",
	}

	cases := []testCase{
		{
			name:             "Postgres",
			dialector:        database.PostgresDialectorName,
			jsonColumnName:   column,
			jsonPathValueMap: paths,
			expectedSQL:      "contexts.json#>>? = ? AND contexts.json#>>? = ?",
			expectedArgs:     []any{"{key1}", "value1", "{key2,nested}", "value2"},
		},
		{
			name:             "Sqlite",
			dialector:        database.SQLiteDialectorName,
			jsonColumnName:   column,
			jsonPathValueMap: paths,
			expectedSQL:      "contexts.json->>? = ? AND contexts.json->>? = ?",
			expectedArgs:     []any{"key1", "value1", "key2.nested", "value2"},
		},
		{
			name:             "SqliteEmptyMap",
			dialector:        database.SQLiteDialectorName,
			jsonColumnName:   column,
			jsonPathValueMap: map[string]string{},
			expectedSQL:      "",
			expectedArgs:     nil,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotSQL, gotArgs := BuildJsonCondition(tc.dialector, tc.jsonColumnName, tc.jsonPathValueMap)
			assert.Equal(t, tc.expectedSQL, gotSQL)
			// Arguments come from map iteration, so compare them without regard to order.
			assert.ElementsMatch(t, tc.expectedArgs, gotArgs)
		})
	}
}
8 changes: 8 additions & 0 deletions pkg/api/mlflow/dao/repositories/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type MetricRepositoryProvider interface {
experimentIDs []string, runIDs []string, metricKeys []string,
viewType request.ViewType,
limit int32,
jsonPathValueMap map[string]string,
) (*sql.Rows, func(*sql.Rows, interface{}) error, error)
// GetMetricHistoryBulk returns metrics history bulk.
GetMetricHistoryBulk(
Expand Down Expand Up @@ -151,6 +152,7 @@ func (r MetricRepository) GetMetricHistories(
experimentIDs []string, runIDs []string, metricKeys []string,
viewType request.ViewType,
limit int32,
jsonPathValueMap map[string]string,
) (*sql.Rows, func(*sql.Rows, interface{}) error, error) {
// if experimentIDs has been provided then firstly get the runs by provided experimentIDs.
if len(experimentIDs) > 0 {
Expand Down Expand Up @@ -219,6 +221,12 @@ func (r MetricRepository) GetMetricHistories(
query.Where("metrics.key IN ?", metricKeys)
}

if len(jsonPathValueMap) > 0 {
query.Joins("LEFT JOIN contexts on metrics.context_id = contexts.id")
sql, args := BuildJsonCondition(query.Dialector.Name(), "contexts.json", jsonPathValueMap)
query.Where(sql, args...)
}

rows, err := query.Rows()
if err != nil {
return nil, nil, eris.Wrapf(
Expand Down
22 changes: 11 additions & 11 deletions pkg/api/mlflow/dao/repositories/mock_metric_repository_provider.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pkg/api/mlflow/service/metric/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (s Service) GetMetricHistories(
req.MetricKeys,
req.ViewType,
req.MaxResults,
req.Context,
)
if err != nil {
return nil, nil, api.NewInternalError("Unable to search runs: %s", err)
Expand Down
2 changes: 2 additions & 0 deletions pkg/api/mlflow/service/metric/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ func TestNewService_GetMetricHistories_Ok(t *testing.T) {
[]string{"key1", "key2"},
request.ViewTypeActiveOnly,
int32(1),
map[string]string(nil),
).Return(
tt.expectedRows,
tt.expectedIter,
Expand Down Expand Up @@ -423,6 +424,7 @@ func TestNewService_GetMetricHistories_Error(t *testing.T) {
[]string{"key1", "key2"},
request.ViewTypeAll,
int32(1),
map[string]string(nil),
).Return(
nil,
nil,
Expand Down
1 change: 1 addition & 0 deletions tests/integration/golang/fixtures/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func (f baseFixtures) TruncateTables() error {
models.Param{},
models.LatestMetric{},
models.Metric{},
models.Context{},
models.Run{},
models.ExperimentTag{},
models.Experiment{},
Expand Down
46 changes: 44 additions & 2 deletions tests/integration/golang/fixtures/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/repositories"
"github.com/G-Research/fasttrackml/pkg/database"
)

// MetricFixtures represents data fixtures object.
Expand All @@ -26,18 +27,59 @@ func NewMetricFixtures(db *gorm.DB) (*MetricFixtures, error) {

// CreateMetric creates new test Metric.
//
// When the metric carries a Context, that context is first fetched or inserted with FirstOrCreate
// and its ID is stored in metric.ContextID before the metric itself is inserted. The same metric
// pointer is returned on success; any database error is wrapped and returned with a nil metric.
func (f MetricFixtures) CreateMetric(ctx context.Context, metric *models.Metric) (*models.Metric, error) {
	if metric.Context != nil {
		if err := f.baseFixtures.db.WithContext(ctx).FirstOrCreate(&metric.Context).Error; err != nil {
			return nil, eris.Wrap(err, "error creating metric context")
		}
		// Link the metric to the context row resolved above.
		metric.ContextID = &metric.Context.ID
	}
	if err := f.baseFixtures.db.WithContext(ctx).Create(metric).Error; err != nil {
		return nil, eris.Wrap(err, "error creating metric")
	}
	return metric, nil
}

// GetMetricsByRunID loads every metric whose run_uuid column equals runID.
// A query failure is wrapped and returned together with a nil slice.
func (f MetricFixtures) GetMetricsByRunID(ctx context.Context, runID string) ([]*models.Metric, error) {
	var found []*models.Metric
	result := f.db.WithContext(ctx).Where("run_uuid = ?", runID).Find(&found)
	if result.Error != nil {
		return nil, eris.Wrapf(result.Error, "error getting metric by run_uuid: %v", runID)
	}
	return found, nil
}

// GetMetricsByContext loads the metrics whose joined context row satisfies every path/value pair
// in metricContext. The condition is built by repositories.BuildJsonCondition against the
// "contexts.json" column for the dialector of the current connection.
func (f MetricFixtures) GetMetricsByContext(
	ctx context.Context,
	metricContext map[string]string,
) ([]*models.Metric, error) {
	query := f.db.WithContext(ctx).
		Model(&database.Metric{}).
		Joins("LEFT JOIN contexts ON metrics.context_id = contexts.id")
	condition, values := repositories.BuildJsonCondition(query.Dialector.Name(), "contexts.json", metricContext)

	var matched []*models.Metric
	err := query.Where(condition, values...).Find(&matched).Error
	if err != nil {
		return nil, eris.Wrapf(err, "error getting metrics by context: %v", metricContext)
	}
	return matched, nil
}

// CreateLatestMetric creates new test Latest Metric.
//
// When the latest metric carries a Context, that context is first fetched or inserted with
// FirstOrCreate and its ID is stored in metric.ContextID before the latest metric itself is
// inserted. The same pointer is returned on success; any database error is wrapped and returned
// with a nil metric.
func (f MetricFixtures) CreateLatestMetric(
	ctx context.Context, metric *models.LatestMetric,
) (*models.LatestMetric, error) {
	if metric.Context != nil {
		if err := f.baseFixtures.db.WithContext(ctx).FirstOrCreate(&metric.Context).Error; err != nil {
			return nil, eris.Wrap(err, "error creating latest metric context")
		}
		// Link the latest metric to the context row resolved above.
		metric.ContextID = &metric.Context.ID
	}
	if err := f.baseFixtures.db.WithContext(ctx).Create(metric).Error; err != nil {
		return nil, eris.Wrap(err, "error creating latest metric")
	}
	return metric, nil
}
Expand Down
50 changes: 50 additions & 0 deletions tests/integration/golang/helpers/arrow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package helpers

import (
"bytes"

"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/rotisserie/eris"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

// DecodeArrowMetrics decodes an Arrow IPC stream held in buf into a slice of metrics.
//
// Columns are read by position: 0 = run ID (string), 1 = key (string), 2 = step (int64),
// 3 = timestamp (int64), 4 = value (float64). A null in the value column is reported as
// IsNan = true with Value left at zero. An error is returned if the reader cannot be created, if
// a record has fewer than five columns or unexpected column types, or if the reader reports an
// error after iteration. A stream without rows yields a nil slice.
func DecodeArrowMetrics(buf *bytes.Buffer) ([]models.Metric, error) {
	reader, err := ipc.NewReader(buf, ipc.WithAllocator(memory.NewGoAllocator()))
	if err != nil {
		return nil, eris.Wrap(err, "error creating reader for arrow decode")
	}
	defer reader.Release()

	var metrics []models.Metric

	for reader.Next() {
		// The record belongs to the reader and is only valid until the next call to Next; it is
		// not retained here, so it must not be released here either.
		rec := reader.Record()
		if rec.NumCols() < 5 {
			return nil, eris.Errorf(
				"error decoding arrow record: expected at least 5 columns, got %d", rec.NumCols(),
			)
		}

		// Resolve the typed columns once per record instead of once per row, and fail with an
		// error rather than a panic when the schema is not the expected one.
		runIDs, okRunID := rec.Column(0).(*array.String)
		keys, okKey := rec.Column(1).(*array.String)
		steps, okStep := rec.Column(2).(*array.Int64)
		timestamps, okTimestamp := rec.Column(3).(*array.Int64)
		values, okValue := rec.Column(4).(*array.Float64)
		if !(okRunID && okKey && okStep && okTimestamp && okValue) {
			return nil, eris.New("error decoding arrow record: unexpected column types")
		}

		for i := 0; i < int(rec.NumRows()); i++ {
			metric := models.Metric{
				RunID:     runIDs.Value(i),
				Key:       keys.Value(i),
				Step:      steps.Value(i),
				Timestamp: timestamps.Value(i),
				IsNan:     values.IsNull(i),
			}
			if !metric.IsNan {
				metric.Value = values.Value(i)
			}
			metrics = append(metrics, metric)
		}
	}

	if err := reader.Err(); err != nil {
		return nil, eris.Wrap(err, "error processing reader in arrow decode")
	}

	return metrics, nil
}
Loading

0 comments on commit 0b87a35

Please sign in to comment.