diff --git a/api/src/main/scala/ai/chronon/api/Constants.scala b/api/src/main/scala/ai/chronon/api/Constants.scala index 85aaf3f84e..582f67ae59 100644 --- a/api/src/main/scala/ai/chronon/api/Constants.scala +++ b/api/src/main/scala/ai/chronon/api/Constants.scala @@ -77,12 +77,7 @@ object Constants { val extensionsToIgnore: Array[String] = Array(".class", ".csv", ".java", ".scala", ".py", ".DS_Store") val foldersToIgnore: Array[String] = Array(".git") - // import base64 - // text_bytes = "chronon".encode('utf-8') - // base64_str = base64.b64encode(text_bytes) - // int.from_bytes(base64.b64decode(base64_str), "big") - // - // output: 27980863399423854 - - val magicNullDouble: java.lang.Double = -27980863399423854.0 + // A negative integer within the safe range for both long and double in JavaScript, Java, Scala, Python + val magicNullLong: java.lang.Long = -1234567890L + val magicNullDouble: java.lang.Double = -1234567890.0 } diff --git a/api/src/test/scala/ai/chronon/api/test/TileSeriesSerializationTest.scala b/api/src/test/scala/ai/chronon/api/test/TileSeriesSerializationTest.scala index 2072ee3303..b3c7cc7d69 100644 --- a/api/src/test/scala/ai/chronon/api/test/TileSeriesSerializationTest.scala +++ b/api/src/test/scala/ai/chronon/api/test/TileSeriesSerializationTest.scala @@ -4,10 +4,13 @@ import ai.chronon.api.Constants import ai.chronon.api.ScalaJavaConversions.JListOps import ai.chronon.api.ThriftJsonCodec import ai.chronon.observability.TileDriftSeries +import ai.chronon.observability.TileSummarySeries import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import java.lang.{Double => JDouble} +import java.lang.{Long => JLong} +import scala.jdk.CollectionConverters.asScalaBufferConverter class TileSeriesSerializationTest extends AnyFlatSpec with Matchers { @@ -27,8 +30,49 @@ class TileSeriesSerializationTest extends AnyFlatSpec with Matchers { val jsonStr = ThriftJsonCodec.toJsonStr(tileDriftSeries) - jsonStr should be ("""{"percentileDriftSeries":[0.1,-2.7980863399423856E16,-2.7980863399423856E16,-2.7980863399423856E16,0.5]}""") + jsonStr should be (s"""{"percentileDriftSeries":[0.1,${Constants.magicNullDouble},${Constants.magicNullDouble},${Constants.magicNullDouble},0.5]}""") } + it should "deserialize double values correctly" in { + val json = s"""{"percentileDriftSeries":[0.1,${Constants.magicNullDouble},${Constants.magicNullDouble},${Constants.magicNullDouble},0.5]}""" + + val series = ThriftJsonCodec.fromJsonStr[TileDriftSeries](json, true, classOf[TileDriftSeries])(manifest[TileDriftSeries]) + + val drifts = series.getPercentileDriftSeries.asScala.toList + drifts.size should be (5) + drifts(0) should be (0.1) + drifts(1) should be (Constants.magicNullDouble) + drifts(2) should be (Constants.magicNullDouble) + drifts(3) should be (Constants.magicNullDouble) + drifts(4) should be (0.5) + } + + "TileSummarySeries" should "serialize with nulls and special long values" in { + val tileSummarySeries = new TileSummarySeries() + + val counts: Seq[JLong] = Seq(100L, null, Long.MaxValue, Constants.magicNullLong, 500L) + .map(v => if (v == null) Constants.magicNullLong else v.asInstanceOf[JLong]) + + val countsList: java.util.List[JLong] = counts.toJava + tileSummarySeries.setCount(countsList) + + val jsonStr = ThriftJsonCodec.toJsonStr(tileSummarySeries) + + jsonStr should be (s"""{"count":[100,${Constants.magicNullLong},9223372036854775807,${Constants.magicNullLong},500]}""") + } + + it should "deserialize long values correctly" in { + val json = s"""{"count":[100,${Constants.magicNullLong},9223372036854775807,${Constants.magicNullLong},500]}""" + + val series = ThriftJsonCodec.fromJsonStr[TileSummarySeries](json, true, classOf[TileSummarySeries])(manifest[TileSummarySeries]) + + val counts = series.getCount.asScala.toList + counts.size should be (5) + counts(0) should be (100L) + counts(1) should be (Constants.magicNullLong) + counts(2) should be (Long.MaxValue) + counts(3) should be (Constants.magicNullLong) + counts(4) should be (500L) + } } diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift index 4a4acf46ff..46419f25c1 100644 --- a/api/thrift/api.thrift +++ b/api/thrift/api.thrift @@ -432,4 +432,4 @@ struct Model { 3: optional TDataType outputSchema 4: optional Source source 5: optional map modelParams -} \ No newline at end of file +} diff --git a/api/thrift/hub.thrift b/api/thrift/hub.thrift index e7b8553552..704d46ab7c 100644 --- a/api/thrift/hub.thrift +++ b/api/thrift/hub.thrift @@ -115,3 +115,39 @@ struct Submission { 20: optional i64 finishedTs 21: optional DateRange dateRange } + +enum ConfType{ + STAGING_QUERY = 1 + GROUP_BY = 2 + JOIN = 3 + MODEL = 4 +} + +struct ConfRequest { + 1: optional string confName + 2: optional ConfType confType + + // one of either branch or version are set - otherwise we will pull conf for main branch + 3: optional string branch + 4: optional string version +} + +/** + * lists all confs of the specified type + */ +struct ConfListRequest { + 1: optional ConfType confType + + // if not specified we will pull conf list for main branch + 2: optional string branch +} + +/** + * Response for listing configurations of a specific type + */ +struct ConfListResponse { + 1: optional list joins + 2: optional list groupBys + 3: optional list models + 4: optional list stagingQueries +} diff --git a/api/thrift/observability.thrift b/api/thrift/observability.thrift index fc1d84b616..d201b4bdee 100644 --- a/api/thrift/observability.thrift +++ b/api/thrift/observability.thrift @@ -139,4 +139,24 @@ struct DriftSpec { // default drift metric to use 6: optional DriftMetric driftMetric = DriftMetric.JENSEN_SHANNON -} \ No newline at end of file +} + +struct JoinDriftRequest { + 1: required string name + 2: required i64 startTs + 3: required i64 endTs + 6: optional string offset // Format: "24h" or "7d" + 7: optional DriftMetric algorithm + 8: optional string columnName +} + +struct JoinDriftResponse { + 1: required list driftSeries +} + +struct JoinSummaryRequest { + 1: required string name + 2: required i64 startTs + 3: required i64 endTs + 8: required string columnName +} diff --git a/frontend/src/lib/api/api.test.ts b/frontend/src/lib/api/api.test.ts index 9cbc21a3dc..2be638f30f 100644 --- a/frontend/src/lib/api/api.test.ts +++ b/frontend/src/lib/api/api.test.ts @@ -30,9 +30,9 @@ describe('API module', () => { text: () => Promise.resolve(JSON.stringify(mockResponse)) }); - const result = await api.getModels(); + const result = await api.getModelList(); - expect(mockFetch).toHaveBeenCalledWith(`/api/v1/models`, { + expect(mockFetch).toHaveBeenCalledWith(`/api/v1/conf/list?confType=MODEL`, { method: 'GET', headers: { 'Content-Type': 'application/json' @@ -47,7 +47,7 @@ describe('API module', () => { text: () => Promise.resolve('') }); - const result = await api.getModels(); + const result = await api.getModelList(); expect(result).toEqual({}); }); @@ -58,7 +58,7 @@ describe('API module', () => { status: 404 }); - await api.getModels(); + await api.getModelList(); expect(error).toHaveBeenCalledWith(404); }); diff --git a/frontend/src/lib/api/api.ts b/frontend/src/lib/api/api.ts index 4f0232f1b4..eba01ad652 100644 --- a/frontend/src/lib/api/api.ts +++ b/frontend/src/lib/api/api.ts @@ -1,10 +1,17 @@ import { error } from '@sveltejs/kit'; +import type { FeatureResponse, JoinTimeSeriesResponse } from '$lib/types/Model/Model'; import type { - FeatureResponse, - JoinsResponse, - JoinTimeSeriesResponse, - ModelsResponse -} from '$lib/types/Model/Model'; + Join, + GroupBy, + Model, + StagingQuery, + IJoinDriftRequestArgs, + IJoinDriftResponseArgs, + ITileSummarySeries, + IJoinSummaryRequestArgs +} from '$lib/types/codegen'; +import { ConfType, DriftMetric } from '$lib/types/codegen'; +import type { ConfListResponse } from '$lib/types/codegen/ConfListResponse'; export type ApiOptions = { base?: string; @@ -30,25 +37,35 @@ export class Api { this.#accessToken = opts.accessToken; } - // TODO: eventually move this to a model-specific file/decide on a good project structure for organizing api calls - async getModels() { - return this.#send('models'); - } - - async getJoins(offset: number = 0, limit: number = 10) { + async getConf(name: string, type: ConfType) { const params = new URLSearchParams({ - offset: offset.toString(), - limit: limit.toString() + confName: name, + confType: ConfType[type] }); - return this.#send(`joins?${params.toString()}`); + return this.#send(`conf?${params.toString()}`); + } + + async getJoin(name: string): Promise { + return this.getConf(name, ConfType.JOIN) as Promise; } - async search(term: string, limit: number = 20) { + async getGroupBy(name: string): Promise { + return this.getConf(name, ConfType.GROUP_BY) as Promise; + } + + async getModel(name: string): Promise { + return this.getConf(name, ConfType.MODEL) as Promise; + } + + async getStagingQuery(name: string): Promise { + return this.getConf(name, ConfType.STAGING_QUERY) as Promise; + } + + async search(term: string) { const params = new URLSearchParams({ - term, - limit: limit.toString() + confName: term }); - return this.#send(`search?${params.toString()}`); + return this.#send(`search?${params.toString()}`); } async getJoinTimeseries({ @@ -115,6 +132,74 @@ export class Api { ); } + async getConfList(type: ConfType): Promise { + const params = new URLSearchParams({ + confType: ConfType[type] + }); + return this.#send(`conf/list?${params.toString()}`); + } + + async getJoinList(): Promise { + return this.getConfList(ConfType.JOIN); + } + + async getGroupByList(): Promise { + return this.getConfList(ConfType.GROUP_BY); + } + + async getModelList(): Promise { + return this.getConfList(ConfType.MODEL); + } + + async getStagingQueryList(): Promise { + return this.getConfList(ConfType.STAGING_QUERY); + } + + async getJoinDrift({ + name, + startTs, + endTs, + offset = '10h', + algorithm = DriftMetric.PSI + }: IJoinDriftRequestArgs) { + const params = new URLSearchParams({ + startTs: startTs.toString(), + endTs: endTs.toString(), + offset, + algorithm: DriftMetric[algorithm] + }); + return this.#send(`join/${name}/drift?${params.toString()}`); + } + + async getColumnDrift({ + name, + columnName, + startTs, + endTs, + offset = '10h', + algorithm = DriftMetric.PSI + }: IJoinDriftRequestArgs) { + const params = new URLSearchParams({ + startTs: startTs.toString(), + endTs: endTs.toString(), + offset, + algorithm: DriftMetric[algorithm] + }); + return this.#send( + `join/${name}/column/${columnName}/drift?${params.toString()}` + ); + } + + async getColumnSummary({ name, columnName, startTs, endTs }: IJoinSummaryRequestArgs) { + const params = new URLSearchParams({ + startTs: startTs.toString(), + endTs: endTs.toString() + }); + return this.#send( + `join/${name}/column/${columnName}/summary?${params.toString()}` + ); + } + async #send(resource: string, options?: ApiRequestOptions) { let url = `${this.#base}/${resource}`; diff --git a/frontend/src/lib/components/LogicalNodeTable.svelte b/frontend/src/lib/components/LogicalNodeTable.svelte new file mode 100644 index 0000000000..b4e0405a1b --- /dev/null +++ b/frontend/src/lib/components/LogicalNodeTable.svelte @@ -0,0 +1,64 @@ + + + + +
+ +
+ + + + + + + {title} + + + + {#if items.length === 0} + + + No {title.toLowerCase()} found. + + + {:else} + {#each items as item} + + + + + {item.metaData?.name} + + + + {/each} + {/if} + +
+ + diff --git a/frontend/src/lib/components/NavigationBar.svelte b/frontend/src/lib/components/NavigationBar.svelte index 33ca2a2f63..4aec2c59d2 100644 --- a/frontend/src/lib/components/NavigationBar.svelte +++ b/frontend/src/lib/components/NavigationBar.svelte @@ -11,7 +11,6 @@ CommandEmpty } from '$lib/components/ui/command/'; import { Api } from '$lib/api/api'; - import type { Model } from '$lib/types/Model/Model'; import debounce from 'lodash/debounce'; import { onDestroy, onMount } from 'svelte'; import { @@ -23,7 +22,12 @@ import { goto } from '$app/navigation'; import { isMacOS } from '$lib/util/browser'; import { Badge } from '$lib/components/ui/badge'; - import { getEntity, type Entity } from '$lib/types/Entity/Entity'; + import { + getEntity, + type Entity, + EntityTypes, + type EntityWithType + } from '$lib/types/Entity/Entity'; import IconArrowsUpDown from '~icons/heroicons/arrows-up-down-16-solid'; import IconAdjustmentsHorizontal from '~icons/heroicons/adjustments-horizontal-16-solid'; @@ -45,7 +49,7 @@ const { navItems, user }: Props = $props(); let open = $state(false); - let searchResults: Model[] = $state([]); + let searchResults: EntityWithType[] = $state([]); let isMac: boolean | undefined = $state(undefined); const api = new Api(); @@ -53,7 +57,16 @@ const debouncedSearch = debounce(async () => { if (input.length > 0) { const response = await api.search(input); - searchResults = response.items; + searchResults = [ + ...(response.joins?.map((item) => ({ ...item, entityType: EntityTypes.JOINS })) || []), + ...(response.groupBys?.map((item) => ({ ...item, entityType: EntityTypes.GROUPBYS })) || + []), + ...(response.models?.map((item) => ({ ...item, entityType: EntityTypes.MODELS })) || []), + ...(response.stagingQueries?.map((item) => ({ + ...item, + entityType: EntityTypes.STAGINGQUERIES + })) || []) + ]; } else { searchResults = []; } @@ -235,15 +248,18 @@ {/if} {:else} - {#each searchResults as entity (entity.name)} + {#each searchResults as entity} - handleSelect(`${getEntity('joins').path}/${encodeURIComponent(entity.name)}`)} + handleSelect( + `${getEntity(entity.entityType).path}/${encodeURIComponent(entity.metaData?.name || '')}` + )} > - {@const IconJoins = getEntity('joins').icon} - + {@const IconEntity = getEntity(entity.entityType).icon} + + {entity.metaData?.name} {/each} diff --git a/frontend/src/lib/server/conf-loader.ts b/frontend/src/lib/server/conf-loader.ts new file mode 100644 index 0000000000..a57b2bc204 --- /dev/null +++ b/frontend/src/lib/server/conf-loader.ts @@ -0,0 +1,43 @@ +import { Api } from '$lib/api/api'; +import { ConfType } from '$lib/types/codegen'; +import type { IConfListResponse } from '$lib/types/codegen/ConfListResponse'; +import type { RequestEvent } from '@sveltejs/kit'; +import { entityConfig } from '$lib/types/Entity/Entity'; + +const ConfResponseMap: Record = { + [ConfType.MODEL]: 'models', + [ConfType.STAGING_QUERY]: 'stagingQueries', + [ConfType.GROUP_BY]: 'groupBys', + [ConfType.JOIN]: 'joins' +}; + +export async function loadConfList({ fetch, url }: Pick) { + const path = url.pathname; + const entityMatch = entityConfig.find((entity) => path.startsWith(entity.path)); + + if (!entityMatch) { + return { + items: [], + basePath: path, + title: '' + }; + } + + try { + const api = new Api({ fetch }); + const response = await api.getConfList(entityMatch.type); + if (!response) throw new Error(`Failed to fetch ${entityMatch.label.toLowerCase()}`); + + const responseKey = ConfResponseMap[entityMatch.type]; + const items = response[responseKey] ?? []; + + return { + items: items, + basePath: path, + title: entityMatch.label + }; + } catch (error) { + console.error(`Failed to load ${entityMatch.label.toLowerCase()}:`, error); + throw error; + } +} diff --git a/frontend/src/lib/types/Entity/Entity.ts b/frontend/src/lib/types/Entity/Entity.ts index b93620c841..f585cc449d 100644 --- a/frontend/src/lib/types/Entity/Entity.ts +++ b/frontend/src/lib/types/Entity/Entity.ts @@ -2,39 +2,60 @@ import IconCube from '~icons/heroicons/cube-16-solid'; import IconSquare3Stack3d from '~icons/heroicons/square-3-stack-3d-16-solid'; import IconCubeTransparent from '~icons/heroicons/cube-transparent-16-solid'; import IconRectangleStack from '~icons/heroicons/rectangle-stack-16-solid'; +import { + type IJoin, + type IGroupBy, + type IModel, + type IStagingQuery, + ConfType +} from '$lib/types/codegen'; + +export const EntityTypes = { + MODELS: 'models', + JOINS: 'joins', + GROUPBYS: 'groupbys', + STAGINGQUERIES: 'stagingqueries' +} as const; + +export type EntityId = (typeof EntityTypes)[keyof typeof EntityTypes]; export const entityConfig = [ { label: 'Models', path: '/models', icon: IconCube, - id: 'models' + id: EntityTypes.MODELS, + type: ConfType.MODEL }, { label: 'Joins', path: '/joins', icon: IconSquare3Stack3d, - id: 'joins' + id: EntityTypes.JOINS, + type: ConfType.JOIN }, { label: 'GroupBys', - path: '/GroupBys', + path: '/groupbys', icon: IconRectangleStack, - id: 'groupbys' + id: EntityTypes.GROUPBYS, + type: ConfType.GROUP_BY }, - { label: 'Staging Queries', - path: '/StagingQueries', + path: '/stagingqueries', icon: IconCubeTransparent, - id: 'stagingqueries' + id: EntityTypes.STAGINGQUERIES, + type: ConfType.STAGING_QUERY } ] as const; export type Entity = (typeof entityConfig)[number]; -export type EntityId = Entity['id']; -// Helper function to get entity by ID +// This is a workaround, see https://app.asana.com/0/1208277377735902/1209205208293672/f +export type EntityWithType = (IJoin | IGroupBy | IModel | IStagingQuery) & { entityType: EntityId }; + +// Helper function to get entity config by ID export function getEntity(id: EntityId): Entity { const entity = entityConfig.find((entity) => entity.id === id); if (!entity) throw new Error(`Entity with id "${id}" not found`); diff --git a/frontend/src/lib/types/Model/Model.ts b/frontend/src/lib/types/Model/Model.ts index b45377a094..fde71373b1 100644 --- a/frontend/src/lib/types/Model/Model.ts +++ b/frontend/src/lib/types/Model/Model.ts @@ -65,8 +65,3 @@ export type NullComparedFeatureResponse = { oldValueCount: number; newValueCount: number; }; - -export type JoinsResponse = { - offset: number; - items: Join[]; -}; diff --git a/frontend/src/routes/groupbys/+page.server.ts b/frontend/src/routes/groupbys/+page.server.ts new file mode 100644 index 0000000000..ddc6c06a47 --- /dev/null +++ b/frontend/src/routes/groupbys/+page.server.ts @@ -0,0 +1,4 @@ +import type { PageServerLoad } from './$types'; +import { loadConfList } from '$lib/server/conf-loader'; + +export const load: PageServerLoad = loadConfList; diff --git a/frontend/src/routes/groupbys/+page.svelte b/frontend/src/routes/groupbys/+page.svelte new file mode 100644 index 0000000000..f10bb22943 --- /dev/null +++ b/frontend/src/routes/groupbys/+page.svelte @@ -0,0 +1,7 @@ + + + diff --git a/frontend/src/routes/joins/+page.server.ts b/frontend/src/routes/joins/+page.server.ts index 8c67418f3e..ddc6c06a47 100644 --- a/frontend/src/routes/joins/+page.server.ts +++ b/frontend/src/routes/joins/+page.server.ts @@ -1,12 +1,4 @@ import type { PageServerLoad } from './$types'; -import type { JoinsResponse } from '$lib/types/Model/Model'; -import { Api } from '$lib/api/api'; +import { loadConfList } from '$lib/server/conf-loader'; -export const load: PageServerLoad = async ({ fetch }): Promise<{ joins: JoinsResponse }> => { - const offset = 0; - const limit = 100; - const api = new Api({ fetch }); - return { - joins: await api.getJoins(offset, limit) - }; -}; +export const load: PageServerLoad = loadConfList; diff --git a/frontend/src/routes/joins/+page.svelte b/frontend/src/routes/joins/+page.svelte index 12852a1589..efec7356a1 100644 --- a/frontend/src/routes/joins/+page.svelte +++ b/frontend/src/routes/joins/+page.svelte @@ -1,54 +1,15 @@ - - -
- -
- - - - - Join - - - - {#each reorderedJoins as join} - - - - {join.name} - - - - {/each} - -
- + diff --git a/frontend/src/routes/joins/[slug]/observability/ModelTable.svelte b/frontend/src/routes/joins/[slug]/observability/ModelTable.svelte index 449668656c..1b67fecdf2 100644 --- a/frontend/src/routes/joins/[slug]/observability/ModelTable.svelte +++ b/frontend/src/routes/joins/[slug]/observability/ModelTable.svelte @@ -1,19 +1,22 @@ - - + + - + diff --git a/frontend/src/routes/joins/[slug]/services/joins.service.ts b/frontend/src/routes/joins/[slug]/services/joins.service.ts index d5573b7cb6..45abde4f5a 100644 --- a/frontend/src/routes/joins/[slug]/services/joins.service.ts +++ b/frontend/src/routes/joins/[slug]/services/joins.service.ts @@ -1,14 +1,15 @@ import { Api } from '$lib/api/api'; -import type { JoinTimeSeriesResponse, Model } from '$lib/types/Model/Model'; +import type { JoinTimeSeriesResponse } from '$lib/types/Model/Model'; import type { MetricType } from '$lib/types/MetricType/MetricType'; import { sortDrift, type SortDirection } from '$lib/util/sort'; +import type { IModel } from '$lib/types/codegen'; const FALLBACK_START_TS = 1672531200000; // 2023-01-01 const FALLBACK_END_TS = 1677628800000; // 2023-03-01 export type JoinData = { joinTimeseries: JoinTimeSeriesResponse; - model?: Model; + model?: IModel; metricType: MetricType; dateRange: { startTimestamp: number; @@ -36,11 +37,13 @@ async function fetchInitialData( offset: undefined, algorithm: metricType }), - api.getModels() + api.getModelList() ]); const sortedJoinTimeseries = sortDrift(joinTimeseries, sortDirection); - const modelToReturn = models.items.find((m) => m.join.name === joinName); + const modelToReturn = models.models?.find( + (m) => m.source?.joinSource?.join?.metaData?.name === joinName + ); return { joinTimeseries: sortedJoinTimeseries, diff --git a/frontend/src/routes/models/+page.server.ts b/frontend/src/routes/models/+page.server.ts new file mode 100644 index 0000000000..ddc6c06a47 --- /dev/null +++ b/frontend/src/routes/models/+page.server.ts @@ -0,0 +1,4 @@ +import type { PageServerLoad } from './$types'; +import { loadConfList } from '$lib/server/conf-loader'; + +export const load: PageServerLoad = loadConfList; diff --git a/frontend/src/routes/models/+page.svelte b/frontend/src/routes/models/+page.svelte new file mode 100644 index 0000000000..f10bb22943 --- /dev/null +++ b/frontend/src/routes/models/+page.svelte @@ -0,0 +1,7 @@ + + + diff --git a/frontend/src/routes/stagingqueries/+page.server.ts b/frontend/src/routes/stagingqueries/+page.server.ts new file mode 100644 index 0000000000..ddc6c06a47 --- /dev/null +++ b/frontend/src/routes/stagingqueries/+page.server.ts @@ -0,0 +1,4 @@ +import type { PageServerLoad } from './$types'; +import { loadConfList } from '$lib/server/conf-loader'; + +export const load: PageServerLoad = loadConfList; diff --git a/frontend/src/routes/stagingqueries/+page.svelte b/frontend/src/routes/stagingqueries/+page.svelte new file mode 100644 index 0000000000..f10bb22943 --- /dev/null +++ b/frontend/src/routes/stagingqueries/+page.svelte @@ -0,0 +1,7 @@ + + + diff --git a/frontend/src/routes/thrift/+page.svelte b/frontend/src/routes/thrift/+page.svelte new file mode 100644 index 0000000000..763e6ddc3d --- /dev/null +++ b/frontend/src/routes/thrift/+page.svelte @@ -0,0 +1,135 @@ + + + +
hello!
diff --git a/hub/src/main/java/ai/chronon/hub/HubVerticle.java b/hub/src/main/java/ai/chronon/hub/HubVerticle.java index 3b84f5a32a..6aa5fdbdd8 100644 --- a/hub/src/main/java/ai/chronon/hub/HubVerticle.java +++ b/hub/src/main/java/ai/chronon/hub/HubVerticle.java @@ -3,6 +3,8 @@ import ai.chronon.api.Constants; import ai.chronon.hub.handlers.*; import ai.chronon.hub.store.MonitoringModelStore; +import ai.chronon.observability.JoinDriftRequest; +import ai.chronon.observability.JoinSummaryRequest; import ai.chronon.observability.TileKey; import ai.chronon.online.Api; import ai.chronon.online.KVStore; @@ -58,10 +60,10 @@ protected void startHttpServer(int port, String configJsonString, Api api, Promi // Add routes for metadata retrieval MonitoringModelStore store = new MonitoringModelStore(api); - router.get("/api/v1/models").handler(new ModelsHandler(store)); - router.get("/api/v1/join/:name").handler(new JoinsHandler(store).getHandler()); - router.get("/api/v1/joins").handler(new JoinsHandler(store).listHandler()); - router.get("/api/v1/search").handler(new SearchHandler(store)); + ConfHandler confHandler = new ConfHandler(store); + router.get("/api/v1/conf").handler(RouteHandlerWrapper.createHandler(confHandler::getConf, ConfRequest.class)); + router.get("/api/v1/conf/list").handler(RouteHandlerWrapper.createHandler(confHandler::getConfList, ConfListRequest.class)); + router.get("/api/v1/search").handler(RouteHandlerWrapper.createHandler(confHandler::searchConf, ConfRequest.class)); router.get("/api/v1/:name/job/type/:type").handler(RouteHandlerWrapper.createHandler(JobTracker::handle, JobTrackerRequest.class)); // hacked up in mem kv store bulkPut @@ -77,6 +79,11 @@ protected void startHttpServer(int port, String configJsonString, Api api, Promi router.get("/api/v1/join/:name/timeseries").handler(new TimeSeriesHandler(driftStore).joinDriftHandler()); router.get("/api/v1/join/:join/feature/:name/timeseries").handler(new TimeSeriesHandler(driftStore).featureDriftHandler()); + DriftHandler driftHandler = new DriftHandler(driftStore); + router.get("/api/v1/join/:name/drift").handler(RouteHandlerWrapper.createHandler(driftHandler::getJoinDrift, JoinDriftRequest.class)); + router.get("/api/v1/join/:name/column/:columnName/drift").handler(RouteHandlerWrapper.createHandler(driftHandler::getColumnDrift, JoinDriftRequest.class)); + router.get("/api/v1/join/:name/column/:columnName/summary").handler(RouteHandlerWrapper.createHandler(driftHandler::getColumnSummary, JoinSummaryRequest.class)); + // Start HTTP server HttpServerOptions httpOptions = new HttpServerOptions() diff --git a/hub/src/main/scala/ai/chronon/hub/handlers/ConfHandler.scala b/hub/src/main/scala/ai/chronon/hub/handlers/ConfHandler.scala new file mode 100644 index 0000000000..6e52867aa4 --- /dev/null +++ b/hub/src/main/scala/ai/chronon/hub/handlers/ConfHandler.scala @@ -0,0 +1,110 @@ +package ai.chronon.hub.handlers + +import ai.chronon.hub.ConfListRequest +import ai.chronon.hub.ConfListResponse +import ai.chronon.hub.ConfRequest +import ai.chronon.hub.ConfType +import ai.chronon.hub.store.MonitoringModelStore +import ai.chronon.orchestration.LogicalNode +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import scala.collection.JavaConverters._ + +class ConfHandler(store: MonitoringModelStore) { + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + + /** + * Returns a specific configuration by name and type + */ + def getConf(req: ConfRequest): LogicalNode = { + logger.debug(s"Retrieving ${req.getConfName} of type ${req.getConfType}") + val registry = store.configRegistryCache("default") + + val node = new LogicalNode() + + req.getConfType match { + case ConfType.JOIN => + node.setJoin(findConfig(registry.joins, "join", req.getConfName)) + case ConfType.GROUP_BY => + node.setGroupBy(findConfig(registry.groupBys, "groupBy", req.getConfName)) + case ConfType.MODEL => + node.setModel(findConfig(registry.models, "model", req.getConfName)) + case ConfType.STAGING_QUERY => + node.setStagingQuery(findConfig(registry.stagingQueries, "staging query", req.getConfName)) + case _ => throw new RuntimeException(s"Unsupported configuration type ${req.getConfType}") + } + + node + } + + /** + * Finds a specific configuration by name within a sequence of configs + */ + private def findConfig[T](configs: Seq[T], configType: String, name: String): T = { + configs + .find(_.asInstanceOf[{ def getMetaData: { def getName: String } }].getMetaData.getName.equalsIgnoreCase(name)) + .getOrElse(throw new RuntimeException(s"Unable to retrieve $configType $name")) + } + + /** + * Returns all configurations of a specific type + */ + def getConfList(req: ConfListRequest): ConfListResponse = { + logger.debug(s"Retrieving all configurations of type ${req.getConfType}") + val registry = store.configRegistryCache("default") + + val response = new ConfListResponse() + + req.getConfType match { + case ConfType.JOIN => + response.setJoins(registry.joins.asJava) + case ConfType.GROUP_BY => + response.setGroupBys(registry.groupBys.asJava) + case ConfType.MODEL => + response.setModels(registry.models.asJava) + case ConfType.STAGING_QUERY => + response.setStagingQueries(registry.stagingQueries.asJava) + case _ => throw new RuntimeException(s"Unsupported configuration type ${req.getConfType}") + } + response + } + + /** + * Returns configurations matching the search criteria + */ + def searchConf(req: ConfRequest): ConfListResponse = { + logger.debug(s"Searching for configurations matching '${req.getConfName}' of type ${req.getConfType}") + val registry = store.configRegistryCache("default") + val searchTerm = Option(req.getConfName).getOrElse("").toLowerCase + + val response = new ConfListResponse() + + // Helper function to filter configs by name + def filterByName[T](configs: Seq[T]): Seq[T] = { + configs.filter( + _.asInstanceOf[{ def getMetaData: { def getName: String } }].getMetaData.getName.toLowerCase + .contains(searchTerm)) + } + + // If confType is specified, only search that type + Option(req.getConfType) match { + case Some(ConfType.JOIN) => + response.setJoins(filterByName(registry.joins).asJava) + case Some(ConfType.GROUP_BY) => + response.setGroupBys(filterByName(registry.groupBys).asJava) + case Some(ConfType.MODEL) => + response.setModels(filterByName(registry.models).asJava) + case Some(ConfType.STAGING_QUERY) => + response.setStagingQueries(filterByName(registry.stagingQueries).asJava) + case None => + // If no type specified, search all types + response.setJoins(filterByName(registry.joins).asJava) + response.setGroupBys(filterByName(registry.groupBys).asJava) + response.setModels(filterByName(registry.models).asJava) + response.setStagingQueries(filterByName(registry.stagingQueries).asJava) + } + + response + } +} diff --git a/hub/src/main/scala/ai/chronon/hub/handlers/DriftHandler.scala b/hub/src/main/scala/ai/chronon/hub/handlers/DriftHandler.scala new file mode 100644 index 0000000000..4cf4e883c6 --- /dev/null +++ b/hub/src/main/scala/ai/chronon/hub/handlers/DriftHandler.scala @@ -0,0 +1,102 @@ +package ai.chronon.hub.handlers + +import ai.chronon.api.TimeUnit +import ai.chronon.api.Window +import ai.chronon.observability.JoinDriftRequest +import ai.chronon.observability.JoinDriftResponse +import ai.chronon.observability.JoinSummaryRequest +import ai.chronon.observability.TileDriftSeries +import ai.chronon.observability.TileSummarySeries +import ai.chronon.online.stats.DriftStore +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import scala.collection.JavaConverters._ +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.util.Failure +import scala.util.Success + +class DriftHandler(driftStore: DriftStore) { + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + + private def getDriftSeriesWithWindow(req: JoinDriftRequest): Seq[TileDriftSeries] = { + logger.debug(s"Processing drift request for join: ${req.getName}, algorithm: ${req.getAlgorithm}, feature: ${Option( + req.getColumnName).getOrElse("none")}") + + val offsetDuration = parseOffset(Option(req.getOffset)) match { + case Some(duration) => + logger.debug(s"Parsed offset duration: $duration") + duration + case None => + logger.error(s"Failed to parse offset: ${req.getOffset}") + throw new IllegalArgumentException(s"Unable to parse offset - ${req.getOffset}") + } + + val window = new Window(offsetDuration.toMinutes.toInt, TimeUnit.MINUTES) + val joinPath = req.getName.replaceFirst("\\.", "/") + logger.debug(s"Querying drift store with window: $window, joinPath: $joinPath") + + driftStore.getDriftSeries( + joinPath, + req.getAlgorithm, + window, + req.getStartTs, + req.getEndTs, + Option(req.getColumnName) + ) match { + case Success(driftSeriesFuture) => + val result = Await.result(driftSeriesFuture, 30.seconds) + logger.debug(s"Successfully retrieved ${result.size} drift series entries") + result + case Failure(exception) => + logger.error("Failed to retrieve drift series", exception) + throw new RuntimeException(s"Error getting drift - ${exception.getMessage}") + } + } + + def getJoinDrift(req: JoinDriftRequest): JoinDriftResponse = { + val driftSeries = getDriftSeriesWithWindow(req) + new JoinDriftResponse().setDriftSeries(driftSeries.asJava) + } + + def getColumnDrift(req: JoinDriftRequest): TileDriftSeries = { + val driftSeries = getDriftSeriesWithWindow(req) + driftSeries.headOption.getOrElse(new TileDriftSeries()) + } + + private def parseOffset(offset: Option[String]): Option[Duration] = { + logger.debug(s"Parsing offset: $offset") + val hourPattern = """(\d+)h""".r + val dayPattern = """(\d+)d""".r + val result = offset.map(_.toLowerCase) match { + case Some(hourPattern(num)) => Some(num.toInt.hours) + case Some(dayPattern(num)) => Some(num.toInt.days) + case _ => None + } + logger.debug(s"Parsed offset result: $result") + result + } + + def getColumnSummary(req: JoinSummaryRequest): TileSummarySeries = { + logger.debug(s"Processing summary request for join: ${req.getName}, column: ${req.getColumnName}") + + val joinPath = req.getName.replaceFirst("\\.", "/") + logger.debug(s"Querying summary store with joinPath: $joinPath") + + driftStore.getSummarySeries( + joinPath, + req.getStartTs, + req.getEndTs, + Some(req.getColumnName) + ) match { + case Success(summarySeriesFuture) => + val result = Await.result(summarySeriesFuture, 30.seconds) + logger.debug(s"Successfully retrieved ${result.size} summary series entries") + result.headOption.getOrElse(new TileSummarySeries()) + case Failure(exception) => + logger.error("Failed to retrieve summary series", exception) + throw new RuntimeException(s"Error getting summary - ${exception.getMessage}") + } + } +} diff --git a/hub/src/main/scala/ai/chronon/hub/handlers/JoinsHandler.scala b/hub/src/main/scala/ai/chronon/hub/handlers/JoinsHandler.scala deleted file mode 100644 index 11bad55146..0000000000 --- a/hub/src/main/scala/ai/chronon/hub/handlers/JoinsHandler.scala +++ /dev/null @@ -1,52 +0,0 @@ -package ai.chronon.hub.handlers - -import ai.chronon.hub.model.ListJoinResponse -import ai.chronon.hub.store.MonitoringModelStore -import io.circe.generic.auto._ -import io.circe.syntax._ -import io.vertx.core.Handler -import io.vertx.ext.web.RoutingContext -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -class JoinsHandler(monitoringStore: MonitoringModelStore) extends Paginate { - - val logger: Logger = LoggerFactory.getLogger(this.getClass) - import VertxExtensions._ - - /** - * Powers the /api/v1/joins endpoint. Returns a list of models - * offset - For pagination. We skip over offset entries before returning results - * limit - Number of elements to return - */ - val listHandler: Handler[RoutingContext] = (ctx: RoutingContext) => { - val offset = Option(ctx.queryParams.get("offset")).map(_.toInt).getOrElse(defaultOffset) - val limit = Option(ctx.queryParams.get("limit")).map(l => math.min(l.toInt, maxLimit)).getOrElse(defaultLimit) - - if (offset < 0) { - ctx.BadRequest("Invalid offset - expect a positive number") - } else if (limit < 0) { - ctx.BadRequest("Invalid limit - expect a positive number") - } else { - val joins = monitoringStore.getJoins - val paginatedResults = paginateResults(joins, offset, limit) - val json = ListJoinResponse(offset, paginatedResults).asJson.noSpaces - ctx.Ok(json) - } - } - - /** - * Returns a specific join by name (/api/v1/join/:name) - */ - val getHandler: Handler[RoutingContext] = (ctx: RoutingContext) => { - val entityName = ctx.pathParam("name"); - logger.debug("Retrieving {}", entityName); - - val maybeJoin = monitoringStore.getJoins.find(j => j.name.equalsIgnoreCase(entityName)) - maybeJoin match { - case None => ctx.NotFound(s"Unable to retrive $entityName") - case Some(join) => ctx.Ok(join.asJson.noSpaces) - } - } - -} diff --git a/hub/src/main/scala/ai/chronon/hub/handlers/ModelsHandler.scala b/hub/src/main/scala/ai/chronon/hub/handlers/ModelsHandler.scala deleted file mode 100644 index 8cdf57d0ff..0000000000 --- a/hub/src/main/scala/ai/chronon/hub/handlers/ModelsHandler.scala +++ /dev/null @@ -1,34 +0,0 @@ -package ai.chronon.hub.handlers - -import ai.chronon.hub.model.ListModelResponse -import ai.chronon.hub.store.MonitoringModelStore -import io.circe.generic.auto._ -import io.circe.syntax._ -import io.vertx.core.Handler -import io.vertx.ext.web.RoutingContext - -/** - * Powers the /api/v1/models endpoint. Returns a list of models - * offset - For pagination. We skip over offset entries before returning results - * limit - Number of elements to return - */ -class ModelsHandler(monitoringStore: MonitoringModelStore) extends Handler[RoutingContext] with Paginate { - - import VertxExtensions._ - - override def handle(ctx: RoutingContext): Unit = { - val offset = Option(ctx.queryParams.get("offset")).map(_.toInt).getOrElse(defaultOffset) - val limit = Option(ctx.queryParams.get("limit")).map(l => math.min(l.toInt, maxLimit)).getOrElse(defaultLimit) - - if (offset < 0) { - ctx.BadRequest("Invalid offset - expect a positive number") - } else if (limit < 0) { - ctx.BadRequest("Invalid limit - expect a positive number") - } else { - val models = monitoringStore.getModels - val paginatedResults = paginateResults(models, offset, limit) - val json = ListModelResponse(offset, paginatedResults).asJson.noSpaces - ctx.Ok(json) - } - } -} diff --git a/hub/src/main/scala/ai/chronon/hub/handlers/Paginate.scala b/hub/src/main/scala/ai/chronon/hub/handlers/Paginate.scala deleted file mode 100644 index 34ed49abcb..0000000000 --- a/hub/src/main/scala/ai/chronon/hub/handlers/Paginate.scala +++ /dev/null @@ -1,11 +0,0 @@ -package ai.chronon.hub.handlers - -trait Paginate { - val defaultOffset = 0 - val defaultLimit = 10 - val maxLimit = 100 - - def paginateResults[T](results: Seq[T], offset: Int, limit: Int): Seq[T] = { - results.slice(offset, offset + limit) - } -} diff --git a/hub/src/main/scala/ai/chronon/hub/handlers/SearchHandler.scala b/hub/src/main/scala/ai/chronon/hub/handlers/SearchHandler.scala deleted file mode 100644 index 89ec1c2828..0000000000 --- a/hub/src/main/scala/ai/chronon/hub/handlers/SearchHandler.scala +++ /dev/null @@ -1,43 +0,0 @@ -package ai.chronon.hub.handlers - -import ai.chronon.hub.model.Join -import ai.chronon.hub.model.SearchJoinResponse -import ai.chronon.hub.store.MonitoringModelStore -import io.circe.generic.auto._ -import io.circe.syntax._ -import io.vertx.core.Handler -import io.vertx.ext.web.RoutingContext - -/** - * Powers the /api/v1/search endpoint. Returns a list of joins - * term - Search term to search for (currently we only support searching join names) - * offset - For pagination. We skip over offset entries before returning results - * limit - Number of elements to return - */ -class SearchHandler(monitoringStore: MonitoringModelStore) extends Handler[RoutingContext] with Paginate { - - import VertxExtensions._ - - override def handle(ctx: RoutingContext): Unit = { - val term = Option(ctx.queryParams.get("term")).getOrElse("") - val offset = Option(ctx.queryParams.get("offset")).map(_.toInt).getOrElse(defaultOffset) - val limit = Option(ctx.queryParams.get("limit")).map(l => math.min(l.toInt, maxLimit)).getOrElse(defaultLimit) - - if (offset < 0) { - ctx.BadRequest("Invalid offset - expect a positive number") - } else if (limit < 0) { - ctx.BadRequest("Invalid limit - expect a positive number") - } else { - val searchResults = searchRegistry(term) - val paginatedResults = paginateResults(searchResults, offset, limit) - val json = SearchJoinResponse(offset, paginatedResults).asJson.noSpaces - ctx.Ok(json) - } - } - - // a trivial search where we check the join name for similarity with the search term - private def searchRegistry(term: String): Seq[Join] = { - val joins = monitoringStore.getJoins - joins.filter(j => j.name.contains(term)) - } -} diff --git a/hub/src/main/scala/ai/chronon/hub/handlers/TimeSeriesHandler.scala b/hub/src/main/scala/ai/chronon/hub/handlers/TimeSeriesHandler.scala index 99a6d34d90..9fe7ce0856 100644 --- a/hub/src/main/scala/ai/chronon/hub/handlers/TimeSeriesHandler.scala +++ b/hub/src/main/scala/ai/chronon/hub/handlers/TimeSeriesHandler.scala @@ -119,13 +119,15 @@ class TimeSeriesHandler(driftStore: DriftStore) { // check if we have a numeric / categorical feature. If the percentile drift series has non-null doubles // then we have a numeric feature at hand val isNumeric = - tileDriftSeries.percentileDriftSeries.asScala != null && tileDriftSeries.percentileDriftSeries.asScala - .exists(_ != null) + Option(tileDriftSeries.percentileDriftSeries).exists(series => series.asScala.exists(_ != null)) + val lhsList = if (metric == NullMetric) { - tileDriftSeries.nullRatioChangePercentSeries.asScala + Option(tileDriftSeries.nullRatioChangePercentSeries).map(_.asScala).getOrElse(Seq.empty) } else { - if (isNumeric) tileDriftSeries.percentileDriftSeries.asScala - else tileDriftSeries.histogramDriftSeries.asScala + if (isNumeric) + Option(tileDriftSeries.percentileDriftSeries).map(_.asScala).getOrElse(Seq.empty) + else + Option(tileDriftSeries.histogramDriftSeries).map(_.asScala).getOrElse(Seq.empty) } val points = lhsList.zip(tileDriftSeries.timestamps.asScala).map { case (v, ts) => TimeSeriesPoint(v, ts) diff --git a/hub/src/main/scala/ai/chronon/hub/model/Model.scala b/hub/src/main/scala/ai/chronon/hub/model/Model.scala index 41dd1ae91a..d949938de7 100644 --- a/hub/src/main/scala/ai/chronon/hub/model/Model.scala +++ b/hub/src/main/scala/ai/chronon/hub/model/Model.scala @@ -57,10 +57,4 @@ case class ComparedFeatureTimeSeries(feature: String, current: Seq[TimeSeriesPoint]) case class GroupByTimeSeries(name: String, items: Seq[FeatureTimeSeries]) -// Currently search only covers joins -case class ListModelResponse(offset: Int, items: Seq[Model]) -case class SearchJoinResponse(offset: Int, items: Seq[Join]) -case class ListJoinResponse(offset: Int, items: Seq[Join]) - -case class ModelTimeSeriesResponse(id: String, items: Seq[TimeSeriesPoint]) case class JoinTimeSeriesResponse(name: String, items: Seq[GroupByTimeSeries]) diff --git a/hub/src/test/scala/ai/chronon/hub/handlers/ConfHandlerTest.scala b/hub/src/test/scala/ai/chronon/hub/handlers/ConfHandlerTest.scala new file mode 100644 index 0000000000..e1be553401 --- /dev/null +++ b/hub/src/test/scala/ai/chronon/hub/handlers/ConfHandlerTest.scala @@ -0,0 +1,296 @@ +package ai.chronon.hub.handlers + +import ai.chronon.api.GroupBy +import ai.chronon.api.Join +import ai.chronon.api.MetaData +import ai.chronon.api.Model +import ai.chronon.api.StagingQuery +import ai.chronon.hub.ConfListRequest +import ai.chronon.hub.ConfRequest +import ai.chronon.hub.ConfType +import ai.chronon.hub.store.LoadedConfs +import ai.chronon.hub.store.MonitoringModelStore +import ai.chronon.online.TTLCache +import io.vertx.ext.unit.junit.VertxUnitRunner +import org.junit.Assert._ +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.Mock +import org.mockito.Mockito.when +import org.mockito.MockitoAnnotations + +@RunWith(classOf[VertxUnitRunner]) +class ConfHandlerTest { + + @Mock private var mockedStore: MonitoringModelStore = _ + @Mock private var mockedCache: TTLCache[String, LoadedConfs] = _ + private var handler: ConfHandler = _ + private val defaultRegistry = MockConfigRegistry.createMockRegistry() + + @Before + def setUp(): Unit = { + MockitoAnnotations.openMocks(this) + handler = new ConfHandler(mockedStore) + + when(mockedStore.configRegistryCache).thenReturn(mockedCache) + when(mockedCache.apply("default")).thenReturn(LoadedConfs( + joins = defaultRegistry.joins, + groupBys = defaultRegistry.groupBys, + stagingQueries = defaultRegistry.stagingQueries, + models = defaultRegistry.models + )) + } + + @Test + def testGetConfForJoin(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.JOIN) + request.setConfName("test_join_1") + + val result = handler.getConf(request) + + assertNotNull(result.getJoin) + assertEquals("test_join_1", result.getJoin.getMetaData.getName) + } + + @Test + def testGetConfForModel(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.MODEL) + request.setConfName("test_model_1") + + val result = handler.getConf(request) + + assertNotNull(result.getModel) + assertEquals("test_model_1", result.getModel.getMetaData.getName) + } + + @Test + def testGetConfForGroupBy(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.GROUP_BY) + request.setConfName("test_groupby_1") + + val result = handler.getConf(request) + + assertNotNull(result.getGroupBy) + assertEquals("test_groupby_1", result.getGroupBy.getMetaData.getName) + } + + @Test + def testGetConfForStagingQuery(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.STAGING_QUERY) + request.setConfName("test_query_1") + + val result = handler.getConf(request) + + assertNotNull(result.getStagingQuery) + assertEquals("test_query_1", result.getStagingQuery.getMetaData.getName) + } + + @Test(expected = classOf[RuntimeException]) + def testGetConfForNonexistentJoin(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.JOIN) + request.setConfName("nonexistent_join") + handler.getConf(request) + } + + @Test(expected = classOf[RuntimeException]) + def testGetConfForNonexistentModel(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.MODEL) + request.setConfName("nonexistent_model") + handler.getConf(request) + } + + @Test(expected = classOf[RuntimeException]) + def testGetConfForNonexistentGroupBy(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.GROUP_BY) + request.setConfName("nonexistent_groupby") + handler.getConf(request) + } + + @Test(expected = classOf[RuntimeException]) + def testGetConfForNonexistentStagingQuery(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.STAGING_QUERY) + request.setConfName("nonexistent_query") + handler.getConf(request) + } + + @Test + def testGetConfListForJoins(): Unit = { + val request = new ConfListRequest() + request.setConfType(ConfType.JOIN) + val result = handler.getConfList(request) + assertNotNull(result.getJoins) + assertEquals(3, result.getJoins.size()) + } + + @Test + def testGetConfListForModels(): Unit = { + val request = new ConfListRequest() + request.setConfType(ConfType.MODEL) + val result = handler.getConfList(request) + assertNotNull(result.getModels) + assertEquals(2, result.getModels.size()) + } + + @Test + def testGetConfListForGroupBys(): Unit = { + val request = new ConfListRequest() + request.setConfType(ConfType.GROUP_BY) + val result = handler.getConfList(request) + assertNotNull(result.getGroupBys) + assertEquals(2, result.getGroupBys.size()) + } + + @Test + def testGetConfListForStagingQueries(): Unit = { + val request = new ConfListRequest() + request.setConfType(ConfType.STAGING_QUERY) + val result = handler.getConfList(request) + assertNotNull(result.getStagingQueries) + assertEquals(2, result.getStagingQueries.size()) + } + + @Test + def testSearchConfForJoins(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.JOIN) + request.setConfName("1") + val result = handler.searchConf(request) + assertNotNull(result.getJoins) + assertEquals(1, result.getJoins.size()) + assertEquals("test_join_1", result.getJoins.get(0).getMetaData.getName) + } + + @Test + def testSearchConfForModels(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.MODEL) + request.setConfName("1") + val result = handler.searchConf(request) + assertNotNull(result.getModels) + assertEquals(1, result.getModels.size()) + assertEquals("test_model_1", result.getModels.get(0).getMetaData.getName) + } + + @Test + def testSearchConfForGroupBys(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.GROUP_BY) + request.setConfName("1") + val result = handler.searchConf(request) + assertNotNull(result.getGroupBys) + assertEquals(1, result.getGroupBys.size()) + assertEquals("test_groupby_1", result.getGroupBys.get(0).getMetaData.getName) + } + + @Test + def testSearchConfForStagingQueries(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.STAGING_QUERY) + request.setConfName("1") + val result = handler.searchConf(request) + assertNotNull(result.getStagingQueries) + assertEquals(1, result.getStagingQueries.size()) + assertEquals("test_query_1", result.getStagingQueries.get(0).getMetaData.getName) + } + + @Test + def testSearchConfAcrossAllTypes(): Unit = { + val request = new ConfRequest() + request.setConfName("1") + + val result = handler.searchConf(request) + + assertEquals(1, result.getJoins.size()) + assertEquals(1, result.getModels.size()) + assertEquals(1, result.getGroupBys.size()) + assertEquals(1, result.getStagingQueries.size()) + } + + @Test + def testSearchConfWithPartialName(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.JOIN) + request.setConfName("join") + + val result = handler.searchConf(request) + assertNotNull(result.getJoins) + assertEquals(3, result.getJoins.size()) + assertEquals("test_join_1", result.getJoins.get(0).getMetaData.getName) + } + + @Test + def testSearchConfWithEmptyString(): Unit = { + val request = new ConfRequest() + request.setConfType(ConfType.JOIN) + request.setConfName("") + + val result = handler.searchConf(request) + assertNotNull(result.getJoins) + assertEquals(3, result.getJoins.size()) + } +} + +object MockConfigRegistry { + def createMockRegistry(): createMockRegistry = new createMockRegistry() + class createMockRegistry() extends { + val joins: Seq[Join] = Seq( + createMockJoin("test_join_1"), + createMockJoin("test_join_2"), + createMockJoin("test_join_3") + ) + + val models: Seq[Model] = Seq( + createMockModel("test_model_1"), + createMockModel("test_model_2") + ) + + val groupBys: Seq[GroupBy] = Seq( + createMockGroupBy("test_groupby_1"), + createMockGroupBy("test_groupby_2") + ) + + val stagingQueries: Seq[StagingQuery] = Seq( + createMockStagingQuery("test_query_1"), + createMockStagingQuery("test_query_2") + ) + } + + private def createMetaData(name: String): MetaData = { + val metadata = new MetaData() + metadata.setName(name) + metadata + } + + private def createMockJoin(name: String): Join = { + val join = new Join() + join.setMetaData(createMetaData(name)) + join + } + + private def createMockModel(name: String): Model = { + val model = new Model() + model.setMetaData(createMetaData(name)) + model + } + + private def createMockGroupBy(name: String): GroupBy = { + val groupBy = new GroupBy() + groupBy.setMetaData(createMetaData(name)) + groupBy + } + + private def createMockStagingQuery(name: String): StagingQuery = { + val query = new StagingQuery() + query.setMetaData(createMetaData(name)) + query + } +} \ No newline at end of file diff --git a/hub/src/test/scala/ai/chronon/hub/handlers/DriftHandlerTest.scala b/hub/src/test/scala/ai/chronon/hub/handlers/DriftHandlerTest.scala new file mode 100644 index 0000000000..24a80f34f2 --- /dev/null +++ b/hub/src/test/scala/ai/chronon/hub/handlers/DriftHandlerTest.scala @@ -0,0 +1,208 @@ +package ai.chronon.hub.handlers + +import ai.chronon.api.Window +import ai.chronon.observability._ +import ai.chronon.online.stats.DriftStore +import io.vertx.ext.unit.junit.VertxUnitRunner +import org.junit.Assert._ +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.ArgumentMatchers.any +import org.mockito.ArgumentMatchers.anyLong +import org.mockito.ArgumentMatchers.anyString +import org.mockito.Mock +import org.mockito.Mockito._ +import org.mockito.MockitoAnnotations + +import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ +import scala.concurrent.Future +import scala.util.Success + +@RunWith(classOf[VertxUnitRunner]) +class DriftHandlerTest { + + @Mock private var mockedDriftStore: DriftStore = _ + private var handler: DriftHandler = _ + + private val testJoinName = "test_join" + private val testColumnName = "test_column" + private val baseTimestamp = System.currentTimeMillis() + + @Before + def setUp(): Unit = { + MockitoAnnotations.openMocks(this) + handler = new DriftHandler(mockedDriftStore) + } + + private def createMockDriftSeries(timestamp: Long): TileDriftSeries = { + val series = new TileDriftSeries() + series.setTimestamps(List(timestamp: java.lang.Long).asJava) + series.setPercentileDriftSeries(List(0.1: java.lang.Double).asJava) + series.setHistogramDriftSeries(List(0.2: java.lang.Double).asJava) + series.setCountChangePercentSeries(List(0.3: java.lang.Double).asJava) + + val key = new TileSeriesKey() + key.setColumn(testColumnName) + series.setKey(key) + series + } + + private def createMockSummarySeries(timestamp: Long): TileSummarySeries = { + val series = new TileSummarySeries() + series.setTimestamps(List(timestamp: java.lang.Long).asJava) + series.setPercentiles(List(List(0.1: java.lang.Double, 0.5: java.lang.Double, 0.9: java.lang.Double).asJava).asJava) + series.setCount(List(100L: java.lang.Long).asJava) + + val key = new TileSeriesKey() + key.setColumn(testColumnName) + series.setKey(key) + series + } + + @Test + def testGetDriftWithHourOffset(): Unit = { + val request = new JoinDriftRequest() + request.setName(testJoinName) + request.setColumnName(testColumnName) + request.setStartTs(baseTimestamp - TimeUnit.HOURS.toMillis(24)) + request.setEndTs(baseTimestamp) + request.setOffset("24h") + request.setAlgorithm(DriftMetric.JENSEN_SHANNON) + + val mockDriftSeries = List(createMockDriftSeries(baseTimestamp)) + when(mockedDriftStore.getDriftSeries( + anyString(), + any[DriftMetric], + any[Window], + anyLong(), + anyLong(), + any[Option[String]] + )).thenReturn(Success(Future.successful(mockDriftSeries))) + + val response = handler.getJoinDrift(request) + assertNotNull(response.getDriftSeries) + assertEquals(1, response.getDriftSeries.size()) + assertEquals(testColumnName, response.getDriftSeries.get(0).getKey.getColumn) + } + + @Test + def testGetDriftWithDayOffset(): Unit = { + val request = new JoinDriftRequest() + request.setName(testJoinName) + request.setColumnName(testColumnName) + request.setStartTs(baseTimestamp - TimeUnit.DAYS.toMillis(7)) + request.setEndTs(baseTimestamp) + request.setOffset("7d") + request.setAlgorithm(DriftMetric.JENSEN_SHANNON) + + val mockDriftSeries = List(createMockDriftSeries(baseTimestamp)) + when(mockedDriftStore.getDriftSeries( + anyString(), + any[DriftMetric], + any[Window], + anyLong(), + anyLong(), + any[Option[String]] + )).thenReturn(Success(Future.successful(mockDriftSeries))) + + val response = handler.getJoinDrift(request) + assertNotNull(response.getDriftSeries) + assertEquals(1, response.getDriftSeries.size()) + } + + @Test(expected = classOf[IllegalArgumentException]) + def testGetDriftWithInvalidOffset(): Unit = { + val request = new JoinDriftRequest() + request.setName(testJoinName) + request.setOffset("invalid") + request.setStartTs(baseTimestamp - TimeUnit.HOURS.toMillis(24)) + request.setEndTs(baseTimestamp) + + handler.getJoinDrift(request) + } + + @Test + def testGetSummary(): Unit = { + val request = new JoinSummaryRequest() + request.setName(testJoinName) + request.setColumnName(testColumnName) + request.setStartTs(baseTimestamp - TimeUnit.HOURS.toMillis(24)) + request.setEndTs(baseTimestamp) + + val mockSummarySeries = List(createMockSummarySeries(baseTimestamp)) + when(mockedDriftStore.getSummarySeries( + anyString(), + anyLong(), + anyLong(), + any[Option[String]] + )).thenReturn(Success(Future.successful(mockSummarySeries))) + + val response = handler.getColumnSummary(request) + assertNotNull(response) + assertNotNull(response.getTimestamps) + assertEquals(1, response.getTimestamps.size()) + assertEquals(testColumnName, response.getKey.getColumn) + } + + @Test(expected = classOf[IllegalArgumentException]) + def testGetDriftWithEndTimeBeforeStartTime(): Unit = { + val request = new JoinDriftRequest() + request.setName(testJoinName) + request.setStartTs(baseTimestamp) + request.setEndTs(baseTimestamp - TimeUnit.HOURS.toMillis(24)) + // We don't need to set offset or algorithm since the timestamp validation should happen first + + handler.getJoinDrift(request) + } + + @Test + def testGetDriftWithDefaultAlgorithm(): Unit = { + val request = new JoinDriftRequest() + request.setName(testJoinName) + request.setColumnName(testColumnName) + request.setStartTs(baseTimestamp - TimeUnit.HOURS.toMillis(24)) + request.setEndTs(baseTimestamp) + request.setOffset("24h") + // Not setting algorithm should use default + + val mockDriftSeries = List(createMockDriftSeries(baseTimestamp)) + when(mockedDriftStore.getDriftSeries( + anyString(), + any[DriftMetric], + any[Window], + anyLong(), + anyLong(), + any[Option[String]] + )).thenReturn(Success(Future.successful(mockDriftSeries))) + + val response = handler.getJoinDrift(request) + assertNotNull(response.getDriftSeries) + assertEquals(1, response.getDriftSeries.size()) + } + + @Test + def testGetDriftWithNullColumnName(): Unit = { + val request = new JoinDriftRequest() + request.setName(testJoinName) + request.setStartTs(baseTimestamp - TimeUnit.HOURS.toMillis(24)) + request.setEndTs(baseTimestamp) + request.setOffset("24h") + // Not setting columnName + + val mockDriftSeries = List(createMockDriftSeries(baseTimestamp)) + when(mockedDriftStore.getDriftSeries( + anyString(), + any[DriftMetric], + any[Window], + anyLong(), + anyLong(), + any[Option[String]] + )).thenReturn(Success(Future.successful(mockDriftSeries))) + + val response = handler.getJoinDrift(request) + assertNotNull(response.getDriftSeries) + assertEquals(1, response.getDriftSeries.size()) + } +} \ No newline at end of file diff --git a/hub/src/test/scala/ai/chronon/hub/handlers/JoinHandlerTest.scala b/hub/src/test/scala/ai/chronon/hub/handlers/JoinHandlerTest.scala deleted file mode 100644 index dd52c7bc0b..0000000000 --- a/hub/src/test/scala/ai/chronon/hub/handlers/JoinHandlerTest.scala +++ /dev/null @@ -1,193 +0,0 @@ -package ai.chronon.hub.handlers - -import ai.chronon.hub.handlers.MockJoinService.mockJoinRegistry -import ai.chronon.hub.model.Join -import ai.chronon.hub.model.ListJoinResponse -import ai.chronon.hub.store.MonitoringModelStore -import io.circe._ -import io.circe.generic.auto._ -import io.circe.parser._ -import io.vertx.core.Handler -import io.vertx.core.MultiMap -import io.vertx.core.Vertx -import io.vertx.core.http.HttpServerResponse -import io.vertx.ext.unit.TestContext -import io.vertx.ext.unit.junit.VertxUnitRunner -import io.vertx.ext.web.RequestBody -import io.vertx.ext.web.RoutingContext -import org.junit.Assert._ -import org.junit.Before -import org.junit.Test -import org.junit.runner.RunWith -import org.mockito.ArgumentCaptor -import org.mockito.ArgumentMatchers.anyInt -import org.mockito.ArgumentMatchers.anyString -import org.mockito.Mock -import org.mockito.Mockito.verify -import org.mockito.Mockito.when -import org.mockito.MockitoAnnotations -import org.scalatest.EitherValues - -@RunWith(classOf[VertxUnitRunner]) -class JoinHandlerTest extends EitherValues { - - @Mock var routingContext: RoutingContext = _ - @Mock var response: HttpServerResponse = _ - @Mock var requestBody: RequestBody = _ - @Mock var mockedStore: MonitoringModelStore = _ - - var vertx: Vertx = _ - var listHandler: Handler[RoutingContext] = _ - var getHandler: Handler[RoutingContext] = _ - - @Before - def setUp(context: TestContext): Unit = { - MockitoAnnotations.openMocks(this) - vertx = Vertx.vertx - listHandler = new JoinsHandler(mockedStore).listHandler - getHandler = new JoinsHandler(mockedStore).getHandler - // Set up common routing context behavior - when(routingContext.response).thenReturn(response) - when(response.putHeader(anyString, anyString)).thenReturn(response) - when(response.setStatusCode(anyInt)).thenReturn(response) - when(routingContext.body).thenReturn(requestBody) - } - - @Test - def testSend400BadOffset(context: TestContext) : Unit = { - val async = context.async - val multiMap = MultiMap.caseInsensitiveMultiMap - multiMap.add("offset", "-1") - multiMap.add("limit", "10") - when(routingContext.queryParams()).thenReturn(multiMap) - - // Trigger call// Trigger call - listHandler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(400) - async.complete() - }) - } - - @Test - def testSend400BadLimit(context: TestContext) : Unit = { - val async = context.async - val multiMap = MultiMap.caseInsensitiveMultiMap - multiMap.add("offset", "10") - multiMap.add("limit", "-1") - when(routingContext.queryParams()).thenReturn(multiMap) - - // Trigger call// Trigger call - listHandler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(400) - async.complete() - }) - } - - @Test - def testSend404MissingJoin(context: TestContext) : Unit = { - val async = context.async - val multiMap = MultiMap.caseInsensitiveMultiMap - multiMap.add("offset", "10") - multiMap.add("limit", "-1") - when(routingContext.queryParams()).thenReturn(multiMap) - when(routingContext.pathParam("name")).thenReturn("fake_join") - when(mockedStore.getJoins).thenReturn(mockJoinRegistry) - - // Trigger call// Trigger call - getHandler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(404) - async.complete() - }) - } - - @Test - def testSendValidResults(context: TestContext) : Unit = { - val async = context.async - when(mockedStore.getJoins).thenReturn(mockJoinRegistry) - val multiMap = MultiMap.caseInsensitiveMultiMap - when(routingContext.queryParams()).thenReturn(multiMap) - - // Capture the response that will be sent - val responseCaptor = ArgumentCaptor.forClass(classOf[String]) - - // Trigger call// Trigger call - listHandler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(200) - verify(response).putHeader("content-type", "application/json") - verify(response).end(responseCaptor.capture) - val jsonResponse = responseCaptor.getValue - - val listResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](jsonResponse) - val items = listResponse.right.value.items - assertEquals(items.length, new JoinsHandler(mockedStore).defaultLimit) - assertEquals(items.map(_.name.toInt).toSet, (0 until 10).toSet) - - async.complete() - }) - } - - @Test - def testSendPaginatedResultsCorrectly(context: TestContext) : Unit = { - val async = context.async - when(mockedStore.getJoins).thenReturn(mockJoinRegistry) - - val multiMap = MultiMap.caseInsensitiveMultiMap - val number = 10 - val startOffset = 25 - multiMap.add("offset", startOffset.toString) - multiMap.add("limit", number.toString) - when(routingContext.queryParams()).thenReturn(multiMap) - - // Capture the response that will be sent - val responseCaptor = ArgumentCaptor.forClass(classOf[String]) - - // Trigger call// Trigger call - listHandler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(200) - verify(response).putHeader("content-type", "application/json") - verify(response).end(responseCaptor.capture) - val jsonResponse = responseCaptor.getValue - - val listResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](jsonResponse) - val items = listResponse.right.value.items - assertEquals(items.length, number) - assertEquals(items.map(_.name.toInt).toSet, (startOffset until startOffset + number).toSet) - - async.complete() - }) - } - - @Test - def testSendValidJoinOnLookup(context: TestContext) : Unit = { - val async = context.async - val multiMap = MultiMap.caseInsensitiveMultiMap - multiMap.add("offset", "10") - multiMap.add("limit", "-1") - when(routingContext.pathParam("name")).thenReturn("10") - - when(mockedStore.getJoins).thenReturn(mockJoinRegistry) - - // Capture the response that will be sent - val responseCaptor = ArgumentCaptor.forClass(classOf[String]) - - // Trigger call// Trigger call - getHandler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(200) - verify(response).putHeader("content-type", "application/json") - verify(response).end(responseCaptor.capture) - val jsonResponse = responseCaptor.getValue - - val joinResponse: Either[Error, Join] = decode[Join](jsonResponse) - assertEquals(joinResponse.right.value.name, "10") - - async.complete() - }) - } - -} diff --git a/hub/src/test/scala/ai/chronon/hub/handlers/ModelHandlerTest.scala b/hub/src/test/scala/ai/chronon/hub/handlers/ModelHandlerTest.scala deleted file mode 100644 index a3b8b501e8..0000000000 --- a/hub/src/test/scala/ai/chronon/hub/handlers/ModelHandlerTest.scala +++ /dev/null @@ -1,155 +0,0 @@ -package ai.chronon.hub.handlers - -import ai.chronon.hub.handlers.MockDataService.mockModelRegistry -import ai.chronon.hub.model.GroupBy -import ai.chronon.hub.model.Join -import ai.chronon.hub.model.ListModelResponse -import ai.chronon.hub.model.Model -import ai.chronon.hub.store.MonitoringModelStore -import io.circe._ -import io.circe.generic.auto._ -import io.circe.parser._ -import io.vertx.core.MultiMap -import io.vertx.core.Vertx -import io.vertx.core.http.HttpServerResponse -import io.vertx.ext.unit.TestContext -import io.vertx.ext.unit.junit.VertxUnitRunner -import io.vertx.ext.web.RequestBody -import io.vertx.ext.web.RoutingContext -import org.junit.Assert._ -import org.junit.Before -import org.junit.Test -import org.junit.runner.RunWith -import org.mockito.ArgumentCaptor -import org.mockito.ArgumentMatchers.anyInt -import org.mockito.ArgumentMatchers.anyString -import org.mockito.Mock -import org.mockito.Mockito.verify -import org.mockito.Mockito.when -import org.mockito.MockitoAnnotations -import org.scalatest.EitherValues - -@RunWith(classOf[VertxUnitRunner]) -class ModelHandlerTest extends EitherValues { - - @Mock var routingContext: RoutingContext = _ - @Mock var response: HttpServerResponse = _ - @Mock var requestBody: RequestBody = _ - @Mock var mockedStore: MonitoringModelStore = _ - - var vertx: Vertx = _ - var handler: ModelsHandler = _ - - @Before - def setUp(context: TestContext): Unit = { - MockitoAnnotations.openMocks(this) - vertx = Vertx.vertx - handler = new ModelsHandler(mockedStore) - // Set up common routing context behavior - when(routingContext.response).thenReturn(response) - when(response.putHeader(anyString, anyString)).thenReturn(response) - when(response.setStatusCode(anyInt)).thenReturn(response) - when(routingContext.body).thenReturn(requestBody) - } - - @Test - def testSend400BadOffset(context: TestContext) : Unit = { - val async = context.async - val multiMap = MultiMap.caseInsensitiveMultiMap - multiMap.add("offset", "-1") - multiMap.add("limit", "10") - when(routingContext.queryParams()).thenReturn(multiMap) - - // Trigger call// Trigger call - handler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(400) - async.complete() - }) - } - - @Test - def testSend400BadLimit(context: TestContext) : Unit = { - val async = context.async - val multiMap = MultiMap.caseInsensitiveMultiMap - multiMap.add("offset", "10") - multiMap.add("limit", "-1") - when(routingContext.queryParams()).thenReturn(multiMap) - - // Trigger call// Trigger call - handler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(400) - async.complete() - }) - } - - @Test - def testSendValidResults(context: TestContext) : Unit = { - val async = context.async - when(mockedStore.getModels).thenReturn(mockModelRegistry) - val multiMap = MultiMap.caseInsensitiveMultiMap - when(routingContext.queryParams()).thenReturn(multiMap) - - // Capture the response that will be sent - val responseCaptor = ArgumentCaptor.forClass(classOf[String]) - - // Trigger call// Trigger call - handler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(200) - verify(response).putHeader("content-type", "application/json") - verify(response).end(responseCaptor.capture) - val jsonResponse = responseCaptor.getValue - - val listModelResponse: Either[Error, ListModelResponse] = decode[ListModelResponse](jsonResponse) - val items = listModelResponse.right.value.items - assertEquals(items.length, handler.defaultLimit) - assertEquals(items.map(_.name.toInt).toSet, (0 until 10).toSet) - - async.complete() - }) - } - - @Test - def testSendPaginatedResultsCorrectly(context: TestContext) : Unit = { - val async = context.async - when(mockedStore.getModels).thenReturn(mockModelRegistry) - - val multiMap = MultiMap.caseInsensitiveMultiMap - val number = 10 - val startOffset = 25 - multiMap.add("offset", startOffset.toString) - multiMap.add("limit", number.toString) - when(routingContext.queryParams()).thenReturn(multiMap) - - // Capture the response that will be sent - val responseCaptor = ArgumentCaptor.forClass(classOf[String]) - - // Trigger call// Trigger call - handler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(200) - verify(response).putHeader("content-type", "application/json") - verify(response).end(responseCaptor.capture) - val jsonResponse = responseCaptor.getValue - - val listModelResponse: Either[Error, ListModelResponse] = decode[ListModelResponse](jsonResponse) - val items = listModelResponse.right.value.items - assertEquals(items.length, number) - assertEquals(items.map(_.name.toInt).toSet, (startOffset until startOffset + number).toSet) - - async.complete() - }) - } -} - -object MockDataService { - def generateMockModel(id: String): Model = { - val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2"))) - val join = Join("my_join", Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, Some("my_team")) - Model(id, join, online = true, production = true, "my team", "XGBoost") - } - - val mockModelRegistry: Seq[Model] = (0 until 100).map(i => generateMockModel(i.toString)) -} diff --git a/hub/src/test/scala/ai/chronon/hub/handlers/SearchHandlerTest.scala b/hub/src/test/scala/ai/chronon/hub/handlers/SearchHandlerTest.scala deleted file mode 100644 index 2e4abaa408..0000000000 --- a/hub/src/test/scala/ai/chronon/hub/handlers/SearchHandlerTest.scala +++ /dev/null @@ -1,162 +0,0 @@ -package ai.chronon.hub.handlers - -import ai.chronon.hub.handlers.MockJoinService.mockJoinRegistry -import ai.chronon.hub.model.GroupBy -import ai.chronon.hub.model.Join -import ai.chronon.hub.model.SearchJoinResponse -import ai.chronon.hub.store.MonitoringModelStore -import io.circe._ -import io.circe.generic.auto._ -import io.circe.parser._ -import io.vertx.core.MultiMap -import io.vertx.core.Vertx -import io.vertx.core.http.HttpServerResponse -import io.vertx.ext.unit.TestContext -import io.vertx.ext.unit.junit.VertxUnitRunner -import io.vertx.ext.web.RequestBody -import io.vertx.ext.web.RoutingContext -import org.junit.Assert._ -import org.junit.Before -import org.junit.Test -import org.junit.runner.RunWith -import org.mockito.ArgumentCaptor -import org.mockito.ArgumentMatchers.anyInt -import org.mockito.ArgumentMatchers.anyString -import org.mockito.Mock -import org.mockito.Mockito.verify -import org.mockito.Mockito.when -import org.mockito.MockitoAnnotations -import org.scalatest.EitherValues - -@RunWith(classOf[VertxUnitRunner]) -class SearchHandlerTest extends EitherValues { - - @Mock var routingContext: RoutingContext = _ - @Mock var response: HttpServerResponse = _ - @Mock var requestBody: RequestBody = _ - @Mock var mockedStore: MonitoringModelStore = _ - - var vertx: Vertx = _ - var handler: SearchHandler = _ - - @Before - def setUp(context: TestContext): Unit = { - MockitoAnnotations.openMocks(this) - vertx = Vertx.vertx - handler = new SearchHandler(mockedStore) - // Set up common routing context behavior - when(routingContext.response).thenReturn(response) - when(response.putHeader(anyString, anyString)).thenReturn(response) - when(response.setStatusCode(anyInt)).thenReturn(response) - when(routingContext.body).thenReturn(requestBody) - } - - @Test - def testSend400BadOffset(context: TestContext) : Unit = { - val async = context.async - val multiMap = MultiMap.caseInsensitiveMultiMap - multiMap.add("term", "foo") - multiMap.add("offset", "-1") - multiMap.add("limit", "10") - when(routingContext.queryParams()).thenReturn(multiMap) - - // Trigger call// Trigger call - handler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(400) - async.complete() - }) - } - - @Test - def testSend400BadLimit(context: TestContext) : Unit = { - val async = context.async - val multiMap = MultiMap.caseInsensitiveMultiMap - multiMap.add("term", "foo") - multiMap.add("offset", "10") - multiMap.add("limit", "-1") - when(routingContext.queryParams()).thenReturn(multiMap) - - // Trigger call// Trigger call - handler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(400) - async.complete() - }) - } - - @Test - def testSendValidResults(context: TestContext) : Unit = { - val async = context.async - when(mockedStore.getJoins).thenReturn(mockJoinRegistry) - val multiMap = MultiMap.caseInsensitiveMultiMap - multiMap.add("term", "1") - when(routingContext.queryParams()).thenReturn(multiMap) - - // Capture the response that will be sent - val responseCaptor = ArgumentCaptor.forClass(classOf[String]) - - // Trigger call// Trigger call - handler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(200) - verify(response).putHeader("content-type", "application/json") - verify(response).end(responseCaptor.capture) - val jsonResponse = responseCaptor.getValue - - val listResponse: Either[Error, SearchJoinResponse] = decode[SearchJoinResponse](jsonResponse) - val items = listResponse.right.value.items - assertEquals(items.length, handler.defaultLimit) - assertEquals(items.map(_.name.toInt).toSet, Set(1, 10, 11, 12, 13, 14, 15, 16, 17, 18)) - - async.complete() - }) - } - - @Test - def testSendPaginatedResultsCorrectly(context: TestContext) : Unit = { - val async = context.async - when(mockedStore.getJoins).thenReturn(mockJoinRegistry) - - val multiMap = MultiMap.caseInsensitiveMultiMap - val number = 6 - val startOffset = 3 - multiMap.add("term", "1") - multiMap.add("offset", startOffset.toString) - multiMap.add("limit", number.toString) - when(routingContext.queryParams()).thenReturn(multiMap) - - // we have names: 0, 1, 2, .. 99 - // our result should give us: 1, 10, 11, 12, .. 19, 21, 31, .. 91 - val expected = Set(12, 13, 14, 15, 16, 17) - - // Capture the response that will be sent - val responseCaptor = ArgumentCaptor.forClass(classOf[String]) - - // Trigger call// Trigger call - handler.handle(routingContext) - vertx.setTimer(1000, _ => { - verify(response).setStatusCode(200) - verify(response).putHeader("content-type", "application/json") - verify(response).end(responseCaptor.capture) - val jsonResponse = responseCaptor.getValue - - val listResponse: Either[Error, SearchJoinResponse] = decode[SearchJoinResponse](jsonResponse) - val items = listResponse.right.value.items - assertEquals(items.length, number) - assertEquals(items.map(_.name.toInt).toSet, expected) - - async.complete() - }) - } -} - -object MockJoinService { - def generateMockJoin(id: String): Join = { - val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2"))) - Join(id, Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, Some("my_team")) - } - - val mockJoinRegistry: Seq[Join] = (0 until 100).map(i => generateMockJoin(i.toString)) -} - diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index 6af7933ee0..110e882b60 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -84,11 +84,14 @@ class DriftStore(kvStore: KVStore, val tileKeyMap = tileKeysForJoin(joinConf, None, columnPrefix) val requestContextMap: Map[GetRequest, SummaryRequestContext] = tileKeyMap.flatMap { case (group, keys) => - keys.map { key => - val keyBytes = serializer.serialize(key) - val get = GetRequest(keyBytes, summaryDataset, startTsMillis = startMs, endTsMillis = endMs) - get -> SummaryRequestContext(get, key, group) - } + // Only create requests for keys that match our column prefix + keys + .filter(key => columnPrefix.forall(prefix => key.getColumn == prefix)) + .map { key => + val keyBytes = serializer.serialize(key) + val get = GetRequest(keyBytes, summaryDataset, startTsMillis = startMs, endTsMillis = endMs) + get -> SummaryRequestContext(get, key, group) + } } val responseFuture = kvStore.multiGet(requestContextMap.keys.toSeq) diff --git a/online/src/main/scala/ai/chronon/online/stats/PivotUtils.scala b/online/src/main/scala/ai/chronon/online/stats/PivotUtils.scala index 12b172a502..64392599b2 100644 --- a/online/src/main/scala/ai/chronon/online/stats/PivotUtils.scala +++ b/online/src/main/scala/ai/chronon/online/stats/PivotUtils.scala @@ -95,7 +95,7 @@ object PivotUtils { if (isSetFunc(summary)) { JLong.valueOf(extract(summary)) } else { - null + Constants.magicNullLong } } } diff --git a/online/src/test/scala/ai/chronon/online/test/stats/PivotUtilsTest.scala b/online/src/test/scala/ai/chronon/online/test/stats/PivotUtilsTest.scala index b9ce6d0e2e..046e86cfde 100644 --- a/online/src/test/scala/ai/chronon/online/test/stats/PivotUtilsTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/stats/PivotUtilsTest.scala @@ -113,7 +113,7 @@ class PivotUtilsTest extends AnyFlatSpec with Matchers { (ts3, 3000L) )) - result.getCount.asScala shouldEqual List(100L, null, 300L) + result.getCount.asScala.toList shouldEqual List(100L, Constants.magicNullLong, 300L) } it should "preserve timestamp order" in { @@ -313,4 +313,62 @@ class PivotUtilsTest extends AnyFlatSpec with Matchers { series(0) shouldBe Constants.magicNullDouble series(1) shouldBe 0.5 } + + it should "handle Long.MAX_VALUE and magicNullLong values" in { + val ts1 = new TileSummary() + ts1.setCount(Long.MaxValue) + + val ts2 = new TileSummary() + // count is not set, should become magicNullLong + + val ts3 = new TileSummary() + ts3.setCount(100L) + + val result = pivot(Array( + (ts1, 1000L), + (ts2, 2000L), + (ts3, 3000L) + )) + + result.getCount.asScala shouldEqual List(Long.MaxValue, Constants.magicNullLong, 100L) + } + + it should "handle all null Long values" in { + val ts1 = new TileSummary() + val ts2 = new TileSummary() + val ts3 = new TileSummary() + // no counts set for any summary + + val result = pivot(Array( + (ts1, 1000L), + (ts2, 2000L), + (ts3, 3000L) + )) + + // Since all values are unset, they should all be magicNullLong rather than null + result.getCount.asScala.toList shouldEqual List.fill(3)(Constants.magicNullLong) + } + + it should "handle mixed null and non-null Long fields" in { + val ts1 = new TileSummary() + ts1.setCount(100L) + ts1.setNullCount(10L) + + val ts2 = new TileSummary() + // count not set + ts2.setNullCount(20L) + + val ts3 = new TileSummary() + ts3.setCount(300L) + // nullCount not set + + val result = pivot(Array( + (ts1, 1000L), + (ts2, 2000L), + (ts3, 3000L) + )) + + result.getCount.asScala shouldEqual List(100L, Constants.magicNullLong, 300L) + result.getNullCount.asScala shouldEqual List(10L, 20L, Constants.magicNullLong) + } } diff --git a/service_commons/src/main/java/ai/chronon/service/RouteHandlerWrapper.java b/service_commons/src/main/java/ai/chronon/service/RouteHandlerWrapper.java index 85cedb5bcf..b90822a622 100644 --- a/service_commons/src/main/java/ai/chronon/service/RouteHandlerWrapper.java +++ b/service_commons/src/main/java/ai/chronon/service/RouteHandlerWrapper.java @@ -2,6 +2,7 @@ import ai.chronon.api.thrift.*; import ai.chronon.api.thrift.protocol.TBinaryProtocol; +import ai.chronon.api.thrift.protocol.TSimpleJSONProtocol; import ai.chronon.api.thrift.transport.TTransportException; import io.vertx.core.Handler; import io.vertx.core.json.JsonObject; @@ -50,7 +51,6 @@ public class RouteHandlerWrapper { private static final ThreadLocal base64Encoder = ThreadLocal.withInitial(Base64::getEncoder); private static final ThreadLocal base64Decoder = ThreadLocal.withInitial(Base64::getDecoder); - public static T deserializeTBinaryBase64(String base64Data, Class clazz) throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, TException { byte[] binaryData = base64Decoder.get().decode(base64Data); T tb = (T) clazz.getDeclaredConstructor().newInstance(); @@ -87,8 +87,20 @@ public static Handler createHandler(Function transf String responseFormat = ctx.request().getHeader(RESPONSE_CONTENT_TYPE_HEADER); if (responseFormat == null || responseFormat.equals("application/json")) { - // Send json response - ctx.response().setStatusCode(200).putHeader("content-type", JSON_TYPE_VALUE).end(JsonObject.mapFrom(output).encode()); + try { + TSerializer serializer = new TSerializer(new TSimpleJSONProtocol.Factory()); + String jsonString = serializer.toString((TBase)output); + ctx.response() + .setStatusCode(200) + .putHeader("content-type", JSON_TYPE_VALUE) + .end(jsonString); + } catch (TException e) { + LOGGER.error("Failed to serialize response", e); + throw new RuntimeException(e); + } catch (Exception e) { + LOGGER.error("Unexpected error during serialization", e); + throw new RuntimeException(e); + } } else { if (!responseFormat.equals(TBINARY_B64_TYPE_VALUE)) { throw new IllegalArgumentException(String.format("Unsupported response-content-type: %s. Supported values are: %s and %s", responseFormat, JSON_TYPE_VALUE, TBINARY_B64_TYPE_VALUE));