Skip to content

Commit

Permalink
Context query subscript (#791)
Browse files Browse the repository at this point in the history
Adds dictionary and tuple parsing to `runs.metric[]` subscript slicer
  • Loading branch information
suprjinx authored Jan 23, 2024
1 parent 943662a commit ad67e89
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 62 deletions.
242 changes: 188 additions & 54 deletions pkg/api/aim/query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@ import (
"github.com/go-python/gpython/parser"
"github.com/go-python/gpython/py"
"github.com/gofiber/fiber/v2"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"gorm.io/gorm/clause"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

const (
// TableContexts is the table name used to qualify the "json" column
// when building JSON conditions on metric contexts.
TableContexts = "contexts"
)

type DefaultExpression struct {
Contains string
Expression string
Expand All @@ -36,6 +41,7 @@ type ParsedQuery interface {
type parsedQuery struct {
qp *QueryParser
joins map[string]join
joinKeys []string
conditions []clause.Expression
metricSelected bool
}
Expand All @@ -47,6 +53,7 @@ type attributeGetter func(attr string) (any, error)
type subscriptSlicer func(index ast.Slicer) (any, error)

type join struct {
key string
alias string
query string
args []any
Expand Down Expand Up @@ -154,9 +161,25 @@ func (qp *QueryParser) Parse(q string) (ParsedQuery, error) {
return pq, nil
}

// AddJoin registers the join j under key unless a join with that key is
// already present. Newly registered keys are appended to joinKeys so the
// insertion order of joins is remembered.
func (pq *parsedQuery) AddJoin(key string, j join) {
	if _, exists := pq.joins[key]; exists {
		// First registration wins; the existing join is left untouched.
		return
	}
	pq.joins[key] = j
	pq.joinKeys = append(pq.joinKeys, key)
}

// Filter will add the appropriate Joins and Where clauses to the tx.
func (pq *parsedQuery) Filter(tx *gorm.DB) *gorm.DB {
for _, j := range pq.joins {
tx.Joins(j.query, j.args...)
for _, k := range pq.joinKeys {
j, ok := pq.joins[k]
// prevents panic, but something is wrong if not okay here
if ok {
tx.Joins(j.query, j.args...)
} else {
log.Errorf("error preparing query filter, join key not found in joins map: %s", k)
}
}
if len(pq.conditions) > 0 {
tx.Where(clause.And(pq.conditions...))
Expand Down Expand Up @@ -204,6 +227,10 @@ func (pq *parsedQuery) _parseNode(node ast.Expr) (any, error) {
return pq.parseAttribute(n)
case *ast.Compare:
return pq.parseCompare(n)
case *ast.Dict:
return pq.parseDictionary(n)
case *ast.Tuple:
return pq.parseTuple(n)
default:
return nil, fmt.Errorf("unsupported expression %q", ast.Dump(n))
}
Expand Down Expand Up @@ -404,6 +431,38 @@ func (pq *parsedQuery) parseCompare(node *ast.Compare) (any, error) {
}, nil
}

// parseDictionary returns []JsonEq conditions derived from the dictionary.
//
// Each key/value pair becomes one JsonEq whose left side is the "json" column
// of TableContexts at the JSON path given by the key, and whose value is the
// dictionary value. Only string keys and string values are supported; any
// other node type yields an error rather than a panic.
func (pq *parsedQuery) parseDictionary(node *ast.Dict) (any, error) {
	if len(node.Keys) != len(node.Values) {
		return nil, fmt.Errorf(
			"malformed dictionary: %d keys but %d values", len(node.Keys), len(node.Values),
		)
	}
	clauses := make([]JsonEq, len(node.Keys))
	for i, key := range node.Keys {
		// Checked assertions: queries are user input, so a non-string key or
		// value must be reported as an error instead of crashing the parser.
		keyStr, ok := key.(*ast.Str)
		if !ok {
			return nil, fmt.Errorf("unsupported dictionary key type %T (should be string)", key)
		}
		valueStr, ok := node.Values[i].(*ast.Str)
		if !ok {
			return nil, fmt.Errorf(
				"unsupported dictionary value type %T for key %q (should be string)",
				node.Values[i], string(keyStr.S),
			)
		}
		clauses[i] = JsonEq{
			Left: Json{
				Column: clause.Column{
					Table: TableContexts,
					Name:  "json",
				},
				JsonPath:  string(keyStr.S),
				Dialector: pq.qp.Dialector,
			},
			Value: string(valueStr.S),
		}
	}
	return clauses, nil
}

// parseTuple parses every element of the tuple node with parseNode and
// returns the results as a []any in element order. The first element that
// fails to parse aborts the whole tuple and its error is returned.
func (pq *parsedQuery) parseTuple(node *ast.Tuple) (any, error) {
	elements := make([]any, 0, len(node.Elts))
	for _, elt := range node.Elts {
		parsed, err := pq.parseNode(elt)
		if err != nil {
			return nil, err
		}
		elements = append(elements, parsed)
	}
	return elements, nil
}

func (pq *parsedQuery) parseList(node *ast.List) (any, error) {
var err error
list := make([]any, len(node.Elts))
Expand Down Expand Up @@ -488,41 +547,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
if err != nil {
return nil, err
}
switch v := v.(type) {
case string:
j, ok := pq.joins[fmt.Sprintf("metrics:%s", v)]
if !ok {
alias := fmt.Sprintf("metrics_%d", len(pq.joins))
j = join{
alias: alias,
query: fmt.Sprintf(
"LEFT JOIN latest_metrics %s ON %s.run_uuid = %s.run_uuid AND %s.key = ?",
alias, table, alias, alias,
),
args: []any{v},
}
pq.joins[fmt.Sprintf("metrics:%s", v)] = j
}
return attributeGetter(func(attr string) (any, error) {
var name string
switch attr {
case "last":
name = "value"
case "last_step":
name = "last_iter"
case "first_step":
return 0, nil
default:
return nil, fmt.Errorf("unsupported metrics attribute %q", attr)
}
return clause.Column{
Table: j.alias,
Name: name,
}, nil
}), nil
default:
return nil, fmt.Errorf("unsupported index value type %t", v)
}
return pq.metricSubscriptSlicer(v)
default:
return nil, fmt.Errorf("unsupported slicer %q", ast.Dump(s))
}
Expand All @@ -537,7 +562,8 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
}
switch v := v.(type) {
case string:
j, ok := pq.joins[fmt.Sprintf("tags:%s", v)]
joinKey := fmt.Sprintf("tags:%s", v)
j, ok := pq.joins[joinKey]
if !ok {
alias := fmt.Sprintf("tags_%d", len(pq.joins))
j = join{
Expand All @@ -548,21 +574,22 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
),
args: []any{v},
}
pq.joins[fmt.Sprintf("tags:%s", v)] = j
pq.AddJoin(joinKey, j)
}
return clause.Column{
Table: j.alias,
Name: "value",
}, nil
default:
return nil, fmt.Errorf("unsupported index value type %t", v)
return nil, fmt.Errorf("unsupported index value type %T", v)
}
default:
return nil, fmt.Errorf("unsupported slicer %q", ast.Dump(s))
}
}), nil
default:
j, ok := pq.joins[fmt.Sprintf("params:%s", attr)]
joinKey := fmt.Sprintf("params:%s", attr)
j, ok := pq.joins[joinKey]
if !ok {
alias := fmt.Sprintf("params_%d", len(pq.joins))
j = join{
Expand All @@ -573,7 +600,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
),
args: []any{attr},
}
pq.joins[fmt.Sprintf("params:%s", attr)] = j
pq.AddJoin(joinKey, j)
}
return clause.Column{
Table: j.alias,
Expand Down Expand Up @@ -611,21 +638,10 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
case "context":
return attributeGetter(
func(contextKey string) (any, error) {
// create the join for contexts
_, ok := pq.joins["metric_contexts"]
if !ok {
alias := "contexts"
j := join{
alias: alias,
query: "LEFT JOIN contexts ON latest_metrics.context_id = contexts.id",
}
pq.joins["metric_contexts"] = j
}

// Add a WHERE clause for the context key
return Json{
Column: clause.Column{
Table: "contexts",
Table: TableContexts,
Name: "json",
},
JsonPath: contextKey,
Expand Down Expand Up @@ -728,6 +744,124 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
}
}

// metricSubscriptSlicer resolves an already-parsed metric subscript value into
// an attribute getter over the matching latest_metrics join.
//
// Two shapes of v are accepted:
//   - string: the metric key;
//   - []any of length 2: the metric key (string) at index 0 and the context
//     conditions ([]JsonEq, as produced by parseDictionary) at index 1.
//
// Any other type, tuple length, or tuple element type returns an error. On
// success pq.metricSelected is set and the required joins (and, for the tuple
// form, context conditions) are registered on pq.
func (pq *parsedQuery) metricSubscriptSlicer(v any) (any, error) {
	table, ok := pq.qp.Tables["runs"]
	if !ok {
		return nil, errors.New("unsupported table name 'runs'")
	}
	switch v := v.(type) {
	case string:
		// case of metric key
		pq.metricSelected = true
		latestMetricJoin := pq.latestMetricsKeyJoin(v, table)
		return metricAttributeGetter(latestMetricJoin.alias)
	case []any:
		// case of subscript tuple (string and context dictionary)
		if len(v) != 2 {
			return nil, fmt.Errorf("unsupported tuple length %d (should be 2)", len(v))
		}
		metricKey, ok := v[0].(string)
		if !ok {
			// Report the offending element's type, not the enclosing slice's.
			return nil, fmt.Errorf("unsupported tuple value type %T (should be string at 0)", v[0])
		}
		metricContextExpression, ok := v[1].([]JsonEq)
		if !ok {
			// Report the offending element's type, not the enclosing slice's.
			return nil, fmt.Errorf("unsupported index value type %T (should be []JsonEq at 1)", v[1])
		}
		pq.metricSelected = true
		latestMetricJoin := pq.latestMetricsKeyJoin(metricKey, table)
		pq.latestMetricsContextJoin(metricContextExpression, latestMetricJoin)
		return metricAttributeGetter(latestMetricJoin.alias)
	default:
		return nil, fmt.Errorf("unsupported index value type %T", v)
	}
}

// latestMetricsKeyJoin returns the join of latest_metrics onto table for the
// given metric key, matching on run_uuid and key. The join is registered under
// "metrics:<key>"; if one is already registered it is returned unchanged.
func (pq *parsedQuery) latestMetricsKeyJoin(key, table string) join {
	joinsKey := "metrics:" + key
	if existing, found := pq.joins[joinsKey]; found {
		return existing
	}

	// The alias is numbered by how many joins exist so far.
	alias := fmt.Sprintf("metrics_%d", len(pq.joins))
	query := fmt.Sprintf(
		"LEFT JOIN latest_metrics %s ON %s.run_uuid = %s.run_uuid AND %s.key = ?",
		alias, table, alias, alias,
	)
	created := join{
		key:   joinsKey,
		alias: alias,
		query: query,
		args:  []any{key},
	}
	pq.AddJoin(joinsKey, created)
	return created
}

// latestMetricsContextJoin ensures a latest_metrics join and a contexts join
// attached to it exist, then adds the given context expressions as conditions.
//
// The join registered under latestMetricsJoin.key is reused when present;
// otherwise a new latest_metrics join USING(run_uuid) is registered. The
// contexts join is registered under "contexts:<latest_metrics alias>".
// The Left.Table of every element of exps is rewritten in place to the
// contexts alias, and when exps is non-empty the expressions are appended to
// pq.conditions as a single AND group.
//
// It returns the latest_metrics join and the contexts join, in that order.
func (pq *parsedQuery) latestMetricsContextJoin(exps []JsonEq, latestMetricsJoin join) (join, join) {
	if existing, found := pq.joins[latestMetricsJoin.key]; found {
		latestMetricsJoin = existing
	} else {
		metricsAlias := fmt.Sprintf("metrics_%d", len(pq.joins))
		latestMetricsJoin = join{
			key:   metricsAlias,
			alias: metricsAlias,
			query: fmt.Sprintf(
				"LEFT JOIN latest_metrics %s USING(run_uuid)",
				metricsAlias,
			),
		}
		pq.AddJoin(metricsAlias, latestMetricsJoin)
	}

	contextsJoinKey := "contexts:" + latestMetricsJoin.alias
	contextJoin, found := pq.joins[contextsJoinKey]
	if !found {
		contextsAlias := fmt.Sprintf("contexts_%d", len(pq.joins))
		contextJoin = join{
			key:   contextsJoinKey,
			alias: contextsAlias,
			query: fmt.Sprintf(
				"LEFT JOIN contexts %s ON %s.context_id = %s.id",
				contextsAlias, latestMetricsJoin.alias, contextsAlias,
			),
		}
		pq.AddJoin(contextsJoinKey, contextJoin)
	}

	if len(exps) == 0 {
		return latestMetricsJoin, contextJoin
	}

	// Point each expression at the alias of the contexts join before it is
	// recorded as a condition.
	scoped := make([]clause.Expression, 0, len(exps))
	for i := range exps {
		exps[i].Left.Table = contextJoin.alias
		scoped = append(scoped, exps[i])
	}
	pq.conditions = append(pq.conditions, clause.And(scoped...))
	return latestMetricsJoin, contextJoin
}

func metricAttributeGetter(table string) (any, error) {
return attributeGetter(func(attr string) (any, error) {
var name string
switch attr {
case "last":
name = "value"
case "last_step":
name = "last_iter"
case "first_step":
return 0, nil
default:
return nil, fmt.Errorf("unsupported metrics attribute %q", attr)
}
return clause.Column{
Table: table,
Name: name,
}, nil
}), nil
}

func (pq *parsedQuery) parseNameConstant(node *ast.NameConstant) (any, error) {
switch node.Value.Type() {
case py.NoneTypeType:
Expand Down
Loading

0 comments on commit ad67e89

Please sign in to comment.