Integrate context filtering logic into the Aim GetRunMetrics and SearchAlignedMetrics endpoints. #787

Merged
merged 31 commits on Feb 1, 2024
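This PR extends the Aim GetRunMetrics and SearchAlignedMetrics endpoints so that metrics can be filtered by their context in addition to their name. As a rough sketch of the request shape GetRunMetrics now parses (the metric names and the "subset" context key are illustrative examples, not values from this PR), the body is a list of name/context pairs:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Illustrative payload only: the shape matches the anonymous struct parsed
	// by GetRunMetrics in the diff below; the concrete names are made up.
	payload := []struct {
		Name    string         `json:"name"`
		Context map[string]any `json:"context"`
	}{
		{Name: "loss", Context: map[string]any{"subset": "train"}},
		{Name: "loss", Context: map[string]any{"subset": "val"}},
		{Name: "accuracy"}, // no context: the metric is matched by name only
	}
	body, err := json.Marshal(payload)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
}

Entries that carry a context are serialized and matched against the stored context JSON of each metric; entries without one keep the previous name-only behaviour.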
Changes from 16 commits
Commits
31 commits
bf9ab6f
*extend GetRunMetrics and SearchAlignedMetrics to work with `context`.
dsuhinin Jan 3, 2024
b6f5b50
*extend GetRunMetrics and SearchAlignedMetrics to work with `context`.
dsuhinin Jan 4, 2024
be14364
*adjust integration tests for `GetRunMetrics` endpoint.
dsuhinin Jan 4, 2024
72761df
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 4, 2024
140d41b
*adjust integration tests for `SearchAlignedMetrics` endpoint.
dsuhinin Jan 4, 2024
ce44ca8
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 10, 2024
c59bf17
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 12, 2024
1ebc12e
*apply last discussed changes.
dsuhinin Jan 13, 2024
6fa3e06
*bug fix for `AlignedSearch` endpoint.
dsuhinin Jan 15, 2024
d8d9181
*other fixes for `SearchAligned` endpoint.
dsuhinin Jan 15, 2024
75c53c7
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 15, 2024
c7bec68
*fix `json` comparing.
dsuhinin Jan 19, 2024
252e9e4
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 19, 2024
b8a0960
*fix unit tests.
dsuhinin Jan 20, 2024
7f8cb26
Merge remote-tracking branch 'origin/dsuhinin/adopt-endpoints-to-filt…
dsuhinin Jan 20, 2024
f7dc098
*fix integration test fixtures. make `json` comparing to be database…
dsuhinin Jan 20, 2024
ef3eb20
*changes according to discussion.
dsuhinin Jan 24, 2024
da077ab
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 24, 2024
36284a0
*apply code formatting.
dsuhinin Jan 24, 2024
300584c
*fix `SearchMetrics` endpoint to correctly handle search when `x_axis…
dsuhinin Jan 26, 2024
5be2271
*working with integration tests.
dsuhinin Jan 29, 2024
280abd7
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 29, 2024
9e1ef6d
*fixing integration tests.
dsuhinin Jan 29, 2024
a8fb3e0
Merge remote-tracking branch 'origin/dsuhinin/adopt-endpoints-to-filt…
dsuhinin Jan 29, 2024
c320e32
*fixing integration tests.
dsuhinin Jan 29, 2024
e45f996
*fixed typo.
dsuhinin Jan 29, 2024
b3f633c
*revert changes back.
dsuhinin Jan 29, 2024
0603eba
*bug fix for postgres case.
dsuhinin Jan 31, 2024
ded7e14
Merge branch 'main' into dsuhinin/adopt-endpoints-to-filter-by-context
dsuhinin Jan 31, 2024
21435f5
*fix integration tests for postgres case.
dsuhinin Jan 31, 2024
1edd602
Merge remote-tracking branch 'origin/dsuhinin/adopt-endpoints-to-filt…
dsuhinin Jan 31, 2024
153 changes: 115 additions & 38 deletions pkg/api/aim/runs.go
@@ -3,6 +3,7 @@ package aim
import (
"bufio"
"bytes"
"crypto/sha256"
"encoding/binary"
"encoding/json"
"errors"
@@ -160,17 +161,24 @@ func GetRunMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

b := []struct {
var b []struct {
Name string `json:"name"`
Context fiber.Map `json:"context"`
}{}
}

if err := c.BodyParser(&b); err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

metricKeysMap := make(fiber.Map, len(b))
metricKeysMap, contexts := make(fiber.Map, len(b)), make([]string, 0, len(b))
for _, m := range b {
if m.Context != nil {
serializedContext, err := json.Marshal(m.Context)
if err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}
contexts = append(contexts, string(serializedContext))
}
metricKeysMap[m.Name] = nil
}
metricKeys := make([]string, len(metricKeysMap))
@@ -181,39 +189,60 @@ func GetRunMetrics(c *fiber.Ctx) error {
i++
}

r := database.Run{
// check that requested run actually exists.
if err := database.DB.Select(
"ID",
).InnerJoins(
"Experiment",
database.DB.Select(
"ID",
).Where(
&models.Experiment{NamespaceID: ns.ID},
),
).First(&database.Run{
ID: p.ID,
}
if err := database.DB.
Select("ID").
InnerJoins(
"Experiment",
database.DB.Select(
"ID",
).Where(
&models.Experiment{NamespaceID: ns.ID},
),
).
Preload("Metrics", func(db *gorm.DB) *gorm.DB {
return db.
Where("key IN ?", metricKeys).
Order("iter")
}).
Preload("Metrics.Context").
First(&r).Error; err != nil {
}).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fiber.ErrNotFound
}
return fmt.Errorf("unable to find run %q: %w", p.ID, err)
}

// fetch run metrics based on provided criteria.
var data []database.Metric
if err := database.DB.Where(
"run_uuid = ?", p.ID,
).InnerJoins(
"Context",
func() *gorm.DB {
query := database.DB
for _, context := range contexts {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(json) = json(?)", context)
} else {
query = query.Or("json = ?", context)
}
}
return query
}(),
).Where(
"key IN ?", metricKeys,
).Order(
"iter",
).Find(
&data,
).Error; err != nil {
return fmt.Errorf("unable to find run metrics: %w", err)
}

metrics := make(map[string]struct {
name string
iters []int
values []*float64
context datatypes.JSON
}, len(metricKeys))
for _, m := range r.Metrics {

for _, m := range data {
v := m.Value
pv := &v
if m.IsNan {
@@ -231,13 +260,12 @@ func GetRunMetrics(c *fiber.Ctx) error {

resp := make([]fiber.Map, 0, len(metrics))
for _, m := range metrics {
data := fiber.Map{
resp = append(resp, fiber.Map{
"name": m.name,
"iters": m.iters,
"values": m.values,
"context": m.context,
}
resp = append(resp, data)
})
}

return c.JSON(resp)
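The context match used above is dialect-dependent. Pulled out into a standalone helper purely for illustration (the helper and package names are hypothetical, and the dialector name constant is assumed to be "sqlite", matching the diff's database.SQLiteDialectorName): on SQLite both sides are passed through json() so the comparison happens on normalized JSON text, while other dialects such as PostgreSQL compare the serialized context string directly.

package example

import "gorm.io/gorm"

// sqliteDialectorName mirrors database.SQLiteDialectorName from the diff;
// it is redefined here only so the sketch is self-contained.
const sqliteDialectorName = "sqlite"

// contextFilter is a hypothetical helper mirroring the OR-chained clause that
// GetRunMetrics builds inline above to select matching context rows.
func contextFilter(db *gorm.DB, serializedContexts []string) *gorm.DB {
	query := db
	for _, context := range serializedContexts {
		if db.Dialector.Name() == sqliteDialectorName {
			// SQLite: normalize both the stored and the requested JSON via json()
			// so that formatting differences do not break the equality check.
			query = query.Or("json(json) = json(?)", context)
		} else {
			// PostgreSQL and other dialects: compare the serialized JSON text directly.
			query = query.Or("json = ?", context)
		}
	}
	return query
}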
@@ -933,22 +961,55 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

var capacity int
var values []any
values, capacity, contextsMap := []any{}, 0, map[string]string{}
for _, r := range b.Runs {
for _, t := range r.Traces {
l := t.Slice[2]
if l > capacity {
capacity = l
}
values = append(values, r.ID, t.Name, float32(l))
// collect map of unique contexts.
data, err := json.Marshal(t.Context)
if err != nil {
return api.NewInternalError("error serializing provided context: %s", err)
}
sum := sha256.Sum256(data)
contextHash := fmt.Sprintf("%x", sum)
_, ok := contextsMap[contextHash]
if !ok {
contextsMap[contextHash] = string(data)
}
values = append(values, r.ID, t.Name, string(data), float32(l))
}
}

// map context values to context ids
query := database.DB
for _, context := range contextsMap {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(contexts.json) = json(?)", context)
} else {
query = query.Or("contexts.json = ?", context)
}
}
var contexts []database.Context
if err := query.Find(&contexts).Error; err != nil {
return api.NewInternalError("error getting context information: %s", err)
}

// add context ids to `values`
for _, context := range contexts {
for i := 2; i < len(values); i += 4 {
if values[i] == context.Json.String() {
values[i] = context.ID
}
}
}

var valuesStmt strings.Builder
length := len(values) / 3
length := len(values) / 4
for i := 0; i < length; i++ {
valuesStmt.WriteString("(?, ?, CAST(? AS numeric))")
valuesStmt.WriteString("(?, ?, CAST(? AS numeric), CAST(? AS float))")
if i < length-1 {
valuesStmt.WriteString(",")
}
@@ -957,21 +1018,34 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
// TODO this should probably be batched

values = append(values, ns.ID, b.AlignBy)

rows, err := database.DB.Raw(
fmt.Sprintf("WITH params(run_uuid, key, steps) AS (VALUES %s)", &valuesStmt)+
" SELECT m.run_uuid, rm.key, m.iter, m.value, m.is_nan, c.json AS context_json FROM metrics AS m"+
fmt.Sprintf("WITH params(run_uuid, key, context_id, steps) AS (VALUES %s)", &valuesStmt)+
" SELECT m.run_uuid, "+
" rm.key, "+
" m.iter, "+
" m.value, "+
" m.is_nan, "+
" rm.context_id, "+
" rm.context_json"+
" FROM metrics AS m"+
" RIGHT JOIN ("+
" SELECT p.run_uuid, p.key, lm.last_iter AS max, (lm.last_iter + 1) / p.steps AS interval"+
" SELECT p.run_uuid, "+
" p.key, "+
" p.context_id, "+
" lm.last_iter AS max, "+
" (lm.last_iter + 1) / p.steps AS interval, "+
" contexts.json AS context_json"+
" FROM params AS p"+
" LEFT JOIN latest_metrics AS lm USING(run_uuid, key)"+
" LEFT JOIN latest_metrics AS lm USING(run_uuid, key, context_id)"+
" INNER JOIN contexts ON contexts.id = lm.context_id"+
" ) rm USING(run_uuid)"+
" LEFT JOIN contexts AS c ON c.id = m.context_id"+
" INNER JOIN runs AS r ON m.run_uuid = r.run_uuid"+
" INNER JOIN experiments AS e ON r.experiment_id = e.experiment_id AND e.namespace_id = ?"+
" WHERE m.key = ?"+
" AND m.iter <= rm.max"+
" AND MOD(m.iter + 1 + rm.interval / 2, rm.interval) < 1"+
" ORDER BY m.run_uuid, rm.key, m.iter",
" ORDER BY r.row_num DESC, rm.key, rm.context_id, m.iter",
values...,
).Rows()
if err != nil {
@@ -991,6 +1065,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
var id string
var key string
var context fiber.Map
var contextID uint
metrics := make([]fiber.Map, 0)
values := make([]float64, 0, capacity)
iters := make([]float64, 0, capacity)
@@ -1029,7 +1104,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
}

// New series of metrics
if metric.Key != key || metric.RunID != id {
if metric.Key != key || metric.RunID != id || metric.ContextID != contextID {
addMetrics()

if metric.RunID != id {
@@ -1043,6 +1118,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
key = metric.Key
values = values[:0]
iters = iters[:0]
context = fiber.Map{}
}

v := metric.Value
@@ -1056,6 +1132,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
if err := json.Unmarshal(metric.Context, &context); err != nil {
return eris.Wrap(err, "error unmarshalling `context` json to `fiber.Map` object")
}
contextID = metric.ContextID
}
}

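In SearchAlignedMetrics, the contexts supplied with each trace are deduplicated by hashing their serialized JSON with sha256, resolved to rows of the contexts table via the same dialect-aware match, and the serialized JSON in the bound parameter list is then swapped for the resolved context ID. A compact sketch of that substitution step, as a hypothetical standalone helper (the stride of 4 corresponds to the run_uuid, key, context, steps tuple in the VALUES clause):

package example

// contextRow stands in for database.Context in this sketch: ID is the primary
// key of the contexts table and JSON is the serialized context it stores.
type contextRow struct {
	ID   uint
	JSON string
}

// resolveContextIDs mirrors the loop in SearchAlignedMetrics: every fourth
// element of values (position 2 of each run_uuid/key/context/steps tuple)
// starts out as serialized context JSON and is replaced by the matching
// context ID before the raw aligned-metrics query is executed.
func resolveContextIDs(values []any, contexts []contextRow) []any {
	for _, context := range contexts {
		for i := 2; i < len(values); i += 4 {
			if values[i] == context.JSON {
				values[i] = context.ID
			}
		}
	}
	return values
}

The sha256 hash is used only to deduplicate identical contexts before the lookup, so each distinct context is queried once.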