Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sqlite to postgres db import #566

Merged
merged 19 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ jobs:
env:
DOCKER_BUILDKIT: 1
FML_DATABASE_URI: ${{ matrix.database-uri }}
FML_OUTPUT_DATABASE_URI: ${{ matrix.database-uri }}
suprjinx marked this conversation as resolved.
Show resolved Hide resolved

- name: Save cache
if: steps.cache.outputs.cache-hit != 'true'
Expand Down
15 changes: 15 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@
"--database-uri",
"sqlite:///tmp/fasttrackml.db"
],
},
{
"name": "Launch database import (sqlite to pg)",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/main.go",
"buildFlags": "-tags '${config:go.buildTags}'",
"args": [
"import",
"--input-database-uri",
"sqlite://fasttrackml.db",
"--output-database-uri",
"postgres://postgres:postgres@localhost/postgres"
],
}
]
}
5 changes: 5 additions & 0 deletions pkg/cmd/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ func initDBs() (input, output database.DBProvider, err error) {
if err != nil {
return input, output, fmt.Errorf("error connecting to output DB: %w", err)
}

if err := database.CheckAndMigrateDB(true, output.GormDB()); err != nil {
return nil, nil, fmt.Errorf("error running database migration: %w", err)
}

suprjinx marked this conversation as resolved.
Show resolved Hide resolved
return
}

Expand Down
83 changes: 74 additions & 9 deletions pkg/database/import.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
package database

import (
"fmt"
"reflect"
"regexp"

"github.com/google/uuid"
"github.com/rotisserie/eris"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"gorm.io/gorm/clause"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/common"
)

type experimentInfo struct {
destID int64
sourceID int64
}

var uuidRegexp = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)

// Importer will handle transport of data from source to destination db.
type Importer struct {
destDB *gorm.DB
Expand Down Expand Up @@ -47,6 +56,9 @@ func (s *Importer) Import() error {
return eris.Wrapf(err, "error importing table %s", table)
}
}
if err := s.updateNamespaceDefaultExperiment(); err != nil {
return eris.Wrap(err, "error updating namespace default experiment")
}
return nil
}

