Skip to content

Commit

Permalink
Use JSONB instead of JSON for context object in SQLite. (#871)
Browse files Browse the repository at this point in the history
Use `jsonb` type instead of `json` for metric `context` object.
  • Loading branch information
dsuhinin authored Feb 7, 2024
1 parent 3a12ef0 commit 0bda8b4
Show file tree
Hide file tree
Showing 14 changed files with 424 additions and 76 deletions.
15 changes: 14 additions & 1 deletion pkg/api/aim/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"math"
"reflect"
"time"

"github.com/gofiber/fiber/v2"
Expand Down Expand Up @@ -41,7 +42,7 @@ func RunsSearchAsCSVResponse(ctx *fiber.Ctx, runs []database.Run, excludeTraces,
if metric.IsNan {
v = math.NaN()
}
key := fmt.Sprintf("%s %s", metric.Key, metric.Context.Json.String())
key := fmt.Sprintf("%s %s", metric.Key, string(metric.Context.Json))
if _, ok := metricData[key]; ok {
metricData[key][run.ID] = v
} else {
Expand Down Expand Up @@ -270,3 +271,15 @@ func RunsSearchAsStreamResponse(
log.Infof("body - %s %s %s", time.Since(start), ctx.Method(), ctx.Path())
})
}

// CompareJson compares two json objects.
func CompareJson(json1, json2 []byte) bool {
var j, j2 interface{}
if err := json.Unmarshal(json1, &j); err != nil {
return false
}
if err := json.Unmarshal(json2, &j2); err != nil {
return false
}
return reflect.DeepEqual(j2, j)
}
33 changes: 11 additions & 22 deletions pkg/api/aim/runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/G-Research/fasttrackml/pkg/api/mlflow/api"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/repositories"
"github.com/G-Research/fasttrackml/pkg/common/db/types"
"github.com/G-Research/fasttrackml/pkg/common/middleware/namespace"
"github.com/G-Research/fasttrackml/pkg/database"
)
Expand Down Expand Up @@ -170,14 +171,14 @@ func GetRunMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

metricKeysMap, contexts := make(fiber.Map, len(b)), make([]string, 0, len(b))
metricKeysMap, contexts := make(fiber.Map, len(b)), make([]types.JSONB, 0, len(b))
for _, m := range b {
if m.Context != nil {
serializedContext, err := json.Marshal(m.Context)
if err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}
contexts = append(contexts, string(serializedContext))
contexts = append(contexts, serializedContext)
}
metricKeysMap[m.Name] = nil
}
Expand Down Expand Up @@ -217,11 +218,7 @@ func GetRunMetrics(c *fiber.Ctx) error {
func() *gorm.DB {
query := database.DB
for _, context := range contexts {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(json) = json(?)", context)
} else {
query = query.Or("json = ?", context)
}
query = query.Or("json = ?", context)
}
return query
}(),
Expand All @@ -239,7 +236,7 @@ func GetRunMetrics(c *fiber.Ctx) error {
name string
iters []int
values []*float64
context datatypes.JSON
context json.RawMessage
}, len(metricKeys))

