Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metric context query #638

Merged
merged 39 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
3b547bb
add json context to metrics and latest_metrics
suprjinx Nov 15, 2023
96f6d0f
go mod tidy
suprjinx Nov 15, 2023
44cf62f
normalize metric context
suprjinx Nov 15, 2023
2585c31
add metric context model to dao/models
suprjinx Nov 16, 2023
1339d6b
Merge branch 'main' into metric-context-storage
suprjinx Nov 20, 2023
cd3cd81
WIP for metric context querying
suprjinx Nov 21, 2023
00c4ecf
WIP for test
suprjinx Nov 22, 2023
a3051ab
Merge branch 'main' into metric-context-query
suprjinx Nov 22, 2023
ab57751
Create metrics with context properly in integration test fixtures; ve…
suprjinx Nov 22, 2023
38ba059
Add json column selection
suprjinx Nov 23, 2023
c817313
need to check error
suprjinx Nov 23, 2023
52b2ac7
Merge branch 'main' into metric-context-query
suprjinx Nov 27, 2023
798bbd6
go mod tidy
suprjinx Nov 27, 2023
3912c62
fix mock expectation in unit test, remove change to request struct
suprjinx Nov 27, 2023
da78737
Add implementation tweak and test for nested json
suprjinx Nov 27, 2023
b6ecfc2
oops fix imports
suprjinx Nov 27, 2023
edd1f18
use jsonpath: value map instead of nesting
suprjinx Nov 27, 2023
33b7557
use same postgres syntax in fixture and helper
suprjinx Nov 27, 2023
4de4650
Regen mocks
suprjinx Nov 27, 2023
b6f48e4
remove error return and handling
suprjinx Nov 27, 2023
f79615b
too many return vals after removing error from helper
suprjinx Nov 27, 2023
d71b177
Change helper comment to explain usage
suprjinx Nov 27, 2023
5b06bff
PR var renames
suprjinx Nov 28, 2023
b150720
Remove duplicate json column helper from fixture, reuse the function
suprjinx Nov 28, 2023
728b480
Add context to GetMetricHistory request
suprjinx Nov 28, 2023
c1acee0
better validation of get-histories response with context query
suprjinx Nov 29, 2023
bf5c652
add nested test case and rename response validation
suprjinx Nov 29, 2023
b811cf2
add context selection to GetMetricHistory
suprjinx Nov 29, 2023
902bc69
Revert "add context selection to GetMetricHistory"
suprjinx Nov 29, 2023
222b794
Merge branch 'main' into metric-context-query
suprjinx Nov 29, 2023
502c8e9
make json `where` helper more testable and add test
suprjinx Nov 30, 2023
8504e9f
Merge branch 'metric-context-query' of github.com:suprjinx/fasttrackm…
suprjinx Nov 30, 2023
0f8126c
fix lint
suprjinx Nov 30, 2023
4c47e1c
syntax fix
suprjinx Nov 30, 2023
98606ef
arrow decoding for GetHistories test
suprjinx Nov 30, 2023
5bf5dcf
PR requested changes
suprjinx Dec 1, 2023
ab7ff6b
Fix unit test flake
suprjinx Dec 1, 2023
d16e85f
Merge branch 'main' into metric-context-query
suprjinx Dec 1, 2023
4e04e3c
Merge branch 'main' into metric-context-query
suprjinx Dec 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions pkg/api/mlflow/api/request/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ type GetMetricHistoryBulkRequest struct {

// GetMetricHistoriesRequest is a request object for `POST /mlflow/metrics/get-histories` endpoint.
type GetMetricHistoriesRequest struct {
ExperimentIDs []string `json:"experiment_ids"`
RunIDs []string `json:"run_ids"`
MetricKeys []string `json:"metric_keys"`
ViewType ViewType `json:"run_view_type"`
MaxResults int32 `json:"max_results"`
ExperimentIDs []string `json:"experiment_ids"`
RunIDs []string `json:"run_ids"`
MetricKeys []string `json:"metric_keys"`
ViewType ViewType `json:"run_view_type"`
MaxResults int32 `json:"max_results"`
Context map[string]string `json:"context"`
}
38 changes: 38 additions & 0 deletions pkg/api/mlflow/dao/repositories/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"strings"

"gorm.io/driver/postgres"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

Expand All @@ -30,3 +32,39 @@ func makeParamConflictPlaceholdersAndValues(params []models.Param) (string, []in
}
return placeholders, valuesArray
}