Expand Down Expand Up @@ -80,6 +92,10 @@ func (s *Importer) importExperiments() error {
CreationTime: scannedItem.CreationTime,
LastUpdateTime: scannedItem.LastUpdateTime,
}
// keep default experiment ID, but otherwise draw new one
if *scannedItem.ID == int32(0) {
newItem.ID = scannedItem.ID
}
if err := destTX.
Where(Experiment{Name: scannedItem.Name}).
FirstOrCreate(&newItem).Error; err != nil {
Expand Down Expand Up @@ -159,26 +175,75 @@ func (s *Importer) saveExperimentInfo(source, dest Experiment) {

// translateFields alters row before creation as needed (especially, replacing old experiment_id with new).
func (s *Importer) translateFields(item map[string]any) (map[string]any, error) {
// boolean is numeric when coming from sqlite
if isNaN, ok := item["is_nan"]; ok {
switch v := isNaN.(type) {
case bool:
break
default:
item["is_nan"] = v != 0.0
// boolean fields are numeric when coming from sqlite
booleanFields := []string{"is_nan", "is_archived"}
for _, field := range booleanFields {
if fieldVal, ok := item[field]; ok {
switch v := fieldVal.(type) {
case bool:
break
default:
item[field] = v != 0.0
}
}
}
// items with experiment_id fk need to reference the new ID
// items with experiment_id need to reference the new ID
if expID, ok := item["experiment_id"]; ok {
id, ok := expID.(int64)
if !ok {
return nil, eris.Errorf("unable to assert experiment_id as int64: %d", expID)
return nil, eris.Errorf("unable to assert %s as int64: %d", "experiment_id", expID)
}
for _, expInfo := range s.experimentInfos {
if expInfo.sourceID == id {
item["experiment_id"] = expInfo.destID
}
}
}
// items with string uuid need to translate to UUID native type
uuidFields := []string{"id", "app_id"}
for _, field := range uuidFields {
if srcUUID, ok := item[field]; ok {
// when uuid, this field will be pointer to interface{} and requires some reflection
stringUUID := fmt.Sprintf("%v", reflect.Indirect(reflect.ValueOf(srcUUID)))
if uuidRegexp.MatchString(stringUUID) {
binID, err := uuid.Parse(stringUUID)
if err != nil {
return nil, eris.Errorf("unable to create binary UUID field from string: %s", stringUUID)
}
item[field] = binID
}
}
}
return item, nil
}

// updateNamespaceDefaultExperiment updates the default_experiment_id for all namespaces
// when its related experiment received a new id.
func (s Importer) updateNamespaceDefaultExperiment() error {
// Start transaction in the destDB
err := s.destDB.Transaction(func(destTX *gorm.DB) error {
// Get namespaces
var namespaces []Namespace
if err := destTX.Model(Namespace{}).Find(&namespaces).Error; err != nil {
return eris.Wrap(err, "error reading namespaces in destination")
}
for _, ns := range namespaces {
updatedExperimentID := ns.DefaultExperimentID
for _, expInfo := range s.experimentInfos {
if ns.DefaultExperimentID != nil && expInfo.sourceID == int64(*ns.DefaultExperimentID) {
updatedExperimentID = common.GetPointer[int32](int32(expInfo.destID))
break
}
}
if err := destTX.
Model(Namespace{}).
Where(Namespace{ID: ns.ID}).
Update("default_experiment_id", updatedExperimentID).Error; err != nil {
return eris.Wrap(err, "error updating destination namespace row")
}
}
log.Infof("Updating namespaces - processed %d records", len(namespaces))
return nil
})
return err
}
82 changes: 45 additions & 37 deletions tests/integration/golang/database/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,37 +59,73 @@ func (s *ImportTestSuite) SetupTest() {
require.Nil(s.T(), database.CreateDefaultExperiment(db.GormDB(), "s3://fasttrackml"))
s.inputDB = db.GormDB()

inputExperimentFixtures, err := fixtures.NewExperimentFixtures(db.GormDB())
require.Nil(s.T(), err)
inputRunFixtures, err := fixtures.NewRunFixtures(db.GormDB())
require.Nil(s.T(), err)
s.inputRunFixtures = inputRunFixtures
s.populateDB(s.inputDB)
dave-gantenbein marked this conversation as resolved.
Show resolved Hide resolved

// prepare output database.
db, err = database.NewDBProvider(
helpers.GetOutputDatabaseUri(),
1*time.Second,
20,
)
require.Nil(s.T(), err)
require.Nil(s.T(), database.CheckAndMigrateDB(true, db.GormDB()))
require.Nil(s.T(), database.CreateDefaultNamespace(db.GormDB()))
require.Nil(s.T(), database.CreateDefaultExperiment(db.GormDB(), "s3://fasttrackml"))
s.outputDB = db.GormDB()

outputRunFixtures, err := fixtures.NewRunFixtures(db.GormDB())
require.Nil(s.T(), err)
s.outputRunFixtures = outputRunFixtures

s.populatedRowCounts = rowCounts{
namespaces: 1,
experiments: 3,
runs: 10,
distinctRunExperimentIDs: 2,
metrics: 40,
latestMetrics: 20,
tags: 10,
params: 20,
dashboards: 2,
apps: 1,
}
}

func (s *ImportTestSuite) populateDB(db *gorm.DB) {
experimentFixtures, err := fixtures.NewExperimentFixtures(db)
require.Nil(s.T(), err)

runFixtures, err := fixtures.NewRunFixtures(db)
require.Nil(s.T(), err)

// experiment 1
experiment, err := inputExperimentFixtures.CreateExperiment(context.Background(), &models.Experiment{
experiment, err := experimentFixtures.CreateExperiment(context.Background(), &models.Experiment{
Name: uuid.New().String(),
NamespaceID: 1,
LifecycleStage: models.LifecycleStageActive,
})
require.Nil(s.T(), err)

runs, err := inputRunFixtures.CreateExampleRuns(context.Background(), experiment, 5)
runs, err := runFixtures.CreateExampleRuns(context.Background(), experiment, 5)
require.Nil(s.T(), err)
s.runs = runs

// experiment 2
experiment, err = inputExperimentFixtures.CreateExperiment(context.Background(), &models.Experiment{
experiment, err = experimentFixtures.CreateExperiment(context.Background(), &models.Experiment{
Name: uuid.New().String(),
NamespaceID: 1,
LifecycleStage: models.LifecycleStageActive,
})
require.Nil(s.T(), err)

runs, err = inputRunFixtures.CreateExampleRuns(context.Background(), experiment, 5)
runs, err = runFixtures.CreateExampleRuns(context.Background(), experiment, 5)
require.Nil(s.T(), err)
s.runs = runs

appFixtures, err := fixtures.NewAppFixtures(db.GormDB())
appFixtures, err := fixtures.NewAppFixtures(db)
require.Nil(s.T(), err)
app, err := appFixtures.CreateApp(context.Background(), &database.App{
Base: database.Base{
Expand All @@ -102,7 +138,7 @@ func (s *ImportTestSuite) SetupTest() {
})
require.Nil(s.T(), err)

dashboardFixtures, err := fixtures.NewDashboardFixtures(db.GormDB())
dashboardFixtures, err := fixtures.NewDashboardFixtures(db)
require.Nil(s.T(), err)

// dashboard 1
Expand All @@ -126,34 +162,6 @@ func (s *ImportTestSuite) SetupTest() {
Name: uuid.NewString(),
})
require.Nil(s.T(), err)
// prepare output database.
db, err = database.NewDBProvider(
helpers.GetOutputDatabaseUri(),
1*time.Second,
20,
)
require.Nil(s.T(), err)
require.Nil(s.T(), database.CheckAndMigrateDB(true, db.GormDB()))
require.Nil(s.T(), database.CreateDefaultNamespace(db.GormDB()))
require.Nil(s.T(), database.CreateDefaultExperiment(db.GormDB(), "s3://fasttrackml"))
s.outputDB = db.GormDB()

outputRunFixtures, err := fixtures.NewRunFixtures(db.GormDB())
require.Nil(s.T(), err)
s.outputRunFixtures = outputRunFixtures

s.populatedRowCounts = rowCounts{
namespaces: 1,
experiments: 3,
runs: 10,
distinctRunExperimentIDs: 2,
metrics: 40,
latestMetrics: 20,
tags: 10,
params: 20,
dashboards: 2,
apps: 1,
}
}

func (s *ImportTestSuite) Test_Ok() {
Expand All @@ -180,7 +188,7 @@ func (s *ImportTestSuite) Test_Ok() {
err = importer.Import()
require.Nil(s.T(), err)

// dest DB should still only have the expected
// dest DB should still only have the expected (idempotent)
validateRowCounts(s.T(), s.outputDB, s.populatedRowCounts)

// confirm row-for-row equality
Expand Down