for _, m := range data {
Expand All @@ -254,7 +251,7 @@ func GetRunMetrics(c *fiber.Ctx) error {
k.name = m.Key
k.iters = append(k.iters, int(m.Iter))
k.values = append(k.values, pv)
k.context = m.Context.Json
k.context = json.RawMessage(m.Context.Json)
metrics[key] = k
}

Expand Down Expand Up @@ -869,7 +866,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

values, capacity, contextsMap := []any{}, 0, map[string]string{}
values, capacity, contextsMap := []any{}, 0, map[string]types.JSONB{}
for _, r := range b.Runs {
for _, t := range r.Traces {
l := t.Slice[2]
Expand All @@ -885,20 +882,16 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
contextHash := fmt.Sprintf("%x", sum)
_, ok := contextsMap[contextHash]
if !ok {
contextsMap[contextHash] = string(data)
contextsMap[contextHash] = data
}
values = append(values, r.ID, t.Name, string(data), float32(l))
values = append(values, r.ID, t.Name, data, float32(l))
}
}

// map context values to context ids
query := database.DB
for _, context := range contextsMap {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(contexts.json) = json(?)", context)
} else {
query = query.Or("contexts.json = ?", context)
}
query = query.Or("contexts.json = ?", context)
}
var contexts []database.Context
if err := query.Find(&contexts).Error; err != nil {
Expand All @@ -908,11 +901,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
// add context ids to `values`
for _, context := range contexts {
for i := 2; i < len(values); i += 4 {
json, err := json.Marshal(context.Json)
if err != nil {
return api.NewInternalError("error serializing context: %s", err)
}
if values[i] == string(json) {
if CompareJson(values[i].([]byte), context.Json) {
values[i] = context.ID
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/mlflow/controller/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (c Controller) GetMetricHistories(ctx *fiber.Ctx) error {
} else {
b.Field(4).(*array.Float64Builder).Append(m.Value)
}
b.Field(5).(*array.StringBuilder).Append(m.Context.Json.String())
b.Field(5).(*array.StringBuilder).Append(string(m.Context.Json))
if (i+1)%100000 == 0 {
if err := WriteStreamingRecord(writer, b.NewRecord()); err != nil {
return fmt.Errorf("unable to write Arrow record batch: %w", err)
Expand Down
8 changes: 4 additions & 4 deletions pkg/api/mlflow/dao/models/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"crypto/sha256"
"fmt"

"gorm.io/datatypes"
"github.com/G-Research/fasttrackml/pkg/common/db/types"
)

// DefaultContext is the default metric context
var DefaultContext = Context{Json: datatypes.JSON("{}")}
var DefaultContext = Context{Json: types.JSONB("{}")}

// Metric represents model to work with `metrics` table.
type Metric struct {
Expand Down Expand Up @@ -48,8 +48,8 @@ func (m LatestMetric) UniqueKey() string {

// Context represents model to work with `contexts` table.
type Context struct {
ID uint `gorm:"primaryKey;autoIncrement"`
Json datatypes.JSON `gorm:"not null;unique;index"`
ID uint `gorm:"primaryKey;autoIncrement"`
Json types.JSONB `gorm:"not null;unique;index"`
}

// GetJsonHash returns hash of the Context.Json
Expand Down
86 changes: 86 additions & 0 deletions pkg/common/db/types/jsonb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package types

import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

// JSONB defined JSONB data type, need to implements driver.Valuer, sql.Scanner interface
type JSONB json.RawMessage

// Value return json value, implement driver.Valuer interface
func (j JSONB) Value() (driver.Value, error) {
if len(j) == 0 {
return nil, nil
}
return string(j), nil
}

// Scan scan value into Jsonb, implements sql.Scanner interface
func (j *JSONB) Scan(value interface{}) error {
if value == nil {
*j = JSONB("null")
return nil
}
var bytes []byte
switch v := value.(type) {
case []byte:
if len(v) > 0 {
bytes = make([]byte, len(v))
copy(bytes, v)
}
case string:
bytes = []byte(v)
default:
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}

result := json.RawMessage(bytes)
*j = JSONB(result)
return nil
}

// MarshalJSON to output non base64 encoded []byte
func (j JSONB) MarshalJSON() ([]byte, error) {
return json.RawMessage(j).MarshalJSON()
}

// UnmarshalJSON to deserialize []byte
func (j *JSONB) UnmarshalJSON(b []byte) error {
result := json.RawMessage{}
err := result.UnmarshalJSON(b)
*j = JSONB(result)
return err
}

func (j JSONB) String() string {
return string(j)
}

// GormDataType gorm common data type
func (JSONB) GormDataType() string {
return "json"
}

// GormDBDataType gorm db data type
func (JSONB) GormDBDataType(db *gorm.DB, field *schema.Field) string {
return "JSONB"
}

// GormValue gorm db actual value
// nolint
func (js JSONB) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
if len(js) == 0 {
return gorm.Expr("NULL")
}

data, _ := js.MarshalJSON()
return gorm.Expr("?", string(data))
}
16 changes: 12 additions & 4 deletions pkg/database/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import (
"time"

log "github.com/sirupsen/logrus"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/logger"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/common"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
"github.com/G-Research/fasttrackml/pkg/common/db/types"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0001"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0002"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0003"
Expand All @@ -24,6 +24,7 @@ import (
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0007"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0008"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0009"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0010"
)

var supportedAlembicVersions = []string{
Expand All @@ -45,7 +46,7 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error {
tx.First(&schemaVersion)
}

if !slices.Contains(supportedAlembicVersions, alembicVersion.Version) || schemaVersion.Version != v_0009.Version {
if !slices.Contains(supportedAlembicVersions, alembicVersion.Version) || schemaVersion.Version != v_0010.Version {
if !migrate && alembicVersion.Version != "" {
return fmt.Errorf(
"unsupported database schema versions alembic %s, FastTrackML %s",
Expand Down Expand Up @@ -190,6 +191,13 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error {
if err := v_0009.Migrate(db); err != nil {
return fmt.Errorf("error migrating database to FastTrackML schema %s: %w", v_0009.Version, err)
}
fallthrough

case v_0009.Version:
log.Infof("Migrating database to FastTrackML schema %s", v_0010.Version)
if err := v_0010.Migrate(db); err != nil {
return fmt.Errorf("error migrating database to FastTrackML schema %s: %w", v_0010.Version, err)
}

default:
return fmt.Errorf("unsupported database FastTrackML schema version %s", schemaVersion.Version)
Expand Down Expand Up @@ -221,7 +229,7 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error {
Version: "97727af70f4d",
})
tx.Create(&SchemaVersion{
Version: v_0009.Version,
Version: v_0010.Version,
})
tx.Commit()
if tx.Error != nil {
Expand Down Expand Up @@ -308,7 +316,7 @@ func CreateDefaultExperiment(db *gorm.DB, defaultArtifactRoot string) error {

// CreateDefaultMetricContext creates the default metric context if it doesn't exist.
func CreateDefaultMetricContext(db *gorm.DB) error {
defaultContext := Context{Json: datatypes.JSON("{}")}
defaultContext := Context{Json: types.JSONB("{}")}
if err := db.First(&defaultContext).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info("Creating default context")
Expand Down
19 changes: 19 additions & 0 deletions pkg/database/migrations/v_0010/migrate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package v_0010

import (
"gorm.io/gorm"
)

const Version = "10d125c68d9a"

func Migrate(db *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
if err := tx.AutoMigrate(&Context{}); err != nil {
return err
}
return tx.Model(&SchemaVersion{}).
Where("1 = 1").
Update("Version", Version).
Error
})
}
Loading

0 comments on commit 0bda8b4

Please sign in to comment.