Skip to content

Commit

Permalink
Integrate context filtering logic into the Aim GetRunMetrics and Sear…
Browse files Browse the repository at this point in the history
…chAlignedMetrics endpoints. (#787)

Integrate context filtering into AIM GetRunMetrics and SearchAlignedMetrics endpoints.
  • Loading branch information
dsuhinin authored Feb 1, 2024
1 parent 27e237f commit dd785a7
Show file tree
Hide file tree
Showing 8 changed files with 541 additions and 142 deletions.
21 changes: 13 additions & 8 deletions pkg/api/aim/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,24 +191,29 @@ func GetProjectParams(c *fiber.Ctx) error {
).Joins(
"INNER JOIN experiments ON experiments.experiment_id = runs.experiment_id AND experiments.namespace_id = ?",
ns.ID,
).Preload(
).Joins(
"Context",
).Where(
"runs.lifecycle_stage = ?", database.LifecycleStageActive,
).Find(&metrics); tx.Error != nil {
return fmt.Errorf("error retrieving metric keys: %w", tx.Error)
}

data := make(map[string][]fiber.Map, len(metrics))
data, mapped := make(map[string][]fiber.Map, len(metrics)), make(map[string]map[string]fiber.Map, len(metrics))
for _, metric := range metrics {
// to be properly decoded by AIM UI, json should be represented as a key:value object.
context := fiber.Map{}
if err := json.Unmarshal(metric.Context.Json, &context); err != nil {
return eris.Wrap(err, "error unmarshalling `context` json to `fiber.Map` object")
if mapped[metric.Key] == nil {
mapped[metric.Key] = map[string]fiber.Map{}
}
if _, ok := mapped[metric.Key][metric.Context.GetJsonHash()]; !ok {
// to be properly decoded by AIM UI, json should be represented as a key:value object.
context := fiber.Map{}
if err := json.Unmarshal(metric.Context.Json, &context); err != nil {
return eris.Wrap(err, "error unmarshalling `context` json to `fiber.Map` object")
}
mapped[metric.Key][metric.Context.GetJsonHash()] = context
data[metric.Key] = append(data[metric.Key], context)
}
data[metric.Key] = append(data[metric.Key], context)
}

resp[s] = data
default:
return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("%q is not a valid Sequence", s))
Expand Down
162 changes: 122 additions & 40 deletions pkg/api/aim/runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aim
import (
"bufio"
"bytes"
"crypto/sha256"
"encoding/binary"
"encoding/json"
"errors"
Expand Down Expand Up @@ -160,17 +161,24 @@ func GetRunMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

b := []struct {
var b []struct {
Name string `json:"name"`
Context fiber.Map `json:"context"`
}{}
}

if err := c.BodyParser(&b); err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

metricKeysMap := make(fiber.Map, len(b))
metricKeysMap, contexts := make(fiber.Map, len(b)), make([]string, 0, len(b))
for _, m := range b {
if m.Context != nil {
serializedContext, err := json.Marshal(m.Context)
if err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}
contexts = append(contexts, string(serializedContext))
}
metricKeysMap[m.Name] = nil
}
metricKeys := make([]string, len(metricKeysMap))
Expand All @@ -181,39 +189,60 @@ func GetRunMetrics(c *fiber.Ctx) error {
i++
}

r := database.Run{
// check that requested run actually exists.
if err := database.DB.Select(
"ID",
).InnerJoins(
"Experiment",
database.DB.Select(
"ID",
).Where(
&models.Experiment{NamespaceID: ns.ID},
),
).First(&database.Run{
ID: p.ID,
}
if err := database.DB.
Select("ID").
InnerJoins(
"Experiment",
database.DB.Select(
"ID",
).Where(
&models.Experiment{NamespaceID: ns.ID},
),
).
Preload("Metrics", func(db *gorm.DB) *gorm.DB {
return db.
Where("key IN ?", metricKeys).
Order("iter")
}).
Preload("Metrics.Context").
First(&r).Error; err != nil {
}).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fiber.ErrNotFound
}
return fmt.Errorf("unable to find run %q: %w", p.ID, err)
}

