Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions api/py/ai/chronon/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import ai.chronon.api.ttypes as ttypes
from typing import Optional


class ModelType:
XGBoost = ttypes.ModelType.XGBoost
PyTorch = ttypes.ModelType.PyTorch


# Name must match S3 path that we expose if you're uploading trained models?
def Model(
source: ttypes.Source,
outputSchema: ttypes.TDataType,
modelType: ModelType,
name: str = None,
modelParams: Optional[dict[str, str]] = None
) -> ttypes.Model:
if not isinstance(source, ttypes.Source):
raise ValueError("Invalid source type")
if not (isinstance(outputSchema, ttypes.TDataType) or isinstance(outputSchema, int)):
raise ValueError("outputSchema must be a TDataType or DataKind")
if isinstance(outputSchema, int):
# Convert DataKind to TDataType
outputSchema = ttypes.TDataType(outputSchema)

if modelParams is None:
modelParams = {}

metaData = ttypes.MetaData(
name=name,
)

return ttypes.Model(modelType=modelType, outputSchema=outputSchema, source=source,
modelParams=modelParams, metaData=metaData)
1 change: 1 addition & 0 deletions api/py/ai/chronon/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
JOIN_FOLDER_NAME = 'joins'
GROUP_BY_FOLDER_NAME = 'group_bys'
STAGING_QUERY_FOLDER_NAME = 'staging_queries'
MODEL_FOLDER_NAME = 'models'
# TODO - make team part of thrift API?
TEAMS_FILE_PATH = 'teams.json'
5 changes: 3 additions & 2 deletions api/py/ai/chronon/repo/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import ai.chronon.api.ttypes as api
import ai.chronon.repo.extract_objects as eo
import ai.chronon.utils as utils
from ai.chronon.api.ttypes import GroupBy, Join, StagingQuery
from ai.chronon.api.ttypes import GroupBy, Join, StagingQuery, Model
from ai.chronon.repo import JOIN_FOLDER_NAME, \
GROUP_BY_FOLDER_NAME, STAGING_QUERY_FOLDER_NAME, TEAMS_FILE_PATH
GROUP_BY_FOLDER_NAME, STAGING_QUERY_FOLDER_NAME, MODEL_FOLDER_NAME, TEAMS_FILE_PATH
from ai.chronon.repo import teams
from ai.chronon.repo.serializer import thrift_simple_json_protected
from ai.chronon.repo.validator import ChrononRepoValidator, get_join_output_columns, get_group_by_output_columns
Expand All @@ -38,6 +38,7 @@
GROUP_BY_FOLDER_NAME: GroupBy,
JOIN_FOLDER_NAME: Join,
STAGING_QUERY_FOLDER_NAME: StagingQuery,
MODEL_FOLDER_NAME: Model,
}

DEFAULT_TEAM_NAME = "default"
Expand Down
20 changes: 20 additions & 0 deletions api/py/test/sample/models/quickstart/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

from ai.chronon.model import Model, ModelType
from ai.chronon.api.ttypes import DataKind, EventSource, Source, TDataType
from ai.chronon.query import Query, select


"""
This is the "left side" of the join that will comprise our training set. It is responsible for providing the primary keys
and timestamps for which features will be computed.
"""
source = Source(
events=EventSource(
table="data.checkouts",
query=Query(
selects=select("user_id"),
time_column="ts",
)
))

v1 = Model(source=source, outputSchema=TDataType(DataKind.DOUBLE), modelType=ModelType.XGBoost)
27 changes: 27 additions & 0 deletions api/py/test/sample/production/models/quickstart/test.v1
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"outputSchema": {
"kind": 6
},
"modelType": 1,
"metaData": {
"name": "quickstart.test.v1",
"tableProperties": {
"source": "chronon"
},
"outputNamespace": "default",
"team": "quickstart"
},
"source": {
"events": {
"table": "data.checkouts",
"query": {
"selects": {
"user_id": "user_id"
},
"timeColumn": "ts",
"setups": []
}
}
},
"modelParams": {}
}
15 changes: 14 additions & 1 deletion api/thrift/api.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -420,4 +420,17 @@ struct DataSpec {
2: optional list<string> partitionColumns
3: optional i32 retentionDays
4: optional map<string, string> props
}
}

enum ModelType {
XGBoost = 1
PyTorch = 2
}

struct Model {
1: optional TDataType outputSchema
2: optional ModelType modelType
3: optional MetaData metaData
4: optional Source source
5: optional map<string, string> modelParams
}