// BuildJsonCondition creates sql and values for where condition to select items having the specified map of json paths
// and values in the given json column. Json path is expressed as "key" or "outerkey.nestedKey".
func BuildJsonCondition(
dialector string,
jsonColumnName string,
jsonPathValueMap map[string]string,
) (sql string, args []any) {
if len(jsonPathValueMap) == 0 {
return sql, args
}
var conditionTemplate string
args = make([]any, len(jsonPathValueMap)*2)
switch dialector {
case postgres.Dialector{}.Name():
conditionTemplate = "%s#>>? = ?"
idx := 0
for k, v := range jsonPathValueMap {
path := strings.ReplaceAll(k, ".", ",")
args[idx] = fmt.Sprintf("{%s}", path)
args[idx+1] = v
idx = idx + 2
}
default:
conditionTemplate = "%s->>? = ?"
idx := 0
for k, v := range jsonPathValueMap {
args[idx] = k
args[idx+1] = v
idx = idx + 2
}
}
conditionTemplate = fmt.Sprintf(conditionTemplate, jsonColumnName)
sql = strings.Repeat(conditionTemplate+" AND ", len(jsonPathValueMap)-1) + conditionTemplate
return sql, args
}
51 changes: 51 additions & 0 deletions pkg/api/mlflow/dao/repositories/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
"github.com/G-Research/fasttrackml/pkg/database"
)

func Test_makeSqlPlaceholders(t *testing.T) {
Expand Down Expand Up @@ -53,3 +54,53 @@ func Test_makeParamConflictPlaceholdersAndValues(t *testing.T) {
assert.Equal(t, tt.expectedValues, values)
}
}