// fetch run metrics based on provided criteria.
var data []database.Metric
if err := database.DB.Where(
"run_uuid = ?", p.ID,
).InnerJoins(
"Context",
func() *gorm.DB {
query := database.DB
for _, context := range contexts {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(json) = json(?)", context)
} else {
query = query.Or("json = ?", context)
}
}
return query
}(),
).Where(
"key IN ?", metricKeys,
).Order(
"iter",
).Find(
&data,
).Error; err != nil {
return fmt.Errorf("unable to find run metrics: %w", err)
}

metrics := make(map[string]struct {
name string
iters []int
values []*float64
context datatypes.JSON
}, len(metricKeys))
for _, m := range r.Metrics {

for _, m := range data {
v := m.Value
pv := &v
if m.IsNan {
Expand All @@ -231,13 +260,12 @@ func GetRunMetrics(c *fiber.Ctx) error {

resp := make([]fiber.Map, 0, len(metrics))
for _, m := range metrics {
data := fiber.Map{
resp = append(resp, fiber.Map{
"name": m.name,
"iters": m.iters,
"values": m.values,
"context": m.context,
}
resp = append(resp, data)
})
}

return c.JSON(resp)
Expand Down Expand Up @@ -742,7 +770,8 @@ func SearchMetrics(c *fiber.Ctx) error {
tx.
Select("metrics.*", "runmetrics.context_json", "x_axis.value as x_axis_value", "x_axis.is_nan as x_axis_is_nan").
Joins(
"LEFT JOIN metrics x_axis ON metrics.run_uuid = x_axis.run_uuid AND metrics.iter = x_axis.iter AND x_axis.key = ?",
"LEFT JOIN metrics x_axis ON metrics.run_uuid = x_axis.run_uuid AND "+
"metrics.iter = x_axis.iter AND x_axis.context_id = metrics.context_id AND x_axis.key = ?",
q.XAxis,
)
xAxis = true
Expand Down Expand Up @@ -934,22 +963,59 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

var capacity int
var values []any
values, capacity, contextsMap := []any{}, 0, map[string]string{}
for _, r := range b.Runs {
for _, t := range r.Traces {
l := t.Slice[2]
if l > capacity {
capacity = l
}
values = append(values, r.ID, t.Name, float32(l))
// collect map of unique contexts.
data, err := json.Marshal(t.Context)
if err != nil {
return api.NewInternalError("error serializing context: %s", err)
}
sum := sha256.Sum256(data)
contextHash := fmt.Sprintf("%x", sum)
_, ok := contextsMap[contextHash]
if !ok {
contextsMap[contextHash] = string(data)
}
values = append(values, r.ID, t.Name, string(data), float32(l))
}
}

// map context values to context ids
query := database.DB
for _, context := range contextsMap {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(contexts.json) = json(?)", context)
} else {
query = query.Or("contexts.json = ?", context)
}
}
var contexts []database.Context
if err := query.Find(&contexts).Error; err != nil {
return api.NewInternalError("error getting context information: %s", err)
}

// add context ids to `values`
for _, context := range contexts {
for i := 2; i < len(values); i += 4 {
json, err := json.Marshal(context.Json)
if err != nil {
return api.NewInternalError("error serializing context: %s", err)
}
if values[i] == string(json) {
values[i] = context.ID
}
}
}

var valuesStmt strings.Builder
length := len(values) / 3
length := len(values) / 4
for i := 0; i < length; i++ {
valuesStmt.WriteString("(?, ?, CAST(? AS numeric))")
valuesStmt.WriteString("(?, ?, CAST(? AS numeric), CAST(? AS numeric))")
if i < length-1 {
valuesStmt.WriteString(",")
}
Expand All @@ -958,21 +1024,34 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
// TODO this should probably be batched

values = append(values, ns.ID, b.AlignBy)

rows, err := database.DB.Raw(
fmt.Sprintf("WITH params(run_uuid, key, steps) AS (VALUES %s)", &valuesStmt)+
" SELECT m.run_uuid, rm.key, m.iter, m.value, m.is_nan, c.json AS context_json FROM metrics AS m"+
fmt.Sprintf("WITH params(run_uuid, key, context_id, steps) AS (VALUES %s)", &valuesStmt)+
" SELECT m.run_uuid, "+
" rm.key, "+
" m.iter, "+
" m.value, "+
" m.is_nan, "+
" rm.context_id, "+
" rm.context_json"+
" FROM metrics AS m"+
" RIGHT JOIN ("+
" SELECT p.run_uuid, p.key, lm.last_iter AS max, (lm.last_iter + 1) / p.steps AS interval"+
" SELECT p.run_uuid, "+
" p.key, "+
" p.context_id, "+
" lm.last_iter AS max, "+
" (lm.last_iter + 1) / p.steps AS interval, "+
" contexts.json AS context_json"+
" FROM params AS p"+
" LEFT JOIN latest_metrics AS lm USING(run_uuid, key)"+
" ) rm USING(run_uuid)"+
" LEFT JOIN contexts AS c ON c.id = m.context_id"+
" LEFT JOIN latest_metrics AS lm USING(run_uuid, key, context_id)"+
" INNER JOIN contexts ON contexts.id = lm.context_id"+
" ) rm USING(run_uuid, context_id)"+
" INNER JOIN runs AS r ON m.run_uuid = r.run_uuid"+
" INNER JOIN experiments AS e ON r.experiment_id = e.experiment_id AND e.namespace_id = ?"+
" WHERE m.key = ?"+
" AND m.iter <= rm.max"+
" AND MOD(m.iter + 1 + rm.interval / 2, rm.interval) < 1"+
" ORDER BY m.run_uuid, rm.key, m.iter",
" ORDER BY r.row_num DESC, rm.key, rm.context_id, m.iter",
values...,
).Rows()
if err != nil {
Expand All @@ -992,6 +1071,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
var id string
var key string
var context fiber.Map
var contextID uint
metrics := make([]fiber.Map, 0)
values := make([]float64, 0, capacity)
iters := make([]float64, 0, capacity)
Expand Down Expand Up @@ -1030,7 +1110,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
}

// New series of metrics
if metric.Key != key || metric.RunID != id {
if metric.Key != key || metric.RunID != id || metric.ContextID != contextID {
addMetrics()

if metric.RunID != id {
Expand All @@ -1044,6 +1124,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
key = metric.Key
values = values[:0]
iters = iters[:0]
context = fiber.Map{}
}

v := metric.Value
Expand All @@ -1057,6 +1138,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
if err := json.Unmarshal(metric.Context, &context); err != nil {
return eris.Wrap(err, "error unmarshalling `context` json to `fiber.Map` object")
}
contextID = metric.ContextID
}
}

Expand Down
7 changes: 7 additions & 0 deletions pkg/database/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package database

import (
"context"
"crypto/sha256"
"database/sql"
"database/sql/driver"
"encoding/hex"
Expand Down Expand Up @@ -165,6 +166,12 @@ type Context struct {
Json datatypes.JSON `gorm:"not null;unique;index"`
}

// GetJsonHash computes the SHA-256 digest of c.Json and returns the raw
// 32 digest bytes converted to a string. The result is NOT hex-encoded, so it
// may contain non-printable bytes.
func (c Context) GetJsonHash() string {
	digest := sha256.Sum256(c.Json)
	raw := digest[:]
	return string(raw)
}

// AlembicVersion is a single-field model. Its Version field is mapped by the
// gorm tag to a non-null varchar(32) primary-key column named `version_num`.
// NOTE(review): presumably mirrors the version table maintained by Alembic
// migrations — confirm against the migration tooling.
type AlembicVersion struct {
Version string `gorm:"column:version_num;type:varchar(32);not null;primaryKey"`
}
Expand Down
1 change: 0 additions & 1 deletion python/fasttrackml/LICENSE

This file was deleted.

1 change: 1 addition & 0 deletions python/fasttrackml/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../LICENSE
1 change: 0 additions & 1 deletion python/fasttrackml/README.md

This file was deleted.

1 change: 1 addition & 0 deletions python/fasttrackml/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../README.md
Loading

0 comments on commit dd785a7

Please sign in to comment.