Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use JSONB instead of JSON for context object in SQLite. #871

Merged
merged 15 commits into from
Feb 7, 2024
Merged
15 changes: 14 additions & 1 deletion pkg/api/aim/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"math"
"reflect"
"time"

"github.com/gofiber/fiber/v2"
Expand Down Expand Up @@ -41,7 +42,7 @@ func RunsSearchAsCSVResponse(ctx *fiber.Ctx, runs []database.Run, excludeTraces,
if metric.IsNan {
v = math.NaN()
}
key := fmt.Sprintf("%s %s", metric.Key, metric.Context.Json.String())
key := fmt.Sprintf("%s %s", metric.Key, string(metric.Context.Json))
if _, ok := metricData[key]; ok {
metricData[key][run.ID] = v
} else {
Expand Down Expand Up @@ -270,3 +271,15 @@ func RunsSearchAsStreamResponse(
log.Infof("body - %s %s %s", time.Since(start), ctx.Method(), ctx.Path())
})
}

// CompareJson reports whether two JSON documents are semantically equal.
//
// Both inputs are decoded into generic Go values and compared with
// reflect.DeepEqual, so differences in whitespace or object key order are
// ignored. If either input fails to decode, the documents are treated as
// not equal and false is returned.
func CompareJson(json1, json2 []byte) bool {
	var decoded [2]interface{}
	for i, raw := range [][]byte{json1, json2} {
		if err := json.Unmarshal(raw, &decoded[i]); err != nil {
			// Invalid JSON never compares equal to anything.
			return false
		}
	}
	return reflect.DeepEqual(decoded[1], decoded[0])
}
33 changes: 11 additions & 22 deletions pkg/api/aim/runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/G-Research/fasttrackml/pkg/api/mlflow/api"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/repositories"
"github.com/G-Research/fasttrackml/pkg/common/db/types"
"github.com/G-Research/fasttrackml/pkg/common/middleware/namespace"
"github.com/G-Research/fasttrackml/pkg/database"
)
Expand Down Expand Up @@ -170,14 +171,14 @@ func GetRunMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

metricKeysMap, contexts := make(fiber.Map, len(b)), make([]string, 0, len(b))
metricKeysMap, contexts := make(fiber.Map, len(b)), make([]types.JSONB, 0, len(b))
for _, m := range b {
if m.Context != nil {
serializedContext, err := json.Marshal(m.Context)
if err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}
contexts = append(contexts, string(serializedContext))
contexts = append(contexts, serializedContext)
}
metricKeysMap[m.Name] = nil
}
Expand Down Expand Up @@ -217,11 +218,7 @@ func GetRunMetrics(c *fiber.Ctx) error {
func() *gorm.DB {
query := database.DB
for _, context := range contexts {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(json) = json(?)", context)
} else {
query = query.Or("json = ?", context)
}
query = query.Or("json = ?", context)
}
return query
}(),
Expand All @@ -239,7 +236,7 @@ func GetRunMetrics(c *fiber.Ctx) error {
name string
iters []int
values []*float64
context datatypes.JSON
context json.RawMessage
}, len(metricKeys))