func TestBuildJsonCondition(t *testing.T) {
tests := []struct {
name string
dialector string
jsonColumnName string
jsonPathValueMap map[string]string
expectedSQL string
expectedArgs []interface{}
}{
{
name: "Postgres",
dialector: database.PostgresDialectorName,
jsonColumnName: "contexts.json",
jsonPathValueMap: map[string]string{
"key1": "value1",
"key2.nested": "value2",
},
expectedSQL: "contexts.json#>>? = ? AND contexts.json#>>? = ?",
expectedArgs: []interface{}{"{key1}", "value1", "{key2,nested}", "value2"},
},
{
name: "Sqlite",
dialector: database.SQLiteDialectorName,
jsonColumnName: "contexts.json",
jsonPathValueMap: map[string]string{
"key1": "value1",
"key2.nested": "value2",
},
expectedSQL: "contexts.json->>? = ? AND contexts.json->>? = ?",
expectedArgs: []interface{}{"key1", "value1", "key2.nested", "value2"},
},
{
name: "SqliteEmptyMap",
dialector: database.SQLiteDialectorName,
jsonColumnName: "contexts.json",
jsonPathValueMap: map[string]string{},
expectedSQL: "",
expectedArgs: []interface{}(nil),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sql, args := BuildJsonCondition(tt.dialector, tt.jsonColumnName, tt.jsonPathValueMap)
assert.Equal(t, tt.expectedSQL, sql)
assert.ElementsMatch(t, tt.expectedArgs, args)
})
}
}
8 changes: 8 additions & 0 deletions pkg/api/mlflow/dao/repositories/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type MetricRepositoryProvider interface {
experimentIDs []string, runIDs []string, metricKeys []string,
viewType request.ViewType,
limit int32,
jsonPathValueMap map[string]string,
) (*sql.Rows, func(*sql.Rows, interface{}) error, error)
// GetMetricHistoryBulk returns metrics history bulk.
GetMetricHistoryBulk(
Expand Down Expand Up @@ -151,6 +152,7 @@ func (r MetricRepository) GetMetricHistories(
experimentIDs []string, runIDs []string, metricKeys []string,
viewType request.ViewType,
limit int32,
jsonPathValueMap map[string]string,
) (*sql.Rows, func(*sql.Rows, interface{}) error, error) {
// if experimentIDs has been provided then firstly get the runs by provided experimentIDs.
if len(experimentIDs) > 0 {
Expand Down Expand Up @@ -219,6 +221,12 @@ func (r MetricRepository) GetMetricHistories(
query.Where("metrics.key IN ?", metricKeys)
}

if len(jsonPathValueMap) > 0 {
suprjinx marked this conversation as resolved.
Show resolved Hide resolved
query.Joins("LEFT JOIN contexts on metrics.context_id = contexts.id")
sql, args := BuildJsonCondition(query.Dialector.Name(), "contexts.json", jsonPathValueMap)
query.Where(sql, args...)
suprjinx marked this conversation as resolved.
Show resolved Hide resolved
}

rows, err := query.Rows()
if err != nil {
return nil, nil, eris.Wrapf(
Expand Down
22 changes: 11 additions & 11 deletions pkg/api/mlflow/dao/repositories/mock_metric_repository_provider.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pkg/api/mlflow/service/metric/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (s Service) GetMetricHistories(
req.MetricKeys,
req.ViewType,
req.MaxResults,
req.Context,
)
if err != nil {
return nil, nil, api.NewInternalError("Unable to search runs: %s", err)
Expand Down
2 changes: 2 additions & 0 deletions pkg/api/mlflow/service/metric/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ func TestNewService_GetMetricHistories_Ok(t *testing.T) {
[]string{"key1", "key2"},
request.ViewTypeActiveOnly,
int32(1),
map[string]string(nil),
).Return(
tt.expectedRows,
tt.expectedIter,
Expand Down Expand Up @@ -423,6 +424,7 @@ func TestNewService_GetMetricHistories_Error(t *testing.T) {
[]string{"key1", "key2"},
request.ViewTypeAll,
int32(1),
map[string]string(nil),
).Return(
nil,
nil,
Expand Down
1 change: 1 addition & 0 deletions tests/integration/golang/fixtures/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func (f baseFixtures) TruncateTables() error {
models.Param{},
models.LatestMetric{},
models.Metric{},
models.Context{},
models.Run{},
models.ExperimentTag{},
models.Experiment{},
Expand Down
46 changes: 44 additions & 2 deletions tests/integration/golang/fixtures/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/repositories"
"github.com/G-Research/fasttrackml/pkg/database"
)

// MetricFixtures represents data fixtures object.
Expand All @@ -26,18 +27,59 @@ func NewMetricFixtures(db *gorm.DB) (*MetricFixtures, error) {

// CreateMetric creates new test Metric.
func (f MetricFixtures) CreateMetric(ctx context.Context, metric *models.Metric) (*models.Metric, error) {
if metric.Context != nil {
if err := f.baseFixtures.db.WithContext(ctx).FirstOrCreate(&metric.Context).Error; err != nil {
return nil, eris.Wrap(err, "error creating metric context")
}
metric.ContextID = &metric.Context.ID
}
if err := f.baseFixtures.db.WithContext(ctx).Create(metric).Error; err != nil {
return nil, eris.Wrap(err, "error creating test metric")
return nil, eris.Wrap(err, "error creating metric")
}
return metric, nil
}
Comment on lines 29 to 40
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please check this moment -> https://gorm.io/docs/create.html#Create-With-Associations. amybe we even don't need to create twice and gorm can do all the magic under the hood.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gorm is not doing the magic we need -- it could perhaps become a BeforeCreate hook on the Metric and LatestMetric model?


// GetMetricsByRunID returns the metrics by Run ID.
func (f MetricFixtures) GetMetricsByRunID(ctx context.Context, runID string) ([]*models.Metric, error) {
var metrics []*models.Metric
if err := f.db.WithContext(ctx).Where(
"run_uuid = ?", runID,
).Find(&metrics).Error; err != nil {
return nil, eris.Wrapf(err, "error getting metric by run_uuid: %v", runID)
}
return metrics, nil
}

// GetMetricsByContext returns metric by a context partial match.
func (f MetricFixtures) GetMetricsByContext(
ctx context.Context,
metricContext map[string]string,
) ([]*models.Metric, error) {
var metrics []*models.Metric
tx := f.db.WithContext(ctx).Model(
&database.Metric{},
).Joins(
"LEFT JOIN contexts ON metrics.context_id = contexts.id",
)
sql, args := repositories.BuildJsonCondition(tx.Dialector.Name(), "contexts.json", metricContext)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we pass it right into Where?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, personally I think what Geoff wrote is more readable...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also we have to ... the second return val in this approach

if err := tx.Where(sql, args...).Find(&metrics).Error; err != nil {
return nil, eris.Wrapf(err, "error getting metrics by context: %v", metricContext)
}
return metrics, nil
}

// CreateLatestMetric creates new test Latest Metric.
func (f MetricFixtures) CreateLatestMetric(
ctx context.Context, metric *models.LatestMetric,
) (*models.LatestMetric, error) {
if metric.Context != nil {
if err := f.baseFixtures.db.WithContext(ctx).FirstOrCreate(&metric.Context).Error; err != nil {
return nil, eris.Wrap(err, "error creating latest metric context")
}
metric.ContextID = &metric.Context.ID
}
if err := f.baseFixtures.db.WithContext(ctx).Create(metric).Error; err != nil {
return nil, eris.Wrap(err, "error creating test latest metric")
return nil, eris.Wrap(err, "error creating latest metric")
}
Comment on lines +75 to 83
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the same moment here -> maybe we can just create everything during one gorm call?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gorm is not able to do "FirstOrCreate" on the association by magic, but I think we could move to Metric.BeforeCreate hook if that's preferable (?)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a bit strange cause they say:

When creating some data with associations, if its associations value is not zero-value, those associations will be upserted, and its Hooks methods will be invoked.
``` - upsert. if no, then ok.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my testing did not show that to work in this case, possibly because our FK relationship is inverted (ie, we'd have to create Context and then it could auto create Context.Metrics if we had that association defined).

return metric, nil
}
Expand Down
50 changes: 50 additions & 0 deletions tests/integration/golang/helpers/arrow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package helpers

import (
"bytes"

"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/rotisserie/eris"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

func DecodeArrowMetrics(buf *bytes.Buffer) ([]models.Metric, error) {
pool := memory.NewGoAllocator()

// Create a new reader
reader, err := ipc.NewReader(buf, ipc.WithAllocator(pool))
if err != nil {
return nil, eris.Wrap(err, "error creating reader for arrow decode")
}
defer reader.Release()

var metrics []models.Metric

// Iterate over all records in the reader
for reader.Next() {
rec := reader.Record()
for i := 0; i < int(rec.NumRows()); i++ {
metric := models.Metric{
RunID: rec.Column(0).(*array.String).Value(i),
Key: rec.Column(1).(*array.String).Value(i),
Step: rec.Column(2).(*array.Int64).Value(i),
Timestamp: rec.Column(3).(*array.Int64).Value(i),
IsNan: rec.Column(4).(*array.Float64).IsNull(i),
}
if !metric.IsNan {
metric.Value = rec.Column(4).(*array.Float64).Value(i)
}
metrics = append(metrics, metric)
}
rec.Release()
}

if reader.Err() != nil {
return nil, eris.Wrap(reader.Err(), "error processing reader in arrow decode")
}

return metrics, nil
}
Loading