Integrate context filtering logic into the Aim GetRunMetrics and SearchAlignedMetrics endpoints. #787

Merged: 31 commits, Feb 1, 2024

Changes from all commits (31 commits):
bf9ab6f
*extend GetRunMetrics and SearchAlignedMetrics to work with `context`.
dsuhinin Jan 3, 2024
b6f5b50
*extend GetRunMetrics and SearchAlignedMetrics to work with `context`.
dsuhinin Jan 4, 2024
be14364
*adjust integration tests for `GetRunMetrics` endpoint.
dsuhinin Jan 4, 2024
72761df
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 4, 2024
140d41b
*adjust integration tests for `SearchAlignedMetrics` endpoint.
dsuhinin Jan 4, 2024
ce44ca8
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 10, 2024
c59bf17
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 12, 2024
1ebc12e
*apply last discussed changes.
dsuhinin Jan 13, 2024
6fa3e06
*bug fix for `AlignedSearch` endpoint.
dsuhinin Jan 15, 2024
d8d9181
*other fixes for `SearchAligned` endpoint.
dsuhinin Jan 15, 2024
75c53c7
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 15, 2024
c7bec68
*fix `json` comparison.
dsuhinin Jan 19, 2024
252e9e4
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 19, 2024
b8a0960
*fix unit tests.
dsuhinin Jan 20, 2024
7f8cb26
Merge remote-tracking branch 'origin/dsuhinin/adopt-endpoints-to-filt…
dsuhinin Jan 20, 2024
f7dc098
*fix integration test fixtures. make `json` comparison to be database…
dsuhinin Jan 20, 2024
ef3eb20
*changes according to discussion.
dsuhinin Jan 24, 2024
da077ab
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 24, 2024
36284a0
*apply code formatting.
dsuhinin Jan 24, 2024
300584c
*fix `SearchMetrics` endpoint to correctly handle search when `x_axis…
dsuhinin Jan 26, 2024
5be2271
*working with integration tests.
dsuhinin Jan 29, 2024
280abd7
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 29, 2024
9e1ef6d
*fixing integration tests.
dsuhinin Jan 29, 2024
a8fb3e0
Merge remote-tracking branch 'origin/dsuhinin/adopt-endpoints-to-filt…
dsuhinin Jan 29, 2024
c320e32
*fixing integration tests.
dsuhinin Jan 29, 2024
e45f996
*fixed typo.
dsuhinin Jan 29, 2024
b3f633c
*revert changes back.
dsuhinin Jan 29, 2024
0603eba
*bug fix for postgres case.
dsuhinin Jan 31, 2024
ded7e14
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 31, 2024
21435f5
*fix integration tests for postgres case.
dsuhinin Jan 31, 2024
1edd602
Merge remote-tracking branch 'origin/dsuhinin/adopt-endpoints-to-filt…
dsuhinin Jan 31, 2024
21 changes: 13 additions & 8 deletions pkg/api/aim/projects.go
@@ -191,24 +191,29 @@ func GetProjectParams(c *fiber.Ctx) error {
).Joins(
"INNER JOIN experiments ON experiments.experiment_id = runs.experiment_id AND experiments.namespace_id = ?",
ns.ID,
).Preload(
).Joins(
"Context",
).Where(
"runs.lifecycle_stage = ?", database.LifecycleStageActive,
).Find(&metrics); tx.Error != nil {
return fmt.Errorf("error retrieving metric keys: %w", tx.Error)
}

data := make(map[string][]fiber.Map, len(metrics))
data, mapped := make(map[string][]fiber.Map, len(metrics)), make(map[string]map[string]fiber.Map, len(metrics))
for _, metric := range metrics {
// to be properly decoded by AIM UI, json should be represented as a key:value object.
context := fiber.Map{}
if err := json.Unmarshal(metric.Context.Json, &context); err != nil {
return eris.Wrap(err, "error unmarshalling `context` json to `fiber.Map` object")
if mapped[metric.Key] == nil {
mapped[metric.Key] = map[string]fiber.Map{}
}
if _, ok := mapped[metric.Key][metric.Context.GetJsonHash()]; !ok {
// to be properly decoded by AIM UI, json should be represented as a key:value object.
context := fiber.Map{}
if err := json.Unmarshal(metric.Context.Json, &context); err != nil {
return eris.Wrap(err, "error unmarshalling `context` json to `fiber.Map` object")
}
mapped[metric.Key][metric.Context.GetJsonHash()] = context
data[metric.Key] = append(data[metric.Key], context)
}
data[metric.Key] = append(data[metric.Key], context)
}

resp[s] = data
default:
return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("%q is not a valid Sequence", s))
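For reference, the de-duplication pattern introduced in `GetProjectParams` above can be shown as a standalone sketch: unique contexts are collected per metric key, keyed by a hash of the serialized context JSON (the merged code uses `Context.GetJsonHash`, added in `pkg/database/model.go` below). The rows and context payloads here are made up.

```go
package main

import (
	"crypto/sha256"
	"encoding/json"
	"fmt"
)

// row mirrors the fields the handler relies on: a metric key plus raw context JSON.
type row struct {
	Key         string
	ContextJSON []byte
}

func main() {
	// Hypothetical rows, as they might come back from the metrics query.
	rows := []row{
		{Key: "loss", ContextJSON: []byte(`{"subset":"train"}`)},
		{Key: "loss", ContextJSON: []byte(`{"subset":"train"}`)}, // duplicate context, listed once
		{Key: "loss", ContextJSON: []byte(`{"subset":"val"}`)},
	}

	data := map[string][]map[string]any{}    // metric key -> unique contexts
	seen := map[string]map[string]struct{}{} // metric key -> set of context hashes

	for _, r := range rows {
		if seen[r.Key] == nil {
			seen[r.Key] = map[string]struct{}{}
		}
		hash := fmt.Sprintf("%x", sha256.Sum256(r.ContextJSON))
		if _, ok := seen[r.Key][hash]; ok {
			continue
		}
		// Decode to a key:value object so the AIM UI can consume it.
		ctx := map[string]any{}
		if err := json.Unmarshal(r.ContextJSON, &ctx); err != nil {
			panic(err)
		}
		seen[r.Key][hash] = struct{}{}
		data[r.Key] = append(data[r.Key], ctx)
	}

	fmt.Println(data) // map[loss:[map[subset:train] map[subset:val]]]
}
```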
162 changes: 122 additions & 40 deletions pkg/api/aim/runs.go
@@ -3,6 +3,7 @@ package aim
import (
"bufio"
"bytes"
"crypto/sha256"
"encoding/binary"
"encoding/json"
"errors"
@@ -160,17 +161,24 @@ func GetRunMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

b := []struct {
var b []struct {
Name string `json:"name"`
Context fiber.Map `json:"context"`
}{}
}

if err := c.BodyParser(&b); err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

metricKeysMap := make(fiber.Map, len(b))
metricKeysMap, contexts := make(fiber.Map, len(b)), make([]string, 0, len(b))
for _, m := range b {
if m.Context != nil {
serializedContext, err := json.Marshal(m.Context)
if err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}
contexts = append(contexts, string(serializedContext))
}
metricKeysMap[m.Name] = nil
}
metricKeys := make([]string, len(metricKeysMap))
@@ -181,39 +189,60 @@ func GetRunMetrics(c *fiber.Ctx) error {
i++
}

r := database.Run{
// check that requested run actually exists.
if err := database.DB.Select(
"ID",
).InnerJoins(
"Experiment",
database.DB.Select(
"ID",
).Where(
&models.Experiment{NamespaceID: ns.ID},
),
).First(&database.Run{
ID: p.ID,
}
if err := database.DB.
Select("ID").
InnerJoins(
"Experiment",
database.DB.Select(
"ID",
).Where(
&models.Experiment{NamespaceID: ns.ID},
),
).
Preload("Metrics", func(db *gorm.DB) *gorm.DB {
return db.
Where("key IN ?", metricKeys).
Order("iter")
}).
Preload("Metrics.Context").
First(&r).Error; err != nil {
}).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fiber.ErrNotFound
}
return fmt.Errorf("unable to find run %q: %w", p.ID, err)
}

// fetch run metrics based on provided criteria.
var data []database.Metric
if err := database.DB.Where(
"run_uuid = ?", p.ID,
).InnerJoins(
"Context",
func() *gorm.DB {
query := database.DB
for _, context := range contexts {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(json) = json(?)", context)
} else {
query = query.Or("json = ?", context)
}
}
return query
}(),
).Where(
"key IN ?", metricKeys,
).Order(
"iter",
).Find(
&data,
).Error; err != nil {
return fmt.Errorf("unable to find run metrics: %w", err)
}

metrics := make(map[string]struct {
name string
iters []int
values []*float64
context datatypes.JSON
}, len(metricKeys))
for _, m := range r.Metrics {

for _, m := range data {
v := m.Value
pv := &v
if m.IsNan {
@@ -231,13 +260,12 @@ func GetRunMetrics(c *fiber.Ctx) error {

resp := make([]fiber.Map, 0, len(metrics))
for _, m := range metrics {
data := fiber.Map{
resp = append(resp, fiber.Map{
"name": m.name,
"iters": m.iters,
"values": m.values,
"context": m.context,
}
resp = append(resp, data)
})
}

return c.JSON(resp)
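The change above means a client of `GetRunMetrics` sends a list of metric selectors, each optionally carrying the context it should match, and each returned series carries its context back. A minimal sketch of such a call, assuming a hypothetical base URL, route, run ID, and metric/context names:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// One entry per requested metric; "context" narrows the match and may be omitted.
	body, err := json.Marshal([]map[string]any{
		{"name": "loss", "context": map[string]any{"subset": "train"}},
		{"name": "accuracy"}, // no context filter attached to this entry
	})
	if err != nil {
		panic(err)
	}

	// Hypothetical mount point and run ID; adjust to the actual Aim API route.
	url := "http://localhost:5000/aim/api/runs/<run_id>/metric/get-batch"
	resp, err := http.Post(url, "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Each element of the response has the shape
	// {"name": ..., "iters": [...], "values": [...], "context": {...}}.
	out, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
```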
@@ -742,7 +770,8 @@ func SearchMetrics(c *fiber.Ctx) error {
tx.
Select("metrics.*", "runmetrics.context_json", "x_axis.value as x_axis_value", "x_axis.is_nan as x_axis_is_nan").
Joins(
"LEFT JOIN metrics x_axis ON metrics.run_uuid = x_axis.run_uuid AND metrics.iter = x_axis.iter AND x_axis.key = ?",
"LEFT JOIN metrics x_axis ON metrics.run_uuid = x_axis.run_uuid AND "+
"metrics.iter = x_axis.iter AND x_axis.context_id = metrics.context_id AND x_axis.key = ?",
q.XAxis,
)
xAxis = true
@@ -934,22 +963,59 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

var capacity int
var values []any
values, capacity, contextsMap := []any{}, 0, map[string]string{}
for _, r := range b.Runs {
for _, t := range r.Traces {
l := t.Slice[2]
if l > capacity {
capacity = l
}
values = append(values, r.ID, t.Name, float32(l))
// collect map of unique contexts.
data, err := json.Marshal(t.Context)
if err != nil {
return api.NewInternalError("error serializing context: %s", err)
}
sum := sha256.Sum256(data)
contextHash := fmt.Sprintf("%x", sum)
_, ok := contextsMap[contextHash]
if !ok {
contextsMap[contextHash] = string(data)
}
values = append(values, r.ID, t.Name, string(data), float32(l))
}
}

// map context values to context ids
query := database.DB
for _, context := range contextsMap {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(contexts.json) = json(?)", context)
} else {
query = query.Or("contexts.json = ?", context)
}
}
var contexts []database.Context
if err := query.Find(&contexts).Error; err != nil {
return api.NewInternalError("error getting context information: %s", err)
}

// add context ids to `values`
for _, context := range contexts {
for i := 2; i < len(values); i += 4 {
json, err := json.Marshal(context.Json)
if err != nil {
return api.NewInternalError("error serializing context: %s", err)
}
if values[i] == string(json) {
values[i] = context.ID
}
}
}

var valuesStmt strings.Builder
length := len(values) / 3
length := len(values) / 4
for i := 0; i < length; i++ {
valuesStmt.WriteString("(?, ?, CAST(? AS numeric))")
valuesStmt.WriteString("(?, ?, CAST(? AS numeric), CAST(? AS numeric))")
if i < length-1 {
valuesStmt.WriteString(",")
}
@@ -958,21 +1024,34 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
// TODO this should probably be batched

values = append(values, ns.ID, b.AlignBy)

rows, err := database.DB.Raw(
fmt.Sprintf("WITH params(run_uuid, key, steps) AS (VALUES %s)", &valuesStmt)+
" SELECT m.run_uuid, rm.key, m.iter, m.value, m.is_nan, c.json AS context_json FROM metrics AS m"+
fmt.Sprintf("WITH params(run_uuid, key, context_id, steps) AS (VALUES %s)", &valuesStmt)+
" SELECT m.run_uuid, "+
" rm.key, "+
" m.iter, "+
" m.value, "+
" m.is_nan, "+
" rm.context_id, "+
" rm.context_json"+
" FROM metrics AS m"+
" RIGHT JOIN ("+
" SELECT p.run_uuid, p.key, lm.last_iter AS max, (lm.last_iter + 1) / p.steps AS interval"+
" SELECT p.run_uuid, "+
" p.key, "+
" p.context_id, "+
" lm.last_iter AS max, "+
" (lm.last_iter + 1) / p.steps AS interval, "+
" contexts.json AS context_json"+
" FROM params AS p"+
" LEFT JOIN latest_metrics AS lm USING(run_uuid, key)"+
" ) rm USING(run_uuid)"+
" LEFT JOIN contexts AS c ON c.id = m.context_id"+
" LEFT JOIN latest_metrics AS lm USING(run_uuid, key, context_id)"+
" INNER JOIN contexts ON contexts.id = lm.context_id"+
" ) rm USING(run_uuid, context_id)"+
" INNER JOIN runs AS r ON m.run_uuid = r.run_uuid"+
" INNER JOIN experiments AS e ON r.experiment_id = e.experiment_id AND e.namespace_id = ?"+
" WHERE m.key = ?"+
" AND m.iter <= rm.max"+
" AND MOD(m.iter + 1 + rm.interval / 2, rm.interval) < 1"+
" ORDER BY m.run_uuid, rm.key, m.iter",
" ORDER BY r.row_num DESC, rm.key, rm.context_id, m.iter",
values...,
).Rows()
if err != nil {
@@ -992,6 +1071,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
var id string
var key string
var context fiber.Map
var contextID uint
metrics := make([]fiber.Map, 0)
values := make([]float64, 0, capacity)
iters := make([]float64, 0, capacity)
@@ -1030,7 +1110,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
}

// New series of metrics
if metric.Key != key || metric.RunID != id {
if metric.Key != key || metric.RunID != id || metric.ContextID != contextID {
addMetrics()

if metric.RunID != id {
@@ -1044,6 +1124,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
key = metric.Key
values = values[:0]
iters = iters[:0]
context = fiber.Map{}
}

v := metric.Value
@@ -1057,6 +1138,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
if err := json.Unmarshal(metric.Context, &context); err != nil {
return eris.Wrap(err, "error unmarshalling `context` json to `fiber.Map` object")
}
contextID = metric.ContextID
}
}

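The dialect-dependent JSON comparison appears twice in `runs.go` above: once to filter metric contexts in `GetRunMetrics`, and once to resolve contexts to IDs in `SearchAlignedMetrics`. A rough standalone sketch of that pattern as a helper, assuming a GORM handle and that the SQLite dialector reports "sqlite" (the merged code checks the `database.SQLiteDialectorName` constant):

```go
package aimfilter

import "gorm.io/gorm"

// orContexts ORs together equality conditions on a JSON column holding contexts.
// SQLite compares through json() so that key order and whitespace differences in
// the stored document do not break the match; other dialects (e.g. Postgres with
// a jsonb column) compare the serialized value directly.
func orContexts(db *gorm.DB, serializedContexts []string) *gorm.DB {
	query := db
	for _, ctx := range serializedContexts {
		if db.Dialector.Name() == "sqlite" { // merged code: database.SQLiteDialectorName
			query = query.Or("json(json) = json(?)", ctx)
		} else {
			query = query.Or("json = ?", ctx)
		}
	}
	return query
}
```

In `GetRunMetrics` the resulting OR-chain is used as the join condition for the `Context` association; in `SearchAlignedMetrics` it is run against the `contexts` table to map each JSON payload to its `contexts.id` before building the VALUES list.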
7 changes: 7 additions & 0 deletions pkg/database/model.go
@@ -2,6 +2,7 @@ package database

import (
"context"
"crypto/sha256"
"database/sql"
"database/sql/driver"
"encoding/hex"
@@ -165,6 +166,12 @@ type Context struct {
Json datatypes.JSON `gorm:"not null;unique;index"`
}

// GetJsonHash returns hash of the Context.Json
func (c Context) GetJsonHash() string {
hash := sha256.Sum256(c.Json)
return string(hash[:])
}

type AlembicVersion struct {
Version string `gorm:"column:version_num;type:varchar(32);not null;primaryKey"`
}
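A brief usage note on `GetJsonHash` (illustrative, assuming the module path `github.com/G-Research/fasttrackml`): the method returns the raw 32-byte SHA-256 digest as a string, which is suitable as a map key for de-duplicating contexts but should be hex-encoded before logging or display.

```go
package main

import (
	"encoding/hex"
	"fmt"

	"github.com/G-Research/fasttrackml/pkg/database"
)

func main() {
	a := database.Context{Json: []byte(`{"subset":"train"}`)}
	b := database.Context{Json: []byte(`{"subset":"train"}`)}

	// Identical JSON payloads hash to the same key, so a map keyed by the
	// hash collapses duplicate contexts.
	seen := map[string]database.Context{}
	seen[a.GetJsonHash()] = a
	seen[b.GetJsonHash()] = b
	fmt.Println(len(seen)) // 1

	// Hex-encode the raw digest for human-readable output.
	fmt.Println(hex.EncodeToString([]byte(a.GetJsonHash())))
}
```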
1 change: 0 additions & 1 deletion python/fasttrackml/LICENSE

This file was deleted.

1 change: 1 addition & 0 deletions python/fasttrackml/LICENSE
@@ -0,0 +1 @@
../LICENSE
1 change: 0 additions & 1 deletion python/fasttrackml/README.md

This file was deleted.

1 change: 1 addition & 0 deletions python/fasttrackml/README.md
@@ -0,0 +1 @@
../README.md