Context query subscript #791

Merged Jan 23, 2024 (68 commits)
Changes from 54 commits
Commits
fcee4fb
WIP for parsing context dictionary subscript
suprjinx Dec 20, 2023
e5476a6
Fix needed for migrating existing data
suprjinx Jan 3, 2024
8aecc94
Set latest_metric.context_id
suprjinx Jan 3, 2024
c9cc2bb
Use AutoMigrate and new unique index for context_id col
suprjinx Jan 3, 2024
f08e216
switch back to primary keys
suprjinx Jan 3, 2024
71e8bfb
Create new tables and copy data
suprjinx Jan 3, 2024
be16a83
Drop old indexes
suprjinx Jan 3, 2024
b83615c
make helper for deleting index if exists
suprjinx Jan 3, 2024
5f90110
Postgres does not relax FK constraints, so we have to supply the real
suprjinx Jan 4, 2024
0c6b35f
Remove the default:0 gorm annotation -- will never work
suprjinx Jan 4, 2024
6b99867
add test to make sure latest_metrics are differentiated by context
suprjinx Jan 4, 2024
d850736
lint
suprjinx Jan 4, 2024
a1bd805
Clarified comments
suprjinx Jan 4, 2024
7e36e35
Properly consolidate metrics by contiguous key and context
suprjinx Jan 4, 2024
e87ff9a
Merge branch 'main' into context-query-subscript
suprjinx Jan 4, 2024
de5cbbc
Merge branch 'main' into fix-context-migration-for-upgrades
suprjinx Jan 5, 2024
c207009
add dictionary parsing
suprjinx Jan 5, 2024
7160444
add subscript handler for []clause.Expression (which is what the
suprjinx Jan 5, 2024
dbb40e2
remove extra blank line
suprjinx Jan 5, 2024
7e91551
DRY up the metric attribute getter
suprjinx Jan 5, 2024
db649de
Lint
suprjinx Jan 5, 2024
f47cba3
Fix lints
suprjinx Jan 5, 2024
e3937fb
Fix another lint
suprjinx Jan 5, 2024
da289c6
another lint fix
suprjinx Jan 6, 2024
f87e085
fix the metric attribute getter
suprjinx Jan 6, 2024
1012b2c
add support for metric subscript tuple
suprjinx Jan 7, 2024
7979db1
additional test reversing tuple order
suprjinx Jan 7, 2024
c4748a0
Move conditions append to the subscript slicer
suprjinx Jan 7, 2024
c706d0a
Add tuple test for Postgres
suprjinx Jan 7, 2024
f6efd8d
lint
suprjinx Jan 7, 2024
e16dd20
WIP refactor to use one latest_metrics join
suprjinx Jan 11, 2024
4545edf
WIP cont'd
suprjinx Jan 11, 2024
108a7e7
Update query parser to handle tuple case
suprjinx Jan 11, 2024
0914330
Fix lint issues
suprjinx Jan 12, 2024
4fa155e
Add row-count verification to migration
suprjinx Jan 12, 2024
77b33bd
refactor migration for tables/indexes map
suprjinx Jan 12, 2024
17e8458
PR cleanup
suprjinx Jan 12, 2024
20b82b0
lint
suprjinx Jan 12, 2024
a93f904
Merge branch 'main' into fix-context-migration-for-upgrades
suprjinx Jan 12, 2024
a0a5fb8
more lint
suprjinx Jan 12, 2024
633aa4d
Merge branch 'fix-context-migration-for-upgrades' of github.com:suprj…
suprjinx Jan 12, 2024
8a9854f
fix typo
suprjinx Jan 12, 2024
35c2f1a
Merge branch 'fix-context-migration-for-upgrades' into context-query-…
suprjinx Jan 12, 2024
401ad4d
WIP adapting tests
suprjinx Jan 12, 2024
ce294f8
Fixed remaining tests by adding joins ordering slice
suprjinx Jan 12, 2024
c0479be
Make sure latest_metrics is joined for query parser
suprjinx Jan 13, 2024
26ae1ea
cosmetic query change
suprjinx Jan 13, 2024
76ec123
put latest_metrics and context joins into the input query
suprjinx Jan 13, 2024
d5a84da
adjust query_test, metric and context joins are not created by Filter
suprjinx Jan 14, 2024
55d9fd4
Remove the joinKeys slice
suprjinx Jan 14, 2024
20d1b26
Add joins needed by pq.Filter
suprjinx Jan 14, 2024
3da858e
Merge branch 'main' into context-query-subscript
suprjinx Jan 14, 2024
4fc85e9
Can't have metrics without latest_metrics
suprjinx Jan 14, 2024
755d0fb
Merge branch 'main' into context-query-subscript
suprjinx Jan 15, 2024
d79d161
restore unique joins for each metric evaluation
suprjinx Jan 15, 2024
3525982
fix lint and unit tests
suprjinx Jan 15, 2024
6ea1774
comments and a couple of more test cases
suprjinx Jan 15, 2024
cdeb399
Comment and additional tests
suprjinx Jan 15, 2024
1c2816e
Don't add new join for metric.context queries
suprjinx Jan 16, 2024
c1aefc8
lint
suprjinx Jan 16, 2024
59af0fa
lint
suprjinx Jan 16, 2024
28b9e16
errant capitalization removal
suprjinx Jan 16, 2024
fa0b9af
lint
suprjinx Jan 16, 2024
2e5c2b9
use constant for table name
suprjinx Jan 17, 2024
f3d8fef
allow "run.metrics[name]" to count as metric selected
suprjinx Jan 17, 2024
682ba0f
error case with joins map
suprjinx Jan 22, 2024
8ec7363
Merge branch 'main' into context-query-subscript
suprjinx Jan 23, 2024
82c6c9e
Merge branch 'main' into context-query-subscript
suprjinx Jan 23, 2024
153 changes: 105 additions & 48 deletions pkg/api/aim/query/query.go
@@ -16,6 +16,11 @@ import (
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

const (
TableContexts = "contexts"
TableLatestMetrics = "latest_metrics"
)

type DefaultExpression struct {
Contains string
Expression string
@@ -154,6 +159,7 @@ func (qp *QueryParser) Parse(q string) (ParsedQuery, error) {
return pq, nil
}

// Filter will add the appropriate Joins and Where clauses to the tx.
func (pq *parsedQuery) Filter(tx *gorm.DB) *gorm.DB {
for _, j := range pq.joins {
tx.Joins(j.query, j.args...)
@@ -204,6 +210,10 @@ func (pq *parsedQuery) _parseNode(node ast.Expr) (any, error) {
return pq.parseAttribute(n)
case *ast.Compare:
return pq.parseCompare(n)
case *ast.Dict:
return pq.parseDictionary(n)
case *ast.Tuple:
return pq.parseTuple(n)
default:
return nil, fmt.Errorf("unsupported expression %q", ast.Dump(n))
}
@@ -404,6 +414,38 @@ func (pq *parsedQuery) parseCompare(node *ast.Compare) (any, error) {
}, nil
}

// parseDictionary returns a `clause.And` of JsonEq conditions, one per
// key/value pair in the dictionary. Keys and values are expected to be
// string literals.
func (pq *parsedQuery) parseDictionary(node *ast.Dict) (any, error) {
clauses := make([]clause.Expression, len(node.Keys))
for i, key := range node.Keys {
clauses[i] = JsonEq{
Left: Json{
Column: clause.Column{
Table: TableContexts,
Name: "json",
},
JsonPath: string(key.(*ast.Str).S),
Dialector: pq.qp.Dialector,
},
Value: string(node.Values[i].(*ast.Str).S),
}
}
return clause.And(clauses...), nil
}

func (pq *parsedQuery) parseTuple(node *ast.Tuple) (any, error) {
var err error
list := make([]any, len(node.Elts))
for i, e := range node.Elts {
list[i], err = pq.parseNode(e)
if err != nil {
return nil, err
}
}
return list, nil
}

func (pq *parsedQuery) parseList(node *ast.List) (any, error) {
var err error
list := make([]any, len(node.Elts))
@@ -488,41 +530,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
if err != nil {
return nil, err
}
switch v := v.(type) {
case string:
j, ok := pq.joins[fmt.Sprintf("metrics:%s", v)]
if !ok {
alias := fmt.Sprintf("metrics_%d", len(pq.joins))
j = join{
alias: alias,
query: fmt.Sprintf(
"LEFT JOIN latest_metrics %s ON %s.run_uuid = %s.run_uuid AND %s.key = ?",
alias, table, alias, alias,
),
args: []any{v},
}
pq.joins[fmt.Sprintf("metrics:%s", v)] = j
}
return attributeGetter(func(attr string) (any, error) {
var name string
switch attr {
case "last":
name = "value"
case "last_step":
name = "last_iter"
case "first_step":
return 0, nil
default:
return nil, fmt.Errorf("unsupported metrics attribute %q", attr)
}
return clause.Column{
Table: j.alias,
Name: name,
}, nil
}), nil
default:
return nil, fmt.Errorf("unsupported index value type %t", v)
}
return pq.metricSubscriptSlicer(v)
default:
return nil, fmt.Errorf("unsupported slicer %q", ast.Dump(s))
}
@@ -555,7 +563,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
Name: "value",
}, nil
default:
return nil, fmt.Errorf("unsupported index value type %t", v)
return nil, fmt.Errorf("unsupported index value type %T", v)
}
default:
return nil, fmt.Errorf("unsupported slicer %q", ast.Dump(s))
@@ -610,22 +618,12 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
return 0, nil
case "context":
return attributeGetter(
func(contextKey string) (any, error) {
// create the join for contexts
_, ok := pq.joins["metric_contexts"]
if !ok {
alias := "contexts"
j := join{
alias: alias,
query: "LEFT JOIN contexts ON latest_metrics.context_id = contexts.id",
}
pq.joins["metric_contexts"] = j
}

func(contextKey string) (any, error) {
// Add a WHERE clause for the context key
return Json{
Column: clause.Column{
Table: "contexts",
Table: TableContexts,
Name: "json",
},
JsonPath: contextKey,
@@ -728,6 +726,65 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
}
}

func (pq *parsedQuery) metricSubscriptSlicer(v any) (any, error) {
switch v := v.(type) {
case string:
// case of metric key
keyEqual := pq.metricSubscriptStringExpression(v)
pq.conditions = append(pq.conditions, keyEqual)
return metricAttributeGetter(TableLatestMetrics)
case clause.Expression:
// case of metric context dictionary
pq.conditions = append(pq.conditions, v)
return metricAttributeGetter(TableLatestMetrics)
case []any:
// case of subscript tuple (string and context dictionary)
if len(v) != 2 {
return nil, fmt.Errorf("unsupported tuple length %d (should be 2)", len(v))
}
metricKey, ok := v[0].(string)
if !ok {
return nil, fmt.Errorf("unsupported tuple value type %T (should be string at 0)", v[0])
}
metricContextExpression, ok := v[1].(clause.Expression)
if !ok {
return nil, fmt.Errorf("unsupported index value type %T (should be clause.Expression at 1)", v[1])
}
metricKeyExpression := pq.metricSubscriptStringExpression(metricKey)
pq.conditions = append(pq.conditions, clause.And(metricKeyExpression, metricContextExpression))
return metricAttributeGetter(TableLatestMetrics)
default:
return nil, fmt.Errorf("unsupported index value type %T", v)
}
}

func (pq *parsedQuery) metricSubscriptStringExpression(v string) clause.Expression {
return clause.Eq{
Column: fmt.Sprintf("%s.%s", TableLatestMetrics, "key"),
Value: v,
}
}

func metricAttributeGetter(table string) (any, error) {
return attributeGetter(func(attr string) (any, error) {
var name string
switch attr {
case "last":
name = "value"
case "last_step":
name = "last_iter"
case "first_step":
return 0, nil
default:
return nil, fmt.Errorf("unsupported metrics attribute %q", attr)
}
return clause.Column{
Table: table,
Name: name,
}, nil
}), nil
}

func (pq *parsedQuery) parseNameConstant(node *ast.NameConstant) (any, error) {
switch node.Value.Type() {
case py.NoneTypeType:
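The dispatch in `metricSubscriptSlicer` can be sketched with the standard library alone. This is a minimal illustration, not the PR's code: `condition` is a hypothetical string type standing in for gorm's `clause.Expression`, `subscriptConditions` and `keyCondition` are made-up names, and the pair case returns two conditions rather than wrapping them in `clause.And`.

```go
package main

import "fmt"

// condition is a hypothetical stand-in for gorm's clause.Expression.
type condition string

// keyCondition builds the equality check on the metric key column.
func keyCondition(key string) condition {
	return condition(fmt.Sprintf("latest_metrics.key = %q", key))
}

// subscriptConditions accepts a metric key, a context condition, or a
// (key, context) pair in that order, and rejects everything else.
func subscriptConditions(v any) ([]condition, error) {
	switch v := v.(type) {
	case string:
		// run.metrics["my_metric"]
		return []condition{keyCondition(v)}, nil
	case condition:
		// run.metrics[{"key1": "value1"}]
		return []condition{v}, nil
	case []any:
		// run.metrics["my_metric", {"key1": "value1"}]
		if len(v) != 2 {
			return nil, fmt.Errorf("unsupported tuple length %d (should be 2)", len(v))
		}
		key, ok := v[0].(string)
		if !ok {
			return nil, fmt.Errorf("unsupported tuple value type %T (should be string at 0)", v[0])
		}
		ctx, ok := v[1].(condition)
		if !ok {
			return nil, fmt.Errorf("unsupported tuple value type %T (should be condition at 1)", v[1])
		}
		return []condition{keyCondition(key), ctx}, nil
	default:
		return nil, fmt.Errorf("unsupported index value type %T", v)
	}
}

func main() {
	conds, err := subscriptConditions([]any{"my_metric", condition("ctx.key1 = 'value1'")})
	fmt.Println(conds, err)

	// Reversed order is rejected, matching TestMetricContextSliceTupleOrder.
	_, err = subscriptConditions([]any{condition("ctx.key1 = 'value1'"), "my_metric"})
	fmt.Println(err)
}
```

Reporting the type of the offending element (`v[0]` or `v[1]`) rather than the whole slice keeps the error message specific to the value that failed the check.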
63 changes: 51 additions & 12 deletions pkg/api/aim/query/query_test.go
@@ -110,24 +110,21 @@ func (s *QueryTestSuite) TestPostgresDialector_Ok() {
name: "TestNegativeInteger",
query: `run.metrics['my_metric'].last < -1`,
expectedSQL: `SELECT "run_uuid" FROM "runs" ` +
`LEFT JOIN latest_metrics metrics_0 ON runs.run_uuid = metrics_0.run_uuid AND metrics_0.key = $1 ` +
`WHERE ("metrics_0"."value" < $2 AND "runs"."lifecycle_stage" <> $3)`,
`WHERE ("latest_metrics"."key" = $1 AND ("latest_metrics"."value" < $2 AND "runs"."lifecycle_stage" <> $3))`,
expectedVars: []interface{}{"my_metric", -1, models.LifecycleStageDeleted},
},
{
name: "TestNegativeFloat",
query: `run.metrics['my_metric'].last < -1.0`,
expectedSQL: `SELECT "run_uuid" FROM "runs" ` +
`LEFT JOIN latest_metrics metrics_0 ON runs.run_uuid = metrics_0.run_uuid AND metrics_0.key = $1 ` +
`WHERE ("metrics_0"."value" < $2 AND "runs"."lifecycle_stage" <> $3)`,
`WHERE ("latest_metrics"."key" = $1 AND ("latest_metrics"."value" < $2 AND "runs"."lifecycle_stage" <> $3))`,
expectedVars: []interface{}{"my_metric", -1.0, models.LifecycleStageDeleted},
},
{
name: "TestMetricContext",
query: `metric.context.key1 == 'value1'`,
selectMetrics: true,
expectedSQL: `SELECT ID FROM "metrics" ` +
`LEFT JOIN contexts ON latest_metrics.context_id = contexts.id ` +
`WHERE ("contexts"."json"#>>$1 = $2 AND "runs"."lifecycle_stage" <> $3)`,
expectedVars: []interface{}{"{key1}", "value1", models.LifecycleStageDeleted},
},
@@ -136,10 +133,25 @@ func (s *QueryTestSuite) TestPostgresDialector_Ok() {
query: `metric.context.key1 != 'value1'`,
selectMetrics: true,
expectedSQL: `SELECT ID FROM "metrics" ` +
`LEFT JOIN contexts ON latest_metrics.context_id = contexts.id ` +
`WHERE ("contexts"."json"#>>$1 <> $2 AND "runs"."lifecycle_stage" <> $3)`,
expectedVars: []interface{}{"{key1}", "value1", models.LifecycleStageDeleted},
},
{
name: "TestMetricContextSlice",
query: `run.metrics[{"key1": "value1"}].last < -1`,
expectedSQL: `SELECT "run_uuid" FROM "runs" ` +
`WHERE ("contexts"."json"#>>$1 = $2 AND ("latest_metrics"."value" < $3 ` +
`AND "runs"."lifecycle_stage" <> $4))`,
expectedVars: []interface{}{"{key1}", "value1", -1, models.LifecycleStageDeleted},
},
{
name: "TestMetricContextSliceTuple",
query: `run.metrics["my_metric", {"key1": "value1"}].last < -1`,
expectedSQL: `SELECT "run_uuid" FROM "runs" ` +
`WHERE (("latest_metrics"."key" = $1 AND "contexts"."json"#>>$2 = $3) ` +
`AND ("latest_metrics"."value" < $4 AND "runs"."lifecycle_stage" <> $5))`,
expectedVars: []interface{}{"my_metric", "{key1}", "value1", -1, models.LifecycleStageDeleted},
},
}

for _, tt := range tests {
@@ -251,24 +263,23 @@ func (s *QueryTestSuite) TestSqliteDialector_Ok() {
name: "TestNegativeInteger",
query: `run.metrics['my_metric'].last < -1`,
expectedSQL: `SELECT "run_uuid" FROM "runs" ` +
`LEFT JOIN latest_metrics metrics_0 ON runs.run_uuid = metrics_0.run_uuid AND metrics_0.key = $1 ` +
`WHERE ("metrics_0"."value" < $2 AND "runs"."lifecycle_stage" <> $3)`,
`WHERE ("latest_metrics"."key" = $1 AND ("latest_metrics"."value" < $2 ` +
`AND "runs"."lifecycle_stage" <> $3))`,
expectedVars: []interface{}{"my_metric", -1, models.LifecycleStageDeleted},
},
{
name: "TestNegativeFloat",
query: `run.metrics['my_metric'].last < -1.0`,
expectedSQL: `SELECT "run_uuid" FROM "runs" ` +
`LEFT JOIN latest_metrics metrics_0 ON runs.run_uuid = metrics_0.run_uuid AND metrics_0.key = $1 ` +
`WHERE ("metrics_0"."value" < $2 AND "runs"."lifecycle_stage" <> $3)`,
`WHERE ("latest_metrics"."key" = $1 AND ("latest_metrics"."value" < $2 ` +
`AND "runs"."lifecycle_stage" <> $3))`,
expectedVars: []interface{}{"my_metric", -1.0, models.LifecycleStageDeleted},
},
{
name: "TestMetricContext",
query: `metric.context.key1 == 'value1'`,
selectMetrics: true,
expectedSQL: `SELECT ID FROM "metrics" ` +
`LEFT JOIN contexts ON latest_metrics.context_id = contexts.id ` +
`WHERE (IFNULL("contexts"."json", JSON('{}'))->>$1 = $2 AND "runs"."lifecycle_stage" <> $3)`,
expectedVars: []interface{}{"key1", "value1", models.LifecycleStageDeleted},
},
@@ -277,10 +288,33 @@ func (s *QueryTestSuite) TestSqliteDialector_Ok() {
query: `metric.context.key1 != 'value1'`,
selectMetrics: true,
expectedSQL: `SELECT ID FROM "metrics" ` +
`LEFT JOIN contexts ON latest_metrics.context_id = contexts.id ` +
`WHERE (IFNULL("contexts"."json", JSON('{}'))->>$1 <> $2 AND "runs"."lifecycle_stage" <> $3)`,
expectedVars: []interface{}{"key1", "value1", models.LifecycleStageDeleted},
},
{
name: "TestMetricKeySlice",
query: `run.metrics["key1"].last < -1`,
expectedSQL: `SELECT "run_uuid" FROM "runs" ` +
`WHERE ("latest_metrics"."key" = $1 AND ` +
`("latest_metrics"."value" < $2 AND "runs"."lifecycle_stage" <> $3))`,
expectedVars: []interface{}{"key1", -1, models.LifecycleStageDeleted},
},
{
name: "TestMetricContextSlice",
query: `run.metrics[{"key1": "value1"}].last < -1`,
expectedSQL: `SELECT "run_uuid" FROM "runs" ` +
`WHERE (IFNULL("contexts"."json", JSON('{}'))->>$1 = $2 AND ` +
`("latest_metrics"."value" < $3 AND "runs"."lifecycle_stage" <> $4))`,
expectedVars: []interface{}{"key1", "value1", -1, models.LifecycleStageDeleted},
},
{
name: "TestMetricContextSliceTuple",
query: `run.metrics["my_metric", {"key1": "value1"}].last < -1`,
expectedSQL: `SELECT "run_uuid" FROM "runs" ` +
`WHERE (("latest_metrics"."key" = $1 AND IFNULL("contexts"."json", JSON('{}'))->>$2 = $3) ` +
`AND ("latest_metrics"."value" < $4 AND "runs"."lifecycle_stage" <> $5))`,
expectedVars: []interface{}{"my_metric", "key1", "value1", -1, models.LifecycleStageDeleted},
},
}

for _, tt := range tests {
@@ -328,6 +362,11 @@ func (s *QueryTestSuite) Test_Error() {
query: `metric.context.parent.nested == 'value1'`,
expectedError: SyntaxError{},
},
{
name: "TestMetricContextSliceTupleOrder",
query: `run.metrics[{"key1": "value1"}, "my_metric"].last < -1`,
expectedError: SyntaxError{},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
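The expected SQL in these tests differs by dialect: Postgres uses `#>>` with a path argument such as `{key1}`, while SQLite uses `->>` on an `IFNULL`-wrapped column with the bare key. The sketch below reproduces only those two string shapes as they appear in the test expectations. `jsonEqSQL` is a hypothetical helper; the PR renders these conditions through its `JsonEq` and `Json` types, whose implementation is not shown in this diff.

```go
package main

import "fmt"

// jsonEqSQL is a hypothetical helper that returns the SQL fragment and bind
// variables for a context key/value equality, starting at placeholder $n.
// The fragment shapes are copied from the expected SQL in the tests.
func jsonEqSQL(dialect, key, value string, n int) (string, []any, error) {
	switch dialect {
	case "postgres":
		// Postgres takes the JSON path as a text array literal, e.g. {key1}.
		sql := fmt.Sprintf(`"contexts"."json"#>>$%d = $%d`, n, n+1)
		return sql, []any{"{" + key + "}", value}, nil
	case "sqlite":
		// SQLite takes the bare key and guards against a NULL json column.
		sql := fmt.Sprintf(`IFNULL("contexts"."json", JSON('{}'))->>$%d = $%d`, n, n+1)
		return sql, []any{key, value}, nil
	default:
		return "", nil, fmt.Errorf("unsupported dialect %q", dialect)
	}
}

func main() {
	for _, d := range []string{"postgres", "sqlite"} {
		sql, vars, err := jsonEqSQL(d, "key1", "value1", 1)
		fmt.Println(sql, vars, err)
	}
}
```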
12 changes: 8 additions & 4 deletions pkg/api/aim/runs.go
@@ -466,6 +466,9 @@ func SearchRuns(c *fiber.Ctx) error {
}

var runs []database.Run
// add joins needed by pq.Filter
tx.Joins("JOIN latest_metrics USING(run_uuid)").
Joins("JOIN contexts ON latest_metrics.context_id = contexts.id")
pq.Filter(tx).Find(&runs)
if tx.Error != nil {
return fmt.Errorf("error searching runs: %w", tx.Error)
@@ -660,7 +663,9 @@ func SearchMetrics(c *fiber.Ctx) error {
"INNER JOIN experiments ON experiments.experiment_id = runs.experiment_id AND experiments.namespace_id = ?",
ns.ID,
).
Joins("LEFT JOIN latest_metrics USING(run_uuid)"))).
Joins("JOIN latest_metrics USING(run_uuid)").
Joins("JOIN contexts ON latest_metrics.context_id = contexts.id"),
)).
Order("runs.row_num DESC").
Find(&runs); tx.Error != nil {
return fmt.Errorf("error searching run metrics: %w", tx.Error)
@@ -718,7 +723,7 @@ func SearchMetrics(c *fiber.Ctx) error {
"runs.row_num",
"latest_metrics.key",
"latest_metrics.context_id",
"latest_metrics_context.json AS context_json",
"contexts.json AS context_json",
fmt.Sprintf("(latest_metrics.last_iter + 1)/ %f AS interval", float32(q.Steps)),
).
Table("runs").
@@ -727,8 +732,7 @@ func SearchMetrics(c *fiber.Ctx) error {
ns.ID,
).
Joins("LEFT JOIN latest_metrics USING(run_uuid)").
Joins(`LEFT JOIN contexts latest_metrics_context `+
`ON latest_metrics.context_id = latest_metrics_context.id`)),
Joins("LEFT JOIN contexts ON latest_metrics.context_id = contexts.id")),
).
Where("MOD(metrics.iter + 1 + runmetrics.interval / 2, runmetrics.interval) < 1").
Order("runmetrics.row_num DESC").
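With this change the `latest_metrics` and `contexts` joins are supplied by the caller before `pq.Filter` runs, instead of being generated per metric inside the parser. The ordering can be illustrated with a toy builder. `query` and its methods are hypothetical stand-ins for a gorm transaction, and the generated SQL is simplified (no quoting, no bind variables).

```go
package main

import (
	"fmt"
	"strings"
)

// query is a hypothetical, much-simplified stand-in for a gorm transaction.
type query struct {
	table string
	joins []string
	where []string
}

// Joins records a join clause and returns the query for chaining.
func (q *query) Joins(j string) *query { q.joins = append(q.joins, j); return q }

// Where records a condition and returns the query for chaining.
func (q *query) Where(w string) *query { q.where = append(q.where, w); return q }

// SQL renders joins in insertion order, followed by AND-ed conditions.
func (q *query) SQL() string {
	var b strings.Builder
	fmt.Fprintf(&b, "SELECT run_uuid FROM %s", q.table)
	for _, j := range q.joins {
		b.WriteString(" " + j)
	}
	if len(q.where) > 0 {
		b.WriteString(" WHERE " + strings.Join(q.where, " AND "))
	}
	return b.String()
}

func main() {
	// The caller adds the joins first; the filter step then only adds
	// conditions that reference latest_metrics and contexts by name.
	q := (&query{table: "runs"}).
		Joins("JOIN latest_metrics USING(run_uuid)").
		Joins("JOIN contexts ON latest_metrics.context_id = contexts.id").
		Where("latest_metrics.key = 'my_metric'")
	fmt.Println(q.SQL())
}
```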