Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context query subscript #791

Merged
merged 68 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from 65 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
fcee4fb
WIP for parsing context dictionary subscript
suprjinx Dec 20, 2023
e5476a6
Fix needed for migrating existing data
suprjinx Jan 3, 2024
8aecc94
Set latest_metric.context_id
suprjinx Jan 3, 2024
c9cc2bb
Use AutoMigrate and new unique index for context_id col
suprjinx Jan 3, 2024
f08e216
switch back to primary keys
suprjinx Jan 3, 2024
71e8bfb
Create new tables and copy data
suprjinx Jan 3, 2024
be16a83
Drop old indexes
suprjinx Jan 3, 2024
b83615c
make helper for deleting index if exists
suprjinx Jan 3, 2024
5f90110
Postgres does not relax FK constraints, so we have to supply the real
suprjinx Jan 4, 2024
0c6b35f
Remove the default:0 gorm annotation -- will never work
suprjinx Jan 4, 2024
6b99867
add test to make sure latest_metrics are differentiated by context
suprjinx Jan 4, 2024
d850736
lint
suprjinx Jan 4, 2024
a1bd805
Clarified comments
suprjinx Jan 4, 2024
7e36e35
Properly consolidate metrics by contiguous key and context
suprjinx Jan 4, 2024
e87ff9a
Merge branch 'main' into context-query-subscript
suprjinx Jan 4, 2024
de5cbbc
Merge branch 'main' into fix-context-migration-for-upgrades
suprjinx Jan 5, 2024
c207009
add dictionary parsing
suprjinx Jan 5, 2024
7160444
add subscript handler for []clause.Expression (which is what the
suprjinx Jan 5, 2024
dbb40e2
remove extra blank line
suprjinx Jan 5, 2024
7e91551
DRY up the metric attribute getter
suprjinx Jan 5, 2024
db649de
Lint
suprjinx Jan 5, 2024
f47cba3
Fix lints
suprjinx Jan 5, 2024
e3937fb
Fix another lint
suprjinx Jan 5, 2024
da289c6
another lint fix
suprjinx Jan 6, 2024
f87e085
fix the metric attribute getter
suprjinx Jan 6, 2024
1012b2c
add support for metric subscript tuple
suprjinx Jan 7, 2024
7979db1
additional test reversing tuple order
suprjinx Jan 7, 2024
c4748a0
Move conditions append to the subscript slicer
suprjinx Jan 7, 2024
c706d0a
Add tuple test for Postgres
suprjinx Jan 7, 2024
f6efd8d
lint
suprjinx Jan 7, 2024
e16dd20
WIP refactor to use one latest_metrics join
suprjinx Jan 11, 2024
4545edf
WIP cont'd
suprjinx Jan 11, 2024
108a7e7
Update query parser to handle tuple case
suprjinx Jan 11, 2024
0914330
Fix lint issues
suprjinx Jan 12, 2024
4fa155e
Add row-count verification to migration
suprjinx Jan 12, 2024
77b33bd
refactor migration for tables/indexes map
suprjinx Jan 12, 2024
17e8458
PR cleanup
suprjinx Jan 12, 2024
20b82b0
lint
suprjinx Jan 12, 2024
a93f904
Merge branch 'main' into fix-context-migration-for-upgrades
suprjinx Jan 12, 2024
a0a5fb8
more lint
suprjinx Jan 12, 2024
633aa4d
Merge branch 'fix-context-migration-for-upgrades' of github.com:suprj…
suprjinx Jan 12, 2024
8a9854f
fix typo
suprjinx Jan 12, 2024
35c2f1a
Merge branch 'fix-context-migration-for-upgrades' into context-query-…
suprjinx Jan 12, 2024
401ad4d
WIP adapting tests
suprjinx Jan 12, 2024
ce294f8
Fixed remaining tests by adding joins ordering slice
suprjinx Jan 12, 2024
c0479be
Make sure latest_metrics is joined for query parser
suprjinx Jan 13, 2024
26ae1ea
cosmetic query change
suprjinx Jan 13, 2024
76ec123
put latest_metrics and context joins into the input query
suprjinx Jan 13, 2024
d5a84da
adjust query_test, metric and context joins are not created by Filter
suprjinx Jan 14, 2024
55d9fd4
Remove the joinKeys slice
suprjinx Jan 14, 2024
20d1b26
Add joins needed by pq.Filter
suprjinx Jan 14, 2024
3da858e
Merge branch 'main' into context-query-subscript
suprjinx Jan 14, 2024
4fc85e9
Can't have metrics without latest_metrics
suprjinx Jan 14, 2024
755d0fb
Merge branch 'main' into context-query-subscript
suprjinx Jan 15, 2024
d79d161
restore unique joins for each metric evaluation
suprjinx Jan 15, 2024
3525982
fix lint and unit tests
suprjinx Jan 15, 2024
6ea1774
comments and a couple of more test cases
suprjinx Jan 15, 2024
cdeb399
Comment and additional tests
suprjinx Jan 15, 2024
1c2816e
Don't add new join for metric.context queries
suprjinx Jan 16, 2024
c1aefc8
lint
suprjinx Jan 16, 2024
59af0fa
lint
suprjinx Jan 16, 2024
28b9e16
errant capitalization removal
suprjinx Jan 16, 2024
fa0b9af
lint
suprjinx Jan 16, 2024
2e5c2b9
use constant for table name
suprjinx Jan 17, 2024
f3d8fef
allow "run.metrics[name]" to count as metric selected
suprjinx Jan 17, 2024
682ba0f
error case with joins map
suprjinx Jan 22, 2024
8ec7363
Merge branch 'main' into context-query-subscript
suprjinx Jan 23, 2024
82c6c9e
Merge branch 'main' into context-query-subscript
suprjinx Jan 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 181 additions & 53 deletions pkg/api/aim/query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ import (
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

// TableContexts is the name of the contexts table, used as the table
// qualifier for the context "json" column in generated conditions.
const TableContexts = "contexts"

type DefaultExpression struct {
Contains string
Expression string
Expand All @@ -36,6 +40,7 @@ type ParsedQuery interface {
// parsedQuery accumulates the joins and conditions gathered while a query
// expression is parsed.
type parsedQuery struct {
// qp is the QueryParser whose Tables and Dialector settings are consulted
// while parsing.
qp *QueryParser
// joins holds the join definitions indexed by join key, e.g.
// "metrics:<key>", "tags:<name>" or "params:<name>".
joins map[string]join
// joinKeys lists the keys of joins in the order AddJoin first registered
// them; Filter iterates over it instead of ranging over the map.
joinKeys []string
// conditions collects the expressions produced during parsing, such as the
// context conditions appended by latestMetricsContextJoin.
conditions []clause.Expression
// metricSelected is set to true by metricSubscriptSlicer once a metric
// subscript has been parsed.
metricSelected bool
}
Expand All @@ -47,6 +52,7 @@ type attributeGetter func(attr string) (any, error)
// subscriptSlicer is a callback that receives the ast.Slicer index of a
// subscript expression and returns the value it resolves to, or an error.
type subscriptSlicer func(index ast.Slicer) (any, error)

type join struct {
key string
alias string
query string
args []any
Expand Down Expand Up @@ -154,8 +160,19 @@ func (qp *QueryParser) Parse(q string) (ParsedQuery, error) {
return pq, nil
}

// AddJoin registers j under key and remembers the order in which keys were
// first added. A key that is already registered is left untouched, so the
// first join stored for a key wins.
func (pq *parsedQuery) AddJoin(key string, j join) {
	if _, exists := pq.joins[key]; exists {
		return
	}
	pq.joinKeys = append(pq.joinKeys, key)
	pq.joins[key] = j
}

// Filter will add the appropriate Joins and Where clauses to the tx.
func (pq *parsedQuery) Filter(tx *gorm.DB) *gorm.DB {
for _, j := range pq.joins {
for _, k := range pq.joinKeys {
j := pq.joins[k]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could there be a circumstance in which j := pq.joins[k] is nil?

Copy link
Contributor Author

@suprjinx suprjinx Jan 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I searched for all the places where we add to pq.joins to make sure this isn't possible (using pq.AddJoin ensures this). I also added some error-level logging for the case where joins[k] is nil.

tx.Joins(j.query, j.args...)
}
if len(pq.conditions) > 0 {
Expand Down Expand Up @@ -204,6 +221,10 @@ func (pq *parsedQuery) _parseNode(node ast.Expr) (any, error) {
return pq.parseAttribute(n)
case *ast.Compare:
return pq.parseCompare(n)
case *ast.Dict:
return pq.parseDictionary(n)
case *ast.Tuple:
return pq.parseTuple(n)
default:
return nil, fmt.Errorf("unsupported expression %q", ast.Dump(n))
}
Expand Down Expand Up @@ -404,6 +425,38 @@ func (pq *parsedQuery) parseCompare(node *ast.Compare) (any, error) {
}, nil
}

// parseDictionary returns []JsonEq conditions derived from the dictionary.
func (pq *parsedQuery) parseDictionary(node *ast.Dict) (any, error) {
clauses := make([]JsonEq, len(node.Keys))
for i, key := range node.Keys {
clauses[i] = JsonEq{
Left: Json{
Column: clause.Column{
Table: TableContexts,
Name: "json",
},
dsuhinin marked this conversation as resolved.
Show resolved Hide resolved
JsonPath: string(key.(*ast.Str).S),
Dialector: pq.qp.Dialector,
},
Value: string(node.Values[i].(*ast.Str).S),
}
}
return clauses, nil
}

// parseTuple parses every element of the tuple with parseNode and returns the
// results as a []any in element order. The first element that fails to parse
// aborts the conversion and its error is returned.
func (pq *parsedQuery) parseTuple(node *ast.Tuple) (any, error) {
	parsed := make([]any, 0, len(node.Elts))
	for _, elt := range node.Elts {
		value, err := pq.parseNode(elt)
		if err != nil {
			return nil, err
		}
		parsed = append(parsed, value)
	}
	return parsed, nil
}

func (pq *parsedQuery) parseList(node *ast.List) (any, error) {
var err error
list := make([]any, len(node.Elts))
Expand Down Expand Up @@ -488,41 +541,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
if err != nil {
return nil, err
}
switch v := v.(type) {
case string:
j, ok := pq.joins[fmt.Sprintf("metrics:%s", v)]
if !ok {
alias := fmt.Sprintf("metrics_%d", len(pq.joins))
j = join{
alias: alias,
query: fmt.Sprintf(
"LEFT JOIN latest_metrics %s ON %s.run_uuid = %s.run_uuid AND %s.key = ?",
alias, table, alias, alias,
),
args: []any{v},
}
pq.joins[fmt.Sprintf("metrics:%s", v)] = j
}
return attributeGetter(func(attr string) (any, error) {
var name string
switch attr {
case "last":
name = "value"
case "last_step":
name = "last_iter"
case "first_step":
return 0, nil
default:
return nil, fmt.Errorf("unsupported metrics attribute %q", attr)
}
return clause.Column{
Table: j.alias,
Name: name,
}, nil
}), nil
default:
return nil, fmt.Errorf("unsupported index value type %t", v)
}
return pq.metricSubscriptSlicer(v)
default:
return nil, fmt.Errorf("unsupported slicer %q", ast.Dump(s))
}
Expand All @@ -537,7 +556,8 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
}
switch v := v.(type) {
case string:
j, ok := pq.joins[fmt.Sprintf("tags:%s", v)]
joinKey := fmt.Sprintf("tags:%s", v)
j, ok := pq.joins[joinKey]
if !ok {
alias := fmt.Sprintf("tags_%d", len(pq.joins))
j = join{
Expand All @@ -548,21 +568,22 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
),
args: []any{v},
}
pq.joins[fmt.Sprintf("tags:%s", v)] = j
pq.AddJoin(joinKey, j)
}
return clause.Column{
Table: j.alias,
Name: "value",
}, nil
default:
return nil, fmt.Errorf("unsupported index value type %t", v)
return nil, fmt.Errorf("unsupported index value type %T", v)
}
default:
return nil, fmt.Errorf("unsupported slicer %q", ast.Dump(s))
}
}), nil
default:
j, ok := pq.joins[fmt.Sprintf("params:%s", attr)]
joinKey := fmt.Sprintf("params:%s", attr)
j, ok := pq.joins[joinKey]
if !ok {
alias := fmt.Sprintf("params_%d", len(pq.joins))
j = join{
Expand All @@ -573,7 +594,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
),
args: []any{attr},
}
pq.joins[fmt.Sprintf("params:%s", attr)] = j
pq.AddJoin(joinKey, j)
}
return clause.Column{
Table: j.alias,
Expand Down Expand Up @@ -611,21 +632,10 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
case "context":
return attributeGetter(
func(contextKey string) (any, error) {
// create the join for contexts
_, ok := pq.joins["metric_contexts"]
if !ok {
alias := "contexts"
j := join{
alias: alias,
query: "LEFT JOIN contexts ON latest_metrics.context_id = contexts.id",
}
pq.joins["metric_contexts"] = j
}

// Add a WHERE clause for the context key
return Json{
Column: clause.Column{
Table: "contexts",
Table: TableContexts,
Name: "json",
},
JsonPath: contextKey,
Expand Down Expand Up @@ -728,6 +738,124 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
}
}

// metricSubscriptSlicer resolves the already-parsed index of a metrics
// subscript into an attribute getter for the matching latest_metrics join.
//
// v is either:
//   - a string: the metric key, or
//   - a []any of length 2: the metric key (string) followed by the context
//     conditions ([]JsonEq) parsed from a dictionary.
//
// On success pq.metricSelected is set and the required joins (plus context
// conditions for the tuple form) are registered on pq. An error is returned
// when the "runs" table is not configured or v has an unsupported shape.
func (pq *parsedQuery) metricSubscriptSlicer(v any) (any, error) {
	table, ok := pq.qp.Tables["runs"]
	if !ok {
		return nil, errors.New("unsupported table name 'runs'")
	}
	switch v := v.(type) {
	case string:
		// case of metric key
		pq.metricSelected = true
		latestMetricJoin := pq.latestMetricsKeyJoin(v, table)
		return metricAttributeGetter(latestMetricJoin.alias)
	case []any:
		// case of subscript tuple (string and context dictionary)
		if len(v) != 2 {
			return nil, fmt.Errorf("unsupported tuple length %d (should be 2)", len(v))
		}
		metricKey, ok := v[0].(string)
		if !ok {
			// Report the type of the offending element, not of the whole tuple.
			return nil, fmt.Errorf("unsupported tuple value type %T (should be string at 0)", v[0])
		}
		metricContextExpression, ok := v[1].([]JsonEq)
		if !ok {
			return nil, fmt.Errorf("unsupported index value type %T (should be []JsonEq at 1)", v[1])
		}
		pq.metricSelected = true
		latestMetricJoin := pq.latestMetricsKeyJoin(metricKey, table)
		pq.latestMetricsContextJoin(metricContextExpression, latestMetricJoin)
		return metricAttributeGetter(latestMetricJoin.alias)
	default:
		return nil, fmt.Errorf("unsupported index value type %T", v)
	}
}

// latestMetricsKeyJoin returns the join of latest_metrics onto table for the
// given metric key, matching on run_uuid and key. The join is registered under
// "metrics:<key>" the first time a key is seen; later calls for the same key
// return the stored join unchanged.
func (pq *parsedQuery) latestMetricsKeyJoin(key, table string) join {
	joinsKey := fmt.Sprintf("metrics:%s", key)
	if existing, found := pq.joins[joinsKey]; found {
		return existing
	}

	// The alias is numbered by the joins registered so far.
	alias := fmt.Sprintf("metrics_%d", len(pq.joins))
	query := fmt.Sprintf(
		"LEFT JOIN latest_metrics %s ON %s.run_uuid = %s.run_uuid AND %s.key = ?",
		alias, table, alias, alias,
	)
	created := join{
		key:   joinsKey,
		alias: alias,
		query: query,
		args:  []any{key},
	}
	pq.AddJoin(joinsKey, created)
	return created
}

// latestMetricsContextJoin makes sure a latest_metrics join and a matching
// contexts join exist, then adds the context conditions for that pair.
//
// The latest_metrics join registered under latestMetricsJoin.key is reused
// when present; otherwise a new one joined with USING(run_uuid) is added. The
// contexts join is keyed by the latest_metrics alias, so each metrics join
// gets its own contexts join.
//
// The elements of exps are modified in place: their table is set to the
// contexts alias. When exps is non-empty they are appended to pq.conditions as
// a single AND group. It returns the latest_metrics join and the contexts join.
func (pq *parsedQuery) latestMetricsContextJoin(exps []JsonEq, latestMetricsJoin join) (join, join) {
	metricsJoin, found := pq.joins[latestMetricsJoin.key]
	if !found {
		metricsAlias := fmt.Sprintf("metrics_%d", len(pq.joins))
		metricsJoin = join{
			key:   metricsAlias,
			alias: metricsAlias,
			query: fmt.Sprintf(
				"LEFT JOIN latest_metrics %s USING(run_uuid)",
				metricsAlias,
			),
		}
		pq.AddJoin(metricsAlias, metricsJoin)
	}

	contextsJoinKey := fmt.Sprintf("contexts:%s", metricsJoin.alias)
	contextsJoin, found := pq.joins[contextsJoinKey]
	if !found {
		contextsAlias := fmt.Sprintf("contexts_%d", len(pq.joins))
		contextsJoin = join{
			key:   contextsJoinKey,
			alias: contextsAlias,
			query: fmt.Sprintf(
				"LEFT JOIN contexts %s ON %s.context_id = %s.id",
				contextsAlias, metricsJoin.alias, contextsAlias,
			),
		}
		pq.AddJoin(contextsJoinKey, contextsJoin)
	}

	if len(exps) == 0 {
		return metricsJoin, contextsJoin
	}

	// Point each expression at the contexts alias chosen for this join.
	grouped := make([]clause.Expression, 0, len(exps))
	for i := range exps {
		exps[i].Left.Table = contextsJoin.alias
		grouped = append(grouped, exps[i])
	}
	pq.conditions = append(pq.conditions, clause.And(grouped...))
	return metricsJoin, contextsJoin
}

func metricAttributeGetter(table string) (any, error) {
return attributeGetter(func(attr string) (any, error) {
var name string
switch attr {
case "last":
name = "value"
case "last_step":
name = "last_iter"
case "first_step":
return 0, nil
default:
return nil, fmt.Errorf("unsupported metrics attribute %q", attr)
}
return clause.Column{
Table: table,
Name: name,
}, nil
}), nil
}

func (pq *parsedQuery) parseNameConstant(node *ast.NameConstant) (any, error) {
switch node.Value.Type() {
case py.NoneTypeType:
Expand Down
Loading