for _, m := range data {
Expand All @@ -254,7 +251,7 @@ func GetRunMetrics(c *fiber.Ctx) error {
k.name = m.Key
k.iters = append(k.iters, int(m.Iter))
k.values = append(k.values, pv)
k.context = m.Context.Json
k.context = json.RawMessage(m.Context.Json)
metrics[key] = k
}

Expand Down Expand Up @@ -869,7 +866,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

values, capacity, contextsMap := []any{}, 0, map[string]string{}
values, capacity, contextsMap := []any{}, 0, map[string]types.JSONB{}
for _, r := range b.Runs {
for _, t := range r.Traces {
l := t.Slice[2]
Expand All @@ -885,20 +882,16 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
contextHash := fmt.Sprintf("%x", sum)
_, ok := contextsMap[contextHash]
if !ok {
contextsMap[contextHash] = string(data)
contextsMap[contextHash] = data
}
values = append(values, r.ID, t.Name, string(data), float32(l))
values = append(values, r.ID, t.Name, data, float32(l))
}
}

// map context values to context ids
query := database.DB
for _, context := range contextsMap {
if query.Dialector.Name() == database.SQLiteDialectorName {
query = query.Or("json(contexts.json) = json(?)", context)
} else {
query = query.Or("contexts.json = ?", context)
}
query = query.Or("contexts.json = ?", context)
}
var contexts []database.Context
if err := query.Find(&contexts).Error; err != nil {
Expand All @@ -908,11 +901,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error {
// add context ids to `values`
for _, context := range contexts {
for i := 2; i < len(values); i += 4 {
json, err := json.Marshal(context.Json)
if err != nil {
return api.NewInternalError("error serializing context: %s", err)
}
if values[i] == string(json) {
if CompareJson(values[i].([]byte), context.Json) {
values[i] = context.ID
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/mlflow/controller/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (c Controller) GetMetricHistories(ctx *fiber.Ctx) error {
} else {
b.Field(4).(*array.Float64Builder).Append(m.Value)
}
b.Field(5).(*array.StringBuilder).Append(m.Context.Json.String())
b.Field(5).(*array.StringBuilder).Append(string(m.Context.Json))
if (i+1)%100000 == 0 {
if err := WriteStreamingRecord(writer, b.NewRecord()); err != nil {
return fmt.Errorf("unable to write Arrow record batch: %w", err)
Expand Down
8 changes: 4 additions & 4 deletions pkg/api/mlflow/dao/models/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"crypto/sha256"
"fmt"

"gorm.io/datatypes"
"github.com/G-Research/fasttrackml/pkg/common/db/types"
)

// DefaultContext is the default metric context
var DefaultContext = Context{Json: datatypes.JSON("{}")}
var DefaultContext = Context{Json: types.JSONB("{}")}

// Metric represents model to work with `metrics` table.
type Metric struct {
Expand Down Expand Up @@ -48,8 +48,8 @@ func (m LatestMetric) UniqueKey() string {

// Context represents model to work with `contexts` table.
type Context struct {
ID uint `gorm:"primaryKey;autoIncrement"`
Json datatypes.JSON `gorm:"not null;unique;index"`
ID uint `gorm:"primaryKey;autoIncrement"`
Json types.JSONB `gorm:"not null;unique;index"`
}

// GetJsonHash returns hash of the Context.Json
Expand Down
86 changes: 86 additions & 0 deletions pkg/common/db/types/jsonb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package types

import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

// JSONB holds a raw JSON document for storage in a database column.
// It implements driver.Valuer and sql.Scanner for database round-trips, and
// json.Marshaler / json.Unmarshaler so the raw bytes are embedded verbatim
// when encoded as JSON.
type JSONB json.RawMessage

// Value implements driver.Valuer.
// An empty JSONB is sent to the driver as nil; any other content is passed
// along as a string.
func (j JSONB) Value() (driver.Value, error) {
	if len(j) > 0 {
		return string(j), nil
	}
	return nil, nil
}

// Scan implements sql.Scanner.
// A nil database value becomes the JSON literal "null"; []byte and string
// values are stored as-is; every other type yields an error.
func (j *JSONB) Scan(value interface{}) error {
	switch v := value.(type) {
	case nil:
		*j = JSONB("null")
	case []byte:
		if len(v) == 0 {
			// An empty byte slice results in a nil JSONB.
			*j = nil
			return nil
		}
		// Take a private copy so later changes to v do not affect the
		// stored value.
		owned := make([]byte, len(v))
		copy(owned, v)
		*j = JSONB(owned)
	case string:
		*j = JSONB([]byte(v))
	default:
		return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
	}
	return nil
}

// MarshalJSON implements json.Marshaler by delegating to json.RawMessage,
// so the stored bytes are emitted as JSON rather than as a base64 string.
func (j JSONB) MarshalJSON() ([]byte, error) {
	raw := json.RawMessage(j)
	return raw.MarshalJSON()
}

// UnmarshalJSON implements json.Unmarshaler by delegating to
// json.RawMessage. The receiver is overwritten even when an error is
// returned.
func (j *JSONB) UnmarshalJSON(b []byte) error {
	raw := json.RawMessage{}
	err := raw.UnmarshalJSON(b)
	*j = JSONB(raw)
	return err
}

// String returns the raw JSON text.
func (j JSONB) String() string {
	return string(j)
}

// GormDataType returns the general data type name reported to gorm for
// JSONB fields, which is always "json".
func (JSONB) GormDataType() string {
	const dataType = "json"
	return dataType
}

// GormDBDataType returns the database column type for JSONB fields.
// The result is always "JSONB"; neither db nor field is consulted.
func (JSONB) GormDBDataType(db *gorm.DB, field *schema.Field) string {
	const columnType = "JSONB"
	return columnType
}

// GormValue returns the SQL expression gorm uses when writing this value:
// a single bound parameter holding the JSON text, or a literal NULL when the
// value is empty.
// nolint
func (j JSONB) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
	if len(j) > 0 {
		// The marshal error is deliberately discarded, as in the original.
		encoded, _ := j.MarshalJSON()
		return gorm.Expr("?", string(encoded))
	}
	return gorm.Expr("NULL")
}
16 changes: 12 additions & 4 deletions pkg/database/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import (
"time"

log "github.com/sirupsen/logrus"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/logger"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/common"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
"github.com/G-Research/fasttrackml/pkg/common/db/types"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0001"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0002"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0003"
Expand All @@ -24,6 +24,7 @@ import (
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0007"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0008"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0009"
"github.com/G-Research/fasttrackml/pkg/database/migrations/v_0010"
)

var supportedAlembicVersions = []string{
Expand All @@ -45,7 +46,7 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error {
tx.First(&schemaVersion)
}

if !slices.Contains(supportedAlembicVersions, alembicVersion.Version) || schemaVersion.Version != v_0009.Version {
if !slices.Contains(supportedAlembicVersions, alembicVersion.Version) || schemaVersion.Version != v_0010.Version {
if !migrate && alembicVersion.Version != "" {
return fmt.Errorf(
"unsupported database schema versions alembic %s, FastTrackML %s",
Expand Down Expand Up @@ -190,6 +191,13 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error {
if err := v_0009.Migrate(db); err != nil {
return fmt.Errorf("error migrating database to FastTrackML schema %s: %w", v_0009.Version, err)
}
fallthrough

case v_0009.Version:
log.Infof("Migrating database to FastTrackML schema %s", v_0010.Version)
if err := v_0010.Migrate(db); err != nil {
return fmt.Errorf("error migrating database to FastTrackML schema %s: %w", v_0010.Version, err)
}

default:
return fmt.Errorf("unsupported database FastTrackML schema version %s", schemaVersion.Version)
Expand Down Expand Up @@ -221,7 +229,7 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error {
Version: "97727af70f4d",
})
tx.Create(&SchemaVersion{
Version: v_0009.Version,
Version: v_0010.Version,
})
tx.Commit()
if tx.Error != nil {
Expand Down Expand Up @@ -308,7 +316,7 @@ func CreateDefaultExperiment(db *gorm.DB, defaultArtifactRoot string) error {

// CreateDefaultMetricContext creates the default metric context if it doesn't exist.
func CreateDefaultMetricContext(db *gorm.DB) error {
defaultContext := Context{Json: datatypes.JSON("{}")}
defaultContext := Context{Json: types.JSONB("{}")}
if err := db.First(&defaultContext).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info("Creating default context")
Expand Down
19 changes: 19 additions & 0 deletions pkg/database/migrations/v_0010/migrate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package v_0010

import (
"gorm.io/gorm"
)

// Version is the schema version identifier that Migrate writes to the
// SchemaVersion model once this migration has been applied.
const Version = "10d125c68d9a"

// Migrate applies this migration inside a single transaction: it runs
// AutoMigrate for the Context model and then sets the Version column of
// every SchemaVersion row to Version. The first error encountered is
// returned.
func Migrate(db *gorm.DB) error {
	apply := func(tx *gorm.DB) error {
		if err := tx.AutoMigrate(&Context{}); err != nil {
			return err
		}
		// "1 = 1" matches every row, giving the update an explicit WHERE clause.
		updated := tx.Model(&SchemaVersion{}).
			Where("1 = 1").
			Update("Version", Version)
		return updated.Error
	}
	return db.Transaction(apply)
}
Loading
Loading