From 7ad85ac8b730045d2d5824ece81527255b4cc5ae Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Fri, 22 Nov 2024 16:51:15 -0700 Subject: [PATCH 001/152] move everything to joins page, update search/default page, etc --- .../NavigationBar/NavigationBar.svelte | 34 +++++++++-- frontend/src/lib/types/Entity/Entity.ts | 32 ++++++++++ frontend/src/lib/types/Model/Model.ts | 12 +++- frontend/src/routes/+layout.svelte | 10 +-- frontend/src/routes/+page.server.ts | 2 +- frontend/src/routes/groupbys/+page.svelte | 5 -- .../routes/{models => joins}/+page.server.ts | 0 frontend/src/routes/joins/+page.svelte | 61 ++++++++++++++++++- .../{models => joins}/[slug]/+page.server.ts | 0 .../{models => joins}/[slug]/+page.svelte | 0 frontend/src/routes/models/+page.svelte | 56 ----------------- 11 files changed, 132 insertions(+), 80 deletions(-) create mode 100644 frontend/src/lib/types/Entity/Entity.ts delete mode 100644 frontend/src/routes/groupbys/+page.svelte rename frontend/src/routes/{models => joins}/+page.server.ts (100%) rename frontend/src/routes/{models => joins}/[slug]/+page.server.ts (100%) rename frontend/src/routes/{models => joins}/[slug]/+page.svelte (100%) delete mode 100644 frontend/src/routes/models/+page.svelte diff --git a/frontend/src/lib/components/NavigationBar/NavigationBar.svelte b/frontend/src/lib/components/NavigationBar/NavigationBar.svelte index 222d362042..781742bdf3 100644 --- a/frontend/src/lib/components/NavigationBar/NavigationBar.svelte +++ b/frontend/src/lib/components/NavigationBar/NavigationBar.svelte @@ -29,13 +29,13 @@ AdjustmentsHorizontal, ArrowsUpDown } from 'svelte-hero-icons'; - import type { IconSource } from 'svelte-hero-icons'; import { goto } from '$app/navigation'; import { isMacOS } from '$lib/util/browser'; import { Badge } from '$lib/components/ui/badge'; + import { getEntity, type Entity } from '$lib/types/Entity/Entity'; type Props = { - navItems: { label: string; href: string; icon: IconSource }[]; + navItems: Entity[]; user: { name: string; avatar: string }; }; @@ -129,16 +129,16 @@ {#each navItems as item}
  • @@ -200,9 +200,31 @@ {:else} {#each searchResults as model} - handleSelect(`/models/${encodeURIComponent(model.name)}`)}> + + handleSelect(`${getEntity('models').path}/${encodeURIComponent(model.name)}`)} + > + {model.name} + + handleSelect(`${getEntity('joins').path}/${encodeURIComponent(model.join.name)}`)} + > + + {model.join.name} + + {#each model.join.groupBys as groupBy} + + handleSelect(`${getEntity('groupbys').path}/${encodeURIComponent(groupBy.name)}`)} + > + + {groupBy.name} + + {/each} {/each} {/if} diff --git a/frontend/src/lib/types/Entity/Entity.ts b/frontend/src/lib/types/Entity/Entity.ts new file mode 100644 index 0000000000..3b497b5a93 --- /dev/null +++ b/frontend/src/lib/types/Entity/Entity.ts @@ -0,0 +1,32 @@ +import { Cube, PuzzlePiece, Square3Stack3d } from 'svelte-hero-icons'; + +export const entityConfig = [ + { + label: 'Models', + path: '/models', + icon: Cube, + id: 'models' + }, + { + label: 'GroupBys', + path: '/groupbys', + icon: PuzzlePiece, + id: 'groupbys' + }, + { + label: 'Joins', + path: '/joins', + icon: Square3Stack3d, + id: 'joins' + } +] as const; + +export type Entity = (typeof entityConfig)[number]; +export type EntityId = Entity['id']; + +// Helper function to get entity by ID +export function getEntity(id: EntityId): Entity { + const entity = entityConfig.find((entity) => entity.id === id); + if (!entity) throw new Error(`Entity with id "${id}" not found`); + return entity; +} diff --git a/frontend/src/lib/types/Model/Model.ts b/frontend/src/lib/types/Model/Model.ts index fb67e8d37d..16b00bf963 100644 --- a/frontend/src/lib/types/Model/Model.ts +++ b/frontend/src/lib/types/Model/Model.ts @@ -4,7 +4,7 @@ export type Model = { production: boolean; team: string; modelType: string; - join: JoinTimeSeriesResponse; // todo: this type needs to be updated to match the actual response once that WIP is finished + join: Join; }; export type ModelsResponse = { @@ -23,6 +23,16 @@ export type TimeSeriesResponse = { id: string; items: TimeSeriesItem[]; }; +export type Join = { + name: string; + joinFeatures: string[]; + groupBys: GroupBy[]; +}; + +export type GroupBy = { + name: string; + features: string[]; +}; export type JoinTimeSeriesResponse = { name: string; // todo: rename to joinName diff --git a/frontend/src/routes/+layout.svelte b/frontend/src/routes/+layout.svelte index a754a64a2c..282723e082 100644 --- a/frontend/src/routes/+layout.svelte +++ b/frontend/src/routes/+layout.svelte @@ -6,7 +6,7 @@ import NavigationBar from '$lib/components/NavigationBar/NavigationBar.svelte'; import BreadcrumbNav from '$lib/components/BreadcrumbNav/BreadcrumbNav.svelte'; import { ScrollArea } from '$lib/components/ui/scroll-area'; - import { Cube, PuzzlePiece, Square3Stack3d } from 'svelte-hero-icons'; + import { entityConfig } from '$lib/types/Entity/Entity'; let { children }: { children: Snippet } = $props(); @@ -16,12 +16,6 @@ avatar: '/path/to/avatar.jpg' }; - const navItems = [ - { label: 'Models', href: '/models', icon: Cube }, - { label: 'GroupBys', href: '/groupbys', icon: PuzzlePiece }, - { label: 'Joins', href: '/joins', icon: Square3Stack3d } - ]; - const breadcrumbs = $derived($page.url.pathname.split('/').filter(Boolean)); @@ -29,7 +23,7 @@ - + entity.id === 'joins')} {user} />
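
The new Entity.ts above is the single source of truth for entity metadata: `as const` lets TypeScript derive the `Entity` and `EntityId` unions directly from the config array, so an unknown id is a compile-time error. A short usage sketch (hypothetical snippet, not part of the patch):

import { entityConfig, getEntity, type EntityId } from '$lib/types/Entity/Entity';

const id: EntityId = 'joins'; // only 'models' | 'groupbys' | 'joins' type-check

// Build a detail link the same way the search results do
const href = `${getEntity(id).path}/${encodeURIComponent('some-join-name')}`;

// The layout derives its nav items by filtering the same config
const navItems = entityConfig.filter((entity) => entity.id === id);
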
    - import ComingSoonPage from '$lib/components/ComingSoonPage/ComingSoonPage.svelte'; - - - diff --git a/frontend/src/routes/models/+page.server.ts b/frontend/src/routes/joins/+page.server.ts similarity index 100% rename from frontend/src/routes/models/+page.server.ts rename to frontend/src/routes/joins/+page.server.ts diff --git a/frontend/src/routes/joins/+page.svelte b/frontend/src/routes/joins/+page.svelte index a6cf0f1602..684b67f3e4 100644 --- a/frontend/src/routes/joins/+page.svelte +++ b/frontend/src/routes/joins/+page.svelte @@ -1,5 +1,60 @@ - - + + +
    + +
    + + + + + Join + Model + Team + Type + Online + Production + + + + {#each models as model} + + + + {model.join.name} + + + + {model.name} + + {model.team} + {model.modelType} + + + + + + + + {/each} + +
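
Each row in this table is driven by the nested `Model`/`Join` shape added to Model.ts in this same commit. For illustration, a minimal value of that shape (example data, not from the patch; the `name` and `online` fields are inferred from the table columns):

import type { Model } from '$lib/types/Model/Model';

const model: Model = {
  name: 'txn_model',
  online: true,
  production: false,
  team: 'risk',
  modelType: 'xgboost',
  join: {
    name: 'risk.user_transactions.txn_join',
    joinFeatures: ['txn_amount_sum_7d'],
    groupBys: [{ name: 'user_transactions', features: ['txn_amount_sum_7d'] }]
  }
};
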
    + diff --git a/frontend/src/routes/models/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts similarity index 100% rename from frontend/src/routes/models/[slug]/+page.server.ts rename to frontend/src/routes/joins/[slug]/+page.server.ts diff --git a/frontend/src/routes/models/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte similarity index 100% rename from frontend/src/routes/models/[slug]/+page.svelte rename to frontend/src/routes/joins/[slug]/+page.svelte diff --git a/frontend/src/routes/models/+page.svelte b/frontend/src/routes/models/+page.svelte deleted file mode 100644 index 76ca35b441..0000000000 --- a/frontend/src/routes/models/+page.svelte +++ /dev/null @@ -1,56 +0,0 @@ - - - - -
    - -
    - - - - - Model - Team - Type - Online - Production - - - - {#each models as model} - - - - {model.name} - - - {model.team} - {model.modelType} - - - - - - - - {/each} - -
    - From 281fcc4c7865fcd10c86842e622ad23e993679fd Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Sat, 23 Nov 2024 08:35:27 -0700 Subject: [PATCH 002/152] URL-encode dynamic parameters in links to prevent potential issues. --- frontend/src/routes/joins/+page.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/routes/joins/+page.svelte b/frontend/src/routes/joins/+page.svelte index 684b67f3e4..274e458e17 100644 --- a/frontend/src/routes/joins/+page.svelte +++ b/frontend/src/routes/joins/+page.svelte @@ -38,7 +38,7 @@ {#each models as model} - + {model.join.name} From 8f62e420988dc753cebd93e40e9fda9738a5743e Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Sat, 23 Nov 2024 08:40:04 -0700 Subject: [PATCH 003/152] add nullValue to test --- frontend/src/lib/types/Model/Model.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/src/lib/types/Model/Model.test.ts b/frontend/src/lib/types/Model/Model.test.ts index b617996c93..948a9bad94 100644 --- a/frontend/src/lib/types/Model/Model.test.ts +++ b/frontend/src/lib/types/Model/Model.test.ts @@ -71,7 +71,7 @@ describe('Model types', () => { if (timeseriesResult.items.length > 0) { const item = timeseriesResult.items[0]; - const expectedItemKeys = ['value', 'ts', 'label']; + const expectedItemKeys = ['value', 'ts', 'label', 'nullValue']; expect(Object.keys(item)).toEqual(expect.arrayContaining(expectedItemKeys)); // Log a warning if there are additional fields @@ -184,7 +184,7 @@ describe('Model types', () => { if (subItem.points.length > 0) { const point = subItem.points[0]; - const expectedPointKeys = ['value', 'ts', 'label']; + const expectedPointKeys = ['value', 'ts', 'label', 'nullValue']; expect(Object.keys(point)).toEqual(expect.arrayContaining(expectedPointKeys)); // Log a warning if there are additional fields From 8cab3a3f0d751eae6392dd1b6178c570adda5958 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Sun, 24 Nov 2024 06:28:00 -0500 Subject: [PATCH 004/152] tooltip improvements --- .../EChartTooltip/EChartTooltip.svelte | 79 +++++++++---------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte b/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte index 47ed4d4802..629e9ac429 100644 --- a/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte +++ b/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte @@ -48,46 +48,45 @@ } -
    -
    -
    - {#if xValue !== null && visible} -
    -
    - {getTooltipTitle(xValue, xAxisCategories)} -
    - -
    - {#each series as item} - - {/each} -
    -
    - -
    - {isMacOS() ? '⌘' : 'Ctrl'} to lock tooltip +
    +
    + {#if xValue !== null && visible} +
    +
    + {getTooltipTitle(xValue, xAxisCategories)} +
    + +
    + {#each series as item} + + {/each}
    +
    + +
    + {isMacOS() ? '⌘' : 'Ctrl'} to lock tooltip
    - {/if} -
    +
    + {/if}
    From e26d99dcb45931292fce3f74b5026d9e85ea32c8 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Sun, 24 Nov 2024 06:33:18 -0500 Subject: [PATCH 005/152] justify start content --- .../src/lib/components/EChartTooltip/EChartTooltip.svelte | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte b/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte index 629e9ac429..48c0f3e98c 100644 --- a/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte +++ b/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte @@ -64,12 +64,12 @@ {#each series as item} - -
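
The tooltip's '⌘ / Ctrl to lock tooltip' hint above depends on `isMacOS` from `$lib/util/browser`, whose source is not included in this series. A minimal sketch of what such a helper could look like, an assumption rather than the actual implementation:

// Hypothetical sketch; the real $lib/util/browser module is not shown here.
export function isMacOS(): boolean {
  // navigator only exists in the browser, so guard for SSR
  return typeof navigator !== 'undefined' && /Mac/i.test(navigator.platform);
}
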
    diff --git a/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte b/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte new file mode 100644 index 0000000000..4464f41738 --- /dev/null +++ b/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte @@ -0,0 +1,22 @@ + + +
    + {#each METRIC_TYPES as metricType} + + {/each} +
    diff --git a/frontend/src/lib/types/MetricType/MetricType.ts b/frontend/src/lib/types/MetricType/MetricType.ts new file mode 100644 index 0000000000..2358154067 --- /dev/null +++ b/frontend/src/lib/types/MetricType/MetricType.ts @@ -0,0 +1,8 @@ +export const METRIC_TYPES = ['jsd', 'hellinger', 'psi'] as const; +export type MetricType = (typeof METRIC_TYPES)[number]; + +export const METRIC_LABELS: Record = { + jsd: 'JSD', + hellinger: 'Hellinger', + psi: 'PSI' +}; diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index b15056ba9d..b0dfcf0cd6 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -17,7 +17,7 @@ import { untrack } from 'svelte'; import PageHeader from '$lib/components/PageHeader/PageHeader.svelte'; import Separator from '$lib/components/ui/separator/separator.svelte'; - import DriftSkewToggle from '$lib/components/DriftSkewToggle/DriftSkewToggle.svelte'; + import MetricTypeToggle from '$lib/components/MetricTypeToggle/MetricTypeToggle.svelte'; import ResetZoomButton from '$lib/components/ResetZoomButton/ResetZoomButton.svelte'; import DateRangeSelector from '$lib/components/DateRangeSelector/DateRangeSelector.svelte'; import IntersectionObserver from 'svelte-intersection-observer'; @@ -36,9 +36,10 @@ import { formatDate, formatValue } from '$lib/util/format'; import PercentileChart from '$lib/components/PercentileChart/PercentileChart.svelte'; import { createChartOption } from '$lib/util/chart-options.svelte'; + import type { MetricType } from '$lib/types/MetricType/MetricType.js'; const { data } = $props(); - let selectedDriftSkew = $state<'drift' | 'skew'>('drift'); + let selectedMetricType = $state('psi'); // todo setup query param for jsd/psi/hellinger and use it in groupby charts and drilldown const joinTimeseries = $derived(data.joinTimeseries); const model = $derived(data.model); const distributions = $derived(data.distributions); @@ -341,7 +342,7 @@ {#if isZoomed} {/if} - +
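
MetricType.ts follows the same `as const` pattern as Entity.ts: `METRIC_TYPES` drives the `MetricType` union, and `METRIC_LABELS` is a `Record` mapping each metric id to its display label. A hedged usage sketch (`parseMetric` is an illustrative helper, not part of the patch):

import { METRIC_TYPES, METRIC_LABELS, type MetricType } from '$lib/types/MetricType/MetricType';

// Iterate the options exactly as the toggle's {#each} block does
for (const metricType of METRIC_TYPES) {
  console.log(metricType, METRIC_LABELS[metricType]); // e.g. 'jsd' -> 'JSD'
}

// Narrow an untrusted string into the union, as the later query-param helper does
function parseMetric(raw: string | null): MetricType | undefined {
  return METRIC_TYPES.includes(raw as MetricType) ? (raw as MetricType) : undefined;
}
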
    @@ -387,7 +388,7 @@ - + From 0a4e051aa52452f26537a4a70f2c86977e2b2bcb Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 26 Nov 2024 15:44:10 -0500 Subject: [PATCH 050/152] make metric type toggle functional --- .../MetricTypeToggle/MetricTypeToggle.svelte | 17 +++++++++++++++-- frontend/src/lib/types/MetricType/MetricType.ts | 13 +++++++++++++ .../src/routes/joins/[slug]/+page.server.ts | 8 ++++++-- frontend/src/routes/joins/[slug]/+page.svelte | 11 ++++++----- 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte b/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte index 4464f41738..e3ed53be69 100644 --- a/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte +++ b/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte @@ -1,9 +1,22 @@ diff --git a/frontend/src/lib/types/MetricType/MetricType.ts b/frontend/src/lib/types/MetricType/MetricType.ts index 2358154067..633de1fdee 100644 --- a/frontend/src/lib/types/MetricType/MetricType.ts +++ b/frontend/src/lib/types/MetricType/MetricType.ts @@ -1,8 +1,21 @@ export const METRIC_TYPES = ['jsd', 'hellinger', 'psi'] as const; export type MetricType = (typeof METRIC_TYPES)[number]; +export const DEFAULT_METRIC_TYPE: MetricType = 'psi'; + export const METRIC_LABELS: Record = { jsd: 'JSD', hellinger: 'Hellinger', psi: 'PSI' }; + +export const METRIC_SCALES: Record = { + jsd: { min: 0, max: 1 }, + hellinger: { min: 0, max: 1 }, + psi: { min: 0, max: 25 } +}; + +export function getMetricTypeFromParams(searchParams: URLSearchParams): MetricType { + const metric = searchParams.get('metric'); + return METRIC_TYPES.includes(metric as MetricType) ? (metric as MetricType) : DEFAULT_METRIC_TYPE; +} diff --git a/frontend/src/routes/joins/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts index 2f8c9453d4..d60c36fe02 100644 --- a/frontend/src/routes/joins/[slug]/+page.server.ts +++ b/frontend/src/routes/joins/[slug]/+page.server.ts @@ -3,6 +3,7 @@ import * as api from '$lib/api/api'; import type { JoinTimeSeriesResponse, Model, FeatureResponse } from '$lib/types/Model/Model'; import { parseDateRangeParams } from '$lib/util/date-ranges'; import { generatePercentileData } from '$lib/util/sample-data'; +import { getMetricTypeFromParams, type MetricType } from '$lib/types/MetricType/MetricType'; export const load: PageServerLoad = async ({ params, @@ -11,9 +12,11 @@ export const load: PageServerLoad = async ({ joinTimeseries: JoinTimeSeriesResponse; model?: Model; distributions: FeatureResponse[]; + metricType: MetricType; }> => { const dateRange = parseDateRangeParams(url.searchParams); const joinName = 'risk.user_transactions.txn_join'; // todo use params.slug once backend has data for all joins + const metricType = getMetricTypeFromParams(url.searchParams); const [joinTimeseries, models] = await Promise.all([ api.getJoinTimeseries({ @@ -23,7 +26,7 @@ export const load: PageServerLoad = async ({ metricType: 'drift', metrics: 'value', offset: undefined, - algorithm: 'jsd' // todo setup query param for jsd/psi/hellinger + algorithm: metricType }), api.getModels() ]); @@ -35,6 +38,7 @@ export const load: PageServerLoad = async ({ return { joinTimeseries, model: modelToReturn, - distributions + distributions, + metricType }; }; diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index b0dfcf0cd6..d12622a63f 100644 --- 
a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -36,10 +36,10 @@ import { formatDate, formatValue } from '$lib/util/format'; import PercentileChart from '$lib/components/PercentileChart/PercentileChart.svelte'; import { createChartOption } from '$lib/util/chart-options.svelte'; - import type { MetricType } from '$lib/types/MetricType/MetricType.js'; + import { METRIC_SCALES } from '$lib/types/MetricType/MetricType'; const { data } = $props(); - let selectedMetricType = $state('psi'); // todo setup query param for jsd/psi/hellinger and use it in groupby charts and drilldown + let selectedMetricType = $state(data.metricType); const joinTimeseries = $derived(data.joinTimeseries); const model = $derived(data.model); const distributions = $derived(data.distributions); @@ -88,12 +88,13 @@ } })) as EChartOption.Series[]; + const scale = METRIC_SCALES[selectedMetricType]; + return createChartOption( { yAxis: { - min: 0, - max: 1, - interval: 0.2 + min: scale.min, + max: scale.max }, series: series as EChartOption.Series[] }, From 6f8f7e743f0df67db927711d3717b508e765f46a Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 26 Nov 2024 16:07:31 -0500 Subject: [PATCH 051/152] better ux when switching metric --- .../components/MetricTypeToggle/MetricTypeToggle.svelte | 4 +--- frontend/src/routes/joins/[slug]/+page.svelte | 8 +++----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte b/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte index e3ed53be69..609788e8a7 100644 --- a/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte +++ b/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte @@ -9,9 +9,7 @@ import { page } from '$app/stores'; import { goto } from '$app/navigation'; - let { - selected = $bindable(getMetricTypeFromParams(new URL($page.url).searchParams)) - } = $props(); + let selected = getMetricTypeFromParams(new URL($page.url).searchParams); function toggle(value: MetricType) { const url = new URL($page.url); diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index d12622a63f..0c8adc0158 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -39,7 +39,7 @@ import { METRIC_SCALES } from '$lib/types/MetricType/MetricType'; const { data } = $props(); - let selectedMetricType = $state(data.metricType); + let scale = $derived(METRIC_SCALES[data.metricType]); const joinTimeseries = $derived(data.joinTimeseries); const model = $derived(data.model); const distributions = $derived(data.distributions); @@ -88,8 +88,6 @@ } })) as EChartOption.Series[]; - const scale = METRIC_SCALES[selectedMetricType]; - return createChartOption( { yAxis: { @@ -343,7 +341,7 @@ {#if isZoomed} {/if} - + @@ -389,7 +387,7 @@ - + From 3d86ca7d2c849bc46f8c8bf54b20035ef4a78f47 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 26 Nov 2024 17:14:16 -0500 Subject: [PATCH 052/152] highlight hovered series in legend/tooltip --- .../CustomEChartLegend/CustomEChartLegend.svelte | 11 +++++++++++ frontend/src/lib/components/EChart/EChart.svelte | 1 + .../EChartTooltip/EChartTooltip.svelte | 16 +++++++++++++++- .../src/lib/components/ui/button/button.svelte | 2 ++ frontend/src/lib/util/chart.ts | 14 ++++++++++++++ 5 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 frontend/src/lib/util/chart.ts diff --git 
a/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte b/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte index 9051ba98b4..7c862b6048 100644 --- a/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte +++ b/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte @@ -2,6 +2,7 @@ import { Button } from '$lib/components/ui/button'; import type { EChartsType } from 'echarts'; import { Icon, ChevronDown, ChevronUp } from 'svelte-hero-icons'; + import { handleChartHighlight } from '$lib/util/chart'; type LegendItem = { feature: string }; type Props = { @@ -65,6 +66,14 @@ return () => resizeObserver.disconnect(); }); + + function handleMouseEnter(seriesName: string) { + handleChartHighlight(chart, seriesName, 'highlight'); + } + + function handleMouseLeave(seriesName: string) { + handleChartHighlight(chart, seriesName, 'downplay'); + }
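
In ECharts, `highlight` and `downplay` are dispatched actions that toggle a series' emphasis state; targeting by `seriesName` is the documented form that the new `handleChartHighlight` helper wraps (the series name below is a made-up example):

// Emphasize one series, then return it to its normal state
chart.dispatchAction({ type: 'highlight', seriesName: 'txn_amount_sum_7d' });
chart.dispatchAction({ type: 'downplay', seriesName: 'txn_amount_sum_7d' });
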
    @@ -84,6 +93,8 @@ variant="ghost" on:click={() => toggleSeries(feature)} title={feature} + on:mouseenter={() => handleMouseEnter(feature)} + on:mouseleave={() => handleMouseLeave(feature)} >
    dispatch('click', { detail: event.detail, diff --git a/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte b/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte index 3f3bae1b2b..283ea66b3d 100644 --- a/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte +++ b/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte @@ -7,13 +7,16 @@ import Button from '$lib/components/ui/button/button.svelte'; import { Separator } from '$lib/components/ui/separator/'; import Badge from '$lib/components/ui/badge/badge.svelte'; + import { handleChartHighlight } from '$lib/util/chart'; + import type { EChartsType } from 'echarts'; let { visible, xValue, series, clickable = false, - xAxisCategories = undefined + xAxisCategories = undefined, + chart }: { visible: boolean; xValue: number | null; @@ -24,6 +27,7 @@ }>; clickable?: boolean; xAxisCategories?: string[]; + chart: EChartsType | null; } = $props(); const tooltipHeight = '300px'; @@ -46,6 +50,14 @@ return formatDate(xValue); } + + function handleMouseEnter(seriesName: string) { + handleChartHighlight(chart, seriesName, 'highlight'); + } + + function handleMouseLeave(seriesName: string) { + handleChartHighlight(chart, seriesName, 'downplay'); + }
    @@ -67,6 +79,8 @@ class="px-3 text-small text-neutral-800 text-left justify-between w-full {!clickable && 'pointer-events-none'}" on:click={() => handleSeriesClick(item)} + on:mouseenter={() => handleMouseEnter(item.name ?? '')} + on:mouseleave={() => handleMouseLeave(item.name ?? '')} > {#if series.length > 1}
    diff --git a/frontend/src/lib/components/ui/button/button.svelte b/frontend/src/lib/components/ui/button/button.svelte index a785e95c87..cedf513ca4 100644 --- a/frontend/src/lib/components/ui/button/button.svelte +++ b/frontend/src/lib/components/ui/button/button.svelte @@ -21,6 +21,8 @@ {...$$restProps} on:click on:keydown + on:mouseenter + on:mouseleave > diff --git a/frontend/src/lib/util/chart.ts b/frontend/src/lib/util/chart.ts new file mode 100644 index 0000000000..a7680975e8 --- /dev/null +++ b/frontend/src/lib/util/chart.ts @@ -0,0 +1,14 @@ +import type { EChartsType } from 'echarts'; + +export function handleChartHighlight( + chart: EChartsType | null, + seriesName: string, + type: 'highlight' | 'downplay' +) { + if (!chart || !seriesName) return; + + chart.dispatchAction({ + type, + seriesName + }); +} From 84e7bafd3e710d8182746d74cb38f27e0182862a Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 26 Nov 2024 17:20:31 -0500 Subject: [PATCH 053/152] allow some pointer events --- .../src/lib/components/EChartTooltip/EChartTooltip.svelte | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte b/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte index 283ea66b3d..3737ea1db6 100644 --- a/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte +++ b/frontend/src/lib/components/EChartTooltip/EChartTooltip.svelte @@ -34,6 +34,7 @@ const dispatch = createEventDispatcher(); function handleSeriesClick(item: (typeof series)[number]) { + if (!clickable) return; dispatch('click', { componentType: 'series', data: [xValue, item.value], @@ -76,8 +77,9 @@ {#each series as item} {/each} From 33ff2f968c4fa93576ab4b2ae8b7922a1054f842 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 27 Nov 2024 11:55:39 -0500 Subject: [PATCH 056/152] add key to each --- frontend/src/routes/joins/[slug]/+page.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 0c8adc0158..1535ddde4e 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -408,7 +408,7 @@ - {#each joinTimeseries.items as group} + {#each joinTimeseries.items as group (group.name)} Date: Wed, 27 Nov 2024 12:03:21 -0500 Subject: [PATCH 057/152] move legend to echart component --- frontend/src/lib/components/EChart/EChart.svelte | 12 ++++++++++++ frontend/src/routes/joins/[slug]/+page.svelte | 16 ++++------------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/frontend/src/lib/components/EChart/EChart.svelte b/frontend/src/lib/components/EChart/EChart.svelte index 56c1ade96a..a1062e9f5a 100644 --- a/frontend/src/lib/components/EChart/EChart.svelte +++ b/frontend/src/lib/components/EChart/EChart.svelte @@ -4,6 +4,7 @@ import type { ECElementEvent, EChartOption } from 'echarts'; import merge from 'lodash/merge'; import EChartTooltip from '$lib/components/EChartTooltip/EChartTooltip.svelte'; + import CustomEChartLegend from '$lib/components/CustomEChartLegend/CustomEChartLegend.svelte'; import { getCssColorAsHex } from '$lib/util/colors'; let { @@ -15,6 +16,8 @@ height = '230px', enableCustomTooltip = false, enableTooltipClick = false, + showCustomLegend = false, + legendGroup = undefined, markPoint = undefined }: { option: EChartOption; @@ -25,6 +28,8 @@ height?: string; enableCustomTooltip?: boolean; enableTooltipClick?: boolean; + showCustomLegend?: boolean; + 
legendGroup?: { name: string; items: Array<{ feature: string }> }; markPoint?: ECElementEvent; } = $props(); const dispatch = createEventDispatcher(); @@ -411,3 +416,10 @@
    {/if}
    +{#if showCustomLegend && legendGroup && chartInstance} + +{/if} diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 1535ddde4e..30a981b5d8 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -429,11 +429,8 @@ enableCustomZoom={true} enableCustomTooltip={true} enableTooltipClick={true} - /> - {/snippet} @@ -504,14 +501,9 @@ enableCustomZoom={true} enableCustomTooltip={true} enableTooltipClick={true} + showCustomLegend={true} + legendGroup={selectedGroup} /> - {#if dialogGroupChart} - - {/if} {/snippet} {/if} From 198d8456e26dc0f1ee2b8d72596b4662dde8a6ff Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 27 Nov 2024 12:03:36 -0500 Subject: [PATCH 058/152] remove import --- frontend/src/routes/joins/[slug]/+page.svelte | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 30a981b5d8..1f04e06248 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -30,7 +30,6 @@ import InfoTooltip from '$lib/components/InfoTooltip/InfoTooltip.svelte'; import { Table, TableBody, TableCell, TableRow } from '$lib/components/ui/table/index.js'; import TrueFalseBadge from '$lib/components/TrueFalseBadge/TrueFalseBadge.svelte'; - import CustomEChartLegend from '$lib/components/CustomEChartLegend/CustomEChartLegend.svelte'; import ActionButtons from '$lib/components/ActionButtons/ActionButtons.svelte'; import { Dialog, DialogContent, DialogHeader } from '$lib/components/ui/dialog'; import { formatDate, formatValue } from '$lib/util/format'; From 5a654ba11552767eb0b3cd2642c67cad1e1adff4 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 27 Nov 2024 12:26:28 -0500 Subject: [PATCH 059/152] share series color code --- .../CustomEChartLegend.svelte | 11 +++-------- frontend/src/lib/util/chart.ts | 15 +++++++++++++++ frontend/src/routes/joins/[slug]/+page.svelte | 19 ++++--------------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte b/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte index c4f41a277b..54587b3b8b 100644 --- a/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte +++ b/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte @@ -2,7 +2,7 @@ import { Button } from '$lib/components/ui/button'; import type { EChartsType } from 'echarts'; import { Icon, ChevronDown, ChevronUp } from 'svelte-hero-icons'; - import { handleChartHighlight } from '$lib/util/chart'; + import { getSeriesColor, handleChartHighlight } from '$lib/util/chart'; type LegendItem = { feature: string }; type Props = { @@ -41,10 +41,6 @@ }); } - function getSeriesColor(index: number, colors: string[]): string { - return colors[index % colors.length] || '#000000'; - } - function checkOverflow() { if (!itemsContainer) return; const hasVerticalOverflow = itemsContainer.scrollHeight > itemsContainer.clientHeight; @@ -83,10 +79,9 @@ class={`flex flex-wrap gap-x-4 gap-y-2 flex-1 transition-all duration-150 ease-in-out ${!isExpanded ? 'overflow-hidden' : ''}`} style="height: {isExpanded ? 
containerHeight + 'px' : containerHeightLine + 'px'};" > - {#each items as { feature }, index} - {@const colors = chart?.getOption()?.color || []} - {@const color = getSeriesColor(index, colors)} + {#each items as { feature } (feature)} {@const isHidden = hiddenSeries[groupName]?.has(feature)} + {@const color = getSeriesColor(chart, feature)}
    {/if}
    - {selectedSeries ? selectedSeries + ' at ' : ''}{formatEventDate()} + {selectedSeries ? `${selectedSeries} at ` : ''}{formatEventDate()}
    From 868ccfb261311368b658b522b6c22d1bf50a8932 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 27 Nov 2024 14:46:00 -0500 Subject: [PATCH 066/152] sort groupbys --- .../ActionButtons/ActionButtons.svelte | 16 ++++++++++++++-- .../src/lib/types/SortDirection/SortDirection.ts | 6 ++++++ frontend/src/routes/joins/[slug]/+page.server.ts | 15 +++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 frontend/src/lib/types/SortDirection/SortDirection.ts diff --git a/frontend/src/lib/components/ActionButtons/ActionButtons.svelte b/frontend/src/lib/components/ActionButtons/ActionButtons.svelte index bee9dcd120..26b38d2332 100644 --- a/frontend/src/lib/components/ActionButtons/ActionButtons.svelte +++ b/frontend/src/lib/components/ActionButtons/ActionButtons.svelte @@ -2,11 +2,23 @@ import { Button } from '$lib/components/ui/button'; import { cn } from '$lib/utils'; import { Icon, Plus, ArrowsUpDown, Square3Stack3d, XMark } from 'svelte-hero-icons'; + import { goto } from '$app/navigation'; + import { page } from '$app/stores'; + import { getSortDirection, type SortDirection } from '$lib/types/SortDirection/SortDirection'; let { showCluster = false, class: className }: { showCluster?: boolean; class?: string } = $props(); let activeCluster = showCluster ? 'GroupBys' : null; + + let currentSort: SortDirection = $derived.by(() => getSortDirection($page.url.searchParams)); + + function handleSort() { + const newSort: SortDirection = currentSort === 'asc' ? 'desc' : 'asc'; + const url = new URL($page.url); + url.searchParams.set('sort', newSort); + goto(url, { replaceState: true }); + }
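
SortDirection.ts itself is only six lines per the diffstat and is not reproduced in this series; a minimal module consistent with how `getSortDirection` is used above might look like this (an assumed sketch, not the actual file):

export const SORT_DIRECTIONS = ['asc', 'desc'] as const;
export type SortDirection = (typeof SORT_DIRECTIONS)[number];

// Assumed behavior: default to ascending when the param is absent or invalid
export function getSortDirection(searchParams: URLSearchParams): SortDirection {
  return searchParams.get('sort') === 'desc' ? 'desc' : 'asc';
}
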
    @@ -32,9 +44,9 @@ Filter - {#if showCluster}
    {/if} @@ -371,22 +363,16 @@
    {/if} - -
    -
    -
    - {#if isZoomed} - - {/if} +
    - + +
    - From 33570ce06d36b73b3e16676bac5789539f8e4de7 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 27 Nov 2024 15:24:42 -0500 Subject: [PATCH 068/152] use chart controls in dialog --- frontend/src/routes/joins/[slug]/+page.svelte | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 2466649978..479090fccd 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -441,7 +441,7 @@ highlightSeries(selectedSeries ?? '', dialogGroupChart, 'highlight')} onmouseleave={() => highlightSeries(selectedSeries ?? '', dialogGroupChart, 'downplay')} @@ -456,6 +456,7 @@ {selectedSeries ? `${selectedSeries} at ` : ''}{formatEventDate()}
    + @@ -465,16 +466,6 @@ )} {#if selectedGroup} - {#snippet headerContentRight()} -
    -
    - {#if isZoomed} - - {/if} -
    - -
    - {/snippet} {#snippet collapsibleContent()} Date: Sun, 3 Nov 2024 14:04:28 -0800 Subject: [PATCH 069/152] Changes --- .../spark/utils/DataFramePrinter.scala | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala diff --git a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala new file mode 100644 index 0000000000..764b3fc1e3 --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala @@ -0,0 +1,144 @@ +package ai.chronon.spark.utils + +import org.apache.commons.lang3.StringUtils +import org.apache.spark.sql.DataFrame + +import java.util.logging.Logger + + +// there was no way to print a message, and the contents of the dataframe together +// the methods to convert a dataframe into a string were private inside spark +// so pulling it out +object DataFramePrinter { + private val fullWidthRegex = ("""[""" + + // scalastyle:off nonascii + "\u1100-\u115F" + + "\u2E80-\uA4CF" + + "\uAC00-\uD7A3" + + "\uF900-\uFAFF" + + "\uFE10-\uFE19" + + "\uFE30-\uFE6F" + + "\uFF00-\uFF60" + + "\uFFE0-\uFFE6" + + // scalastyle:on nonascii + """]""").r + + private def stringHalfWidth(str: String): Int = { + if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size + } + + private def escapeMetaCharacters(str: String): String = { + str.replaceAll("\n", "\\\\n") + .replaceAll("\r", "\\\\r") + .replaceAll("\t", "\\\\t") + .replaceAll("\f", "\\\\f") + .replaceAll("\b", "\\\\b") + .replaceAll("\u000B", "\\\\v") + .replaceAll("\u0007", "\\\\a") + } + + def showString( df: DataFrame, + numRows: Int = 10, + truncate: Int = 20, + vertical: Boolean = false): String = { + val data = df.take(numRows + 1) + + // For array values, replace Seq and Array with square brackets + // For cells that are beyond `truncate` characters, replace it with the + // first `truncate-3` and "..." + val tmpRows = df.schema.fieldNames.map(escapeMetaCharacters).toSeq +: data.map { row => + row.toSeq.map { cell => + assert(cell != null, "ToPrettyString is not nullable and should not return null value") + // Escapes meta-characters not to break the `showString` format + val str = escapeMetaCharacters(cell.toString) + if (truncate > 0 && str.length > truncate) { + // do not show ellipses for strings shorter than 4 characters. + if (truncate < 4) str.substring(0, truncate) + else str.substring(0, truncate - 3) + "..." 
+ } else { + str + } + }: Seq[String] + } + + val hasMoreData = tmpRows.length - 1 > numRows + val rows = tmpRows.take(numRows + 1) + + val sb = new StringBuilder + val numCols = df.schema.fieldNames.length + // We set a minimum column width at '3' + val minimumColWidth = 3 + + if (!vertical) { + // Initialise the width of each column to a minimum value + val colWidths = Array.fill(numCols)(minimumColWidth) + + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), stringHalfWidth(cell)) + } + } + + val paddedRows = rows.map { row => + row.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) + } else { + StringUtils.rightPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) + } + } + } + + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() + + // column names + paddedRows.head.addString(sb, "|", "|", "|\n") + sb.append(sep) + + // data + paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n")) + sb.append(sep) + } else { + // Extended display mode enabled + val fieldNames = rows.head + val dataRows = rows.tail + + // Compute the width of field name and data columns + val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) => + math.max(curMax, stringHalfWidth(fieldName)) + } + val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) => + math.max(curMax, row.map(cell => stringHalfWidth(cell)).max) + } + + dataRows.zipWithIndex.foreach { case (row, i) => + // "+ 5" in size means a character length except for padded names and data + val rowHeader = StringUtils.rightPad( + s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") + sb.append(rowHeader).append("\n") + row.zipWithIndex.map { case (cell, j) => + val fieldName = StringUtils.rightPad(fieldNames(j), + fieldNameColWidth - stringHalfWidth(fieldNames(j)) + fieldNames(j).length) + val data = StringUtils.rightPad(cell, + dataColWidth - stringHalfWidth(cell) + cell.length) + s" $fieldName | $data " + }.addString(sb, "", "\n", "\n") + } + } + + // Print a footer + if (vertical && rows.tail.isEmpty) { + // In a vertical mode, print an empty row set explicitly + sb.append("(0 rows)\n") + } else if (hasMoreData) { + // For Data that has more than "numRows" records + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows $rowsString\n") + } + + sb.toString() + } + +} From e3d21152e7f0dfa9ac3f50494c6d1b68936d09b1 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 14:06:46 -0800 Subject: [PATCH 070/152] changes so far --- .../spark/utils/DataFramePrinter.scala | 61 ++++++++++--------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala index 764b3fc1e3..cc7ce60de3 100644 --- a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala +++ b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala @@ -3,9 +3,6 @@ package ai.chronon.spark.utils import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.DataFrame -import java.util.logging.Logger - - // there was no way to print a message, and the contents of the dataframe together // the methods to convert a dataframe into a string were private inside spark // so pulling it out @@ -28,7 +25,8 
@@ object DataFramePrinter { } private def escapeMetaCharacters(str: String): String = { - str.replaceAll("\n", "\\\\n") + str + .replaceAll("\n", "\\\\n") .replaceAll("\r", "\\\\r") .replaceAll("\t", "\\\\t") .replaceAll("\f", "\\\\f") @@ -37,10 +35,7 @@ object DataFramePrinter { .replaceAll("\u0007", "\\\\a") } - def showString( df: DataFrame, - numRows: Int = 10, - truncate: Int = 20, - vertical: Boolean = false): String = { + def showString(df: DataFrame, numRows: Int = 10, truncate: Int = 20, vertical: Boolean = false): String = { val data = df.take(numRows + 1) // For array values, replace Seq and Array with square brackets @@ -81,12 +76,13 @@ object DataFramePrinter { } val paddedRows = rows.map { row => - row.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) - } else { - StringUtils.rightPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) - } + row.zipWithIndex.map { + case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) + } else { + StringUtils.rightPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) + } } } @@ -106,25 +102,30 @@ object DataFramePrinter { val dataRows = rows.tail // Compute the width of field name and data columns - val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) => - math.max(curMax, stringHalfWidth(fieldName)) + val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { + case (curMax, fieldName) => + math.max(curMax, stringHalfWidth(fieldName)) } - val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) => - math.max(curMax, row.map(cell => stringHalfWidth(cell)).max) + val dataColWidth = dataRows.foldLeft(minimumColWidth) { + case (curMax, row) => + math.max(curMax, row.map(cell => stringHalfWidth(cell)).max) } - dataRows.zipWithIndex.foreach { case (row, i) => - // "+ 5" in size means a character length except for padded names and data - val rowHeader = StringUtils.rightPad( - s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") - sb.append(rowHeader).append("\n") - row.zipWithIndex.map { case (cell, j) => - val fieldName = StringUtils.rightPad(fieldNames(j), - fieldNameColWidth - stringHalfWidth(fieldNames(j)) + fieldNames(j).length) - val data = StringUtils.rightPad(cell, - dataColWidth - stringHalfWidth(cell) + cell.length) - s" $fieldName | $data " - }.addString(sb, "", "\n", "\n") + dataRows.zipWithIndex.foreach { + case (row, i) => + // "+ 5" in size means a character length except for padded names and data + val rowHeader = StringUtils.rightPad(s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") + sb.append(rowHeader).append("\n") + row.zipWithIndex + .map { + case (cell, j) => + val fieldName = + StringUtils.rightPad(fieldNames(j), + fieldNameColWidth - stringHalfWidth(fieldNames(j)) + fieldNames(j).length) + val data = StringUtils.rightPad(cell, dataColWidth - stringHalfWidth(cell) + cell.length) + s" $fieldName | $data " + } + .addString(sb, "", "\n", "\n") } } From c0e5d8fabc3082001d2616b5d73b5ec0499bc827 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 14:21:56 -0800 Subject: [PATCH 071/152] remove unused file --- .../spark/utils/DataFramePrinter.scala | 145 ------------------ 1 file changed, 145 deletions(-) delete mode 100644 spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala diff --git a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala 
b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala deleted file mode 100644 index cc7ce60de3..0000000000 --- a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala +++ /dev/null @@ -1,145 +0,0 @@ -package ai.chronon.spark.utils - -import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.DataFrame - -// there was no way to print a message, and the contents of the dataframe together -// the methods to convert a dataframe into a string were private inside spark -// so pulling it out -object DataFramePrinter { - private val fullWidthRegex = ("""[""" + - // scalastyle:off nonascii - "\u1100-\u115F" + - "\u2E80-\uA4CF" + - "\uAC00-\uD7A3" + - "\uF900-\uFAFF" + - "\uFE10-\uFE19" + - "\uFE30-\uFE6F" + - "\uFF00-\uFF60" + - "\uFFE0-\uFFE6" + - // scalastyle:on nonascii - """]""").r - - private def stringHalfWidth(str: String): Int = { - if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size - } - - private def escapeMetaCharacters(str: String): String = { - str - .replaceAll("\n", "\\\\n") - .replaceAll("\r", "\\\\r") - .replaceAll("\t", "\\\\t") - .replaceAll("\f", "\\\\f") - .replaceAll("\b", "\\\\b") - .replaceAll("\u000B", "\\\\v") - .replaceAll("\u0007", "\\\\a") - } - - def showString(df: DataFrame, numRows: Int = 10, truncate: Int = 20, vertical: Boolean = false): String = { - val data = df.take(numRows + 1) - - // For array values, replace Seq and Array with square brackets - // For cells that are beyond `truncate` characters, replace it with the - // first `truncate-3` and "..." - val tmpRows = df.schema.fieldNames.map(escapeMetaCharacters).toSeq +: data.map { row => - row.toSeq.map { cell => - assert(cell != null, "ToPrettyString is not nullable and should not return null value") - // Escapes meta-characters not to break the `showString` format - val str = escapeMetaCharacters(cell.toString) - if (truncate > 0 && str.length > truncate) { - // do not show ellipses for strings shorter than 4 characters. - if (truncate < 4) str.substring(0, truncate) - else str.substring(0, truncate - 3) + "..." 
- } else { - str - } - }: Seq[String] - } - - val hasMoreData = tmpRows.length - 1 > numRows - val rows = tmpRows.take(numRows + 1) - - val sb = new StringBuilder - val numCols = df.schema.fieldNames.length - // We set a minimum column width at '3' - val minimumColWidth = 3 - - if (!vertical) { - // Initialise the width of each column to a minimum value - val colWidths = Array.fill(numCols)(minimumColWidth) - - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), stringHalfWidth(cell)) - } - } - - val paddedRows = rows.map { row => - row.zipWithIndex.map { - case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) - } else { - StringUtils.rightPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) - } - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - paddedRows.head.addString(sb, "|", "|", "|\n") - sb.append(sep) - - // data - paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n")) - sb.append(sep) - } else { - // Extended display mode enabled - val fieldNames = rows.head - val dataRows = rows.tail - - // Compute the width of field name and data columns - val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { - case (curMax, fieldName) => - math.max(curMax, stringHalfWidth(fieldName)) - } - val dataColWidth = dataRows.foldLeft(minimumColWidth) { - case (curMax, row) => - math.max(curMax, row.map(cell => stringHalfWidth(cell)).max) - } - - dataRows.zipWithIndex.foreach { - case (row, i) => - // "+ 5" in size means a character length except for padded names and data - val rowHeader = StringUtils.rightPad(s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") - sb.append(rowHeader).append("\n") - row.zipWithIndex - .map { - case (cell, j) => - val fieldName = - StringUtils.rightPad(fieldNames(j), - fieldNameColWidth - stringHalfWidth(fieldNames(j)) + fieldNames(j).length) - val data = StringUtils.rightPad(cell, dataColWidth - stringHalfWidth(cell) + cell.length) - s" $fieldName | $data " - } - .addString(sb, "", "\n", "\n") - } - } - - // Print a footer - if (vertical && rows.tail.isEmpty) { - // In a vertical mode, print an empty row set explicitly - sb.append("(0 rows)\n") - } else if (hasMoreData) { - // For Data that has more than "numRows" records - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") - } - - sb.toString() - } - -} From e718f91e8a93ff5962ce669c5dd25be7b44496ea Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 14:40:25 -0800 Subject: [PATCH 072/152] fix --- .../scala/ai/chronon/api/ColorPrinter.scala | 24 ------------------- 1 file changed, 24 deletions(-) delete mode 100644 api/src/main/scala/ai/chronon/api/ColorPrinter.scala diff --git a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala deleted file mode 100644 index 4d1dc57c50..0000000000 --- a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala +++ /dev/null @@ -1,24 +0,0 @@ -package ai.chronon.api - -object ColorPrinter { - // ANSI escape codes for text colors - private val ANSI_RESET = "\u001B[0m" - - // Colors chosen for visibility on both dark and light backgrounds - // More muted colors that should still be visible on various backgrounds - private val ANSI_RED = "\u001B[38;5;131m" // Muted red (soft burgundy) - private val 
ANSI_BLUE = "\u001B[38;5;32m" // Medium blue - private val ANSI_YELLOW = "\u001B[38;5;172m" // Muted Orange - private val ANSI_GREEN = "\u001B[38;5;28m" // Forest green - - private val BOLD = "\u001B[1m" - - implicit class ColorString(val s: String) extends AnyVal { - def red: String = s"$ANSI_RED$s$ANSI_RESET" - def blue: String = s"$ANSI_BLUE$s$ANSI_RESET" - def yellow: String = s"$ANSI_YELLOW$s$ANSI_RESET" - def green: String = s"$ANSI_GREEN$s$ANSI_RESET" - def low: String = s.toLowerCase - def highlight: String = s"$BOLD$ANSI_RED$s$ANSI_RESET" - } -} From d3ff253b7a78ed922f9028c5a1c768390ba187e8 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 14:46:02 -0800 Subject: [PATCH 073/152] adding back color printer --- .../scala/ai/chronon/api/ColorPrinter.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 api/src/main/scala/ai/chronon/api/ColorPrinter.scala diff --git a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala new file mode 100644 index 0000000000..bf44fa2d13 --- /dev/null +++ b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala @@ -0,0 +1,21 @@ +package ai.chronon.api + +object ColorPrinter { + // ANSI escape codes for text colors + private val ANSI_RESET = "\u001B[0m" + + // Colors chosen for visibility on both dark and light backgrounds + // More muted colors that should still be visible on various backgrounds + private val ANSI_RED = "\u001B[38;5;131m" // Muted red (soft burgundy) + private val ANSI_BLUE = "\u001B[38;5;32m" // Medium blue + private val ANSI_YELLOW = "\u001B[38;5;172m" // Muted Orange + private val ANSI_GREEN = "\u001B[38;5;28m" // Forest green + + implicit class ColorString(val s: String) extends AnyVal { + def red: String = s"$ANSI_RED$s$ANSI_RESET" + def blue: String = s"$ANSI_BLUE$s$ANSI_RESET" + def yellow: String = s"$ANSI_YELLOW$s$ANSI_RESET" + def green: String = s"$ANSI_GREEN$s$ANSI_RESET" + def low: String = s.toLowerCase + } +} \ No newline at end of file From bbf1ddd201983b774128d4d8c03f01a1534ae962 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 17:39:58 -0800 Subject: [PATCH 074/152] scalafmt fix --- api/src/main/scala/ai/chronon/api/ColorPrinter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala index bf44fa2d13..e779e3eaf1 100644 --- a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala +++ b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala @@ -18,4 +18,4 @@ object ColorPrinter { def green: String = s"$ANSI_GREEN$s$ANSI_RESET" def low: String = s.toLowerCase } -} \ No newline at end of file +} From d61c83a486f69276402308e3b51d6ce85d5415b3 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 6 Nov 2024 18:31:57 -0800 Subject: [PATCH 075/152] assign intervals --- .../online/stats/DistanceMetrics.scala | 147 ++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala new file mode 100644 index 0000000000..0620f73aa0 --- /dev/null +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -0,0 +1,147 @@ +package ai.chronon.online.stats + +import ai.chronon.api.ColorPrinter.ColorString +import ai.chronon.api.Window + +import scala.math._ + + +object 
DistanceMetrics {
+
+  // TODO move this to unit test
+  def main(args: Array[String]): Unit = {
+    val A = Array(0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200).map(
+      _.toDouble)
+    val B = Array(5, 15, 25, 35, 45, 55, 65, 75, 85, 95, 115, 115, 115, 135, 145, 155, 165, 1175, 1205, 1205, 1205).map(
+      _.toDouble)
+
+    val jsd = jensenShannonDivergence(A, B)
+    println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f")
+    val psi = populationStabilityIndex(A, B)
+    println(f"The Population Stability Index between distributions A and B is: $psi%.5f")
+    val hd = hellingerDistance(A, B)
+    println(f"The Hellinger Distance between distributions A and B is: $hd%.5f")
+
+
+    // format: off
+    // aligned vertically for easier reasoning
+    val ptiles = Array( 1, 4, 6,6,6, 8, 9 )
+    val breaks = Array(0, 1, 2, 3, 5, 6, 7, 8, 9, 10)
+    // format: on
+
+    //val interval = 0.25
+    val expected = Array(0.0, 1.0/3.0 , 1.0/3.0, (1.0)/(3.0) + (1.0)/(2.0), (1.0)/(2.0), 2.5, 0.5, 1, 0)
+
+    val result = AssignIntervals.on(ptiles = ptiles.map(_.toDouble), breaks = breaks.map(_.toDouble))
+
+    expected.zip(result).foreach{case (e, r) => println(s"exp: $e res: $r")}
+
+  }
+
+  // all same size - used for distance computation and for drill-down display in the front-end
+  case class Distributions(p: Array[Double], q: Array[Double], bins: Array[String])
+
+  case class Comparison[T](previous: T, current: T, timeDelta: Window)
+
+
+  def functionBuilder[T](binningFunc: Comparison[T] => Distributions, distanceFunc: Distributions => Double): Comparison[T] => Double = {
+    c =>
+      val dists = binningFunc(c)
+      val distance = distanceFunc(dists)
+      distance
+  }
+
+
+  def hellingerDistance(p: Array[Double], q: Array[Double]): Double = {
+    val (pProbs, qProbs) = computePDFs(p, q)
+
+    sqrt(
+      pProbs
+        .zip(qProbs)
+        .map {
+          case (pi, qi) =>
+            pow(sqrt(pi) - sqrt(qi), 2)
+        }
+        .sum / 2)
+  }
+
+  def populationStabilityIndex(p: Array[Double], q: Array[Double]): Double = {
+    val (pProbs, qProbs) = computePDFs(p, q)
+
+    pProbs
+      .zip(qProbs)
+      .map {
+        case (pi, qi) =>
+          if (pi > 0 && qi > 0) (qi - pi) * log(qi / pi)
+          else 0.0 // Handle zero probabilities
+      }
+      .sum
+  }
+
+  def jensenShannonDivergence(p: Array[Double], q: Array[Double]): Double = {
+    // Step 1: compute probability distributions on the same x-axis
+    val (pdfP, pdfQ) = computePDFs(p, q)
+
+    // Step 2: compute the mixture distribution M
+    val pdfM = pdfP.zip(pdfQ).map { case (a, b) => 0.5 * (a + b) }
+
+    // Step 3: compute divergence
+    val klAM = klDivergence(pdfP, pdfM)
+    val klBM = klDivergence(pdfQ, pdfM)
+
+    0.5 * (klAM + klBM)
+  }
+
+  def computePDFs(p: Array[Double], q: Array[Double]): (Array[Double], Array[Double]) = {
+    val breakpoints = (p ++ q).distinct.sorted
+
+    val pdfP = computePDF(p, breakpoints).map(_.value)
+    val pdfQ = computePDF(q, breakpoints).map(_.value)
+
+    pdfP -> pdfQ
+  }
+
+  case class Mass(value: Double, isPointMass: Boolean)
+
+  def computePDF(percentiles: Array[Double], breaks: Array[Double]): Array[Mass] = {
+    val n = percentiles.length
+    require(percentiles.length > 2, "Need at least 3 percentiles to plot a distribution")
+
+    val interval: Double = 1.toDouble / (n - 1.0)
+
+    def mass(i: Int, eh: Int): Mass = {
+      def indexMass(i: Int): Double = interval / (percentiles(i) - percentiles(i-1))
+      val isPointMass = eh > 1
+      val m = (i, eh) match {
+        case (0, _) => 0.0 // before range
+        case (x, 0) if x>=n => 0.0 // after range
+        case (_, e) if e > 1 => (e - 1) * interval // point mass
+        case 
(x, 1) if x==n => indexMass(n-1) // exactly at end of range + case (x, _) => indexMass(x) // somewhere in between + } + Mass(m, isPointMass) + } + + var i = 0 + breaks.map { break => + var equalityHits = 0 + while (i < percentiles.length && percentiles(i) <= break) { + if (percentiles(i) == break) equalityHits += 1 + i += 1 + } + mass(i, equalityHits) + } + } + + def klDivergence(p: Array[Double], q: Array[Double]): Double = { + require(p.length == q.length, s"Inputs are of different length ${p.length}, ${q.length}") + var i = 0 + var result = 0.0 + while(i < p.length) { + val inc = if (p(i)> 0 && q(i) > 0) p(i) * math.log(p(i) / q(i)) else 0 + result += inc + i += 1 + } + result + } +} From ce394283beaa95ece1564fe7f06c7ca74f454bf3 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 6 Nov 2024 18:33:02 -0800 Subject: [PATCH 076/152] assign intervals --- .../online/stats/DistanceMetrics.scala | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala index 0620f73aa0..a7896efc94 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -1,11 +1,8 @@ package ai.chronon.online.stats - -import ai.chronon.api.ColorPrinter.ColorString import ai.chronon.api.Window import scala.math._ - object DistanceMetrics { // TODO move this to unit test @@ -22,7 +19,6 @@ object DistanceMetrics { val hd = hellingerDistance(A, B) println(f"The Hellinger Distance between distributions A and B is: $hd%.5f") - // format: off // aligned vertically for easier reasoning val ptiles = Array( 1, 4, 6,6,6, 8, 9 ) @@ -30,11 +26,11 @@ object DistanceMetrics { // format: on //val interval = 0.25 - val expected = Array(0.0, 1.0/3.0 , 1.0/3.0, (1.0)/(3.0) + (1.0)/(2.0), (1.0)/(2.0), 2.5, 0.5, 1, 0) + val expected = Array(0.0, 1.0 / 3.0, 1.0 / 3.0, (1.0) / (3.0) + (1.0) / (2.0), (1.0) / (2.0), 2.5, 0.5, 1, 0) val result = AssignIntervals.on(ptiles = ptiles.map(_.toDouble), breaks = breaks.map(_.toDouble)) - expected.zip(result).foreach{case (e, r) => println(s"exp: $e res: $r")} + expected.zip(result).foreach { case (e, r) => println(s"exp: $e res: $r") } } @@ -43,15 +39,13 @@ object DistanceMetrics { case class Comparison[T](previous: T, current: T, timeDelta: Window) - - def functionBuilder[T](binningFunc: Comparison[T] => Distributions, distanceFunc: Distributions => Double): Comparison[T] => Double = { - c => - val dists = binningFunc(c) - val distance = distanceFunc(dists) - distance + def functionBuilder[T](binningFunc: Comparison[T] => Distributions, + distanceFunc: Distributions => Double): Comparison[T] => Double = { c => + val dists = binningFunc(c) + val distance = distanceFunc(dists) + distance } - def hellingerDistance(p: Array[Double], q: Array[Double]): Double = { val (pProbs, qProbs) = computePDFs(p, q) @@ -110,14 +104,14 @@ object DistanceMetrics { val interval: Double = 1.toDouble / (n - 1.0) def mass(i: Int, eh: Int): Mass = { - def indexMass(i: Int): Double = interval / (percentiles(i) - percentiles(i-1)) + def indexMass(i: Int): Double = interval / (percentiles(i) - percentiles(i - 1)) val isPointMass = eh > 1 val m = (i, eh) match { - case (0, _) => 0.0 // before range - case (x, 0) if x>=n => 0.0 // after range - case (_, e) if e > 1 => (e - 1) * interval // point mass - case (x, 1) if x==n => indexMass(n-1) // exactly at end of range - 
case (x, _) => indexMass(x) // somewhere in between + case (0, _) => 0.0 // before range + case (x, 0) if x >= n => 0.0 // after range + case (_, e) if e > 1 => (e - 1) * interval // point mass + case (x, 1) if x == n => indexMass(n - 1) // exactly at end of range + case (x, _) => indexMass(x) // somewhere in between } Mass(m, isPointMass) } @@ -137,8 +131,8 @@ object DistanceMetrics { require(p.length == q.length, s"Inputs are of different length ${p.length}, ${q.length}") var i = 0 var result = 0.0 - while(i < p.length) { - val inc = if (p(i)> 0 && q(i) > 0) p(i) * math.log(p(i) / q(i)) else 0 + while (i < p.length) { + val inc = if (p(i) > 0 && q(i) > 0) p(i) * math.log(p(i) / q(i)) else 0 result += inc i += 1 } From 787b66416870b3fc1b558602b4550697d83b17b1 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 6 Nov 2024 23:12:52 -0800 Subject: [PATCH 077/152] tile summary distance --- .../online/stats/DistanceMetrics.scala | 192 +++++++++--------- 1 file changed, 91 insertions(+), 101 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala index a7896efc94..e4820172e4 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -1,7 +1,9 @@ package ai.chronon.online.stats -import ai.chronon.api.Window +import ai.chronon.api.DriftMetric +import ai.chronon.api.TileSummaries import scala.math._ +import scala.util.ScalaJavaConversions.IteratorOps object DistanceMetrics { @@ -12,130 +14,118 @@ object DistanceMetrics { val B = Array(5, 15, 25, 35, 45, 55, 65, 75, 85, 95, 115, 115, 115, 135, 145, 155, 165, 1175, 1205, 1205, 1205).map( _.toDouble) - val jsd = jensenShannonDivergence(A, B) + val jsd = percentileDistance(A, B, DriftMetric.JENSEN_SHANNON) println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f") - val psi = populationStabilityIndex(A, B) + val psi = percentileDistance(A, B, DriftMetric.PSI) println(f"The Population Stability Index between distributions A and B is: $psi%.5f") - val hd = hellingerDistance(A, B) + val hd = percentileDistance(A, B, DriftMetric.HELLINGER) println(f"The Hellinger Distance between distributions A and B is: $hd%.5f") - - // format: off - // aligned vertically for easier reasoning - val ptiles = Array( 1, 4, 6,6,6, 8, 9 ) - val breaks = Array(0, 1, 2, 3, 5, 6, 7, 8, 9, 10) - // format: on - - //val interval = 0.25 - val expected = Array(0.0, 1.0 / 3.0, 1.0 / 3.0, (1.0) / (3.0) + (1.0) / (2.0), (1.0) / (2.0), 2.5, 0.5, 1, 0) - - val result = AssignIntervals.on(ptiles = ptiles.map(_.toDouble), breaks = breaks.map(_.toDouble)) - - expected.zip(result).foreach { case (e, r) => println(s"exp: $e res: $r") } - } - // all same size - used for distance computation and for drill down display in front-end - case class Distributions(p: Array[Double], q: Array[Double], bins: Array[String]) - - case class Comparison[T](previous: T, current: T, timeDelta: Window) - - def functionBuilder[T](binningFunc: Comparison[T] => Distributions, - distanceFunc: Distributions => Double): Comparison[T] => Double = { c => - val dists = binningFunc(c) - val distance = distanceFunc(dists) - distance + @inline + private def toArray(l: java.util.List[java.lang.Double]): Array[Double] = { + l.iterator().toScala.map(_.toDouble).toArray } - def hellingerDistance(p: Array[Double], q: Array[Double]): Double = { - val (pProbs, qProbs) = computePDFs(p, q) - - sqrt( - 
pProbs - .zip(qProbs) - .map { - case (pi, qi) => - pow(sqrt(pi) - sqrt(qi), 2) - } - .sum / 2) + @inline + private def normalizeInplace(arr: Array[Double]): Array[Double] = { + val sum = arr.sum + var i = 0 + while (i < arr.length) { + arr.update(i, arr(i) / sum) + i += 1 + } + arr } - def populationStabilityIndex(p: Array[Double], q: Array[Double]): Double = { - val (pProbs, qProbs) = computePDFs(p, q) - - pProbs - .zip(qProbs) - .map { - case (pi, qi) => - if (pi > 0 && qi > 0) (qi - pi) * log(qi / pi) - else 0.0 // Handle zero probabilities - } - .sum + def distance(a: TileSummaries, b: TileSummaries, metric: DriftMetric): java.lang.Double = { + require(a.isSetPercentiles == b.isSetPercentiles, "Percentiles should be either set or unset together") + require(a.isSetHistogram == b.isSetHistogram, "Histograms should be either set or unset together") + + val isContinuous = a.isSetPercentiles && b.isSetPercentiles + val isCategorical = a.isSetHistogram && b.isSetHistogram + if (isContinuous) + percentileDistance(toArray(a.getPercentiles), toArray(b.getPercentiles), metric) + else if (isCategorical) + categoricalDistance(a.getHistogram, b.getHistogram, metric) + else + null } - def jensenShannonDivergence(p: Array[Double], q: Array[Double]): Double = { - // Step 1: compute probability distributions on the same x-axis - val (pdfP, pdfQ) = computePDFs(p, q) + def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric): Double = { + val breaks = (a ++ b).sorted.distinct + val aProjected = AssignIntervals.on(a, breaks) + val bProjected = AssignIntervals.on(b, breaks) - // Step 2: compute the mixture distribution M - val pdfM = pdfP.zip(pdfQ).map { case (a, b) => 0.5 * (a + b) } + val aNormalized = normalizeInplace(aProjected) + val bNormalized = normalizeInplace(bProjected) - // Step 3: compute divergence - val klAM = klDivergence(pdfP, pdfM) - val klBM = klDivergence(pdfQ, pdfM) + val func = termFunc(metric) - 0.5 * (klAM + klBM) + var i = 0 + var result = 0.0 + while (i < aNormalized.length) { + result += func(aNormalized(i), bNormalized(i)) + i += 1 + } + result } - def computePDFs(p: Array[Double], q: Array[Double]): (Array[Double], Array[Double]) = { - val breakpoints = (p ++ q).distinct.sorted + type Histogram = java.util.Map[String, java.lang.Long] + def categoricalDistance(a: Histogram, b: Histogram, metric: DriftMetric): Double = { + val aIt = a.entrySet().iterator() + var result = 0.0 + val func = termFunc(metric) + while (aIt.hasNext) { + val entry = aIt.next() + val key = entry.getKey + val aVal = entry.getValue.toDouble + val bValOpt = b.get(key) + val bVal = if (bValOpt == null) bValOpt.toDouble else 0.0 + val term = func(aVal, bVal) + result += term + } - val pdfP = computePDF(p, breakpoints).map(_.value) - val pdfQ = computePDF(q, breakpoints).map(_.value) + val bIt = b.entrySet().iterator() + while (bIt.hasNext) { + val entry = bIt.next() + val key = entry.getKey + val bVal = entry.getValue.toDouble + val aValOpt = a.get(key) + if (aValOpt == null) { + result += func(0.0, bVal) + } + } - pdfP -> pdfQ + result } - case class Mass(value: Double, isPointMass: Boolean) - - def computePDF(percentiles: Array[Double], breaks: Array[Double]): Array[Mass] = { - val n = percentiles.length - require(percentiles.length > 2, "Need at-least 3 percentiles to plot a distribution") + @inline + def klDivergenceTerm(a: Double, b: Double): Double = { + if (a > 0 && b > 0) a * math.log(a / b) else 0 + } - val interval: Double = 1.toDouble / (n - 1.0) + @inline + def jsdTerm(a: 
Double, b: Double): Double = { + val m = (a + b) * 0.5 + (klDivergenceTerm(a, m) + klDivergenceTerm(b, m)) * 0.5 + } - def mass(i: Int, eh: Int): Mass = { - def indexMass(i: Int): Double = interval / (percentiles(i) - percentiles(i - 1)) - val isPointMass = eh > 1 - val m = (i, eh) match { - case (0, _) => 0.0 // before range - case (x, 0) if x >= n => 0.0 // after range - case (_, e) if e > 1 => (e - 1) * interval // point mass - case (x, 1) if x == n => indexMass(n - 1) // exactly at end of range - case (x, _) => indexMass(x) // somewhere in between - } - Mass(m, isPointMass) - } + @inline + def hellingerTerm(a: Double, b: Double): Double = { + pow(sqrt(a) - sqrt(b), 2) * 0.5 + } - var i = 0 - breaks.map { break => - var equalityHits = 0 - while (i < percentiles.length && percentiles(i) <= break) { - if (percentiles(i) == break) equalityHits += 1 - i += 1 - } - mass(i, equalityHits) - } + @inline + def psiTerm(a: Double, b: Double): Double = { + if (a > 0 && b > 0) (b - a) * log(b / a) else 0.0 } - def klDivergence(p: Array[Double], q: Array[Double]): Double = { - require(p.length == q.length, s"Inputs are of different length ${p.length}, ${q.length}") - var i = 0 - var result = 0.0 - while (i < p.length) { - val inc = if (p(i) > 0 && q(i) > 0) p(i) * math.log(p(i) / q(i)) else 0 - result += inc - i += 1 + @inline + def termFunc(d: DriftMetric): (Double, Double) => Double = + d match { + case DriftMetric.PSI => psiTerm + case DriftMetric.HELLINGER => hellingerTerm + case DriftMetric.JENSEN_SHANNON => jsdTerm } - result - } } From e576dadbe8057d327e81675039883e5474c0214b Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Thu, 7 Nov 2024 18:54:19 -0800 Subject: [PATCH 078/152] histogram drift --- .../online/stats/DistanceMetrics.scala | 160 +++++++++++++++--- 1 file changed, 135 insertions(+), 25 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala index e4820172e4..a586b95aab 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -4,22 +4,90 @@ import ai.chronon.api.TileSummaries import scala.math._ import scala.util.ScalaJavaConversions.IteratorOps +import scala.util.ScalaJavaConversions.JMapOps object DistanceMetrics { // TODO move this to unit test def main(args: Array[String]): Unit = { - val A = Array(0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200).map( - _.toDouble) - val B = Array(5, 15, 25, 35, 45, 55, 65, 75, 85, 95, 115, 115, 115, 135, 145, 155, 165, 1175, 1205, 1205, 1205).map( - _.toDouble) - - val jsd = percentileDistance(A, B, DriftMetric.JENSEN_SHANNON) - println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f") - val psi = percentileDistance(A, B, DriftMetric.PSI) - println(f"The Population Stability Index between distributions A and B is: $psi%.5f") - val hd = percentileDistance(A, B, DriftMetric.HELLINGER) - println(f"The Hellinger Distance between distributions A and B is: $hd%.5f") + + def buildPercentiles(mean: Double, variance: Double, breaks: Int = 20): Array[Double] = { + val stdDev = math.sqrt(variance) + + // Create probability points from 0.01 to 0.99 instead of 0 to 1 + val probPoints = (0 to breaks).map { i => + if (i == 0) 0.01 // p1 instead of p0 + else if (i == breaks) 0.99 // p99 instead of p100 + else i.toDouble / breaks + }.toArray + + // Convert probability points to 
percentiles + probPoints.map { p => + val standardNormalPercentile = math.sqrt(2) * inverseErf(2 * p - 1) + mean + (stdDev * standardNormalPercentile) + } + } + + // Helper function to calculate inverse error function + def inverseErf(x: Double): Double = { + // Approximation of inverse error function + // This is a rational approximation giving a maximum relative error of 3e-7 + val a = 0.147 + val signX = if (x >= 0) 1 else -1 + val absX = math.abs(x) + + val term1 = math.pow(2 / (math.Pi * a) + math.log(1 - absX * absX) / 2, 0.5) + val term2 = math.log(1 - absX * absX) / a + + signX * math.sqrt(term1 - term2) + } + + def compareDistributions(meanA: Double, + varianceA: Double, + meanB: Double, + varianceB: Double, + breaks: Int = 20, + debug: Boolean = false): Unit = { + + val aPercentiles = buildPercentiles(meanA, varianceA, breaks) + val bPercentiles = buildPercentiles(meanB, varianceB, breaks) + + val aHistogram: Histogram = (0 to breaks) + .map { i => + val value = java.lang.Long.valueOf((math.abs(aPercentiles(i)) * 100).toLong) + i.toString -> value + } + .toMap + .toJava + + val bHistogram: Histogram = (0 to breaks) + .map { i => + val value = java.lang.Long.valueOf((math.abs(bPercentiles(i)) * 100).toLong) + i.toString -> value + } + .toMap + .toJava + + val jsd = percentileDistance(aPercentiles, bPercentiles, DriftMetric.JENSEN_SHANNON, debug = debug) + val jsdHist = histogramDistance(aHistogram, bHistogram, DriftMetric.JENSEN_SHANNON) + println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f, $jsdHist%.5f") + + val psi = percentileDistance(aPercentiles, bPercentiles, DriftMetric.PSI, debug = debug) + val psiHist = histogramDistance(aHistogram, bHistogram, DriftMetric.PSI) + println(f"The Population Stability Index between distributions A and B is: $psi%.5f, $psiHist%.5f") + + val hd = percentileDistance(aPercentiles, bPercentiles, DriftMetric.HELLINGER, debug = debug) + val hdHist = histogramDistance(aHistogram, bHistogram, DriftMetric.HELLINGER) + println(f"The Hellinger Distance between distributions A and B is: $hd%.5f, $hdHist%.5f") + + println() + } + + compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 256.0) + compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 205.0, varianceB = 256.0) + compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 305.0, varianceB = 256.0) + compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 100.0) + } @inline @@ -28,14 +96,16 @@ object DistanceMetrics { } @inline - private def normalizeInplace(arr: Array[Double]): Array[Double] = { + private def normalize(arr: Array[Double]): Array[Double] = { + // TODO-OPTIMIZATION: normalize in place instead if this is a hotspot + val result = Array.ofDim[Double](arr.length) val sum = arr.sum var i = 0 while (i < arr.length) { - arr.update(i, arr(i) / sum) + result.update(i, arr(i) / sum) i += 1 } - arr + result } def distance(a: TileSummaries, b: TileSummaries, metric: DriftMetric): java.lang.Double = { @@ -47,32 +117,69 @@ object DistanceMetrics { if (isContinuous) percentileDistance(toArray(a.getPercentiles), toArray(b.getPercentiles), metric) else if (isCategorical) - categoricalDistance(a.getHistogram, b.getHistogram, metric) + histogramDistance(a.getHistogram, b.getHistogram, metric) else null } - def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric): Double = { + def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric, debug: Boolean = false): 
Double = { val breaks = (a ++ b).sorted.distinct val aProjected = AssignIntervals.on(a, breaks) val bProjected = AssignIntervals.on(b, breaks) - val aNormalized = normalizeInplace(aProjected) - val bNormalized = normalizeInplace(bProjected) + val aNormalized = normalize(aProjected) + val bNormalized = normalize(bProjected) val func = termFunc(metric) var i = 0 var result = 0.0 + + // debug only, remove before merging + val deltas = Array.ofDim[Double](aNormalized.length) + while (i < aNormalized.length) { - result += func(aNormalized(i), bNormalized(i)) + val ai = aNormalized(i) + val bi = bNormalized(i) + val delta = func(ai, bi) + + // debug only remove before merging + deltas.update(i, delta) + + result += delta i += 1 } + + if (debug) { + def printArr(arr: Array[Double]): String = + arr.map(v => f"$v%.3f").mkString(", ") + println(f""" + |aProjected : ${printArr(aProjected)} + |bProjected : ${printArr(bProjected)} + |aNormalized: ${printArr(aNormalized)} + |bNormalized: ${printArr(bNormalized)} + |deltas : ${printArr(deltas)} + |result : $result%.4f + |""".stripMargin) + } result } + // java map is what thrift produces upon deserialization type Histogram = java.util.Map[String, java.lang.Long] - def categoricalDistance(a: Histogram, b: Histogram, metric: DriftMetric): Double = { + def histogramDistance(a: Histogram, b: Histogram, metric: DriftMetric): Double = { + + @inline def sumValues(h: Histogram): Double = { + var result = 0.0 + val it = h.entrySet().iterator() + while (it.hasNext) { + result += it.next().getValue + } + result + } + val aSum = sumValues(a) + val bSum = sumValues(b) + val aIt = a.entrySet().iterator() var result = 0.0 val func = termFunc(metric) @@ -80,9 +187,9 @@ object DistanceMetrics { val entry = aIt.next() val key = entry.getKey val aVal = entry.getValue.toDouble - val bValOpt = b.get(key) - val bVal = if (bValOpt == null) bValOpt.toDouble else 0.0 - val term = func(aVal, bVal) + val bValOpt: java.lang.Long = b.get(key) + val bVal: Double = if (bValOpt == null) 0.0 else bValOpt.toDouble + val term = func(aVal / aSum, bVal / bSum) result += term } @@ -93,7 +200,8 @@ object DistanceMetrics { val bVal = entry.getValue.toDouble val aValOpt = a.get(key) if (aValOpt == null) { - result += func(0.0, bVal) + val term = func(0.0, bVal / bSum) + result += term } } @@ -118,7 +226,9 @@ object DistanceMetrics { @inline def psiTerm(a: Double, b: Double): Double = { - if (a > 0 && b > 0) (b - a) * log(b / a) else 0.0 + val aFixed = if (a == 0.0) 1e-5 else a + val bFixed = if (b == 0.0) 1e-5 else b + (bFixed - aFixed) * log(bFixed / aFixed) } @inline From 3ca60fa775e171137555b543bee1fb75e5fcee9b Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Thu, 7 Nov 2024 22:31:10 -0800 Subject: [PATCH 079/152] tile drift --- .../online/stats/DistanceMetrics.scala | 157 ++++-------------- .../test/stats/DistanceMetricsTest.scala | 128 ++++++++++++++ 2 files changed, 162 insertions(+), 123 deletions(-) create mode 100644 online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala index a586b95aab..1b12789de1 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -1,127 +1,9 @@ package ai.chronon.online.stats import ai.chronon.api.DriftMetric -import ai.chronon.api.TileSummaries import scala.math._ -import 
scala.util.ScalaJavaConversions.IteratorOps -import scala.util.ScalaJavaConversions.JMapOps object DistanceMetrics { - - // TODO move this to unit test - def main(args: Array[String]): Unit = { - - def buildPercentiles(mean: Double, variance: Double, breaks: Int = 20): Array[Double] = { - val stdDev = math.sqrt(variance) - - // Create probability points from 0.01 to 0.99 instead of 0 to 1 - val probPoints = (0 to breaks).map { i => - if (i == 0) 0.01 // p1 instead of p0 - else if (i == breaks) 0.99 // p99 instead of p100 - else i.toDouble / breaks - }.toArray - - // Convert probability points to percentiles - probPoints.map { p => - val standardNormalPercentile = math.sqrt(2) * inverseErf(2 * p - 1) - mean + (stdDev * standardNormalPercentile) - } - } - - // Helper function to calculate inverse error function - def inverseErf(x: Double): Double = { - // Approximation of inverse error function - // This is a rational approximation giving a maximum relative error of 3e-7 - val a = 0.147 - val signX = if (x >= 0) 1 else -1 - val absX = math.abs(x) - - val term1 = math.pow(2 / (math.Pi * a) + math.log(1 - absX * absX) / 2, 0.5) - val term2 = math.log(1 - absX * absX) / a - - signX * math.sqrt(term1 - term2) - } - - def compareDistributions(meanA: Double, - varianceA: Double, - meanB: Double, - varianceB: Double, - breaks: Int = 20, - debug: Boolean = false): Unit = { - - val aPercentiles = buildPercentiles(meanA, varianceA, breaks) - val bPercentiles = buildPercentiles(meanB, varianceB, breaks) - - val aHistogram: Histogram = (0 to breaks) - .map { i => - val value = java.lang.Long.valueOf((math.abs(aPercentiles(i)) * 100).toLong) - i.toString -> value - } - .toMap - .toJava - - val bHistogram: Histogram = (0 to breaks) - .map { i => - val value = java.lang.Long.valueOf((math.abs(bPercentiles(i)) * 100).toLong) - i.toString -> value - } - .toMap - .toJava - - val jsd = percentileDistance(aPercentiles, bPercentiles, DriftMetric.JENSEN_SHANNON, debug = debug) - val jsdHist = histogramDistance(aHistogram, bHistogram, DriftMetric.JENSEN_SHANNON) - println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f, $jsdHist%.5f") - - val psi = percentileDistance(aPercentiles, bPercentiles, DriftMetric.PSI, debug = debug) - val psiHist = histogramDistance(aHistogram, bHistogram, DriftMetric.PSI) - println(f"The Population Stability Index between distributions A and B is: $psi%.5f, $psiHist%.5f") - - val hd = percentileDistance(aPercentiles, bPercentiles, DriftMetric.HELLINGER, debug = debug) - val hdHist = histogramDistance(aHistogram, bHistogram, DriftMetric.HELLINGER) - println(f"The Hellinger Distance between distributions A and B is: $hd%.5f, $hdHist%.5f") - - println() - } - - compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 256.0) - compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 205.0, varianceB = 256.0) - compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 305.0, varianceB = 256.0) - compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 100.0) - - } - - @inline - private def toArray(l: java.util.List[java.lang.Double]): Array[Double] = { - l.iterator().toScala.map(_.toDouble).toArray - } - - @inline - private def normalize(arr: Array[Double]): Array[Double] = { - // TODO-OPTIMIZATION: normalize in place instead if this is a hotspot - val result = Array.ofDim[Double](arr.length) - val sum = arr.sum - var i = 0 - while (i < arr.length) { - result.update(i, arr(i) / sum) - i += 1 - } - result 
- } - - def distance(a: TileSummaries, b: TileSummaries, metric: DriftMetric): java.lang.Double = { - require(a.isSetPercentiles == b.isSetPercentiles, "Percentiles should be either set or unset together") - require(a.isSetHistogram == b.isSetHistogram, "Histograms should be either set or unset together") - - val isContinuous = a.isSetPercentiles && b.isSetPercentiles - val isCategorical = a.isSetHistogram && b.isSetHistogram - if (isContinuous) - percentileDistance(toArray(a.getPercentiles), toArray(b.getPercentiles), metric) - else if (isCategorical) - histogramDistance(a.getHistogram, b.getHistogram, metric) - else - null - } - def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric, debug: Boolean = false): Double = { val breaks = (a ++ b).sorted.distinct val aProjected = AssignIntervals.on(a, breaks) @@ -209,33 +91,62 @@ object DistanceMetrics { } @inline - def klDivergenceTerm(a: Double, b: Double): Double = { + private def normalize(arr: Array[Double]): Array[Double] = { + // TODO-OPTIMIZATION: normalize in place instead if this is a hotspot + val result = Array.ofDim[Double](arr.length) + val sum = arr.sum + var i = 0 + while (i < arr.length) { + result.update(i, arr(i) / sum) + i += 1 + } + result + } + + @inline + private def klDivergenceTerm(a: Double, b: Double): Double = { if (a > 0 && b > 0) a * math.log(a / b) else 0 } @inline - def jsdTerm(a: Double, b: Double): Double = { + private def jsdTerm(a: Double, b: Double): Double = { val m = (a + b) * 0.5 (klDivergenceTerm(a, m) + klDivergenceTerm(b, m)) * 0.5 } @inline - def hellingerTerm(a: Double, b: Double): Double = { + private def hellingerTerm(a: Double, b: Double): Double = { pow(sqrt(a) - sqrt(b), 2) * 0.5 } @inline - def psiTerm(a: Double, b: Double): Double = { + private def psiTerm(a: Double, b: Double): Double = { val aFixed = if (a == 0.0) 1e-5 else a val bFixed = if (b == 0.0) 1e-5 else b (bFixed - aFixed) * log(bFixed / aFixed) } @inline - def termFunc(d: DriftMetric): (Double, Double) => Double = + private def termFunc(d: DriftMetric): (Double, Double) => Double = d match { case DriftMetric.PSI => psiTerm case DriftMetric.HELLINGER => hellingerTerm case DriftMetric.JENSEN_SHANNON => jsdTerm } + + case class Thresholds(moderate: Double, severe: Double) { + def str(driftScore: Double): String = { + if (driftScore < moderate) "LOW" + else if (driftScore < severe) "MODERATE" + else "SEVERE" + } + } + + @inline + def thresholds(d: DriftMetric): Thresholds = + d match { + case DriftMetric.JENSEN_SHANNON => Thresholds(0.05, 0.15) + case DriftMetric.HELLINGER => Thresholds(0.05, 0.15) + case DriftMetric.PSI => Thresholds(0.1, 0.2) + } } diff --git a/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala b/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala new file mode 100644 index 0000000000..5e571ba5d7 --- /dev/null +++ b/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala @@ -0,0 +1,128 @@ +package ai.chronon.online.test.stats + +import ai.chronon.api.DriftMetric +import ai.chronon.online.stats.DistanceMetrics.histogramDistance +import ai.chronon.online.stats.DistanceMetrics.percentileDistance +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +import scala.util.ScalaJavaConversions.JMapOps + +class DistanceMetricsTest extends AnyFunSuite with Matchers { + + def buildPercentiles(mean: Double, variance: Double, breaks: Int = 20): Array[Double] = { + val stdDev = math.sqrt(variance) + 
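+    // a brief sketch of the math used below: for N(mean, variance), the quantile
+    // at probability p is the inverse CDF,
+    //   quantile(p) = mean + stdDev * sqrt(2) * erfinv(2p - 1)
+    // p is clamped to [0.01, 0.99] since the p = 0 and p = 1 quantiles of a
+    // normal distribution are infinite.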
+ val probPoints = (0 to breaks).map { i => + if (i == 0) 0.01 + else if (i == breaks) 0.99 + else i.toDouble / breaks + }.toArray + + probPoints.map { p => + val standardNormalPercentile = math.sqrt(2) * inverseErf(2 * p - 1) + mean + (stdDev * standardNormalPercentile) + } + } + + def inverseErf(x: Double): Double = { + val a = 0.147 + val signX = if (x >= 0) 1 else -1 + val absX = math.abs(x) + + val term1 = math.pow(2 / (math.Pi * a) + math.log(1 - absX * absX) / 2, 0.5) + val term2 = math.log(1 - absX * absX) / a + + signX * math.sqrt(term1 - term2) + } + type Histogram = java.util.Map[String, java.lang.Long] + + def compareDistributions(meanA: Double, + varianceA: Double, + meanB: Double, + varianceB: Double, + breaks: Int = 20, + debug: Boolean = false): Map[DriftMetric, (Double, Double)] = { + + val aPercentiles = buildPercentiles(meanA, varianceA, breaks) + val bPercentiles = buildPercentiles(meanB, varianceB, breaks) + + val aHistogram: Histogram = (0 to breaks) + .map { i => + val value = java.lang.Long.valueOf((math.abs(aPercentiles(i)) * 100).toLong) + i.toString -> value + } + .toMap + .toJava + + val bHistogram: Histogram = (0 to breaks) + .map { i => + val value = java.lang.Long.valueOf((math.abs(bPercentiles(i)) * 100).toLong) + i.toString -> value + } + .toMap + .toJava + + def calculateDrift(metric: DriftMetric): (Double, Double) = { + val pDrift = percentileDistance(aPercentiles, bPercentiles, metric, debug = debug) + val histoDrift = histogramDistance(aHistogram, bHistogram, metric) + (pDrift, histoDrift) + } + + Map( + DriftMetric.JENSEN_SHANNON -> calculateDrift(DriftMetric.JENSEN_SHANNON), + DriftMetric.PSI -> calculateDrift(DriftMetric.PSI), + DriftMetric.HELLINGER -> calculateDrift(DriftMetric.HELLINGER) + ) + } + + test("Low drift - similar distributions") { + val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 101.0, varianceB = 225.0) + + // JSD assertions + val (jsdPercentile, jsdHisto) = drifts(DriftMetric.JENSEN_SHANNON) + jsdPercentile should be < 0.05 + jsdHisto should be < 0.05 + + // Hellinger assertions + val (hellingerPercentile, hellingerHisto) = drifts(DriftMetric.HELLINGER) + hellingerPercentile should be < 0.05 + hellingerHisto should be < 0.05 + } + + test("Moderate drift - slightly different distributions") { + val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 256.0) + + // JSD assertions + val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) + jsdPercentile should (be >= 0.05 and be <= 0.15) + + // Hellinger assertions + val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) + hellingerPercentile should (be >= 0.05 and be <= 0.15) + } + + test("Severe drift - different means") { + val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 110.0, varianceB = 225.0) + + // JSD assertions + val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) + jsdPercentile should be > 0.15 + + // Hellinger assertions + val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) + hellingerPercentile should be > 0.15 + } + + test("Severe drift - different variances") { + val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 100.0) + + // JSD assertions + val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) + jsdPercentile should be > 0.15 + + // Hellinger assertions + val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) + hellingerPercentile should be > 0.15 + } +} From 
ea115e384d80c32af7fd76bdfe951bebb7530cc2 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 12 Nov 2024 05:27:44 -0800 Subject: [PATCH 080/152] test wiring --- hub/app/model/Model.scala | 7 + .../ai/chronon/online/stats/Display.scala | 205 ++++++++++++++++++ .../online/stats/DistanceMetrics.scala | 152 ------------- .../test/stats/DistanceMetricsTest.scala | 128 ----------- 4 files changed, 212 insertions(+), 280 deletions(-) create mode 100644 online/src/main/scala/ai/chronon/online/stats/Display.scala delete mode 100644 online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala delete mode 100644 online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index a498936611..a14ce4d679 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -5,6 +5,13 @@ case class GroupBy(name: String, features: Seq[String]) case class Join(name: String, joinFeatures: Seq[String], groupBys: Seq[GroupBy]) case class Model(name: String, join: Join, online: Boolean, production: Boolean, team: String, modelType: String) +// 1.) metadataUpload: join -> map> +// 2.) fetchJoinConf + listColumns: join => list +// 3.) (columns, start, end) -> list + +// 4.) 1:n/fetchTile: tileKey -> TileSummaries +// 5.) 1:n:n/compareTiles: TileSummaries, TileSummaries -> TileDrift +// 6.) Map[column, Seq[tileDrift]] -> TimeSeriesController /** Supported Metric types */ sealed trait MetricType diff --git a/online/src/main/scala/ai/chronon/online/stats/Display.scala b/online/src/main/scala/ai/chronon/online/stats/Display.scala new file mode 100644 index 0000000000..a4c757dc27 --- /dev/null +++ b/online/src/main/scala/ai/chronon/online/stats/Display.scala @@ -0,0 +1,205 @@ +package ai.chronon.online.stats + +import cask._ +import scalatags.Text.all._ +import scalatags.Text.tags2.title + +// generates html / js code to serve a tabbed board on the network port +// boards are static and do not update, used for debugging only +// uses uPlot under the hood +object Display { + // single line inside a chart + case class Series(series: Array[Double], name: String) + // multiple lines in a chart plus the x-axis and a threshold (horizontal dashed line) + case class Chart(seriesList: Array[Series], + x: Array[String], + name: String, + moderateThreshold: Option[Double] = None, + severeThreshold: Option[Double] = None) + + // multiple charts in a section + case class Section(charts: Array[Chart], name: String) + // multiple sections in a tab + case class Tab(sectionList: Array[Section], name: String) + // multiple tabs in a board + case class Board(tabList: Array[Tab], name: String) + + private def generateChartJs(chart: Chart, chartId: String): String = { + val data = chart.seriesList.map(_.series) + val xData = chart.x.map(_.toString) + chart.seriesList.map(_.name) + + val seriesConfig = chart.seriesList.map(s => s"""{ + | label: "${s.name}", + | stroke: "rgb(${scala.util.Random.nextInt(255)}, ${scala.util.Random.nextInt(255)}, ${scala.util.Random.nextInt(255)})" + | + |}""".stripMargin).mkString(",\n") + + val thresholdLines = (chart.moderateThreshold.map(t => s""" + |{ + | label: "Moderate Threshold", + | value: $t, + | stroke: "#ff9800", + | style: [2, 2] + |}""".stripMargin) ++ + chart.severeThreshold.map(t => s""" + |{ + | label: "Severe Threshold", + | value: $t, + | stroke: "#f44336", + | style: [2, 2] + |}""".stripMargin)).mkString(",") + + s""" + |new uPlot({ + | title: "${chart.name}", + | id: "$chartId", + | class: 
"chart", + | width: 800, + | height: 400, + | scales: { + | x: { + | time: false, + | } + | }, + | series: [ + | {}, + | $seriesConfig + | ], + | axes: [ + | {}, + | { + | label: "Value", + | grid: true, + | } + | ], + | plugins: [ + | { + | hooks: { + | draw: u => { + | ${if (thresholdLines.nonEmpty) + s"""const lines = [$thresholdLines]; + | for (const line of lines) { + | const scale = u.scales.y; + | const y = scale.getPos(line.value); + | + | u.ctx.save(); + | u.ctx.strokeStyle = line.stroke; + | u.ctx.setLineDash(line.style); + | + | u.ctx.beginPath(); + | u.ctx.moveTo(u.bbox.left, y); + | u.ctx.lineTo(u.bbox.left + u.bbox.width, y); + | u.ctx.stroke(); + | + | u.ctx.restore(); + | }""".stripMargin + else ""} + | } + | } + | } + | ] + |}, [${xData.mkString("\"", "\",\"", "\"")}, ${data + .map(_.mkString(",")) + .mkString("[", "],[", "]")}], document.getElementById("$chartId")); + |""".stripMargin + } + + def serve(board: Board, portVal: Int = 9032): Unit = { + + object Server extends cask.MainRoutes { + @get("/") + def index() = { + val page = html( + head( + title(board.name), + script(src := "https://unpkg.com/uplot@1.6.24/dist/uPlot.iife.min.js"), + link(rel := "stylesheet", href := "https://unpkg.com/uplot@1.6.24/dist/uPlot.min.css"), + tag("style")(""" + |body { font-family: Arial, sans-serif; margin: 20px; } + |.tab { display: none; } + |.tab.active { display: block; } + |.tab-button { padding: 10px 20px; margin-right: 5px; cursor: pointer; } + |.tab-button.active { background-color: #ddd; } + |.section { margin: 20px 0; } + |.chart { margin: 20px 0; } + """.stripMargin) + ), + body( + h1(board.name), + div(cls := "tabs")( + board.tabList.map(tab => + button( + cls := "tab-button", + onclick := s"showTab('${tab.name}')", + tab.name + )) + ), + board.tabList.map(tab => + div(cls := "tab", id := tab.name)( + tab.sectionList.map(section => + div(cls := "section")( + h2(section.name), + section.charts.map(chart => + div(cls := "chart")( + div(id := s"${tab.name}-${section.name}-${chart.name}".replaceAll("\\s+", "-")) + )) + )) + )), + script(raw(""" + |function showTab(tabName) { + | document.querySelectorAll('.tab').forEach(tab => { + | tab.style.display = tab.id === tabName ? 
'block' : 'none'; + | }); + | document.querySelectorAll('.tab-button').forEach(button => { + | button.classList.toggle('active', button.textContent === tabName); + | }); + |} + | + |// Show first tab by default + |document.querySelector('.tab-button').click(); + """.stripMargin)), + script( + raw( + board.tabList + .flatMap(tab => + tab.sectionList.flatMap(section => + section.charts.map(chart => + generateChartJs(chart, s"${tab.name}-${section.name}-${chart.name}".replaceAll("\\s+", "-"))))) + .mkString("\n") + )) + ) + ) + +// page.render + + cask.Response( + page.render, + headers = Seq("Content-Type" -> "text/html") + ) + } + + override def host: String = "0.0.0.0" + override def port: Int = portVal + + initialize() + } + + Server.main(Array()) + } + + def main(args: Array[String]): Unit = { + val series = Array(Series(Array(1.0, 2.0, 3.0), "Series 1"), Series(Array(2.0, 3.0, 4.0), "Series 2")) + val chart = Chart(series, Array("A", "B", "C"), "Chart 1", Some(2.5), Some(3.5)) + val section = Section(Array(chart), "Section 1") + val tab = Tab(Array(section), "Tab 1") + val board = Board(Array(tab), "Board 1") + + println("serving board at http://localhost:9032/") + serve(board) + // Keep the program running + while (true) { + Thread.sleep(5000) + } + } +} diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala deleted file mode 100644 index 1b12789de1..0000000000 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ /dev/null @@ -1,152 +0,0 @@ -package ai.chronon.online.stats -import ai.chronon.api.DriftMetric - -import scala.math._ - -object DistanceMetrics { - def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric, debug: Boolean = false): Double = { - val breaks = (a ++ b).sorted.distinct - val aProjected = AssignIntervals.on(a, breaks) - val bProjected = AssignIntervals.on(b, breaks) - - val aNormalized = normalize(aProjected) - val bNormalized = normalize(bProjected) - - val func = termFunc(metric) - - var i = 0 - var result = 0.0 - - // debug only, remove before merging - val deltas = Array.ofDim[Double](aNormalized.length) - - while (i < aNormalized.length) { - val ai = aNormalized(i) - val bi = bNormalized(i) - val delta = func(ai, bi) - - // debug only remove before merging - deltas.update(i, delta) - - result += delta - i += 1 - } - - if (debug) { - def printArr(arr: Array[Double]): String = - arr.map(v => f"$v%.3f").mkString(", ") - println(f""" - |aProjected : ${printArr(aProjected)} - |bProjected : ${printArr(bProjected)} - |aNormalized: ${printArr(aNormalized)} - |bNormalized: ${printArr(bNormalized)} - |deltas : ${printArr(deltas)} - |result : $result%.4f - |""".stripMargin) - } - result - } - - // java map is what thrift produces upon deserialization - type Histogram = java.util.Map[String, java.lang.Long] - def histogramDistance(a: Histogram, b: Histogram, metric: DriftMetric): Double = { - - @inline def sumValues(h: Histogram): Double = { - var result = 0.0 - val it = h.entrySet().iterator() - while (it.hasNext) { - result += it.next().getValue - } - result - } - val aSum = sumValues(a) - val bSum = sumValues(b) - - val aIt = a.entrySet().iterator() - var result = 0.0 - val func = termFunc(metric) - while (aIt.hasNext) { - val entry = aIt.next() - val key = entry.getKey - val aVal = entry.getValue.toDouble - val bValOpt: java.lang.Long = b.get(key) - val bVal: Double = if (bValOpt == null) 0.0 else bValOpt.toDouble - val term = 
func(aVal / aSum, bVal / bSum) - result += term - } - - val bIt = b.entrySet().iterator() - while (bIt.hasNext) { - val entry = bIt.next() - val key = entry.getKey - val bVal = entry.getValue.toDouble - val aValOpt = a.get(key) - if (aValOpt == null) { - val term = func(0.0, bVal / bSum) - result += term - } - } - - result - } - - @inline - private def normalize(arr: Array[Double]): Array[Double] = { - // TODO-OPTIMIZATION: normalize in place instead if this is a hotspot - val result = Array.ofDim[Double](arr.length) - val sum = arr.sum - var i = 0 - while (i < arr.length) { - result.update(i, arr(i) / sum) - i += 1 - } - result - } - - @inline - private def klDivergenceTerm(a: Double, b: Double): Double = { - if (a > 0 && b > 0) a * math.log(a / b) else 0 - } - - @inline - private def jsdTerm(a: Double, b: Double): Double = { - val m = (a + b) * 0.5 - (klDivergenceTerm(a, m) + klDivergenceTerm(b, m)) * 0.5 - } - - @inline - private def hellingerTerm(a: Double, b: Double): Double = { - pow(sqrt(a) - sqrt(b), 2) * 0.5 - } - - @inline - private def psiTerm(a: Double, b: Double): Double = { - val aFixed = if (a == 0.0) 1e-5 else a - val bFixed = if (b == 0.0) 1e-5 else b - (bFixed - aFixed) * log(bFixed / aFixed) - } - - @inline - private def termFunc(d: DriftMetric): (Double, Double) => Double = - d match { - case DriftMetric.PSI => psiTerm - case DriftMetric.HELLINGER => hellingerTerm - case DriftMetric.JENSEN_SHANNON => jsdTerm - } - - case class Thresholds(moderate: Double, severe: Double) { - def str(driftScore: Double): String = { - if (driftScore < moderate) "LOW" - else if (driftScore < severe) "MODERATE" - else "SEVERE" - } - } - - @inline - def thresholds(d: DriftMetric): Thresholds = - d match { - case DriftMetric.JENSEN_SHANNON => Thresholds(0.05, 0.15) - case DriftMetric.HELLINGER => Thresholds(0.05, 0.15) - case DriftMetric.PSI => Thresholds(0.1, 0.2) - } -} diff --git a/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala b/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala deleted file mode 100644 index 5e571ba5d7..0000000000 --- a/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala +++ /dev/null @@ -1,128 +0,0 @@ -package ai.chronon.online.test.stats - -import ai.chronon.api.DriftMetric -import ai.chronon.online.stats.DistanceMetrics.histogramDistance -import ai.chronon.online.stats.DistanceMetrics.percentileDistance -import org.scalatest.funsuite.AnyFunSuite -import org.scalatest.matchers.should.Matchers - -import scala.util.ScalaJavaConversions.JMapOps - -class DistanceMetricsTest extends AnyFunSuite with Matchers { - - def buildPercentiles(mean: Double, variance: Double, breaks: Int = 20): Array[Double] = { - val stdDev = math.sqrt(variance) - - val probPoints = (0 to breaks).map { i => - if (i == 0) 0.01 - else if (i == breaks) 0.99 - else i.toDouble / breaks - }.toArray - - probPoints.map { p => - val standardNormalPercentile = math.sqrt(2) * inverseErf(2 * p - 1) - mean + (stdDev * standardNormalPercentile) - } - } - - def inverseErf(x: Double): Double = { - val a = 0.147 - val signX = if (x >= 0) 1 else -1 - val absX = math.abs(x) - - val term1 = math.pow(2 / (math.Pi * a) + math.log(1 - absX * absX) / 2, 0.5) - val term2 = math.log(1 - absX * absX) / a - - signX * math.sqrt(term1 - term2) - } - type Histogram = java.util.Map[String, java.lang.Long] - - def compareDistributions(meanA: Double, - varianceA: Double, - meanB: Double, - varianceB: Double, - breaks: Int = 20, - debug: Boolean = 
false): Map[DriftMetric, (Double, Double)] = { - - val aPercentiles = buildPercentiles(meanA, varianceA, breaks) - val bPercentiles = buildPercentiles(meanB, varianceB, breaks) - - val aHistogram: Histogram = (0 to breaks) - .map { i => - val value = java.lang.Long.valueOf((math.abs(aPercentiles(i)) * 100).toLong) - i.toString -> value - } - .toMap - .toJava - - val bHistogram: Histogram = (0 to breaks) - .map { i => - val value = java.lang.Long.valueOf((math.abs(bPercentiles(i)) * 100).toLong) - i.toString -> value - } - .toMap - .toJava - - def calculateDrift(metric: DriftMetric): (Double, Double) = { - val pDrift = percentileDistance(aPercentiles, bPercentiles, metric, debug = debug) - val histoDrift = histogramDistance(aHistogram, bHistogram, metric) - (pDrift, histoDrift) - } - - Map( - DriftMetric.JENSEN_SHANNON -> calculateDrift(DriftMetric.JENSEN_SHANNON), - DriftMetric.PSI -> calculateDrift(DriftMetric.PSI), - DriftMetric.HELLINGER -> calculateDrift(DriftMetric.HELLINGER) - ) - } - - test("Low drift - similar distributions") { - val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 101.0, varianceB = 225.0) - - // JSD assertions - val (jsdPercentile, jsdHisto) = drifts(DriftMetric.JENSEN_SHANNON) - jsdPercentile should be < 0.05 - jsdHisto should be < 0.05 - - // Hellinger assertions - val (hellingerPercentile, hellingerHisto) = drifts(DriftMetric.HELLINGER) - hellingerPercentile should be < 0.05 - hellingerHisto should be < 0.05 - } - - test("Moderate drift - slightly different distributions") { - val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 256.0) - - // JSD assertions - val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) - jsdPercentile should (be >= 0.05 and be <= 0.15) - - // Hellinger assertions - val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) - hellingerPercentile should (be >= 0.05 and be <= 0.15) - } - - test("Severe drift - different means") { - val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 110.0, varianceB = 225.0) - - // JSD assertions - val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) - jsdPercentile should be > 0.15 - - // Hellinger assertions - val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) - hellingerPercentile should be > 0.15 - } - - test("Severe drift - different variances") { - val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 100.0) - - // JSD assertions - val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) - jsdPercentile should be > 0.15 - - // Hellinger assertions - val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) - hellingerPercentile should be > 0.15 - } -} From 1af7b767e50f3a2ad95f5134ca6cd9a5ea7b19a7 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 20 Nov 2024 15:48:43 -0500 Subject: [PATCH 081/152] Rename DynamoDB store to monitoring model store --- hub/app/controllers/ModelController.scala | 4 ++-- hub/app/controllers/SearchController.scala | 4 ++-- hub/app/module/DynamoDBModule.scala | 16 ---------------- hub/app/module/ModelStoreModule.scala | 16 ++++++++++++++++ ...ingStore.scala => MonitoringModelStore.scala} | 10 +++++----- hub/conf/application.conf | 2 +- hub/test/controllers/ModelControllerSpec.scala | 4 ++-- hub/test/controllers/SearchControllerSpec.scala | 4 ++-- ...Test.scala => MonitoringModelStoreTest.scala} | 6 +++--- 9 files changed, 33 insertions(+), 33 deletions(-) delete mode 100644 hub/app/module/DynamoDBModule.scala 
create mode 100644 hub/app/module/ModelStoreModule.scala rename hub/app/store/{DynamoDBMonitoringStore.scala => MonitoringModelStore.scala} (94%) rename hub/test/store/{DynamoDBMonitoringStoreTest.scala => MonitoringModelStoreTest.scala} (92%) diff --git a/hub/app/controllers/ModelController.scala b/hub/app/controllers/ModelController.scala index e895c8c27f..40ef41a56c 100644 --- a/hub/app/controllers/ModelController.scala +++ b/hub/app/controllers/ModelController.scala @@ -4,7 +4,7 @@ import io.circe.generic.auto._ import io.circe.syntax._ import model.ListModelResponse import play.api.mvc._ -import store.DynamoDBMonitoringStore +import store.MonitoringModelStore import javax.inject._ @@ -13,7 +13,7 @@ import javax.inject._ */ @Singleton class ModelController @Inject() (val controllerComponents: ControllerComponents, - monitoringStore: DynamoDBMonitoringStore) + monitoringStore: MonitoringModelStore) extends BaseController with Paginate { diff --git a/hub/app/controllers/SearchController.scala b/hub/app/controllers/SearchController.scala index ac6b39110e..cb36e76a62 100644 --- a/hub/app/controllers/SearchController.scala +++ b/hub/app/controllers/SearchController.scala @@ -5,7 +5,7 @@ import io.circe.syntax._ import model.Model import model.SearchModelResponse import play.api.mvc._ -import store.DynamoDBMonitoringStore +import store.MonitoringModelStore import javax.inject._ @@ -13,7 +13,7 @@ import javax.inject._ * Controller to power search related APIs */ class SearchController @Inject() (val controllerComponents: ControllerComponents, - monitoringStore: DynamoDBMonitoringStore) + monitoringStore: MonitoringModelStore) extends BaseController with Paginate { diff --git a/hub/app/module/DynamoDBModule.scala b/hub/app/module/DynamoDBModule.scala deleted file mode 100644 index c9d1ab307e..0000000000 --- a/hub/app/module/DynamoDBModule.scala +++ /dev/null @@ -1,16 +0,0 @@ -package module - -import ai.chronon.integrations.aws.AwsApiImpl -import com.google.inject.AbstractModule -import play.api.Configuration -import play.api.Environment -import store.DynamoDBMonitoringStore - -class DynamoDBModule(environment: Environment, configuration: Configuration) extends AbstractModule { - - override def configure(): Unit = { - val awsApiImpl = new AwsApiImpl(Map.empty) - val dynamoDBMonitoringStore = new DynamoDBMonitoringStore(awsApiImpl) - bind(classOf[DynamoDBMonitoringStore]).toInstance(dynamoDBMonitoringStore) - } -} diff --git a/hub/app/module/ModelStoreModule.scala b/hub/app/module/ModelStoreModule.scala new file mode 100644 index 0000000000..801faeaa77 --- /dev/null +++ b/hub/app/module/ModelStoreModule.scala @@ -0,0 +1,16 @@ +package module + +import ai.chronon.integrations.aws.AwsApiImpl +import com.google.inject.AbstractModule +import play.api.Configuration +import play.api.Environment +import store.MonitoringModelStore + +class ModelStoreModule(environment: Environment, configuration: Configuration) extends AbstractModule { + + override def configure(): Unit = { + val awsApiImpl = new AwsApiImpl(Map.empty) + val dynamoDBMonitoringStore = new MonitoringModelStore(awsApiImpl) + bind(classOf[MonitoringModelStore]).toInstance(dynamoDBMonitoringStore) + } +} diff --git a/hub/app/store/DynamoDBMonitoringStore.scala b/hub/app/store/MonitoringModelStore.scala similarity index 94% rename from hub/app/store/DynamoDBMonitoringStore.scala rename to hub/app/store/MonitoringModelStore.scala index d435f9ad52..bfa1f5e895 100644 --- a/hub/app/store/DynamoDBMonitoringStore.scala +++ 
b/hub/app/store/MonitoringModelStore.scala
@@ -31,10 +31,10 @@ case class LoadedConfs(joins: Seq[api.Join] = Seq.empty,
                        stagingQueries: Seq[api.StagingQuery] = Seq.empty,
                        models: Seq[api.Model] = Seq.empty)
 
-class DynamoDBMonitoringStore(apiImpl: Api) {
+class MonitoringModelStore(apiImpl: Api) {
 
-  val dynamoDBKVStore: KVStore = apiImpl.genKvStore
-  implicit val executionContext: ExecutionContext = dynamoDBKVStore.executionContext
+  val kvStore: KVStore = apiImpl.genKvStore
+  implicit val executionContext: ExecutionContext = kvStore.executionContext
 
   // to help periodically refresh the load config catalog, we wrap this in a TTL cache
   lazy val configRegistryCache: TTLCache[String, LoadedConfs] = {
@@ -59,7 +59,7 @@ class MonitoringModelStore(apiImpl: Api) {
       GroupBy(part.groupBy.metaData.name, part.groupBy.valueColumns)
     }
 
-    val outputColumns = thriftJoin.outputColumnsByGroup.values.flatten.toArray
+    val outputColumns = thriftJoin.outputColumnsByGroup.getOrElse("derivations", Array.empty)
     val join = Join(thriftJoin.metaData.name, outputColumns, groupBys)
     Option(
       Model(m.metaData.name, join, m.metaData.online, m.metaData.production, m.metaData.team, m.modelType.name()))
@@ -81,7 +81,7 @@ class MonitoringModelStore(apiImpl: Api) {
     }
     val listRequest = ListRequest(MetadataEndPoint.ConfByKeyEndPointName, propsMap)
     logger.info(s"Triggering list conf lookup with request: $listRequest")
-    dynamoDBKVStore.list(listRequest).flatMap { response =>
+    kvStore.list(listRequest).flatMap { response =>
       val newLoadedConfs = makeLoadedConfs(response)
       val newAcc = LoadedConfs(
         acc.joins ++ newLoadedConfs.joins,
diff --git a/hub/conf/application.conf b/hub/conf/application.conf
index c3dceef48d..1d6b9996bf 100644
--- a/hub/conf/application.conf
+++ b/hub/conf/application.conf
@@ -28,4 +28,4 @@ play.filters.cors {
 }
 
 # Add DynamoDB module
-play.modules.enabled += "module.DynamoDBModule"
+play.modules.enabled += "module.ModelStoreModule"
diff --git a/hub/test/controllers/ModelControllerSpec.scala b/hub/test/controllers/ModelControllerSpec.scala
index 8e33830093..95b96ca24a 100644
--- a/hub/test/controllers/ModelControllerSpec.scala
+++ b/hub/test/controllers/ModelControllerSpec.scala
@@ -15,7 +15,7 @@ import play.api.http.Status.OK
 import play.api.mvc._
 import play.api.test.Helpers._
 import play.api.test._
-import store.DynamoDBMonitoringStore
+import store.MonitoringModelStore
 
 class ModelControllerSpec extends PlaySpec with Results with EitherValues {
 
@@ -24,7 +24,7 @@ class ModelControllerSpec extends PlaySpec with Results with EitherValues {
 
   // Create a stub ControllerComponents
   val stubCC: ControllerComponents = stubControllerComponents()
   // Create a mocked DynDB store
-  val mockedStore: DynamoDBMonitoringStore = mock(classOf[DynamoDBMonitoringStore])
+  val mockedStore: MonitoringModelStore = mock(classOf[MonitoringModelStore])
 
   val controller = new ModelController(stubCC, mockedStore)
 
diff --git a/hub/test/controllers/SearchControllerSpec.scala b/hub/test/controllers/SearchControllerSpec.scala
index 95807cfc35..5510ea4010 100644
--- a/hub/test/controllers/SearchControllerSpec.scala
+++ b/hub/test/controllers/SearchControllerSpec.scala
@@ -14,14 +14,14 @@ import play.api.http.Status.OK
 import play.api.mvc._
 import play.api.test.Helpers._
 import play.api.test._
-import store.DynamoDBMonitoringStore
+import store.MonitoringModelStore
 
 class SearchControllerSpec extends PlaySpec with Results with EitherValues {
 
   // Create a stub ControllerComponents
   val stubCC: ControllerComponents = stubControllerComponents()
   // 
Create a mocked DynDB store - val mockedStore: DynamoDBMonitoringStore = mock(classOf[DynamoDBMonitoringStore]) + val mockedStore: MonitoringModelStore = mock(classOf[MonitoringModelStore]) val controller = new SearchController(stubCC, mockedStore) diff --git a/hub/test/store/DynamoDBMonitoringStoreTest.scala b/hub/test/store/MonitoringModelStoreTest.scala similarity index 92% rename from hub/test/store/DynamoDBMonitoringStoreTest.scala rename to hub/test/store/MonitoringModelStoreTest.scala index 6a77915a3c..a2fcb925a8 100644 --- a/hub/test/store/DynamoDBMonitoringStoreTest.scala +++ b/hub/test/store/MonitoringModelStoreTest.scala @@ -21,7 +21,7 @@ import scala.io.Source import scala.util.Success import scala.util.Try -class DynamoDBMonitoringStoreTest extends MockitoSugar with Matchers { +class MonitoringModelStoreTest extends MockitoSugar with Matchers { var api: Api = _ var kvStore: KVStore = _ @@ -41,13 +41,13 @@ class DynamoDBMonitoringStoreTest extends MockitoSugar with Matchers { @Test def monitoringStoreShouldReturnModels(): Unit = { - val dynamoDBMonitoringStore = new DynamoDBMonitoringStore(api) + val dynamoDBMonitoringStore = new MonitoringModelStore(api) when(kvStore.list(any())).thenReturn(generateListResponse()) validateLoadedConfigs(dynamoDBMonitoringStore) } - private def validateLoadedConfigs(dynamoDBMonitoringStore: DynamoDBMonitoringStore): Unit = { + private def validateLoadedConfigs(dynamoDBMonitoringStore: MonitoringModelStore): Unit = { // check that our store has loaded the relevant artifacts dynamoDBMonitoringStore.getConfigRegistry.models.length shouldBe 1 dynamoDBMonitoringStore.getConfigRegistry.groupBys.length shouldBe 2 From 1228488b88aa642abcd21dd5f10e5f5102f27453 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Thu, 21 Nov 2024 20:17:18 -0500 Subject: [PATCH 082/152] First cut wiring up with passing tests --- .../controllers/TimeSeriesController.scala | 253 ++++++++++-------- hub/app/module/DriftStoreModule.scala | 15 ++ hub/app/store/MonitoringModelStore.scala | 2 +- hub/conf/application.conf | 1 + hub/conf/routes | 2 +- .../TimeSeriesControllerSpec.scala | 187 ++++++++++--- .../ai/chronon/online/stats/Display.scala | 205 -------------- .../ai/chronon/online/stats/DriftStore.scala | 6 + .../online/stats/TileDriftCalculator.scala | 2 +- 9 files changed, 313 insertions(+), 360 deletions(-) create mode 100644 hub/app/module/DriftStoreModule.scala delete mode 100644 online/src/main/scala/ai/chronon/online/stats/Display.scala diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index f3f546d6b0..0f253999a3 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -1,19 +1,23 @@ package controllers -import ai.chronon.api.DriftMetric +import ai.chronon.api.Extensions.WindowOps +import ai.chronon.api.{DriftMetric, TileDriftSeries, TileSummarySeries, TimeUnit, Window} +import ai.chronon.online.stats.DriftStore import io.circe.generic.auto._ import io.circe.syntax._ import model._ import play.api.mvc._ import javax.inject._ +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ -import scala.util.Random +import scala.util.{Failure, Random, Success} +import scala.jdk.CollectionConverters._ /** * Controller that serves various time series endpoints at the model, join and feature level */ @Singleton -class TimeSeriesController @Inject() (val controllerComponents: ControllerComponents) extends BaseController { +class 
TimeSeriesController @Inject() (val controllerComponents: ControllerComponents, driftStore: DriftStore)(implicit ec: ExecutionContext) extends BaseController { import TimeSeriesController._ @@ -25,17 +29,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon def fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String): Action[AnyContent] = doFetchModel(id, startTs, endTs, offset, algorithm) - /** - * Helps retrieve a model time series with the data sliced based on the relevant slice (identified by sliceId) - */ - def fetchModelSlice(id: String, - sliceId: String, - startTs: Long, - endTs: Long, - offset: String, - algorithm: String): Action[AnyContent] = - doFetchModel(id, startTs, endTs, offset, algorithm, Some(sliceId)) - /** * Helps retrieve a time series (drift or skew) for each of the features that are part of a join. Time series is * retrieved between the start and end ts. If the metric type is for drift, the offset is used to compute the @@ -50,28 +43,15 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon metrics: String, offset: Option[String], algorithm: Option[String]): Action[AnyContent] = - doFetchJoin(name, startTs, endTs, metricType, metrics, None, offset, algorithm) - - /** - * Helps retrieve a time series (drift or skew) for each of the features that are part of a join. The data is sliced - * based on the configured slice (looked up by sliceId) - */ - def fetchJoinSlice(name: String, - sliceId: String, - startTs: Long, - endTs: Long, - metricType: String, - metrics: String, - offset: Option[String], - algorithm: Option[String]): Action[AnyContent] = - doFetchJoin(name, startTs, endTs, metricType, metrics, Some(sliceId), offset, algorithm) + doFetchJoin(name, startTs, endTs, metricType, metrics, offset, algorithm) /** * Helps retrieve a time series (drift or skew) for a given feature. Time series is * retrieved between the start and end ts. Choice of granularity (raw, aggregate, percentiles) along with the * metric type (drift / skew) dictates the shape of the returned time series. */ - def fetchFeature(name: String, + def fetchFeature(join: String, + name: String, startTs: Long, endTs: Long, metricType: String, @@ -79,29 +59,13 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon granularity: String, offset: Option[String], algorithm: Option[String]): Action[AnyContent] = - doFetchFeature(name, startTs, endTs, metricType, metrics, None, granularity, offset, algorithm) - - /** - * Helps retrieve a time series (drift or skew) for a given feature. 
The data is sliced based on the configured slice - * (looked up by sliceId) - */ - def fetchFeatureSlice(name: String, - sliceId: String, - startTs: Long, - endTs: Long, - metricType: String, - metrics: String, - granularity: String, - offset: Option[String], - algorithm: Option[String]): Action[AnyContent] = - doFetchFeature(name, startTs, endTs, metricType, metrics, Some(sliceId), granularity, offset, algorithm) + doFetchFeature(join, name, startTs, endTs, metricType, metrics, granularity, offset, algorithm) private def doFetchModel(id: String, startTs: Long, endTs: Long, offset: String, - algorithm: String, - sliceId: Option[String] = None): Action[AnyContent] = + algorithm: String): Action[AnyContent] = Action { implicit request: Request[AnyContent] => (parseOffset(Some(offset)), parseAlgorithm(Some(algorithm))) match { case (None, _) => BadRequest(s"Unable to parse offset - $offset") @@ -118,18 +82,17 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon endTs: Long, metricType: String, metrics: String, - slice: Option[String], offset: Option[String], algorithm: Option[String]): Action[AnyContent] = - Action { implicit request: Request[AnyContent] => + Action.async { implicit request: Request[AnyContent] => val metricChoice = parseMetricChoice(Some(metricType)) val metricRollup = parseMetricRollup(Some(metrics)) (metricChoice, metricRollup) match { - case (None, _) => BadRequest("Invalid metric choice. Expect drift / skew") - case (_, None) => BadRequest("Invalid metric rollup. Expect null / value") - case (Some(Drift), Some(rollup)) => doFetchJoinDrift(name, startTs, endTs, rollup, slice, offset, algorithm) - case (Some(Skew), Some(rollup)) => doFetchJoinSkew(name, startTs, endTs, rollup, slice) + case (None, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift / skew")) + case (_, None) => Future.successful(BadRequest("Invalid metric rollup. Expect null / value")) + case (Some(Drift), Some(rollup)) => doFetchJoinDrift(name, startTs, endTs, rollup, offset, algorithm) + case (Some(Skew), Some(rollup)) => doFetchJoinSkew(name, startTs, endTs, rollup) } } @@ -137,34 +100,36 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon startTs: Long, endTs: Long, metric: Metric, - sliceId: Option[String], offset: Option[String], - algorithm: Option[String]): Result = { + algorithm: Option[String]): Future[Result] = { (parseOffset(offset), parseAlgorithm(algorithm)) match { - case (None, _) => BadRequest(s"Unable to parse offset - $offset") - case (_, None) => BadRequest("Invalid drift algorithm. Expect JSD, PSI or Hellinger") - case (Some(_), Some(_)) => - // TODO: Use parsedOffset and parsedAlgorithm when ready - val mockGroupBys = generateMockGroupBys(3) - val groupByTimeSeries = mockGroupBys.map { g => - val mockFeatures = generateMockFeatures(g, 10) - val featureTS = mockFeatures.map { - FeatureTimeSeries(_, generateMockTimeSeriesPoints(startTs, endTs)) + case (None, _) => Future.successful(BadRequest(s"Unable to parse offset - $offset")) + case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. 
Expect JSD, PSI or Hellinger")) + case (Some(o), Some(driftMetric)) => + val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) + val maybeDriftSeries = driftStore.getDriftSeries(name, driftMetric, window, startTs, endTs) + maybeDriftSeries match { + case Failure(exception) => Future.successful(InternalServerError(s"Error computing join drift - ${exception.getMessage}")) + case Success(driftSeriesFuture) => driftSeriesFuture.map { + driftSeries => + // pull up a list of drift series objects for all the features in a group + val grpToDriftSeriesList: Map[String, Seq[TileDriftSeries]] = driftSeries.groupBy(_.key.groupName) + val groupByTimeSeries = grpToDriftSeriesList.map { + case (name, featureDriftSeriesInfoSeq) => GroupByTimeSeries(name, featureDriftSeriesInfoSeq.map(series => convertTileDriftSeriesInfoToTimeSeries(series, metric))) + }.toSeq + + val tsData = JoinTimeSeriesResponse(name, groupByTimeSeries) + Ok(tsData.asJson.noSpaces) } - GroupByTimeSeries(g, featureTS) } - - val mockTSData = JoinTimeSeriesResponse(name, groupByTimeSeries) - Ok(mockTSData.asJson.noSpaces) } } private def doFetchJoinSkew(name: String, startTs: Long, endTs: Long, - metric: Metric, - sliceId: Option[String]): Result = { + metric: Metric): Future[Result] = { val mockGroupBys = generateMockGroupBys(3) val groupByTimeSeries = mockGroupBys.map { g => val mockFeatures = generateMockFeatures(g, 10) @@ -176,71 +141,80 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon val mockTSData = JoinTimeSeriesResponse(name, groupByTimeSeries) val json = mockTSData.asJson.noSpaces - Ok(json) + Future.successful(Ok(json)) } - private def doFetchFeature(name: String, + private def doFetchFeature(join: String, + name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, - slice: Option[String], granularity: String, offset: Option[String], algorithm: Option[String]): Action[AnyContent] = - Action { implicit request: Request[AnyContent] => + Action.async { implicit request: Request[AnyContent] => val metricChoice = parseMetricChoice(Some(metricType)) val metricRollup = parseMetricRollup(Some(metrics)) val granularityType = parseGranularity(granularity) (metricChoice, metricRollup, granularityType) match { - case (None, _, _) => BadRequest("Invalid metric choice. Expect drift / skew") - case (_, None, _) => BadRequest("Invalid metric rollup. Expect null / value") - case (_, _, None) => BadRequest("Invalid granularity. Expect raw / percentile / aggregates") + case (None, _, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift / skew")) + case (_, None, _) => Future.successful(BadRequest("Invalid metric rollup. Expect null / value")) + case (_, _, None) => Future.successful(BadRequest("Invalid granularity. 
Expect raw / percentile / aggregates")) case (Some(Drift), Some(rollup), Some(g)) => - doFetchFeatureDrift(name, startTs, endTs, rollup, slice, g, offset, algorithm) - case (Some(Skew), Some(rollup), Some(g)) => doFetchFeatureSkew(name, startTs, endTs, rollup, slice, g) + doFetchFeatureDrift(join, name, startTs, endTs, rollup, g, offset, algorithm) + case (Some(Skew), Some(rollup), Some(g)) => doFetchFeatureSkew(name, startTs, endTs, rollup, g) } } - private def doFetchFeatureDrift(name: String, + private def doFetchFeatureDrift(join: String, + name: String, startTs: Long, endTs: Long, metric: Metric, - sliceId: Option[String], granularity: Granularity, offset: Option[String], - algorithm: Option[String]): Result = { + algorithm: Option[String]): Future[Result] = { if (granularity == Raw) { - BadRequest("We don't support Raw granularity for drift metric types") + Future.successful(BadRequest("We don't support Raw granularity for drift metric types")) } else { (parseOffset(offset), parseAlgorithm(algorithm)) match { - case (None, _) => BadRequest(s"Unable to parse offset - $offset") - case (_, None) => BadRequest("Invalid drift algorithm. Expect PSI or KL") - case (Some(_), Some(_)) => - // TODO: Use parsedOffset and parsedAlgorithm when ready - val featureTsJson = if (granularity == Aggregates) { - // if feature name ends in an even digit we consider it continuous and generate mock data accordingly - // else we generate mock data for a categorical feature - val featureId = name.split("_").last.toInt - val featureTs = if (featureId % 2 == 0) { - ComparedFeatureTimeSeries(name, - generateMockRawTimeSeriesPoints(startTs, 100), - generateMockRawTimeSeriesPoints(startTs, 100)) - } else { - ComparedFeatureTimeSeries(name, - generateMockCategoricalTimeSeriesPoints(startTs, 5, 1), - generateMockCategoricalTimeSeriesPoints(startTs, 5, 2)) + case (None, _) => Future.successful(BadRequest(s"Unable to parse offset - $offset")) + case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. 
Expect JSD, PSI or Hellinger")) + case (Some(o), Some(driftMetric)) => + val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) + if (granularity == Aggregates) { + val maybeDriftSeries = + driftStore.getDriftSeries(join, driftMetric, window, startTs, endTs, Some(name)) + maybeDriftSeries match { + case Failure(exception) => Future.successful(InternalServerError(s"Error computing feature drift - ${exception.getMessage}")) + case Success(driftSeriesFuture) => driftSeriesFuture.map { + driftSeries => + val featureTs = convertTileDriftSeriesInfoToTimeSeries(driftSeries.head, metric) + Ok(featureTs.asJson.noSpaces) + } } - featureTs.asJson } else { - // - //{new: Array[Double], old: Array[Double], x: Array[String]} - //{old_null_count: Long, new_null_count: long, old_total_count: Long, new_total_count: Long} - - FeatureTimeSeries(name, generateMockTimeSeriesPercentilePoints(startTs, endTs)).asJson + // percentiles + val maybeCurrentSummarySeries = driftStore.getSummarySeries(join, startTs, endTs, Some(name)) + val maybeBaselineSummarySeries = driftStore.getSummarySeries(join, startTs - window.millis, endTs - window.millis, Some(name)) + (maybeCurrentSummarySeries, maybeBaselineSummarySeries) match { + case (Failure(exceptionA), Failure(exceptionB)) => Future.successful(InternalServerError(s"Error computing feature percentiles for current + offset time window.\nCurrent window error: ${exceptionA.getMessage}\nOffset window error: ${exceptionB.getMessage}")) + case (_, Failure(exception)) => Future.successful(InternalServerError(s"Error computing feature percentiles for offset time window - ${exception.getMessage}")) + case (Failure(exception), _) => Future.successful(InternalServerError(s"Error computing feature percentiles for current time window - ${exception.getMessage}")) + case (Success(currentSummarySeriesFuture), Success(baselineSummarySeriesFuture)) => + Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { + merged => + val currentSummarySeries = merged.head + val baselineSummarySeries = merged.last + val currentFeatureTs = convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) + val baselineFeatureTs = convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) + val comparedTsData = ComparedFeatureTimeSeries(name, baselineFeatureTs, currentFeatureTs) + Ok(comparedTsData.asJson.noSpaces) + } + } } - Ok(featureTsJson.noSpaces) } } } @@ -249,10 +223,9 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon startTs: Long, endTs: Long, metric: Metric, - sliceId: Option[String], - granularity: Granularity): Result = { + granularity: Granularity): Future[Result] = { if (granularity == Aggregates) { - BadRequest("We don't support Aggregates granularity for skew metric types") + Future.successful(BadRequest("We don't support Aggregates granularity for skew metric types")) } else { val featureTsJson = if (granularity == Raw) { val featureTs = ComparedFeatureTimeSeries(name, @@ -263,7 +236,53 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon val featuresTs = FeatureTimeSeries(name, generateMockTimeSeriesPercentilePoints(startTs, endTs)) featuresTs.asJson.noSpaces } - Ok(featureTsJson) + Future.successful(Ok(featureTsJson)) + } + } + + private def convertTileDriftSeriesInfoToTimeSeries(tileDriftSeries: TileDriftSeries, metric: Metric): FeatureTimeSeries = { + val lhsList = if (metric == NullMetric) { + tileDriftSeries.nullRatioChangePercentSeries.asScala + } else { + // 
check if we have a numeric / categorical feature. If the percentile drift series has non-null doubles
+      // then we have a numeric feature at hand
+      val isNumeric = tileDriftSeries.percentileDriftSeries != null && tileDriftSeries.percentileDriftSeries.asScala.exists(_ != null)
+      if (isNumeric) tileDriftSeries.percentileDriftSeries.asScala
+      else tileDriftSeries.histogramDriftSeries.asScala
+    }
+    val points = lhsList.zip(tileDriftSeries.timestamps.asScala).map {
+      case (v, ts) => TimeSeriesPoint(v, ts)
+    }
+
+    FeatureTimeSeries(tileDriftSeries.getKey.getColumn, points)
+  }
+
+  private def convertTileSummarySeriesToTimeSeries(summarySeries: TileSummarySeries, metric: Metric): Seq[TimeSeriesPoint] = {
+    if (metric == NullMetric) {
+      summarySeries.nullCount.asScala.zip(summarySeries.timestamps.asScala).map {
+        case (nullCount, ts) => TimeSeriesPoint(0, ts, nullValue = Some(nullCount.intValue()))
+      }
+    } else {
+      // check if we have a numeric / categorical feature. If the percentiles series has non-null entries
+      // then we have a numeric feature at hand
+      val isNumeric = summarySeries.percentiles != null && summarySeries.percentiles.asScala.exists(_ != null)
+      if (isNumeric) {
+        summarySeries.percentiles.asScala.zip(summarySeries.timestamps.asScala).flatMap {
+          case (percentiles, ts) =>
+            DriftStore.percentileLabels.zip(percentiles.asScala).map {
+              case (l, value) => TimeSeriesPoint(value, ts, Some(l))
+            }
+        }
+      }
+      else {
+        summarySeries.timestamps.asScala.zipWithIndex.flatMap {
+          case (ts, idx) =>
+            summarySeries.histogram.asScala.map {
+              case (label, values) =>
+                TimeSeriesPoint(values.get(idx).toDouble, ts, Some(label))
+            }
+        }
+      }
+    }
+  }
 }
 
@@ -281,13 +300,11 @@ object TimeSeriesController {
   }
 
   def parseAlgorithm(algorithm: Option[String]): Option[DriftMetric] = {
-    algorithm.map {
-      _.toLowerCase match {
-        case "psi" => DriftMetric.PSI
-        case "hellinger" => DriftMetric.HELLINGER
-        case "jsd" => DriftMetric.JENSEN_SHANNON
-        case _ => throw new IllegalArgumentException("Invalid drift algorithm. 
Pick one of PSI, Hellinger or JSD") - } + algorithm.map(_.toLowerCase) match { + case Some("psi") => Some(DriftMetric.PSI) + case Some("hellinger") => Some(DriftMetric.HELLINGER) + case Some("jsd") => Some(DriftMetric.JENSEN_SHANNON) + case _ => None } } diff --git a/hub/app/module/DriftStoreModule.scala b/hub/app/module/DriftStoreModule.scala new file mode 100644 index 0000000000..ee831e5cad --- /dev/null +++ b/hub/app/module/DriftStoreModule.scala @@ -0,0 +1,15 @@ +package module + +import ai.chronon.integrations.aws.AwsApiImpl +import ai.chronon.online.stats.DriftStore +import com.google.inject.AbstractModule +import play.api.{Configuration, Environment} + +class DriftStoreModule(environment: Environment, configuration: Configuration) extends AbstractModule { + + override def configure(): Unit = { + val awsApiImpl = new AwsApiImpl(Map.empty) + val driftStore = new DriftStore(awsApiImpl.genKvStore) + bind(classOf[DriftStore]).toInstance(driftStore) + } +} diff --git a/hub/app/store/MonitoringModelStore.scala b/hub/app/store/MonitoringModelStore.scala index bfa1f5e895..d4d67d8f4b 100644 --- a/hub/app/store/MonitoringModelStore.scala +++ b/hub/app/store/MonitoringModelStore.scala @@ -59,7 +59,7 @@ class MonitoringModelStore(apiImpl: Api) { GroupBy(part.groupBy.metaData.name, part.groupBy.valueColumns) } - val outputColumns = thriftJoin.ooutputColumnsByGroup.getOrElse("derivations", Array.empty) + val outputColumns = thriftJoin.outputColumnsByGroup.getOrElse("derivations", Array.empty) val join = Join(thriftJoin.metaData.name, outputColumns, groupBys) Option( Model(m.metaData.name, join, m.metaData.online, m.metaData.production, m.metaData.team, m.modelType.name())) diff --git a/hub/conf/application.conf b/hub/conf/application.conf index 1d6b9996bf..5696edf2d2 100644 --- a/hub/conf/application.conf +++ b/hub/conf/application.conf @@ -29,3 +29,4 @@ play.filters.cors { # Add DynamoDB module play.modules.enabled += "module.ModelStoreModule" +play.modules.enabled += "module.DriftStoreModule" diff --git a/hub/conf/routes b/hub/conf/routes index 7f686530ba..7c88e44bb6 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -12,7 +12,7 @@ GET /api/v1/join/:name/timeseries controllers.TimeSeriesC # join -> seq(feature) # when metricType == "drift" - will return time series list of drift values -GET /api/v1/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) +GET /api/v1/join/:join/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(join: String, name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) # TODO - move the core flow to fine-grained end-points diff --git a/hub/test/controllers/TimeSeriesControllerSpec.scala b/hub/test/controllers/TimeSeriesControllerSpec.scala index 2d431e0a24..8fb07cb511 100644 --- a/hub/test/controllers/TimeSeriesControllerSpec.scala +++ b/hub/test/controllers/TimeSeriesControllerSpec.scala @@ -1,9 +1,13 @@ package controllers +import ai.chronon.api.{TileDriftSeries, TileSeriesKey, TileSummarySeries} +import ai.chronon.online.stats.DriftStore import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ import model._ +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.{mock, when} import org.scalatest.EitherValues import org.scalatestplus.play._ import 
play.api.http.Status.BAD_REQUEST @@ -13,14 +17,24 @@ import play.api.test.Helpers._ import play.api.test._ import java.util.concurrent.TimeUnit +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration +import scala.util.{Failure, Success, Try} +import java.lang.{Double => JDouble} +import java.lang.{Long => JLong} +import scala.jdk.CollectionConverters._ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { // Create a stub ControllerComponents val stubCC: ControllerComponents = stubControllerComponents() - val controller = new TimeSeriesController(stubCC) + implicit val ec: ExecutionContext = ExecutionContext.global + + // Create a mocked drift store + val mockedStore: DriftStore = mock(classOf[DriftStore]) + val controller = new TimeSeriesController(stubCC, mockedStore) + val mockCategories: Seq[String] = Seq("a", "b", "c") "TimeSeriesController's model ts lookup" should { @@ -78,11 +92,26 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { status(invalid1) mustBe BAD_REQUEST } - "send valid results on a correctly formed model ts drift lookup request" in { + "send 5xx on failed drift store lookup request" in { + when(mockedStore.getDriftSeries(any(), any(), any(), any(), any(), any())).thenReturn(Failure(new IllegalArgumentException("Some internal error"))) + val startTs = 1725926400000L // 09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC val result = controller.fetchJoin("my_join", startTs, endTs, "drift", "null", Some("10h"), Some("psi")).apply(FakeRequest()) + + status(result) mustBe INTERNAL_SERVER_ERROR + } + + "send valid results on a correctly formed model ts drift lookup request" in { + val startTs = 1725926400000L // 09/10/2024 00:00 UTC + val endTs = 1726106400000L // 09/12/2024 02:00 UTC + + val mockedDriftStoreResponse = generateDriftSeries(startTs, endTs, "my_join", 2, 3) + when(mockedStore.getDriftSeries(any(), any(), any(), any(), any(), any())).thenReturn(mockedDriftStoreResponse) + + val result = + controller.fetchJoin("my_join", startTs, endTs, "drift", "value", Some("10h"), Some("psi")).apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) val modelTSResponse: Either[Error, JoinTimeSeriesResponse] = decode[JoinTimeSeriesResponse](bodyText) @@ -123,32 +152,32 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { "send 400 on an invalid metric choice" in { val invalid = - controller.fetchFeature("my_feature", 123L, 456L, "meow", "null", "raw", None, None).apply(FakeRequest()) + controller.fetchFeature("my_join", "my_feature", 123L, 456L, "meow", "null", "raw", None, None).apply(FakeRequest()) status(invalid) mustBe BAD_REQUEST } "send 400 on an invalid metric rollup" in { val invalid = - controller.fetchFeature("my_feature", 123L, 456L, "drift", "woof", "raw", None, None).apply(FakeRequest()) + controller.fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "woof", "raw", None, None).apply(FakeRequest()) status(invalid) mustBe BAD_REQUEST } "send 400 on an invalid granularity" in { val invalid = - controller.fetchFeature("my_feature", 123L, 456L, "drift", "null", "woof", None, None).apply(FakeRequest()) + controller.fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "woof", None, None).apply(FakeRequest()) status(invalid) mustBe BAD_REQUEST } "send 400 on an invalid time offset for drift metric" in { val invalid1 = controller - .fetchFeature("my_feature", 123L, 456L, 
"drift", "null", "aggregates", Some("Xh"), Some("psi")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "aggregates", Some("Xh"), Some("psi")) .apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST val invalid2 = controller - .fetchFeature("my_feature", 123L, 456L, "drift", "null", "aggregates", Some("-1h"), Some("psi")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "aggregates", Some("-1h"), Some("psi")) .apply(FakeRequest()) status(invalid2) mustBe BAD_REQUEST } @@ -156,7 +185,7 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { "send 400 on an invalid algorithm for drift metric" in { val invalid1 = controller - .fetchFeature("my_feature", 123L, 456L, "drift", "null", "aggregates", Some("10h"), Some("meow")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "aggregates", Some("10h"), Some("meow")) .apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST } @@ -164,7 +193,7 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { "send 400 on an invalid granularity for drift metric" in { val invalid1 = controller - .fetchFeature("my_feature", 123L, 456L, "drift", "null", "raw", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "raw", Some("10h"), Some("psi")) .apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST } @@ -172,74 +201,101 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { "send 400 on an invalid granularity for skew metric" in { val invalid1 = controller - .fetchFeature("my_feature", 123L, 456L, "skew", "null", "aggregates", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "skew", "null", "aggregates", Some("10h"), Some("psi")) .apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST } - "send valid results on a correctly formed numeric feature ts aggregate drift lookup request" in { + "send valid results on a correctly formed feature ts aggregate drift lookup request" in { val startTs = 1725926400000L // 09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC + + val mockedDriftStoreResponse = generateDriftSeries(startTs, endTs, "my_join", 1, 1) + when(mockedStore.getDriftSeries(any(), any(), any(), any(), any(), any())).thenReturn(mockedDriftStoreResponse) + val result = controller - .fetchFeature("my_feature_0", startTs, endTs, "drift", "null", "aggregates", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature_0", startTs, endTs, "drift", "null", "aggregates", Some("10h"), Some("psi")) .apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) - val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] = decode[ComparedFeatureTimeSeries](bodyText) + val featureTSResponse: Either[Error, FeatureTimeSeries] = decode[FeatureTimeSeries](bodyText) featureTSResponse.isRight mustBe true val response = featureTSResponse.right.value response.feature mustBe "my_feature_0" - response.current.length mustBe response.baseline.length - response.current.zip(response.baseline).foreach { - case (current, baseline) => - current.ts mustBe baseline.ts - } + val expectedLength = expectedHours(startTs, endTs) + response.points.length mustBe expectedLength } - "send valid results on a correctly formed categorical feature ts aggregate drift lookup request" in { + "send valid results on a correctly formed numeric feature ts percentile drift lookup request" in { val startTs = 1725926400000L // 
09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC + + val mockedSummarySeriesResponseA = generateSummarySeries(startTs, endTs, "my_join", "my_groupby", "my_feature", ValuesMetric, true) + val offset = Duration.apply(7, TimeUnit.DAYS) + val mockedSummarySeriesResponseB = + generateSummarySeries(startTs - offset.toMillis, endTs - offset.toMillis, "my_join", "my_groupby", "my_feature", ValuesMetric, true) + when(mockedStore.getSummarySeries(any(), any(), any(), any())).thenReturn(mockedSummarySeriesResponseA, mockedSummarySeriesResponseB) + val result = controller - .fetchFeature("my_feature_1", startTs, endTs, "drift", "null", "aggregates", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature", startTs, endTs, "drift", "value", "percentile", Some("10h"), Some("psi")) .apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] = decode[ComparedFeatureTimeSeries](bodyText) featureTSResponse.isRight mustBe true val response = featureTSResponse.right.value - response.feature mustBe "my_feature_1" - response.current.map(_.ts).toSet mustBe response.baseline.map(_.ts).toSet - response.current.foreach(_.label.isEmpty mustBe false) - response.baseline.foreach(_.label.isEmpty mustBe false) + response.feature mustBe "my_feature" + response.current.length mustBe response.baseline.length + response.current.zip(response.baseline).foreach { + case (current, baseline) => + (current.ts - baseline.ts) mustBe offset.toMillis + } + + // expect one entry per percentile for each time series point + val expectedLength = DriftStore.percentileLabels.length * expectedHours(startTs, endTs) + response.current.length mustBe expectedLength + response.baseline.length mustBe expectedLength } - "send valid results on a correctly formed feature ts percentile drift lookup request" in { + "send valid results on a correctly formed categorical feature ts percentile drift lookup request" in { val startTs = 1725926400000L // 09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC + + val mockedSummarySeriesResponseA = generateSummarySeries(startTs, endTs, "my_join", "my_groupby", "my_feature", ValuesMetric, false) + val offset = Duration.apply(7, TimeUnit.DAYS) + val mockedSummarySeriesResponseB = + generateSummarySeries(startTs - offset.toMillis, endTs - offset.toMillis, "my_join", "my_groupby", "my_feature", ValuesMetric, false) + when(mockedStore.getSummarySeries(any(), any(), any(), any())).thenReturn(mockedSummarySeriesResponseA, mockedSummarySeriesResponseB) + val result = controller - .fetchFeature("my_feature", startTs, endTs, "drift", "null", "percentile", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature", startTs, endTs, "drift", "value", "percentile", Some("10h"), Some("psi")) .apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) - val featureTSResponse: Either[Error, FeatureTimeSeries] = decode[FeatureTimeSeries](bodyText) + val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] = decode[ComparedFeatureTimeSeries](bodyText) featureTSResponse.isRight mustBe true val response = featureTSResponse.right.value response.feature mustBe "my_feature" - response.points.nonEmpty mustBe true - - // expect one entry per percentile for each time series point - val expectedLength = TimeSeriesController.mockGeneratedPercentiles.length * expectedHours(startTs, endTs) - response.points.length mustBe expectedLength + response.current.length 
mustBe response.baseline.length
+      // expect one entry per category for each time series point
+      val expectedLength = mockCategories.length * expectedHours(startTs, endTs)
+      response.current.length mustBe expectedLength
+      response.current.zip(response.baseline).foreach {
+        case (current, baseline) =>
+          (current.ts - baseline.ts) mustBe offset.toMillis
+          current.label.isEmpty mustBe false
+          baseline.label.isEmpty mustBe false
+      }
     }
 
     "send valid results on a correctly formed feature ts raw skew lookup request" in {
       val startTs = 1725926400000L // 09/10/2024 00:00 UTC
       val endTs = 1726106400000L // 09/12/2024 02:00 UTC
       val result =
-        controller.fetchFeature("my_feature", startTs, endTs, "skew", "null", "raw", None, None).apply(FakeRequest())
+        controller.fetchFeature("my_join", "my_feature", startTs, endTs, "skew", "null", "raw", None, None).apply(FakeRequest())
       status(result) mustBe OK
       val bodyText = contentAsString(result)
       val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] =
         decode[ComparedFeatureTimeSeries](bodyText)
       featureTSResponse.isRight mustBe true
       val response = featureTSResponse.right.value
       response.feature mustBe "my_feature"
       response.baseline.nonEmpty mustBe true
       response.baseline.length mustBe response.current.length
       // we expect a skew distribution at a fixed time stamp
       response.baseline.foreach(p => p.ts mustBe startTs)
       response.current.foreach(p => p.ts mustBe startTs)
     }
 
     "send valid results on a correctly formed feature ts percentile skew lookup request" in {
       val startTs = 1725926400000L // 09/10/2024 00:00 UTC
       val endTs = 1726106400000L // 09/12/2024 02:00 UTC
       val result =
         controller
-          .fetchFeature("my_feature", startTs, endTs, "skew", "null", "percentile", None, None)
+          .fetchFeature("my_join", "my_feature", startTs, endTs, "skew", "null", "percentile", None, None)
           .apply(FakeRequest())
       status(result) mustBe OK
       val bodyText = contentAsString(result)
       val featureTSResponse: Either[Error, FeatureTimeSeries] = decode[FeatureTimeSeries](bodyText)
       featureTSResponse.isRight mustBe true
       val response = featureTSResponse.right.value
       response.feature mustBe "my_feature"
       response.points.nonEmpty mustBe true
 
       // expect one entry per percentile for each time series point
       val expectedLength = TimeSeriesController.mockGeneratedPercentiles.length * expectedHours(startTs, endTs)
       response.points.length mustBe expectedLength
     }
   }
 
   private def expectedHours(startTs: Long, endTs: Long): Long = {
     Duration(endTs - startTs, TimeUnit.MILLISECONDS).toHours
   }
+
+  private def generateDriftSeries(startTs: Long, endTs: Long, join: String, numGroups: Int, numFeaturesPerGroup: Int): Try[Future[Seq[TileDriftSeries]]] = {
+    val result = for {
+      group <- 0 until numGroups
+      feature <- 0 until numFeaturesPerGroup
+    } yield {
+      val name = s"my_group_$group"
+      val featureName = s"my_feature_$feature"
+      val tileKey = new TileSeriesKey()
+      tileKey.setNodeName(join)
+      tileKey.setGroupName(name)
+      tileKey.setColumn(featureName)
+
+      val tileDriftSeries = new TileDriftSeries()
+      tileDriftSeries.setKey(tileKey)
+
+      val timestamps = (startTs until endTs by (Duration(1, TimeUnit.HOURS).toMillis)).toList.map(JLong.valueOf(_)).asJava
+      // if feature name ends in an even digit we consider it continuous and generate mock data accordingly
+      // else we generate mock data for a categorical feature
+      val isNumeric = feature % 2 == 0
+      val percentileDrifts = if (isNumeric) List.fill(timestamps.size())(JDouble.valueOf(0.12)).asJava else List.fill[JDouble](timestamps.size())(null).asJava
+      val histogramDrifts = if (isNumeric) List.fill[JDouble](timestamps.size())(null).asJava else List.fill(timestamps.size())(JDouble.valueOf(0.23)).asJava
+      val nullRatioChangePercents = List.fill(timestamps.size())(JDouble.valueOf(0.25)).asJava
+      tileDriftSeries.setTimestamps(timestamps)
+      tileDriftSeries.setPercentileDriftSeries(percentileDrifts)
+      tileDriftSeries.setNullRatioChangePercentSeries(nullRatioChangePercents)
+      tileDriftSeries.setHistogramDriftSeries(histogramDrifts)
+    }
+    Success(Future.successful(result))
+  }
+
+  private def generateSummarySeries(startTs: Long, endTs: Long, join: String, groupBy: String, featureName: String, metric: Metric, isNumeric: Boolean): Try[Future[Seq[TileSummarySeries]]] = {
+    val tileKey = new TileSeriesKey()
+    tileKey.setNodeName(join)
+    tileKey.setGroupName(groupBy)
+    
tileKey.setNodeName(join) + tileKey.setColumn(featureName) + + val timestamps = (startTs until endTs by (Duration(1, TimeUnit.HOURS).toMillis)).toList.map(JLong.valueOf(_)) + val tileSummarySeries = new TileSummarySeries() + tileSummarySeries.setKey(tileKey) + tileSummarySeries.setTimestamps(timestamps.asJava) + + if (metric == NullMetric) { + tileSummarySeries.setNullCount(List.fill(timestamps.length)(JLong.valueOf(1)).asJava) + } else { + if (isNumeric) { + val percentileList = timestamps.map { + _ => + List.fill(DriftStore.percentileLabels.length)(JDouble.valueOf(0.12)).asJava + }.asJava + tileSummarySeries.setPercentiles(percentileList) + } else { + val histogramMap = mockCategories.map { + category => + category -> List.fill(timestamps.length)(JLong.valueOf(1)).asJava + }.toMap.asJava + tileSummarySeries.setHistogram(histogramMap) + } + } + + Success(Future.successful(Seq(tileSummarySeries))) + } } diff --git a/online/src/main/scala/ai/chronon/online/stats/Display.scala b/online/src/main/scala/ai/chronon/online/stats/Display.scala deleted file mode 100644 index a4c757dc27..0000000000 --- a/online/src/main/scala/ai/chronon/online/stats/Display.scala +++ /dev/null @@ -1,205 +0,0 @@ -package ai.chronon.online.stats - -import cask._ -import scalatags.Text.all._ -import scalatags.Text.tags2.title - -// generates html / js code to serve a tabbed board on the network port -// boards are static and do not update, used for debugging only -// uses uPlot under the hood -object Display { - // single line inside a chart - case class Series(series: Array[Double], name: String) - // multiple lines in a chart plus the x-axis and a threshold (horizontal dashed line) - case class Chart(seriesList: Array[Series], - x: Array[String], - name: String, - moderateThreshold: Option[Double] = None, - severeThreshold: Option[Double] = None) - - // multiple charts in a section - case class Section(charts: Array[Chart], name: String) - // multiple sections in a tab - case class Tab(sectionList: Array[Section], name: String) - // multiple tabs in a board - case class Board(tabList: Array[Tab], name: String) - - private def generateChartJs(chart: Chart, chartId: String): String = { - val data = chart.seriesList.map(_.series) - val xData = chart.x.map(_.toString) - chart.seriesList.map(_.name) - - val seriesConfig = chart.seriesList.map(s => s"""{ - | label: "${s.name}", - | stroke: "rgb(${scala.util.Random.nextInt(255)}, ${scala.util.Random.nextInt(255)}, ${scala.util.Random.nextInt(255)})" - | - |}""".stripMargin).mkString(",\n") - - val thresholdLines = (chart.moderateThreshold.map(t => s""" - |{ - | label: "Moderate Threshold", - | value: $t, - | stroke: "#ff9800", - | style: [2, 2] - |}""".stripMargin) ++ - chart.severeThreshold.map(t => s""" - |{ - | label: "Severe Threshold", - | value: $t, - | stroke: "#f44336", - | style: [2, 2] - |}""".stripMargin)).mkString(",") - - s""" - |new uPlot({ - | title: "${chart.name}", - | id: "$chartId", - | class: "chart", - | width: 800, - | height: 400, - | scales: { - | x: { - | time: false, - | } - | }, - | series: [ - | {}, - | $seriesConfig - | ], - | axes: [ - | {}, - | { - | label: "Value", - | grid: true, - | } - | ], - | plugins: [ - | { - | hooks: { - | draw: u => { - | ${if (thresholdLines.nonEmpty) - s"""const lines = [$thresholdLines]; - | for (const line of lines) { - | const scale = u.scales.y; - | const y = scale.getPos(line.value); - | - | u.ctx.save(); - | u.ctx.strokeStyle = line.stroke; - | u.ctx.setLineDash(line.style); - | - | u.ctx.beginPath(); - | 
u.ctx.moveTo(u.bbox.left, y); - | u.ctx.lineTo(u.bbox.left + u.bbox.width, y); - | u.ctx.stroke(); - | - | u.ctx.restore(); - | }""".stripMargin - else ""} - | } - | } - | } - | ] - |}, [${xData.mkString("\"", "\",\"", "\"")}, ${data - .map(_.mkString(",")) - .mkString("[", "],[", "]")}], document.getElementById("$chartId")); - |""".stripMargin - } - - def serve(board: Board, portVal: Int = 9032): Unit = { - - object Server extends cask.MainRoutes { - @get("/") - def index() = { - val page = html( - head( - title(board.name), - script(src := "https://unpkg.com/uplot@1.6.24/dist/uPlot.iife.min.js"), - link(rel := "stylesheet", href := "https://unpkg.com/uplot@1.6.24/dist/uPlot.min.css"), - tag("style")(""" - |body { font-family: Arial, sans-serif; margin: 20px; } - |.tab { display: none; } - |.tab.active { display: block; } - |.tab-button { padding: 10px 20px; margin-right: 5px; cursor: pointer; } - |.tab-button.active { background-color: #ddd; } - |.section { margin: 20px 0; } - |.chart { margin: 20px 0; } - """.stripMargin) - ), - body( - h1(board.name), - div(cls := "tabs")( - board.tabList.map(tab => - button( - cls := "tab-button", - onclick := s"showTab('${tab.name}')", - tab.name - )) - ), - board.tabList.map(tab => - div(cls := "tab", id := tab.name)( - tab.sectionList.map(section => - div(cls := "section")( - h2(section.name), - section.charts.map(chart => - div(cls := "chart")( - div(id := s"${tab.name}-${section.name}-${chart.name}".replaceAll("\\s+", "-")) - )) - )) - )), - script(raw(""" - |function showTab(tabName) { - | document.querySelectorAll('.tab').forEach(tab => { - | tab.style.display = tab.id === tabName ? 'block' : 'none'; - | }); - | document.querySelectorAll('.tab-button').forEach(button => { - | button.classList.toggle('active', button.textContent === tabName); - | }); - |} - | - |// Show first tab by default - |document.querySelector('.tab-button').click(); - """.stripMargin)), - script( - raw( - board.tabList - .flatMap(tab => - tab.sectionList.flatMap(section => - section.charts.map(chart => - generateChartJs(chart, s"${tab.name}-${section.name}-${chart.name}".replaceAll("\\s+", "-"))))) - .mkString("\n") - )) - ) - ) - -// page.render - - cask.Response( - page.render, - headers = Seq("Content-Type" -> "text/html") - ) - } - - override def host: String = "0.0.0.0" - override def port: Int = portVal - - initialize() - } - - Server.main(Array()) - } - - def main(args: Array[String]): Unit = { - val series = Array(Series(Array(1.0, 2.0, 3.0), "Series 1"), Series(Array(2.0, 3.0, 4.0), "Series 2")) - val chart = Chart(series, Array("A", "B", "C"), "Chart 1", Some(2.5), Some(3.5)) - val section = Section(Array(chart), "Section 1") - val tab = Tab(Array(section), "Tab 1") - val board = Board(Array(tab), "Board 1") - - println("serving board at http://localhost:9032/") - serve(board) - // Keep the program running - while (true) { - Thread.sleep(5000) - } - } -} diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index eea086e883..44c83e37c1 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -203,4 +203,10 @@ object DriftStore { def compactSerializer: SerializableSerializer = new SerializableSerializer(new TBinaryProtocol.Factory()) def compactDeserializer: TDeserializer = new TDeserializer(new TBinaryProtocol.Factory()) + + // todo - drop this hard-coded list in favor of a well known 
list or exposing as part of summaries + val percentileLabels: Seq[String] = Seq("p0", "p5", "p10", "p15", "p20", + "p25", "p30", "p35", "p40", "p45", + "p50", "p55", "p60", "p65", "p70", + "p75", "p80", "p85", "p90", "p95", "p100") } diff --git a/online/src/main/scala/ai/chronon/online/stats/TileDriftCalculator.scala b/online/src/main/scala/ai/chronon/online/stats/TileDriftCalculator.scala index 47518f555a..ef22b4cf1a 100644 --- a/online/src/main/scala/ai/chronon/online/stats/TileDriftCalculator.scala +++ b/online/src/main/scala/ai/chronon/online/stats/TileDriftCalculator.scala @@ -81,7 +81,7 @@ object TileDriftCalculator { result } - // for each summary with ts >= startMs, use spec.lookBack to find the previous summary and calculate dirft + // for each summary with ts >= startMs, use spec.lookBack to find the previous summary and calculate drift // we do this by first creating a map of summaries by timestamp def toTileDrifts(summariesWithTimestamps: Array[(TileSummary, Long)], metric: DriftMetric, From 86b1cfd85f767aba1aa22924de25252dfeb1e2c2 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Thu, 21 Nov 2024 20:31:35 -0500 Subject: [PATCH 083/152] Rip out mock data generation and corresponding endpoints --- .../controllers/TimeSeriesController.scala | 116 +----------------- hub/conf/routes | 4 +- .../TimeSeriesControllerSpec.scala | 108 ++-------------- 3 files changed, 16 insertions(+), 212 deletions(-) diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index 0f253999a3..625fcd729e 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -21,14 +21,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon import TimeSeriesController._ - /** - * Helps retrieve a model performance drift time series. Time series is retrieved between the start and end ts. - * The offset is used to compute the distribution to compare against (we compare current time range with the same - * sized time range starting offset time period prior). - */ - def fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String): Action[AnyContent] = - doFetchModel(id, startTs, endTs, offset, algorithm) - /** * Helps retrieve a time series (drift or skew) for each of the features that are part of a join. Time series is * retrieved between the start and end ts. If the metric type is for drift, the offset is used to compute the @@ -61,22 +53,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon algorithm: Option[String]): Action[AnyContent] = doFetchFeature(join, name, startTs, endTs, metricType, metrics, granularity, offset, algorithm) - private def doFetchModel(id: String, - startTs: Long, - endTs: Long, - offset: String, - algorithm: String): Action[AnyContent] = - Action { implicit request: Request[AnyContent] => - (parseOffset(Some(offset)), parseAlgorithm(Some(algorithm))) match { - case (None, _) => BadRequest(s"Unable to parse offset - $offset") - case (_, None) => BadRequest("Invalid drift algorithm. 
Expect PSI or KL") - case (Some(_), Some(_)) => - // TODO: Use parsedOffset and parsedAlgorithm when ready - val mockTSData = ModelTimeSeriesResponse(id, generateMockTimeSeriesPoints(startTs, endTs)) - Ok(mockTSData.asJson.noSpaces) - } - } - private def doFetchJoin(name: String, startTs: Long, endTs: Long, @@ -89,10 +65,9 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon val metricRollup = parseMetricRollup(Some(metrics)) (metricChoice, metricRollup) match { - case (None, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift / skew")) + case (None, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift")) case (_, None) => Future.successful(BadRequest("Invalid metric rollup. Expect null / value")) case (Some(Drift), Some(rollup)) => doFetchJoinDrift(name, startTs, endTs, rollup, offset, algorithm) - case (Some(Skew), Some(rollup)) => doFetchJoinSkew(name, startTs, endTs, rollup) } } @@ -126,24 +101,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } } - private def doFetchJoinSkew(name: String, - startTs: Long, - endTs: Long, - metric: Metric): Future[Result] = { - val mockGroupBys = generateMockGroupBys(3) - val groupByTimeSeries = mockGroupBys.map { g => - val mockFeatures = generateMockFeatures(g, 10) - val featureTS = mockFeatures.map { - FeatureTimeSeries(_, generateMockTimeSeriesPoints(startTs, endTs)) - } - GroupByTimeSeries(g, featureTS) - } - - val mockTSData = JoinTimeSeriesResponse(name, groupByTimeSeries) - val json = mockTSData.asJson.noSpaces - Future.successful(Ok(json)) - } - private def doFetchFeature(join: String, name: String, startTs: Long, @@ -159,12 +116,11 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon val granularityType = parseGranularity(granularity) (metricChoice, metricRollup, granularityType) match { - case (None, _, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift / skew")) + case (None, _, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift")) case (_, None, _) => Future.successful(BadRequest("Invalid metric rollup. Expect null / value")) case (_, _, None) => Future.successful(BadRequest("Invalid granularity. 
Expect raw / percentile / aggregates")) case (Some(Drift), Some(rollup), Some(g)) => doFetchFeatureDrift(join, name, startTs, endTs, rollup, g, offset, algorithm) - case (Some(Skew), Some(rollup), Some(g)) => doFetchFeatureSkew(name, startTs, endTs, rollup, g) } } @@ -219,27 +175,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } } - private def doFetchFeatureSkew(name: String, - startTs: Long, - endTs: Long, - metric: Metric, - granularity: Granularity): Future[Result] = { - if (granularity == Aggregates) { - Future.successful(BadRequest("We don't support Aggregates granularity for skew metric types")) - } else { - val featureTsJson = if (granularity == Raw) { - val featureTs = ComparedFeatureTimeSeries(name, - generateMockRawTimeSeriesPoints(startTs, 100), - generateMockRawTimeSeriesPoints(startTs, 100)) - featureTs.asJson.noSpaces - } else { - val featuresTs = FeatureTimeSeries(name, generateMockTimeSeriesPercentilePoints(startTs, endTs)) - featuresTs.asJson.noSpaces - } - Future.successful(Ok(featureTsJson)) - } - } - private def convertTileDriftSeriesInfoToTimeSeries(tileDriftSeries: TileDriftSeries, metric: Metric): FeatureTimeSeries = { val lhsList = if (metric == NullMetric) { tileDriftSeries.nullRatioChangePercentSeries.asScala @@ -308,11 +243,12 @@ object TimeSeriesController { } } + // We currently only support drift def parseMetricChoice(metricType: Option[String]): Option[MetricType] = { metricType.map(_.toLowerCase) match { case Some("drift") => Some(Drift) - case Some("skew") => Some(Skew) - case Some("ooc") => Some(Skew) +// case Some("skew") => Some(Skew) +// case Some("ooc") => Some(Skew) case _ => None } } @@ -333,46 +269,4 @@ object TimeSeriesController { case _ => None } } - - // !!!!! Mock generation code !!!!! 
// - - val mockGeneratedPercentiles: Seq[String] = - Seq("p0", "p10", "p20", "p30", "p40", "p50", "p60", "p70", "p75", "p80", "p90", "p95", "p99", "p100") - - // temporarily serve up mock data while we wait on hooking up our KV store layer + drift calculation - private def generateMockTimeSeriesPoints(startTs: Long, endTs: Long): Seq[TimeSeriesPoint] = { - val random = new Random(1000) - (startTs until endTs by (1.hours.toMillis)).map(ts => TimeSeriesPoint(random.nextDouble(), ts)) - } - - private def generateMockRawTimeSeriesPoints(timestamp: Long, count: Int): Seq[TimeSeriesPoint] = { - val random = new Random(1000) - (0 until count).map(_ => TimeSeriesPoint(random.nextDouble(), timestamp)) - } - - private def generateMockCategoricalTimeSeriesPoints(timestamp: Long, - categoryCount: Int, - nullCategoryCount: Int): Seq[TimeSeriesPoint] = { - val random = new Random(1000) - val catTSPoints = (0 until categoryCount).map(i => TimeSeriesPoint(random.nextInt(1000), timestamp, Some(s"A_$i"))) - val nullCatTSPoints = (0 until nullCategoryCount).map(i => - TimeSeriesPoint(random.nextDouble(), timestamp, Some(s"A_{$i + $categoryCount}"), Some(random.nextInt(10)))) - catTSPoints ++ nullCatTSPoints - } - - private def generateMockTimeSeriesPercentilePoints(startTs: Long, endTs: Long): Seq[TimeSeriesPoint] = { - val random = new Random(1000) - (startTs until endTs by (1.hours.toMillis)).flatMap { ts => - mockGeneratedPercentiles.zipWithIndex.map { - case (p, _) => TimeSeriesPoint(random.nextDouble(), ts, Some(p)) - } - } - } - - private def generateMockGroupBys(numGroupBys: Int): Seq[String] = - (1 to numGroupBys).map(i => s"my_groupby_$i") - - private def generateMockFeatures(groupBy: String, featuresPerGroupBy: Int): Seq[String] = - (1 to featuresPerGroupBy).map(i => s"$groupBy.my_feature_$i") - } diff --git a/hub/conf/routes b/hub/conf/routes index 7c88e44bb6..896fb7655c 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -3,8 +3,8 @@ GET /api/v1/ping controllers.Application GET /api/v1/models controllers.ModelController.list(offset: Option[Int], limit: Option[Int]) GET /api/v1/search controllers.SearchController.search(term: String, offset: Option[Int], limit: Option[Int]) -# model prediction & model drift -GET /api/v1/model/:id/timeseries controllers.TimeSeriesController.fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String) +# model prediction & model drift - this is TBD at the moment +# GET /api/v1/model/:id/timeseries controllers.TimeSeriesController.fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String) # all timeseries of a given join id # when metricType == "drift" - will return time series list of drift values diff --git a/hub/test/controllers/TimeSeriesControllerSpec.scala b/hub/test/controllers/TimeSeriesControllerSpec.scala index 8fb07cb511..77d1f16806 100644 --- a/hub/test/controllers/TimeSeriesControllerSpec.scala +++ b/hub/test/controllers/TimeSeriesControllerSpec.scala @@ -36,41 +36,16 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { val controller = new TimeSeriesController(stubCC, mockedStore) val mockCategories: Seq[String] = Seq("a", "b", "c") - "TimeSeriesController's model ts lookup" should { + "TimeSeriesController's join ts lookup" should { - "send 400 on an invalid time offset" in { - val invalid1 = controller.fetchModel("id-123", 123L, 456L, "Xh", "psi").apply(FakeRequest()) + "send 400 on an invalid metric choice" in { + val invalid1 = controller.fetchJoin("my_join", 123L, 
456L, "meow", "null", None, None).apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST - val invalid2 = controller.fetchModel("id-123", 123L, 456L, "-10h", "psi").apply(FakeRequest()) + val invalid2 = controller.fetchJoin("my_join", 123L, 456L, "skew", "null", None, None).apply(FakeRequest()) status(invalid2) mustBe BAD_REQUEST } - "send 400 on an invalid algorithm" in { - val invalid1 = controller.fetchModel("id-123", 123L, 456L, "10h", "meow").apply(FakeRequest()) - status(invalid1) mustBe BAD_REQUEST - } - - "send valid results on a correctly formed model ts request" in { - val startTs = 1725926400000L // 09/10/2024 00:00 UTC - val endTs = 1726106400000L // 09/12/2024 02:00 UTC - val result = controller.fetchModel("id-123", startTs, endTs, "10h", "psi").apply(FakeRequest()) - status(result) mustBe OK - val bodyText = contentAsString(result) - val modelTSResponse: Either[Error, ModelTimeSeriesResponse] = decode[ModelTimeSeriesResponse](bodyText) - modelTSResponse.isRight mustBe true - val items = modelTSResponse.right.value.items - items.length mustBe (Duration(endTs, TimeUnit.MILLISECONDS) - Duration(startTs, TimeUnit.MILLISECONDS)).toHours - } - } - - "TimeSeriesController's join ts lookup" should { - - "send 400 on an invalid metric choice" in { - val invalid = controller.fetchJoin("my_join", 123L, 456L, "meow", "null", None, None).apply(FakeRequest()) - status(invalid) mustBe BAD_REQUEST - } - "send 400 on an invalid metric rollup" in { val invalid = controller.fetchJoin("my_join", 123L, 456L, "drift", "woof", None, None).apply(FakeRequest()) status(invalid) mustBe BAD_REQUEST @@ -126,34 +101,16 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { grpByTs.items.foreach(featureTs => featureTs.points.length mustBe expectedLength) } } - - "send valid results on a correctly formed model ts skew lookup request" in { - val startTs = 1725926400000L // 09/10/2024 00:00 UTC - val endTs = 1726106400000L // 09/12/2024 02:00 UTC - val result = - controller.fetchJoin("my_join", startTs, endTs, "skew", "null", None, None).apply(FakeRequest()) - status(result) mustBe OK - val bodyText = contentAsString(result) - val modelTSResponse: Either[Error, JoinTimeSeriesResponse] = decode[JoinTimeSeriesResponse](bodyText) - modelTSResponse.isRight mustBe true - val response = modelTSResponse.right.value - response.name mustBe "my_join" - response.items.nonEmpty mustBe true - - val expectedLength = expectedHours(startTs, endTs) - response.items.foreach { grpByTs => - grpByTs.items.isEmpty mustBe false - grpByTs.items.foreach(featureTs => featureTs.points.length mustBe expectedLength) - } - } } "TimeSeriesController's feature ts lookup" should { "send 400 on an invalid metric choice" in { - val invalid = - controller.fetchFeature("my_join", "my_feature", 123L, 456L, "meow", "null", "raw", None, None).apply(FakeRequest()) - status(invalid) mustBe BAD_REQUEST + val invalid1 = controller.fetchFeature("my_join", "my_feature", 123L, 456L, "meow", "null", "raw", None, None).apply(FakeRequest()) + status(invalid1) mustBe BAD_REQUEST + + val invalid2 = controller.fetchFeature("my_join", "my_feature", 123L, 456L, "skew", "null", "raw", None, None).apply(FakeRequest()) + status(invalid2) mustBe BAD_REQUEST } "send 400 on an invalid metric rollup" in { @@ -198,14 +155,6 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { status(invalid1) mustBe BAD_REQUEST } - "send 400 on an invalid granularity for skew metric" in { - val invalid1 = - controller - 
.fetchFeature("my_join", "my_feature", 123L, 456L, "skew", "null", "aggregates", Some("10h"), Some("psi")) - .apply(FakeRequest()) - status(invalid1) mustBe BAD_REQUEST - } - "send valid results on a correctly formed feature ts aggregate drift lookup request" in { val startTs = 1725926400000L // 09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC @@ -290,45 +239,6 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { baseline.label.isEmpty mustBe false } } - - "send valid results on a correctly formed feature ts raw skew lookup request" in { - val startTs = 1725926400000L // 09/10/2024 00:00 UTC - val endTs = 1726106400000L // 09/12/2024 02:00 UTC - val result = - controller.fetchFeature("my_join", "my_feature", startTs, endTs, "skew", "null", "raw", None, None).apply(FakeRequest()) - status(result) mustBe OK - val bodyText = contentAsString(result) - val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] = - decode[ComparedFeatureTimeSeries](bodyText) - featureTSResponse.isRight mustBe true - val response = featureTSResponse.right.value - response.feature mustBe "my_feature" - response.baseline.nonEmpty mustBe true - response.baseline.length mustBe response.current.length - // we expect a skew distribution at a fixed time stamp - response.baseline.foreach(p => p.ts mustBe startTs) - response.current.foreach(p => p.ts mustBe startTs) - } - - "send valid results on a correctly formed feature ts percentile skew lookup request" in { - val startTs = 1725926400000L // 09/10/2024 00:00 UTC - val endTs = 1726106400000L // 09/12/2024 02:00 UTC - val result = - controller - .fetchFeature("my_join", "my_feature", startTs, endTs, "skew", "null", "percentile", None, None) - .apply(FakeRequest()) - status(result) mustBe OK - val bodyText = contentAsString(result) - val featureTSResponse: Either[Error, FeatureTimeSeries] = decode[FeatureTimeSeries](bodyText) - featureTSResponse.isRight mustBe true - val response = featureTSResponse.right.value - response.feature mustBe "my_feature" - response.points.nonEmpty mustBe true - - // expect one entry per percentile for each time series point - val expectedLength = TimeSeriesController.mockGeneratedPercentiles.length * expectedHours(startTs, endTs) - response.points.length mustBe expectedLength - } } private def expectedHours(startTs: Long, endTs: Long): Long = { From 3657bb7f51668a86c4240a175d0e51a9ab4f330c Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Thu, 21 Nov 2024 20:54:06 -0500 Subject: [PATCH 084/152] Add joins endpoints and switch search to use joins --- hub/app/controllers/JoinController.scala | 42 ++++++++++++ hub/app/controllers/Paginate.scala | 4 +- hub/app/controllers/SearchController.scala | 17 +++-- hub/app/model/Model.scala | 5 +- hub/app/store/MonitoringModelStore.scala | 11 ++++ hub/conf/routes | 5 +- hub/test/controllers/JoinControllerSpec.scala | 65 +++++++++++++++++++ .../controllers/SearchControllerSpec.scala | 25 ++++--- 8 files changed, 150 insertions(+), 24 deletions(-) create mode 100644 hub/app/controllers/JoinController.scala create mode 100644 hub/test/controllers/JoinControllerSpec.scala diff --git a/hub/app/controllers/JoinController.scala b/hub/app/controllers/JoinController.scala new file mode 100644 index 0000000000..81fd73a970 --- /dev/null +++ b/hub/app/controllers/JoinController.scala @@ -0,0 +1,42 @@ +package controllers + +import io.circe.generic.auto._ +import io.circe.syntax._ +import model.ListJoinResponse +import play.api.mvc._ +import 
store.MonitoringModelStore
+
+import javax.inject._
+
+/**
+ * Controller for the Zipline Join entities
+ */
+@Singleton
+class JoinController @Inject()(val controllerComponents: ControllerComponents,
+                               monitoringStore: MonitoringModelStore)
+    extends BaseController
+    with Paginate {
+
+  /**
+   * Powers the /api/v1/joins endpoint. Returns a list of joins
+   * @param offset - For pagination. 
We skip over offset entries before returning results * @param limit - Number of elements to return */ @@ -36,14 +35,14 @@ class SearchController @Inject() (val controllerComponents: ControllerComponents } else { val searchResults = searchRegistry(term) val paginatedResults = paginateResults(searchResults, offsetValue, limitValue) - val json = SearchModelResponse(offsetValue, paginatedResults).asJson.noSpaces + val json = SearchJoinResponse(offsetValue, paginatedResults).asJson.noSpaces Ok(json) } } - // a trivial search where we check the model name for similarity with the search term - private def searchRegistry(term: String): Seq[Model] = { - val models = monitoringStore.getModels - models.filter(m => m.name.contains(term)) + // a trivial search where we check the join name for similarity with the search term + private def searchRegistry(term: String): Seq[Join] = { + val joins = monitoringStore.getJoins + joins.filter(j => j.name.contains(term)) } } diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index a14ce4d679..f3e2d9f54c 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -56,9 +56,10 @@ case class FeatureTimeSeries(feature: String, points: Seq[TimeSeriesPoint]) case class ComparedFeatureTimeSeries(feature: String, baseline: Seq[TimeSeriesPoint], current: Seq[TimeSeriesPoint]) case class GroupByTimeSeries(name: String, items: Seq[FeatureTimeSeries]) -// Currently search only covers models -case class SearchModelResponse(offset: Int, items: Seq[Model]) +// Currently search only covers joins case class ListModelResponse(offset: Int, items: Seq[Model]) +case class SearchJoinResponse(offset: Int, items: Seq[Join]) +case class ListJoinResponse(offset: Int, items: Seq[Join]) case class ModelTimeSeriesResponse(id: String, items: Seq[TimeSeriesPoint]) case class JoinTimeSeriesResponse(name: String, items: Seq[GroupByTimeSeries]) diff --git a/hub/app/store/MonitoringModelStore.scala b/hub/app/store/MonitoringModelStore.scala index d4d67d8f4b..9dd11d280c 100644 --- a/hub/app/store/MonitoringModelStore.scala +++ b/hub/app/store/MonitoringModelStore.scala @@ -69,6 +69,17 @@ class MonitoringModelStore(apiImpl: Api) { } } + def getJoins: Seq[Join] = { + configRegistryCache("default").joins.map { thriftJoin => + val groupBys = thriftJoin.joinParts.asScala.map { part => + GroupBy(part.groupBy.metaData.name, part.groupBy.valueColumns) + } + + val outputColumns = thriftJoin.outputColumnsByGroup.getOrElse("derivations", Array.empty) + Join(thriftJoin.metaData.name, outputColumns, groupBys) + } + } + val logger: Logger = Logger(this.getClass) val defaultListLookupLimit: Int = 100 diff --git a/hub/conf/routes b/hub/conf/routes index 896fb7655c..5d24c3bd52 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -1,10 +1,11 @@ # Backend APIs GET /api/v1/ping controllers.ApplicationController.ping() GET /api/v1/models controllers.ModelController.list(offset: Option[Int], limit: Option[Int]) +GET /api/v1/joins controllers.JoinController.list(offset: Option[Int], limit: Option[Int]) GET /api/v1/search controllers.SearchController.search(term: String, offset: Option[Int], limit: Option[Int]) # model prediction & model drift - this is TBD at the moment -# GET /api/v1/model/:id/timeseries controllers.TimeSeriesController.fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String) +# GET /api/v1/model/:id/timeseries controllers.TimeSeriesController.fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String) # all 
timeseries of a given join id # when metricType == "drift" - will return time series list of drift values @@ -12,7 +13,7 @@ GET /api/v1/join/:name/timeseries controllers.TimeSeriesC # join -> seq(feature) # when metricType == "drift" - will return time series list of drift values -GET /api/v1/join/:join/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(join: String, name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) +GET /api/v1/join/:join/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(join: String, name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) # TODO - move the core flow to fine-grained end-points diff --git a/hub/test/controllers/JoinControllerSpec.scala b/hub/test/controllers/JoinControllerSpec.scala new file mode 100644 index 0000000000..b6924cfdd1 --- /dev/null +++ b/hub/test/controllers/JoinControllerSpec.scala @@ -0,0 +1,65 @@ +package controllers + +import controllers.MockJoinService.mockJoinRegistry +import io.circe._ +import io.circe.generic.auto._ +import io.circe.parser._ +import model.ListJoinResponse +import org.mockito.Mockito._ +import org.scalatest.EitherValues +import org.scalatestplus.play._ +import play.api.http.Status.BAD_REQUEST +import play.api.http.Status.OK +import play.api.mvc._ +import play.api.test.Helpers._ +import play.api.test._ +import store.MonitoringModelStore + +class JoinControllerSpec extends PlaySpec with Results with EitherValues { + + // Create a stub ControllerComponents + val stubCC: ControllerComponents = stubControllerComponents() + // Create a mocked DynDB store + val mockedStore: MonitoringModelStore = mock(classOf[MonitoringModelStore]) + + val controller = new JoinController(stubCC, mockedStore) + + "JoinController" should { + + "send 400 on a bad offset" in { + val result = controller.list(Some(-1), Some(10)).apply(FakeRequest()) + status(result) mustBe BAD_REQUEST + } + + "send 400 on a bad limit" in { + val result = controller.list(Some(10), Some(-2)).apply(FakeRequest()) + status(result) mustBe BAD_REQUEST + } + + "send valid results on a correctly formed request" in { + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) + + val result = controller.list(None, None).apply(FakeRequest()) + status(result) mustBe OK + val bodyText = contentAsString(result) + val listJoinResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](bodyText) + val items = listJoinResponse.right.value.items + items.length mustBe controller.defaultLimit + items.map(_.name.toInt).toSet mustBe (0 until 10).toSet + } + + "send results in a paginated fashion correctly" in { + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) + + val startOffset = 25 + val number = 20 + val result = controller.list(Some(startOffset), Some(number)).apply(FakeRequest()) + status(result) mustBe OK + val bodyText = contentAsString(result) + val listJoinResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](bodyText) + val items = listJoinResponse.right.value.items + items.length mustBe number + items.map(_.name.toInt).toSet mustBe (startOffset until startOffset + number).toSet + } + } +} diff --git a/hub/test/controllers/SearchControllerSpec.scala b/hub/test/controllers/SearchControllerSpec.scala index 5510ea4010..b188336fa9 100644 --- a/hub/test/controllers/SearchControllerSpec.scala +++ 
b/hub/test/controllers/SearchControllerSpec.scala @@ -1,10 +1,10 @@ package controllers -import controllers.MockDataService.mockModelRegistry +import controllers.MockJoinService.mockJoinRegistry import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ -import model.ListModelResponse +import model.{GroupBy, Join, ListJoinResponse} import org.mockito.Mockito.mock import org.mockito.Mockito.when import org.scalatest.EitherValues @@ -38,19 +38,19 @@ class SearchControllerSpec extends PlaySpec with Results with EitherValues { } "send valid results on a correctly formed request" in { - when(mockedStore.getModels).thenReturn(mockModelRegistry) + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) val result = controller.search("1", None, None).apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) - val listModelResponse: Either[Error, ListModelResponse] = decode[ListModelResponse](bodyText) - val items = listModelResponse.right.value.items + val listJoinResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](bodyText) + val items = listJoinResponse.right.value.items items.length mustBe controller.defaultLimit items.map(_.name.toInt).toSet mustBe Set(1, 10, 11, 12, 13, 14, 15, 16, 17, 18) } "send results in a paginated fashion correctly" in { - when(mockedStore.getModels).thenReturn(mockModelRegistry) + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) val startOffset = 3 val number = 6 @@ -60,10 +60,19 @@ class SearchControllerSpec extends PlaySpec with Results with EitherValues { val expected = Set(12, 13, 14, 15, 16, 17) status(result) mustBe OK val bodyText = contentAsString(result) - val listModelResponse: Either[Error, ListModelResponse] = decode[ListModelResponse](bodyText) - val items = listModelResponse.right.value.items + val listJoinResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](bodyText) + val items = listJoinResponse.right.value.items items.length mustBe number items.map(_.name.toInt).toSet mustBe expected } } } + +object MockJoinService { + def generateMockJoin(id: String): Join = { + val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2"))) + Join(id, Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys) + } + + val mockJoinRegistry: Seq[Join] = (0 until 100).map(i => generateMockJoin(i.toString)) +} From 622405af048fa131c95dbe61eb1eebd5041006f9 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Mon, 25 Nov 2024 10:04:16 -0500 Subject: [PATCH 085/152] Switch to correct metadata table --- online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala | 2 +- spark/src/main/scala/ai/chronon/spark/Driver.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala b/online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala index cb1bc52e91..114d3b8ada 100644 --- a/online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala +++ b/online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala @@ -22,7 +22,7 @@ case class MetadataEndPoint[Conf <: TBase[_, _]: Manifest: ClassTag]( object MetadataEndPoint { @transient implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass) - val ConfByKeyEndPointName = "ZIPLINE_METADATA" + val ConfByKeyEndPointName = "CHRONON_METADATA" val NameByTeamEndPointName = "CHRONON_ENTITY_BY_TEAM" private def getTeamFromMetadata(metaData: MetaData): String = { diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index 
92489a90fb..103f382ee8 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -18,6 +18,7 @@ package ai.chronon.spark import ai.chronon.api import ai.chronon.api.Constants +import ai.chronon.api.Constants.MetadataDataset import ai.chronon.api.Extensions.GroupByOps import ai.chronon.api.Extensions.MetadataOps import ai.chronon.api.Extensions.SourceOps @@ -565,7 +566,7 @@ object Driver { lazy val api: Api = impl(serializableProps) def metaDataStore = - new MetadataStore(impl(serializableProps).genKvStore, "ZIPLINE_METADATA", timeoutMillis = 10000) + new MetadataStore(impl(serializableProps).genKvStore, MetadataDataset, timeoutMillis = 10000) def impl(props: Map[String, String]): Api = { val urls = Array(new File(onlineJar()).toURI.toURL) From db0619c26cb8a7355cbb80f6f04fc8e424cc787c Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 24 Nov 2024 15:32:54 -0800 Subject: [PATCH 086/152] observability script for demo --- docker-init/start.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-init/start.sh b/docker-init/start.sh index 9f8b39d9f1..64b777a76e 100644 --- a/docker-init/start.sh +++ b/docker-init/start.sh @@ -19,7 +19,7 @@ fi # Load up metadata into DynamoDB echo "Loading metadata.." -if ! java -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver metadata-upload --conf-path=/chronon_sample/production/ --online-jar=$CLOUD_AWS_JAR --online-class=$ONLINE_CLASS; then +if ! java -Dlog4j.configurationFile=log4j.properties -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver metadata-upload --conf-path=/chronon_sample/production/ --online-jar=$CLOUD_AWS_JAR --online-class=$ONLINE_CLASS; then echo "Error: Failed to load metadata into DynamoDB" >&2 exit 1 fi @@ -27,7 +27,7 @@ echo "Metadata load completed successfully!" # Initialize DynamoDB echo "Initializing DynamoDB Table .." -if ! output=$(java -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver create-summary-dataset \ +if ! 
output=$(java -Dlog4j.configurationFile=log4j.properties -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver create-summary-dataset \ --online-jar=$CLOUD_AWS_JAR \ --online-class=$ONLINE_CLASS 2>&1); then echo "Error: Failed to bring up DynamoDB table" >&2 From 12db7cdd4cf936f660050584a1c87a5d6d475018 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 24 Nov 2024 18:32:26 -0800 Subject: [PATCH 087/152] running observability demo --- spark/src/main/scala/ai/chronon/spark/TableUtils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 09aca39240..d516d7a5f0 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -327,9 +327,9 @@ case class TableUtils(sparkSession: SparkSession) { sql(creationSql) } catch { case _: TableAlreadyExistsException => - logger.info(s"Table $tableName already exists, skipping creation") + println(s"Table $tableName already exists, skipping creation") case e: Exception => - logger.error(s"Failed to create table $tableName", e) + println(s"Failed to create table $tableName", e) throw e } } @@ -357,6 +357,7 @@ case class TableUtils(sparkSession: SparkSession) { // so that an exception will be thrown below dfRearranged } + println(s"Repartitioning and writing into table $tableName".yellow) repartitionAndWrite(finalizedDf, tableName, saveMode, stats, sortByCols) } From 0a8c8b327678578f53a1e9938d81d70b932e7ef9 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Mon, 25 Nov 2024 22:37:56 -0500 Subject: [PATCH 088/152] Add support for in-memory controller + kv store module --- build.sbt | 16 +++-- .../controllers/InMemKVStoreController.scala | 37 ++++++++++ hub/app/module/DriftStoreModule.scala | 18 +++-- hub/app/module/InMemoryKVStoreModule.scala | 18 +++++ hub/conf/application.conf | 1 + hub/conf/routes | 2 + .../scala/ai/chronon/online/HTTPKVStore.scala | 69 +++++++++++++++++++ 7 files changed, 151 insertions(+), 10 deletions(-) create mode 100644 hub/app/controllers/InMemKVStoreController.scala create mode 100644 hub/app/module/InMemoryKVStoreModule.scala create mode 100644 online/src/main/scala/ai/chronon/online/HTTPKVStore.scala diff --git a/build.sbt b/build.sbt index 6c60365dcb..139945dcba 100644 --- a/build.sbt +++ b/build.sbt @@ -80,6 +80,12 @@ val jackson = Seq( "com.fasterxml.jackson.module" %% "jackson-module-scala" ).map(_ % jackson_2_15) +val circe = Seq( + "io.circe" %% "circe-core", + "io.circe" %% "circe-generic", + "io.circe" %% "circe-parser", +).map(_ % circeVersion) + val flink_all = Seq( "org.apache.flink" %% "flink-streaming-scala", "org.apache.flink" % "flink-metrics-dropwizard", @@ -129,6 +135,10 @@ lazy val online = project "com.github.ben-manes.caffeine" % "caffeine" % "3.1.8" ), libraryDependencies ++= jackson, + // we pull in circe to help us ser case classes like PutRequest without requiring annotations + libraryDependencies ++= circe, + // dep needed for HTTPKvStore - yank when we rip this out + libraryDependencies += "com.softwaremill.sttp.client3" %% "core" % "3.9.7", libraryDependencies ++= spark_all.map(_ % "provided"), libraryDependencies ++= flink_all.map(_ % "provided") ) @@ -236,20 +246,18 @@ lazy val frontend = (project in file("frontend")) // build interop between one module solely on 2.13 and others on 2.12 is painful lazy val hub = (project in file("hub")) .enablePlugins(PlayScala) - .dependsOn(cloud_aws) + 
.dependsOn(cloud_aws, spark) .settings( name := "hub", libraryDependencies ++= Seq( guice, "org.scalatestplus.play" %% "scalatestplus-play" % "5.1.0" % Test, "org.scalatestplus" %% "mockito-3-4" % "3.2.10.0" % "test", - "io.circe" %% "circe-core" % circeVersion, - "io.circe" %% "circe-generic" % circeVersion, - "io.circe" %% "circe-parser" % circeVersion, "org.scala-lang.modules" %% "scala-xml" % "2.1.0", "org.scala-lang.modules" %% "scala-parser-combinators" % "2.3.0", "org.scala-lang.modules" %% "scala-java8-compat" % "1.0.2" ), + libraryDependencies ++= circe, libraryDependencySchemes ++= Seq( "org.scala-lang.modules" %% "scala-xml" % VersionScheme.Always, "org.scala-lang.modules" %% "scala-parser-combinators" % VersionScheme.Always, diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala new file mode 100644 index 0000000000..82e3f1979d --- /dev/null +++ b/hub/app/controllers/InMemKVStoreController.scala @@ -0,0 +1,37 @@ +package controllers + +import ai.chronon.online.KVStore +import ai.chronon.online.KVStore.PutRequest +import play.api.mvc.{BaseController, ControllerComponents} +import io.circe.parser.decode +import play.api.Logger + +import javax.inject.Inject +import scala.concurrent.{ExecutionContext, Future} + +class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit ec: ExecutionContext) extends BaseController { + + import ai.chronon.online.PutRequestCodec._ + + val logger: Logger = Logger(this.getClass) + + def bulkPut() = Action(parse.raw).async { request => + request.body.asBytes() match { + case Some(bytes) => + decode[Array[PutRequest]](bytes.utf8String) match { + case Right(putRequests) => + logger.info(s"Attempting a bulkPut with ${putRequests.length} items") + val resultFuture = kvStore.multiPut(putRequests) + resultFuture.map { + responses => + if (responses.contains(false)) { + logger.warn(s"Some write failures encountered") + } + Ok("Success") + } + case Left(error) => Future.successful(BadRequest(error.getMessage)) + } + case None => Future.successful(BadRequest("Empty body")) + } + } +} diff --git a/hub/app/module/DriftStoreModule.scala b/hub/app/module/DriftStoreModule.scala index ee831e5cad..b8c12786e3 100644 --- a/hub/app/module/DriftStoreModule.scala +++ b/hub/app/module/DriftStoreModule.scala @@ -1,15 +1,21 @@ package module -import ai.chronon.integrations.aws.AwsApiImpl +import ai.chronon.online.KVStore import ai.chronon.online.stats.DriftStore import com.google.inject.AbstractModule -import play.api.{Configuration, Environment} -class DriftStoreModule(environment: Environment, configuration: Configuration) extends AbstractModule { +import javax.inject.{Inject, Provider} + +class DriftStoreModule extends AbstractModule { override def configure(): Unit = { - val awsApiImpl = new AwsApiImpl(Map.empty) - val driftStore = new DriftStore(awsApiImpl.genKvStore) - bind(classOf[DriftStore]).toInstance(driftStore) + // TODO swap to concrete api impl in a follow up + bind(classOf[DriftStore]).toProvider(classOf[DriftStoreProvider]).asEagerSingleton() + } +} + +class DriftStoreProvider @Inject()(kvStore: KVStore) extends Provider[DriftStore] { + override def get(): DriftStore = { + new DriftStore(kvStore) } } diff --git a/hub/app/module/InMemoryKVStoreModule.scala b/hub/app/module/InMemoryKVStoreModule.scala new file mode 100644 index 0000000000..2467ca4573 --- /dev/null +++ b/hub/app/module/InMemoryKVStoreModule.scala @@ -0,0 +1,18 @@ +package module + 
+import ai.chronon.api.Constants +import ai.chronon.online.KVStore +import ai.chronon.spark.utils.InMemoryKvStore +import com.google.inject.AbstractModule + +// Module that creates and injects an in-memory kv store implementation to allow for quick docker testing +class InMemoryKVStoreModule extends AbstractModule { + + override def configure(): Unit = { + val inMemoryKVStore = InMemoryKvStore.build("hub", () => null) + // create relevant datasets in kv store + inMemoryKVStore.create(Constants.MetadataDataset) + inMemoryKVStore.create(Constants.TiledSummaryDataset) + bind(classOf[KVStore]).toInstance(inMemoryKVStore) + } +} diff --git a/hub/conf/application.conf b/hub/conf/application.conf index 5696edf2d2..292f0d42fc 100644 --- a/hub/conf/application.conf +++ b/hub/conf/application.conf @@ -29,4 +29,5 @@ play.filters.cors { # Add DynamoDB module play.modules.enabled += "module.ModelStoreModule" +play.modules.enabled += "module.InMemoryKVStoreModule" play.modules.enabled += "module.DriftStoreModule" diff --git a/hub/conf/routes b/hub/conf/routes index 5d24c3bd52..8939447e6e 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -15,6 +15,8 @@ GET /api/v1/join/:name/timeseries controllers.TimeSeriesC # when metricType == "drift" - will return time series list of drift values GET /api/v1/join/:join/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(join: String, name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) +# Temporary in-memory kv store endpoint +POST /api/v1/dataset/data controllers.InMemKVStoreController.bulkPut() # TODO - move the core flow to fine-grained end-points #GET /api/v1/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) diff --git a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala new file mode 100644 index 0000000000..0514a5aab5 --- /dev/null +++ b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala @@ -0,0 +1,69 @@ +package ai.chronon.online + +import ai.chronon.online.KVStore.PutRequest +import io.circe._ +import io.circe.generic.semiauto._ +import io.circe.syntax._ +import sttp.client3._ +import sttp.model.StatusCode + +import java.util.Base64 +import scala.concurrent.Future + +// Hacky test kv store that we use to send objects to the in-memory KV store that lives in a different JVM (e.g spark -> hub) +class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore with Serializable { + import PutRequestCodec._ + + val backend = HttpClientSyncBackend() + val baseUrl = s"http://$host:$port/api/v1/dataset" + + override def multiGet(requests: collection.Seq[KVStore.GetRequest]): Future[collection.Seq[KVStore.GetResponse]] = ??? 
+ + override def multiPut(putRequests: collection.Seq[KVStore.PutRequest]): Future[collection.Seq[Boolean]] = { + if (putRequests.isEmpty) { + Future.successful(Seq.empty) + } else { + // typically should see the same dataset but we break up our calls by dataset to be safe + val requestsByDataset = putRequests.groupBy(_.dataset) + val futures: Seq[Future[Boolean]] = requestsByDataset.map { + case (dataset, requests) => + Future { + basicRequest + .post(uri"$baseUrl/$dataset/data") + .header("Content-Type", "application/json") + .body(requests.asJson.noSpaces) + .send(backend) + }.map { + response => + response.code match { + case StatusCode.Ok => true + case _ => + logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") + false + } + } + }.toSeq + + Future.sequence(futures) + } + } + + override def bulkPut(sourceOfflineTable: String, destinationOnlineDataSet: String, partition: String): Unit = ??? + + override def create(dataset: String): Unit = { + logger.warn(s"Skipping creation of $dataset in HTTP kv store implementation") + } +} + +object PutRequestCodec { + // Custom codec for byte arrays using Base64 + implicit val byteArrayEncoder: Encoder[Array[Byte]] = + Encoder.encodeString.contramap[Array[Byte]](Base64.getEncoder.encodeToString) + + implicit val byteArrayDecoder: Decoder[Array[Byte]] = + Decoder.decodeString.map(Base64.getDecoder.decode) + + // Derive codec for PutRequest + implicit val putRequestCodec: Codec[PutRequest] = deriveCodec[PutRequest] +} + From e7e2d1685414e9dfa771b48e4ca8804e9d4414b6 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 26 Nov 2024 10:30:43 -0500 Subject: [PATCH 089/152] Clean up scripts to load data and query via time series controller --- docker-init/Dockerfile | 1 + docker-init/demo/README.md | 20 +++++++ docker-init/demo/build.sh | 1 - docker-init/demo/load_summaries.sh | 12 ++++ docker-init/start.sh | 21 +------ .../controllers/InMemKVStoreController.scala | 17 +++++- .../controllers/TimeSeriesController.scala | 10 ++-- .../scala/ai/chronon/online/HTTPKVStore.scala | 57 ++++++++----------- .../spark/scripts/ObservabilityDemo.scala | 4 +- 9 files changed, 83 insertions(+), 60 deletions(-) delete mode 100755 docker-init/demo/build.sh create mode 100755 docker-init/demo/load_summaries.sh diff --git a/docker-init/Dockerfile b/docker-init/Dockerfile index e9b36c564d..b8b24da7d4 100644 --- a/docker-init/Dockerfile +++ b/docker-init/Dockerfile @@ -43,6 +43,7 @@ ENV CHRONON_DRIVER_JAR="/app/cli/spark.jar" # Set up Spark dependencies to help with launching CLI # Copy Spark JARs from the Bitnami image COPY --from=spark-source /opt/bitnami/spark/jars /opt/spark/jars +COPY --from=spark-source /opt/bitnami/spark/bin /opt/spark/bin # Add all Spark JARs to the classpath ENV CLASSPATH=/opt/spark/jars/* diff --git a/docker-init/demo/README.md b/docker-init/demo/README.md index c1abae2d9b..a3f0807eae 100644 --- a/docker-init/demo/README.md +++ b/docker-init/demo/README.md @@ -1,5 +1,25 @@ +# Populate Observability Demo Data +To populate the observability demo data: +* Launch the set of docker containers: +```bash +~/workspace/chronon $ docker-compose -f docker-init/compose.yaml up --build +... 
+app-1 | [info] 2024-11-26 05:10:45,758 [main] INFO play.api.Play - Application started (Prod) (no global state) +app-1 | [info] 2024-11-26 05:10:45,958 [main] INFO play.core.server.AkkaHttpServer - Listening for HTTP on /[0:0:0:0:0:0:0:0]:9000 +``` +(you can skip the --build if you don't wish to rebuild your code) + +Now you can trigger the script to load summary data: +```bash +~/workspace/chronon $ docker-init/demo/load_summaries.sh +... +Done uploading summaries! 🥳 +``` + +# Streamlit local experimentation run build.sh once, and you can repeatedly exec to quickly visualize In first terminal: `sbt spark/assembly` In second terminal: `./run.sh` to load the built jar and serve the data on localhost:8181 In third terminal: `streamlit run viz.py` + diff --git a/docker-init/demo/build.sh b/docker-init/demo/build.sh deleted file mode 100755 index 5627dac2f5..0000000000 --- a/docker-init/demo/build.sh +++ /dev/null @@ -1 +0,0 @@ -docker build -t obs . \ No newline at end of file diff --git a/docker-init/demo/load_summaries.sh b/docker-init/demo/load_summaries.sh new file mode 100755 index 0000000000..61b4d9db95 --- /dev/null +++ b/docker-init/demo/load_summaries.sh @@ -0,0 +1,12 @@ +# Kick off the ObsDemo spark job in the app container + +docker-compose -f docker-init/compose.yaml exec app /opt/spark/bin/spark-submit \ + --master "local[*]" \ + --driver-memory 8g \ + --conf "spark.driver.maxResultSize=6g" \ + --conf "spark.driver.memory=8g" \ + --driver-class-path "/opt/spark/jars/*:/app/cli/*" \ + --conf "spark.driver.host=localhost" \ + --conf "spark.driver.bindAddress=0.0.0.0" \ + --class ai.chronon.spark.scripts.ObservabilityDemo \ + /app/cli/spark.jar diff --git a/docker-init/start.sh b/docker-init/start.sh index 64b777a76e..a3340f894a 100644 --- a/docker-init/start.sh +++ b/docker-init/start.sh @@ -19,7 +19,7 @@ fi # Load up metadata into DynamoDB echo "Loading metadata.." -if ! java -Dlog4j.configurationFile=log4j.properties -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver metadata-upload --conf-path=/chronon_sample/production/ --online-jar=$CLOUD_AWS_JAR --online-class=$ONLINE_CLASS; then +if ! java -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver metadata-upload --conf-path=/chronon_sample/production/ --online-jar=$CLOUD_AWS_JAR --online-class=$ONLINE_CLASS; then echo "Error: Failed to load metadata into DynamoDB" >&2 exit 1 fi @@ -27,7 +27,7 @@ echo "Metadata load completed successfully!" # Initialize DynamoDB echo "Initializing DynamoDB Table .." -if ! output=$(java -Dlog4j.configurationFile=log4j.properties -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver create-summary-dataset \ +if ! output=$(java -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver create-summary-dataset \ --online-jar=$CLOUD_AWS_JAR \ --online-class=$ONLINE_CLASS 2>&1); then echo "Error: Failed to bring up DynamoDB table" >&2 @@ -39,23 +39,6 @@ echo "DynamoDB Table created successfully!" start_time=$(date +%s) -if ! java \ - --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \ - --add-opens=java.base/sun.security.action=ALL-UNNAMED \ - -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver summarize-and-upload \ - --online-jar=$CLOUD_AWS_JAR \ - --online-class=$ONLINE_CLASS \ - --parquet-path="$(pwd)/drift_data" \ - --conf-path=/chronon_sample/production/ \ - --time-column=transaction_time; then - echo "Error: Failed to load summary data into DynamoDB" >&2 - exit 1 -else - end_time=$(date +%s) - elapsed_time=$((end_time - start_time)) - echo "Summary load completed successfully! Took $elapsed_time seconds." 
-fi - # Add these java options as without them we hit the below error: # throws java.lang.ClassFormatError accessible: module java.base does not "opens java.lang" to unnamed module @36328710 export JAVA_OPTS="--add-opens java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED" diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala index 82e3f1979d..ba6b226fec 100644 --- a/hub/app/controllers/InMemKVStoreController.scala +++ b/hub/app/controllers/InMemKVStoreController.scala @@ -2,16 +2,19 @@ package controllers import ai.chronon.online.KVStore import ai.chronon.online.KVStore.PutRequest +import io.circe.generic.semiauto.deriveCodec +import io.circe.{Codec, Decoder, Encoder} import play.api.mvc.{BaseController, ControllerComponents} import io.circe.parser.decode import play.api.Logger +import java.util.Base64 import javax.inject.Inject import scala.concurrent.{ExecutionContext, Future} class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit ec: ExecutionContext) extends BaseController { - import ai.chronon.online.PutRequestCodec._ + import PutRequestCodec._ val logger: Logger = Logger(this.getClass) @@ -35,3 +38,15 @@ class InMemKVStoreController @Inject() (val controllerComponents: ControllerComp } } } + +object PutRequestCodec { + // Custom codec for byte arrays using Base64 + implicit val byteArrayEncoder: Encoder[Array[Byte]] = + Encoder.encodeString.contramap[Array[Byte]](Base64.getEncoder.encodeToString) + + implicit val byteArrayDecoder: Decoder[Array[Byte]] = + Decoder.decodeString.map(Base64.getDecoder.decode) + + // Derive codec for PutRequest + implicit val putRequestCodec: Codec[PutRequest] = deriveCodec[PutRequest] +} diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index 625fcd729e..2ceaa6bd47 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -83,7 +83,8 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. Expect JSD, PSI or Hellinger")) case (Some(o), Some(driftMetric)) => val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) - val maybeDriftSeries = driftStore.getDriftSeries(name, driftMetric, window, startTs, endTs) + val joinPath = name.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name + val maybeDriftSeries = driftStore.getDriftSeries(joinPath, driftMetric, window, startTs, endTs) maybeDriftSeries match { case Failure(exception) => Future.successful(InternalServerError(s"Error computing join drift - ${exception.getMessage}")) case Success(driftSeriesFuture) => driftSeriesFuture.map { @@ -140,9 +141,10 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. 
Expect JSD, PSI or Hellinger")) case (Some(o), Some(driftMetric)) => val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) + val joinPath = join.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name if (granularity == Aggregates) { val maybeDriftSeries = - driftStore.getDriftSeries(join, driftMetric, window, startTs, endTs, Some(name)) + driftStore.getDriftSeries(joinPath, driftMetric, window, startTs, endTs, Some(name)) maybeDriftSeries match { case Failure(exception) => Future.successful(InternalServerError(s"Error computing feature drift - ${exception.getMessage}")) case Success(driftSeriesFuture) => driftSeriesFuture.map { @@ -153,8 +155,8 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } } else { // percentiles - val maybeCurrentSummarySeries = driftStore.getSummarySeries(join, startTs, endTs, Some(name)) - val maybeBaselineSummarySeries = driftStore.getSummarySeries(join, startTs - window.millis, endTs - window.millis, Some(name)) + val maybeCurrentSummarySeries = driftStore.getSummarySeries(joinPath, startTs, endTs, Some(name)) + val maybeBaselineSummarySeries = driftStore.getSummarySeries(joinPath, startTs - window.millis, endTs - window.millis, Some(name)) (maybeCurrentSummarySeries, maybeBaselineSummarySeries) match { case (Failure(exceptionA), Failure(exceptionB)) => Future.successful(InternalServerError(s"Error computing feature percentiles for current + offset time window.\nCurrent window error: ${exceptionA.getMessage}\nOffset window error: ${exceptionB.getMessage}")) case (_, Failure(exception)) => Future.successful(InternalServerError(s"Error computing feature percentiles for offset time window - ${exception.getMessage}")) diff --git a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala index 0514a5aab5..ff7fbdddf2 100644 --- a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala +++ b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala @@ -1,8 +1,6 @@ package ai.chronon.online import ai.chronon.online.KVStore.PutRequest -import io.circe._ -import io.circe.generic.semiauto._ import io.circe.syntax._ import sttp.client3._ import sttp.model.StatusCode @@ -12,7 +10,6 @@ import scala.concurrent.Future // Hacky test kv store that we use to send objects to the in-memory KV store that lives in a different JVM (e.g spark -> hub) class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore with Serializable { - import PutRequestCodec._ val backend = HttpClientSyncBackend() val baseUrl = s"http://$host:$port/api/v1/dataset" @@ -23,28 +20,21 @@ class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore if (putRequests.isEmpty) { Future.successful(Seq.empty) } else { - // typically should see the same dataset but we break up our calls by dataset to be safe - val requestsByDataset = putRequests.groupBy(_.dataset) - val futures: Seq[Future[Boolean]] = requestsByDataset.map { - case (dataset, requests) => - Future { - basicRequest - .post(uri"$baseUrl/$dataset/data") - .header("Content-Type", "application/json") - .body(requests.asJson.noSpaces) - .send(backend) - }.map { - response => - response.code match { - case StatusCode.Ok => true - case _ => - logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") - false - } + Future { + basicRequest + .post(uri"$baseUrl/data") + .header("Content-Type", "application/json") + .body(jsonList(putRequests)) + .send(backend) + 
}.map { + response => + response.code match { + case StatusCode.Ok => Seq(true) + case _ => + logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") + Seq(false) } - }.toSeq - - Future.sequence(futures) + } } } @@ -53,17 +43,18 @@ class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore override def create(dataset: String): Unit = { logger.warn(s"Skipping creation of $dataset in HTTP kv store implementation") } -} -object PutRequestCodec { - // Custom codec for byte arrays using Base64 - implicit val byteArrayEncoder: Encoder[Array[Byte]] = - Encoder.encodeString.contramap[Array[Byte]](Base64.getEncoder.encodeToString) + // wire up json conversion manually to side step serialization issues in spark executors + def jsonString(request: PutRequest): String = { + val keyBase64 = Base64.getEncoder.encodeToString(request.keyBytes) + val valueBase64 = Base64.getEncoder.encodeToString(request.valueBytes) + s"""{ "keyBytes": "${keyBase64}", "valueBytes": "${valueBase64}", "dataset": "${request.dataset}", "tsMillis": ${request.tsMillis.orNull}}""".stripMargin + } - implicit val byteArrayDecoder: Decoder[Array[Byte]] = - Decoder.decodeString.map(Base64.getDecoder.decode) + def jsonList(requests: Seq[PutRequest]): String = { + val requestsJson = requests.map(jsonString(_)).mkString(", ") - // Derive codec for PutRequest - implicit val putRequestCodec: Codec[PutRequest] = deriveCodec[PutRequest] + s"[ $requestsJson ]" + } } diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala index 064b69700f..b2e5015bf6 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala @@ -10,7 +10,7 @@ import ai.chronon.api.PartitionSpec import ai.chronon.api.TileDriftSeries import ai.chronon.api.TileSummarySeries import ai.chronon.api.Window -import ai.chronon.online.KVStore +import ai.chronon.online.{HTTPKVStore, KVStore} import ai.chronon.online.stats.DriftStore import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils @@ -56,7 +56,7 @@ object ObservabilityDemo { // mock api impl for online fetching and uploading val kvStoreFunc: () => KVStore = () => { // cannot reuse the variable - or serialization error - val result = InMemoryKvStore.build(namespace, () => null) + val result = new HTTPKVStore() result } val api = new MockApi(kvStoreFunc, namespace) From 281fc7c8dda3589055200910534698007d12b68d Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 26 Nov 2024 10:48:32 -0500 Subject: [PATCH 090/152] Address scalafix + fmt --- .../controllers/InMemKVStoreController.scala | 48 ++++---- hub/app/controllers/JoinController.scala | 3 +- hub/app/controllers/ModelController.scala | 3 +- hub/app/controllers/SearchController.scala | 6 +- .../controllers/TimeSeriesController.scala | 105 +++++++++++------- hub/app/module/DriftStoreModule.scala | 5 +- .../controllers/SearchControllerSpec.scala | 4 +- .../TimeSeriesControllerSpec.scala | 18 ++- .../scala/ai/chronon/online/HTTPKVStore.scala | 21 ++-- .../ai/chronon/online/stats/DriftStore.scala | 25 ++++- .../spark/scripts/ObservabilityDemo.scala | 18 +-- 11 files changed, 148 insertions(+), 108 deletions(-) diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala index ba6b226fec..1101b0454c 100644 --- a/hub/app/controllers/InMemKVStoreController.scala +++ 
b/hub/app/controllers/InMemKVStoreController.scala @@ -2,41 +2,49 @@ package controllers import ai.chronon.online.KVStore import ai.chronon.online.KVStore.PutRequest +import io.circe.Codec +import io.circe.Decoder +import io.circe.Encoder import io.circe.generic.semiauto.deriveCodec -import io.circe.{Codec, Decoder, Encoder} -import play.api.mvc.{BaseController, ControllerComponents} import io.circe.parser.decode import play.api.Logger +import play.api.mvc.BaseController +import play.api.mvc.ControllerComponents import java.util.Base64 import javax.inject.Inject -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import play.api.mvc +import play.api.mvc.RawBuffer -class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit ec: ExecutionContext) extends BaseController { +class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit + ec: ExecutionContext) + extends BaseController { import PutRequestCodec._ val logger: Logger = Logger(this.getClass) - def bulkPut() = Action(parse.raw).async { request => - request.body.asBytes() match { - case Some(bytes) => - decode[Array[PutRequest]](bytes.utf8String) match { - case Right(putRequests) => - logger.info(s"Attempting a bulkPut with ${putRequests.length} items") - val resultFuture = kvStore.multiPut(putRequests) - resultFuture.map { - responses => + def bulkPut(): mvc.Action[RawBuffer] = + Action(parse.raw).async { request => + request.body.asBytes() match { + case Some(bytes) => + decode[Array[PutRequest]](bytes.utf8String) match { + case Right(putRequests) => + logger.info(s"Attempting a bulkPut with ${putRequests.length} items") + val resultFuture = kvStore.multiPut(putRequests) + resultFuture.map { responses => if (responses.contains(false)) { - logger.warn(s"Some write failures encountered") + logger.warn("Some write failures encountered") } - Ok("Success") - } - case Left(error) => Future.successful(BadRequest(error.getMessage)) - } - case None => Future.successful(BadRequest("Empty body")) + Ok("Success") + } + case Left(error) => Future.successful(BadRequest(error.getMessage)) + } + case None => Future.successful(BadRequest("Empty body")) + } } - } } object PutRequestCodec { diff --git a/hub/app/controllers/JoinController.scala b/hub/app/controllers/JoinController.scala index 81fd73a970..f65de2976c 100644 --- a/hub/app/controllers/JoinController.scala +++ b/hub/app/controllers/JoinController.scala @@ -12,8 +12,7 @@ import javax.inject._ * Controller for the Zipline Join entities */ @Singleton -class JoinController @Inject()(val controllerComponents: ControllerComponents, - monitoringStore: MonitoringModelStore) +class JoinController @Inject() (val controllerComponents: ControllerComponents, monitoringStore: MonitoringModelStore) extends BaseController with Paginate { diff --git a/hub/app/controllers/ModelController.scala b/hub/app/controllers/ModelController.scala index 40ef41a56c..66e197191e 100644 --- a/hub/app/controllers/ModelController.scala +++ b/hub/app/controllers/ModelController.scala @@ -12,8 +12,7 @@ import javax.inject._ * Controller for the Zipline models entities */ @Singleton -class ModelController @Inject() (val controllerComponents: ControllerComponents, - monitoringStore: MonitoringModelStore) +class ModelController @Inject() (val controllerComponents: ControllerComponents, monitoringStore: MonitoringModelStore) extends BaseController with 
Paginate { diff --git a/hub/app/controllers/SearchController.scala b/hub/app/controllers/SearchController.scala index a6bd1a8ead..075d03de6b 100644 --- a/hub/app/controllers/SearchController.scala +++ b/hub/app/controllers/SearchController.scala @@ -2,7 +2,8 @@ package controllers import io.circe.generic.auto._ import io.circe.syntax._ -import model.{Join, SearchJoinResponse} +import model.Join +import model.SearchJoinResponse import play.api.mvc._ import store.MonitoringModelStore @@ -11,8 +12,7 @@ import javax.inject._ /** * Controller to power search related APIs */ -class SearchController @Inject() (val controllerComponents: ControllerComponents, - monitoringStore: MonitoringModelStore) +class SearchController @Inject() (val controllerComponents: ControllerComponents, monitoringStore: MonitoringModelStore) extends BaseController with Paginate { diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index 2ceaa6bd47..ca5207c688 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -1,6 +1,10 @@ package controllers +import ai.chronon.api.DriftMetric import ai.chronon.api.Extensions.WindowOps -import ai.chronon.api.{DriftMetric, TileDriftSeries, TileSummarySeries, TimeUnit, Window} +import ai.chronon.api.TileDriftSeries +import ai.chronon.api.TileSummarySeries +import ai.chronon.api.TimeUnit +import ai.chronon.api.Window import ai.chronon.online.stats.DriftStore import io.circe.generic.auto._ import io.circe.syntax._ @@ -8,16 +12,20 @@ import model._ import play.api.mvc._ import javax.inject._ -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext +import scala.concurrent.Future import scala.concurrent.duration._ -import scala.util.{Failure, Random, Success} import scala.jdk.CollectionConverters._ +import scala.util.Failure +import scala.util.Success /** * Controller that serves various time series endpoints at the model, join and feature level */ @Singleton -class TimeSeriesController @Inject() (val controllerComponents: ControllerComponents, driftStore: DriftStore)(implicit ec: ExecutionContext) extends BaseController { +class TimeSeriesController @Inject() (val controllerComponents: ControllerComponents, driftStore: DriftStore)(implicit + ec: ExecutionContext) + extends BaseController { import TimeSeriesController._ @@ -79,25 +87,29 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon algorithm: Option[String]): Future[Result] = { (parseOffset(offset), parseAlgorithm(algorithm)) match { - case (None, _) => Future.successful(BadRequest(s"Unable to parse offset - $offset")) - case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. Expect JSD, PSI or Hellinger")) + case (None, _) => Future.successful(BadRequest(s"Unable to parse offset - $offset")) + case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. 
Expect JSD, PSI or Hellinger")) case (Some(o), Some(driftMetric)) => val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) val joinPath = name.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name val maybeDriftSeries = driftStore.getDriftSeries(joinPath, driftMetric, window, startTs, endTs) maybeDriftSeries match { - case Failure(exception) => Future.successful(InternalServerError(s"Error computing join drift - ${exception.getMessage}")) - case Success(driftSeriesFuture) => driftSeriesFuture.map { - driftSeries => + case Failure(exception) => + Future.successful(InternalServerError(s"Error computing join drift - ${exception.getMessage}")) + case Success(driftSeriesFuture) => + driftSeriesFuture.map { driftSeries => // pull up a list of drift series objects for all the features in a group val grpToDriftSeriesList: Map[String, Seq[TileDriftSeries]] = driftSeries.groupBy(_.key.groupName) val groupByTimeSeries = grpToDriftSeriesList.map { - case (name, featureDriftSeriesInfoSeq) => GroupByTimeSeries(name, featureDriftSeriesInfoSeq.map(series => convertTileDriftSeriesInfoToTimeSeries(series, metric))) + case (name, featureDriftSeriesInfoSeq) => + GroupByTimeSeries( + name, + featureDriftSeriesInfoSeq.map(series => convertTileDriftSeriesInfoToTimeSeries(series, metric))) }.toSeq - val tsData = JoinTimeSeriesResponse(name, groupByTimeSeries) - Ok(tsData.asJson.noSpaces) - } + val tsData = JoinTimeSeriesResponse(name, groupByTimeSeries) + Ok(tsData.asJson.noSpaces) + } } } } @@ -141,35 +153,45 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. Expect JSD, PSI or Hellinger")) case (Some(o), Some(driftMetric)) => val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) - val joinPath = join.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name + val joinPath = + join.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name if (granularity == Aggregates) { val maybeDriftSeries = driftStore.getDriftSeries(joinPath, driftMetric, window, startTs, endTs, Some(name)) maybeDriftSeries match { - case Failure(exception) => Future.successful(InternalServerError(s"Error computing feature drift - ${exception.getMessage}")) - case Success(driftSeriesFuture) => driftSeriesFuture.map { - driftSeries => + case Failure(exception) => + Future.successful(InternalServerError(s"Error computing feature drift - ${exception.getMessage}")) + case Success(driftSeriesFuture) => + driftSeriesFuture.map { driftSeries => val featureTs = convertTileDriftSeriesInfoToTimeSeries(driftSeries.head, metric) Ok(featureTs.asJson.noSpaces) - } + } } } else { // percentiles val maybeCurrentSummarySeries = driftStore.getSummarySeries(joinPath, startTs, endTs, Some(name)) - val maybeBaselineSummarySeries = driftStore.getSummarySeries(joinPath, startTs - window.millis, endTs - window.millis, Some(name)) + val maybeBaselineSummarySeries = + driftStore.getSummarySeries(joinPath, startTs - window.millis, endTs - window.millis, Some(name)) (maybeCurrentSummarySeries, maybeBaselineSummarySeries) match { - case (Failure(exceptionA), Failure(exceptionB)) => Future.successful(InternalServerError(s"Error computing feature percentiles for current + offset time window.\nCurrent window error: ${exceptionA.getMessage}\nOffset window error: ${exceptionB.getMessage}")) - case (_, Failure(exception)) => 
Future.successful(InternalServerError(s"Error computing feature percentiles for offset time window - ${exception.getMessage}")) - case (Failure(exception), _) => Future.successful(InternalServerError(s"Error computing feature percentiles for current time window - ${exception.getMessage}")) + case (Failure(exceptionA), Failure(exceptionB)) => + Future.successful(InternalServerError( + s"Error computing feature percentiles for current + offset time window.\nCurrent window error: ${exceptionA.getMessage}\nOffset window error: ${exceptionB.getMessage}")) + case (_, Failure(exception)) => + Future.successful( + InternalServerError( + s"Error computing feature percentiles for offset time window - ${exception.getMessage}")) + case (Failure(exception), _) => + Future.successful( + InternalServerError( + s"Error computing feature percentiles for current time window - ${exception.getMessage}")) case (Success(currentSummarySeriesFuture), Success(baselineSummarySeriesFuture)) => - Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { - merged => - val currentSummarySeries = merged.head - val baselineSummarySeries = merged.last - val currentFeatureTs = convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) - val baselineFeatureTs = convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) - val comparedTsData = ComparedFeatureTimeSeries(name, baselineFeatureTs, currentFeatureTs) - Ok(comparedTsData.asJson.noSpaces) + Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { merged => + val currentSummarySeries = merged.head + val baselineSummarySeries = merged.last + val currentFeatureTs = convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) + val baselineFeatureTs = convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) + val comparedTsData = ComparedFeatureTimeSeries(name, baselineFeatureTs, currentFeatureTs) + Ok(comparedTsData.asJson.noSpaces) } } } @@ -177,13 +199,16 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } } - private def convertTileDriftSeriesInfoToTimeSeries(tileDriftSeries: TileDriftSeries, metric: Metric): FeatureTimeSeries = { + private def convertTileDriftSeriesInfoToTimeSeries(tileDriftSeries: TileDriftSeries, + metric: Metric): FeatureTimeSeries = { val lhsList = if (metric == NullMetric) { tileDriftSeries.nullRatioChangePercentSeries.asScala } else { // check if we have a numeric / categorical feature. 
If the percentile drift series has non-null doubles // then we have a numeric feature at hand - val isNumeric = tileDriftSeries.percentileDriftSeries.asScala != null && tileDriftSeries.percentileDriftSeries.asScala.exists(_ != null) + val isNumeric = + tileDriftSeries.percentileDriftSeries.asScala != null && tileDriftSeries.percentileDriftSeries.asScala + .exists(_ != null) if (isNumeric) tileDriftSeries.percentileDriftSeries.asScala else tileDriftSeries.histogramDriftSeries.asScala } @@ -194,7 +219,8 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon FeatureTimeSeries(tileDriftSeries.getKey.getColumn, points) } - private def convertTileSummarySeriesToTimeSeries(summarySeries: TileSummarySeries, metric: Metric): Seq[TimeSeriesPoint] = { + private def convertTileSummarySeriesToTimeSeries(summarySeries: TileSummarySeries, + metric: Metric): Seq[TimeSeriesPoint] = { if (metric == NullMetric) { summarySeries.nullCount.asScala.zip(summarySeries.timestamps.asScala).map { case (nullCount, ts) => TimeSeriesPoint(0, ts, nullValue = Some(nullCount.intValue())) @@ -202,7 +228,7 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } else { // check if we have a numeric / categorical feature. If the percentile drift series has non-null doubles // then we have a numeric feature at hand - val isNumeric = summarySeries.percentiles.asScala != null && summarySeries.percentiles.asScala.exists(_ != null) + val isNumeric = summarySeries.percentiles.asScala != null && summarySeries.percentiles.asScala.exists(_ != null) if (isNumeric) { summarySeries.percentiles.asScala.zip(summarySeries.timestamps.asScala).flatMap { case (percentiles, ts) => @@ -210,8 +236,7 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon case (l, value) => TimeSeriesPoint(value, ts, Some(l)) } } - } - else { + } else { summarySeries.timestamps.asScala.zipWithIndex.flatMap { case (ts, idx) => summarySeries.histogram.asScala.map { @@ -238,10 +263,10 @@ object TimeSeriesController { def parseAlgorithm(algorithm: Option[String]): Option[DriftMetric] = { algorithm.map(_.toLowerCase) match { - case Some("psi") => Some(DriftMetric.PSI) + case Some("psi") => Some(DriftMetric.PSI) case Some("hellinger") => Some(DriftMetric.HELLINGER) - case Some("jsd") => Some(DriftMetric.JENSEN_SHANNON) - case _ => None + case Some("jsd") => Some(DriftMetric.JENSEN_SHANNON) + case _ => None } } @@ -251,7 +276,7 @@ object TimeSeriesController { case Some("drift") => Some(Drift) // case Some("skew") => Some(Skew) // case Some("ooc") => Some(Skew) - case _ => None + case _ => None } } diff --git a/hub/app/module/DriftStoreModule.scala b/hub/app/module/DriftStoreModule.scala index b8c12786e3..6456626375 100644 --- a/hub/app/module/DriftStoreModule.scala +++ b/hub/app/module/DriftStoreModule.scala @@ -4,7 +4,8 @@ import ai.chronon.online.KVStore import ai.chronon.online.stats.DriftStore import com.google.inject.AbstractModule -import javax.inject.{Inject, Provider} +import javax.inject.Inject +import javax.inject.Provider class DriftStoreModule extends AbstractModule { @@ -14,7 +15,7 @@ class DriftStoreModule extends AbstractModule { } } -class DriftStoreProvider @Inject()(kvStore: KVStore) extends Provider[DriftStore] { +class DriftStoreProvider @Inject() (kvStore: KVStore) extends Provider[DriftStore] { override def get(): DriftStore = { new DriftStore(kvStore) } diff --git a/hub/test/controllers/SearchControllerSpec.scala b/hub/test/controllers/SearchControllerSpec.scala 
index b188336fa9..95a0680674 100644 --- a/hub/test/controllers/SearchControllerSpec.scala +++ b/hub/test/controllers/SearchControllerSpec.scala @@ -4,7 +4,9 @@ import controllers.MockJoinService.mockJoinRegistry import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ -import model.{GroupBy, Join, ListJoinResponse} +import model.GroupBy +import model.Join +import model.ListJoinResponse import org.mockito.Mockito.mock import org.mockito.Mockito.when import org.scalatest.EitherValues diff --git a/hub/test/controllers/TimeSeriesControllerSpec.scala b/hub/test/controllers/TimeSeriesControllerSpec.scala index 77d1f16806..c32d87226e 100644 --- a/hub/test/controllers/TimeSeriesControllerSpec.scala +++ b/hub/test/controllers/TimeSeriesControllerSpec.scala @@ -1,13 +1,16 @@ package controllers -import ai.chronon.api.{TileDriftSeries, TileSeriesKey, TileSummarySeries} +import ai.chronon.api.TileDriftSeries +import ai.chronon.api.TileSeriesKey +import ai.chronon.api.TileSummarySeries import ai.chronon.online.stats.DriftStore import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ import model._ import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.mock +import org.mockito.Mockito.when import org.scalatest.EitherValues import org.scalatestplus.play._ import play.api.http.Status.BAD_REQUEST @@ -16,13 +19,16 @@ import play.api.mvc._ import play.api.test.Helpers._ import play.api.test._ -import java.util.concurrent.TimeUnit -import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration.Duration -import scala.util.{Failure, Success, Try} import java.lang.{Double => JDouble} import java.lang.{Long => JLong} +import java.util.concurrent.TimeUnit +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.Duration import scala.jdk.CollectionConverters._ +import scala.util.Failure +import scala.util.Success +import scala.util.Try class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { diff --git a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala index ff7fbdddf2..e050331db6 100644 --- a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala +++ b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala @@ -1,7 +1,6 @@ package ai.chronon.online import ai.chronon.online.KVStore.PutRequest -import io.circe.syntax._ import sttp.client3._ import sttp.model.StatusCode @@ -11,8 +10,8 @@ import scala.concurrent.Future // Hacky test kv store that we use to send objects to the in-memory KV store that lives in a different JVM (e.g spark -> hub) class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore with Serializable { - val backend = HttpClientSyncBackend() - val baseUrl = s"http://$host:$port/api/v1/dataset" + val backend: SttpBackend[Identity, Any] = HttpClientSyncBackend() + val baseUrl: String = s"http://$host:$port/api/v1/dataset" override def multiGet(requests: collection.Seq[KVStore.GetRequest]): Future[collection.Seq[KVStore.GetResponse]] = ??? 
@@ -26,14 +25,13 @@ class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore .header("Content-Type", "application/json") .body(jsonList(putRequests)) .send(backend) - }.map { - response => - response.code match { - case StatusCode.Ok => Seq(true) - case _ => - logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") - Seq(false) - } + }.map { response => + response.code match { + case StatusCode.Ok => Seq(true) + case _ => + logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") + Seq(false) + } } } } @@ -57,4 +55,3 @@ class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore s"[ $requestsJson ]" } } - diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index 44c83e37c1..02812082e0 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -205,8 +205,25 @@ object DriftStore { def compactDeserializer: TDeserializer = new TDeserializer(new TBinaryProtocol.Factory()) // todo - drop this hard-coded list in favor of a well known list or exposing as part of summaries - val percentileLabels: Seq[String] = Seq("p0", "p5", "p10", "p15", "p20", - "p25", "p30", "p35", "p40", "p45", - "p50", "p55", "p60", "p65", "p70", - "p75", "p80", "p85", "p90", "p95", "p100") + val percentileLabels: Seq[String] = Seq("p0", + "p5", + "p10", + "p15", + "p20", + "p25", + "p30", + "p35", + "p40", + "p45", + "p50", + "p55", + "p60", + "p65", + "p70", + "p75", + "p80", + "p85", + "p90", + "p95", + "p100") } diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala index b2e5015bf6..12622e222f 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala @@ -1,34 +1,20 @@ package ai.chronon.spark.scripts - -import ai.chronon import ai.chronon.api.ColorPrinter.ColorString import ai.chronon.api.Constants -import ai.chronon.api.DriftMetric import ai.chronon.api.Extensions.MetadataOps -import ai.chronon.api.Extensions.WindowOps -import ai.chronon.api.PartitionSpec -import ai.chronon.api.TileDriftSeries -import ai.chronon.api.TileSummarySeries -import ai.chronon.api.Window -import ai.chronon.online.{HTTPKVStore, KVStore} -import ai.chronon.online.stats.DriftStore +import ai.chronon.online.HTTPKVStore +import ai.chronon.online.KVStore import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils import ai.chronon.spark.stats.drift.Summarizer import ai.chronon.spark.stats.drift.SummaryUploader import ai.chronon.spark.stats.drift.scripts.PrepareData -import ai.chronon.spark.utils.InMemoryKvStore import ai.chronon.spark.utils.MockApi import org.rogach.scallop.ScallopConf import org.rogach.scallop.ScallopOption import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.concurrent.TimeUnit -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.util.ScalaJavaConversions.IteratorOps - object ObservabilityDemo { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) From 7fc0637fabd1efbfe0113536676a09256243f48a Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 26 Nov 2024 13:10:46 -0500 Subject: [PATCH 091/152] Update colPrefix pass through in drift store --- 
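Note on the one-line change below: DriftStore.tileKeysForJoin appears to have gained a new optional middle parameter, and this commit forwards the caller's columnPrefix into the third slot while passing None for the new one. A minimal sketch of the assumed shape; the parameter name `keyFilter` and the return type are inferred from the call site, not taken from this patch:

    // Hypothetical signature sketch - `keyFilter` is an assumed name, and the
    // return type is inferred from how tileKeyMap is consumed downstream.
    def tileKeysForJoin(joinConf: Join,
                        keyFilter: Option[String] = None,
                        columnPrefix: Option[String] = None): Map[String, Seq[TileKey]] = ???

    // Call site after this change - passing None keeps the previous behaviour:
    val tileKeyMap = tileKeysForJoin(joinConf, None, columnPrefix)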
online/src/main/scala/ai/chronon/online/stats/DriftStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index 02812082e0..733ebbfc6d 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -77,7 +77,7 @@ class DriftStore(kvStore: KVStore, columnPrefix: Option[String]): Future[Seq[TileSummaryInfo]] = { val serializer: TSerializer = compactSerializer - val tileKeyMap = tileKeysForJoin(joinConf, columnPrefix) + val tileKeyMap = tileKeysForJoin(joinConf, None, columnPrefix) val requestContextMap: Map[GetRequest, SummaryRequestContext] = tileKeyMap.flatMap { case (group, keys) => keys.map { key => From 4d78fa58716d53f50792e5bcb4ed832c73d6ab82 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 26 Nov 2024 14:42:12 -0500 Subject: [PATCH 092/152] Add details to join response + join get endpoint --- .../controllers/InMemKVStoreController.scala | 4 ++-- hub/app/controllers/JoinController.scala | 13 +++++++++++++ hub/app/model/Model.scala | 7 ++++++- hub/app/store/MonitoringModelStore.scala | 14 ++++++++++++-- hub/conf/routes | 1 + hub/test/controllers/JoinControllerSpec.scala | 18 ++++++++++++++++++ hub/test/controllers/ModelControllerSpec.scala | 2 +- .../controllers/SearchControllerSpec.scala | 2 +- 8 files changed, 54 insertions(+), 7 deletions(-) diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala index 1101b0454c..6eda57c4cf 100644 --- a/hub/app/controllers/InMemKVStoreController.scala +++ b/hub/app/controllers/InMemKVStoreController.scala @@ -8,15 +8,15 @@ import io.circe.Encoder import io.circe.generic.semiauto.deriveCodec import io.circe.parser.decode import play.api.Logger +import play.api.mvc import play.api.mvc.BaseController import play.api.mvc.ControllerComponents +import play.api.mvc.RawBuffer import java.util.Base64 import javax.inject.Inject import scala.concurrent.ExecutionContext import scala.concurrent.Future -import play.api.mvc -import play.api.mvc.RawBuffer class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit ec: ExecutionContext) diff --git a/hub/app/controllers/JoinController.scala b/hub/app/controllers/JoinController.scala index f65de2976c..383b6438d5 100644 --- a/hub/app/controllers/JoinController.scala +++ b/hub/app/controllers/JoinController.scala @@ -38,4 +38,17 @@ class JoinController @Inject() (val controllerComponents: ControllerComponents, Ok(json) } } + + /** + * Returns a specific join by name + */ + def get(name: String): Action[AnyContent] = { + Action { implicit request: Request[AnyContent] => + val maybeJoin = monitoringStore.getJoins.find(j => j.name.equalsIgnoreCase(name)) + maybeJoin match { + case None => NotFound(s"Join: $name wasn't found") + case Some(join) => Ok(join.asJson.noSpaces) + } + } + } } diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index f3e2d9f54c..eb523251d3 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -2,7 +2,12 @@ package model /** Captures some details related to ML models registered with Zipline to surface in the Hub UI */ case class GroupBy(name: String, features: Seq[String]) -case class Join(name: String, joinFeatures: Seq[String], groupBys: Seq[GroupBy]) +case class Join(name: String, + joinFeatures: Seq[String], + 
groupBys: Seq[GroupBy], + online: Boolean, + production: Boolean, + team: String) case class Model(name: String, join: Join, online: Boolean, production: Boolean, team: String, modelType: String) // 1.) metadataUpload: join -> map> diff --git a/hub/app/store/MonitoringModelStore.scala b/hub/app/store/MonitoringModelStore.scala index 9dd11d280c..89fbc5fa3e 100644 --- a/hub/app/store/MonitoringModelStore.scala +++ b/hub/app/store/MonitoringModelStore.scala @@ -60,7 +60,12 @@ class MonitoringModelStore(apiImpl: Api) { } val outputColumns = thriftJoin.outputColumnsByGroup.getOrElse("derivations", Array.empty) - val join = Join(thriftJoin.metaData.name, outputColumns, groupBys) + val join = Join(thriftJoin.metaData.name, + outputColumns, + groupBys, + thriftJoin.metaData.online, + thriftJoin.metaData.production, + thriftJoin.metaData.team) Option( Model(m.metaData.name, join, m.metaData.online, m.metaData.production, m.metaData.team, m.modelType.name())) } else { @@ -76,7 +81,12 @@ class MonitoringModelStore(apiImpl: Api) { } val outputColumns = thriftJoin.outputColumnsByGroup.getOrElse("derivations", Array.empty) - Join(thriftJoin.metaData.name, outputColumns, groupBys) + Join(thriftJoin.metaData.name, + outputColumns, + groupBys, + thriftJoin.metaData.online, + thriftJoin.metaData.production, + thriftJoin.metaData.team) } } diff --git a/hub/conf/routes b/hub/conf/routes index 8939447e6e..9ec36cc087 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -2,6 +2,7 @@ GET /api/v1/ping controllers.ApplicationController.ping() GET /api/v1/models controllers.ModelController.list(offset: Option[Int], limit: Option[Int]) GET /api/v1/joins controllers.JoinController.list(offset: Option[Int], limit: Option[Int]) +GET /api/v1/join/:name controllers.JoinController.get(name: String) GET /api/v1/search controllers.SearchController.search(term: String, offset: Option[Int], limit: Option[Int]) # model prediction & model drift - this is TBD at the moment diff --git a/hub/test/controllers/JoinControllerSpec.scala b/hub/test/controllers/JoinControllerSpec.scala index b6924cfdd1..8cf8ad7c38 100644 --- a/hub/test/controllers/JoinControllerSpec.scala +++ b/hub/test/controllers/JoinControllerSpec.scala @@ -4,6 +4,7 @@ import controllers.MockJoinService.mockJoinRegistry import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ +import model.Join import model.ListJoinResponse import org.mockito.Mockito._ import org.scalatest.EitherValues @@ -36,6 +37,13 @@ class JoinControllerSpec extends PlaySpec with Results with EitherValues { status(result) mustBe BAD_REQUEST } + "send 404 on missing join" in { + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) + + val result = controller.get("fake_join").apply(FakeRequest()) + status(result) mustBe NOT_FOUND + } + "send valid results on a correctly formed request" in { when(mockedStore.getJoins).thenReturn(mockJoinRegistry) @@ -61,5 +69,15 @@ class JoinControllerSpec extends PlaySpec with Results with EitherValues { items.length mustBe number items.map(_.name.toInt).toSet mustBe (startOffset until startOffset + number).toSet } + + "send valid join object on specific join lookup" in { + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) + + val result = controller.get("10").apply(FakeRequest()) + status(result) mustBe OK + val bodyText = contentAsString(result) + val joinResponse: Either[Error, Join] = decode[Join](bodyText) + joinResponse.right.value.name mustBe "10" + } } } diff --git a/hub/test/controllers/ModelControllerSpec.scala 
b/hub/test/controllers/ModelControllerSpec.scala index 95b96ca24a..d68b536e39 100644 --- a/hub/test/controllers/ModelControllerSpec.scala +++ b/hub/test/controllers/ModelControllerSpec.scala @@ -71,7 +71,7 @@ class ModelControllerSpec extends PlaySpec with Results with EitherValues { object MockDataService { def generateMockModel(id: String): Model = { val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2"))) - val join = Join("my_join", Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys) + val join = Join("my_join", Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, "my_team") Model(id, join, online = true, production = true, "my team", "XGBoost") } diff --git a/hub/test/controllers/SearchControllerSpec.scala b/hub/test/controllers/SearchControllerSpec.scala index 95a0680674..8ecca5c5c7 100644 --- a/hub/test/controllers/SearchControllerSpec.scala +++ b/hub/test/controllers/SearchControllerSpec.scala @@ -73,7 +73,7 @@ class SearchControllerSpec extends PlaySpec with Results with EitherValues { object MockJoinService { def generateMockJoin(id: String): Join = { val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2"))) - Join(id, Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys) + Join(id, Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, "my_team") } val mockJoinRegistry: Seq[Join] = (0 until 100).map(i => generateMockJoin(i.toString)) From 7e7d2a17cb24b3d1b0f4be164033b4285a2fdbc3 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 10:31:21 -0500 Subject: [PATCH 093/152] Rebase + comments --- .github/workflows/test_scala_no_spark.yaml | 7 ++++++- api/src/main/scala/ai/chronon/api/ColorPrinter.scala | 3 +++ build.sbt | 2 -- hub/app/model/Model.scala | 7 ------- spark/src/main/scala/ai/chronon/spark/TableUtils.scala | 6 +++--- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test_scala_no_spark.yaml b/.github/workflows/test_scala_no_spark.yaml index 77edb859f9..00c4bc25f0 100644 --- a/.github/workflows/test_scala_no_spark.yaml +++ b/.github/workflows/test_scala_no_spark.yaml @@ -60,4 +60,9 @@ jobs: - name: Run api tests run: | - sbt "++ 2.12.18 api/test" \ No newline at end of file + sbt "++ 2.12.18 api/test" + + - name: Run hub tests + run: | + export SBT_OPTS="-Xmx8G -Xms2G" + sbt "++ 2.12.18 hub/test" diff --git a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala index e779e3eaf1..4d1dc57c50 100644 --- a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala +++ b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala @@ -11,11 +11,14 @@ object ColorPrinter { private val ANSI_YELLOW = "\u001B[38;5;172m" // Muted Orange private val ANSI_GREEN = "\u001B[38;5;28m" // Forest green + private val BOLD = "\u001B[1m" + implicit class ColorString(val s: String) extends AnyVal { def red: String = s"$ANSI_RED$s$ANSI_RESET" def blue: String = s"$ANSI_BLUE$s$ANSI_RESET" def yellow: String = s"$ANSI_YELLOW$s$ANSI_RESET" def green: String = s"$ANSI_GREEN$s$ANSI_RESET" def low: String = s.toLowerCase + def highlight: String = s"$BOLD$ANSI_RED$s$ANSI_RESET" } } diff --git a/build.sbt b/build.sbt index 139945dcba..39aabe6364 100644 --- a/build.sbt +++ b/build.sbt @@ -135,8 +135,6 @@ lazy val online = project "com.github.ben-manes.caffeine" % "caffeine" % "3.1.8" ), libraryDependencies ++= jackson, - // we pull in circe to help us ser case classes like PutRequest without requiring annotations - libraryDependencies ++= circe, // dep needed for HTTPKvStore - yank when we rip this out 
libraryDependencies += "com.softwaremill.sttp.client3" %% "core" % "3.9.7", libraryDependencies ++= spark_all.map(_ % "provided"), diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index eb523251d3..a83c1c1803 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -10,13 +10,6 @@ case class Join(name: String, team: String) case class Model(name: String, join: Join, online: Boolean, production: Boolean, team: String, modelType: String) -// 1.) metadataUpload: join -> map> -// 2.) fetchJoinConf + listColumns: join => list -// 3.) (columns, start, end) -> list - -// 4.) 1:n/fetchTile: tileKey -> TileSummaries -// 5.) 1:n:n/compareTiles: TileSummaries, TileSummaries -> TileDrift -// 6.) Map[column, Seq[tileDrift]] -> TimeSeriesController /** Supported Metric types */ sealed trait MetricType diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index d516d7a5f0..5595b46f59 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -327,9 +327,9 @@ case class TableUtils(sparkSession: SparkSession) { sql(creationSql) } catch { case _: TableAlreadyExistsException => - println(s"Table $tableName already exists, skipping creation") + logger.info(s"Table $tableName already exists, skipping creation") case e: Exception => - println(s"Failed to create table $tableName", e) + logger.error(s"Failed to create table $tableName", e) throw e } } @@ -357,7 +357,7 @@ case class TableUtils(sparkSession: SparkSession) { // so that an exception will be thrown below dfRearranged } - println(s"Repartitioning and writing into table $tableName".yellow) + logger.info(s"Repartitioning and writing into table $tableName".yellow) repartitionAndWrite(finalizedDf, tableName, saveMode, stats, sortByCols) } From 470fd9eb35b564a8106321a73654bb3340027a87 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 10:42:11 -0500 Subject: [PATCH 094/152] Revert TableUtils for now --- spark/src/main/scala/ai/chronon/spark/TableUtils.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 5595b46f59..09aca39240 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -357,7 +357,6 @@ case class TableUtils(sparkSession: SparkSession) { // so that an exception will be thrown below dfRearranged } - logger.info(s"Repartitioning and writing into table $tableName".yellow) repartitionAndWrite(finalizedDf, tableName, saveMode, stats, sortByCols) } From 18e260c445cb5eed427b29f201a0aadb400ee4b7 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 12:14:39 -0500 Subject: [PATCH 095/152] Swap to new uploader app and revert old code --- build.sbt | 5 +- docker-init/demo/build.sh | 1 + docker-init/demo/load_summaries.sh | 2 +- .../spark/scripts/ObservabilityDemo.scala | 18 ++- .../scripts/ObservabilityDemoDataLoader.scala | 114 ++++++++++++++++++ 5 files changed, 136 insertions(+), 4 deletions(-) create mode 100755 docker-init/demo/build.sh create mode 100644 spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala diff --git a/build.sbt b/build.sbt index 39aabe6364..fd4852dfb4 100644 --- a/build.sbt +++ b/build.sbt @@ -264,7 +264,10 @@ lazy val hub = (project in file("hub")) excludeDependencies ++= Seq( ExclusionRule(organization = 
"org.slf4j", name = "slf4j-log4j12"), ExclusionRule(organization = "log4j", name = "log4j"), - ExclusionRule(organization = "org.apache.logging.log4j", name = "log4j-to-slf4j") + ExclusionRule(organization = "org.apache.logging.log4j", name = "log4j-to-slf4j"), + ExclusionRule("org.apache.logging.log4j", "log4j-slf4j-impl"), + ExclusionRule("org.apache.logging.log4j", "log4j-core"), + ExclusionRule("org.apache.logging.log4j", "log4j-api") ), // Ensure consistent versions of logging libraries dependencyOverrides ++= Seq( diff --git a/docker-init/demo/build.sh b/docker-init/demo/build.sh new file mode 100755 index 0000000000..5627dac2f5 --- /dev/null +++ b/docker-init/demo/build.sh @@ -0,0 +1 @@ +docker build -t obs . \ No newline at end of file diff --git a/docker-init/demo/load_summaries.sh b/docker-init/demo/load_summaries.sh index 61b4d9db95..15bc3681a0 100755 --- a/docker-init/demo/load_summaries.sh +++ b/docker-init/demo/load_summaries.sh @@ -8,5 +8,5 @@ docker-compose -f docker-init/compose.yaml exec app /opt/spark/bin/spark-submit --driver-class-path "/opt/spark/jars/*:/app/cli/*" \ --conf "spark.driver.host=localhost" \ --conf "spark.driver.bindAddress=0.0.0.0" \ - --class ai.chronon.spark.scripts.ObservabilityDemo \ + --class ai.chronon.spark.scripts.ObservabilityDemoDataLoader \ /app/cli/spark.jar diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala index 12622e222f..064b69700f 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala @@ -1,20 +1,34 @@ package ai.chronon.spark.scripts + +import ai.chronon import ai.chronon.api.ColorPrinter.ColorString import ai.chronon.api.Constants +import ai.chronon.api.DriftMetric import ai.chronon.api.Extensions.MetadataOps -import ai.chronon.online.HTTPKVStore +import ai.chronon.api.Extensions.WindowOps +import ai.chronon.api.PartitionSpec +import ai.chronon.api.TileDriftSeries +import ai.chronon.api.TileSummarySeries +import ai.chronon.api.Window import ai.chronon.online.KVStore +import ai.chronon.online.stats.DriftStore import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils import ai.chronon.spark.stats.drift.Summarizer import ai.chronon.spark.stats.drift.SummaryUploader import ai.chronon.spark.stats.drift.scripts.PrepareData +import ai.chronon.spark.utils.InMemoryKvStore import ai.chronon.spark.utils.MockApi import org.rogach.scallop.ScallopConf import org.rogach.scallop.ScallopOption import org.slf4j.Logger import org.slf4j.LoggerFactory +import java.util.concurrent.TimeUnit +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.ScalaJavaConversions.IteratorOps + object ObservabilityDemo { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) @@ -42,7 +56,7 @@ object ObservabilityDemo { // mock api impl for online fetching and uploading val kvStoreFunc: () => KVStore = () => { // cannot reuse the variable - or serialization error - val result = new HTTPKVStore() + val result = InMemoryKvStore.build(namespace, () => null) result } val api = new MockApi(kvStoreFunc, namespace) diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala new file mode 100644 index 0000000000..65275b8d94 --- /dev/null +++ 
b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala
@@ -0,0 +1,114 @@
+package ai.chronon.spark.scripts
+
+import ai.chronon.api.ColorPrinter.ColorString
+import ai.chronon.api.Constants
+import ai.chronon.api.Extensions.MetadataOps
+import ai.chronon.online.{HTTPKVStore, KVStore}
+import ai.chronon.spark.{SparkSessionBuilder, TableUtils}
+import ai.chronon.spark.stats.drift.{Summarizer, SummaryUploader}
+import ai.chronon.spark.stats.drift.scripts.PrepareData
+import ai.chronon.spark.utils.{InMemoryKvStore, MockApi}
+import org.rogach.scallop.{ScallopConf, ScallopOption}
+import org.slf4j.{Logger, LoggerFactory}
+
+object ObservabilityDemoDataLoader {
+  @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
+
+  def time(message: String)(block: => Unit): Unit = {
+    logger.info(s"$message..".yellow)
+    val start = System.currentTimeMillis()
+    block
+    val end = System.currentTimeMillis()
+    logger.info(s"$message took ${end - start} ms".green)
+  }
+
+  class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
+    val startDs: ScallopOption[String] = opt[String](
+      name = "start-ds",
+      default = Some("2023-01-01"),
+      descr = "Start date in YYYY-MM-DD format"
+    )
+
+    val endDs: ScallopOption[String] = opt[String](
+      name = "end-ds",
+      default = Some("2023-02-28"),
+      descr = "End date in YYYY-MM-DD format"
+    )
+
+    val rowCount: ScallopOption[Int] = opt[Int](
+      name = "row-count",
+      default = Some(700000),
+      descr = "Number of rows to generate"
+    )
+
+    val namespace: ScallopOption[String] = opt[String](
+      name = "namespace",
+      default = Some("observability_demo"),
+      descr = "Namespace for the demo"
+    )
+
+    verify()
+  }
+
+  def main(args: Array[String]): Unit = {
+
+    val config = new Conf(args)
+    val startDs = config.startDs()
+    val endDs = config.endDs()
+    val rowCount = config.rowCount()
+    val namespace = config.namespace()
+
+    val spark = SparkSessionBuilder.build(namespace, local = true)
+    implicit val tableUtils: TableUtils = TableUtils(spark)
+    tableUtils.createDatabase(namespace)
+
+    // generate anomalous data (join output)
+    val prepareData = PrepareData(namespace)
+    val join = prepareData.generateAnomalousFraudJoin
+
+    time("Preparing data") {
+      val df = prepareData.generateFraudSampleData(rowCount, startDs, endDs, join.metaData.loggedTable)
+      df.show(10, truncate = false)
+    }
+
+    // mock api impl for online fetching and uploading
+    val inMemKvStoreFunc: () => KVStore = () => {
+      // cannot reuse the variable - or serialization error
+      val result = InMemoryKvStore.build(namespace, () => null)
+      result
+    }
+    val inMemoryApi = new MockApi(inMemKvStoreFunc, namespace)
+
+    time("Summarizing data") {
+      // compute summary table and packed table (for uploading)
+      Summarizer.compute(inMemoryApi, join.metaData, ds = endDs, useLogs = true)
+    }
+
+    val packedTable = join.metaData.packedSummaryTable
+
+    // create necessary tables in kvstore - we now publish to the HTTP KV store as we need this available to the Hub
+    val httpKvStoreFunc: () => KVStore = () => {
+      // cannot reuse the variable - or serialization error
+      val result = new HTTPKVStore()
+      result
+    }
+    val hubApi = new MockApi(httpKvStoreFunc, namespace)
+
+    val kvStore = hubApi.genKvStore
+    kvStore.create(Constants.MetadataDataset)
+    kvStore.create(Constants.TiledSummaryDataset)
+
+    // upload join conf
+    hubApi.buildFetcher().putJoinConf(join)
+
+    time("Uploading summaries") {
+      val uploader = new SummaryUploader(tableUtils.loadTable(packedTable), hubApi)
+      uploader.run()
+    }
+
+   
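+    // Note: two KV stores are involved by design. Summaries were computed against
+    // the in-process InMemoryKvStore, and the packed summary bytes plus the join
+    // conf were then re-published through the HTTPKVStore-backed MockApi, because
+    // the Hub that serves this data runs in a separate JVM.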
println("Done uploading summaries! \uD83E\uDD73".green) + // clean up spark session and force jvm exit + spark.stop() + System.exit(0) + } +} From de044c4eab538e98d5bf9c698ae1a5ad98275ea2 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 12:14:59 -0500 Subject: [PATCH 096/152] Downgrade in mem controller log to debug --- hub/app/controllers/InMemKVStoreController.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala index 6eda57c4cf..97768e86b8 100644 --- a/hub/app/controllers/InMemKVStoreController.scala +++ b/hub/app/controllers/InMemKVStoreController.scala @@ -32,7 +32,7 @@ class InMemKVStoreController @Inject() (val controllerComponents: ControllerComp case Some(bytes) => decode[Array[PutRequest]](bytes.utf8String) match { case Right(putRequests) => - logger.info(s"Attempting a bulkPut with ${putRequests.length} items") + logger.debug(s"Attempting a bulkPut with ${putRequests.length} items") val resultFuture = kvStore.multiPut(putRequests) resultFuture.map { responses => if (responses.contains(false)) { From dfc4e1fdd623f0465c82a952981bd1c05352a177 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 12:15:28 -0500 Subject: [PATCH 097/152] style: Apply scalafix and scalafmt changes --- .../scripts/ObservabilityDemoDataLoader.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala index 65275b8d94..f317488273 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala @@ -3,13 +3,19 @@ package ai.chronon.spark.scripts import ai.chronon.api.ColorPrinter.ColorString import ai.chronon.api.Constants import ai.chronon.api.Extensions.MetadataOps -import ai.chronon.online.{HTTPKVStore, KVStore} -import ai.chronon.spark.{SparkSessionBuilder, TableUtils} -import ai.chronon.spark.stats.drift.{Summarizer, SummaryUploader} +import ai.chronon.online.HTTPKVStore +import ai.chronon.online.KVStore +import ai.chronon.spark.SparkSessionBuilder +import ai.chronon.spark.TableUtils +import ai.chronon.spark.stats.drift.Summarizer +import ai.chronon.spark.stats.drift.SummaryUploader import ai.chronon.spark.stats.drift.scripts.PrepareData -import ai.chronon.spark.utils.{InMemoryKvStore, MockApi} -import org.rogach.scallop.{ScallopConf, ScallopOption} -import org.slf4j.{Logger, LoggerFactory} +import ai.chronon.spark.utils.InMemoryKvStore +import ai.chronon.spark.utils.MockApi +import org.rogach.scallop.ScallopConf +import org.rogach.scallop.ScallopOption +import org.slf4j.Logger +import org.slf4j.LoggerFactory object ObservabilityDemoDataLoader { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) From 6ab315c8f1f819cce36456eb12510b9a136d760e Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 12:43:56 -0500 Subject: [PATCH 098/152] Handle empty responses --- hub/app/controllers/TimeSeriesController.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index ca5207c688..cae7ca25c9 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ 
-188,8 +188,14 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { merged => val currentSummarySeries = merged.head val baselineSummarySeries = merged.last - val currentFeatureTs = convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) - val baselineFeatureTs = convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) + val currentFeatureTs = { + if (currentSummarySeries.isEmpty) Seq.empty + else convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) + } + val baselineFeatureTs = { + if (baselineSummarySeries.isEmpty) Seq.empty + else convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) + } val comparedTsData = ComparedFeatureTimeSeries(name, baselineFeatureTs, currentFeatureTs) Ok(comparedTsData.asJson.noSpaces) } From 62687b97a83c268fd3554fd1978035b42edf975c Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 15:15:23 -0500 Subject: [PATCH 099/152] Remove redundant log4j props file --- docker-init/demo/log4j2.properties | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 docker-init/demo/log4j2.properties diff --git a/docker-init/demo/log4j2.properties b/docker-init/demo/log4j2.properties deleted file mode 100644 index a0167384ee..0000000000 --- a/docker-init/demo/log4j2.properties +++ /dev/null @@ -1,17 +0,0 @@ -# Root logger -rootLogger.level = ERROR -rootLogger.appenderRef.console.ref = console - -# Console appender configuration -appender.console.type = Console -appender.console.name = console -appender.console.target = SYSTEM_OUT -appender.console.layout.type = PatternLayout -appender.console.layout.pattern = %yellow{%d{yyyy/MM/dd HH:mm:ss}} %highlight{%-5level} %green{%file:%line} - %message%n - -# Configure specific logger -logger.chronon.name = ai.chronon -logger.chronon.level = info - -# Configure colors -appender.console.layout.disableAnsi = false \ No newline at end of file From 5316680a3ec373dd1a34eca8b675875ae009dfd1 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 15:16:00 -0500 Subject: [PATCH 100/152] Use thread locals for thrift serializers --- .../ai/chronon/online/stats/DriftStore.scala | 20 ++++++++++++------- .../spark/stats/drift/Summarizer.scala | 5 +++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index 733ebbfc6d..b1b935a04f 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -12,8 +12,8 @@ import ai.chronon.api.thrift.protocol.TProtocolFactory import ai.chronon.online.KVStore import ai.chronon.online.KVStore.GetRequest import ai.chronon.online.MetadataStore -import ai.chronon.online.stats.DriftStore.compactDeserializer -import ai.chronon.online.stats.DriftStore.compactSerializer +import ai.chronon.online.stats.DriftStore.binaryDeserializer +import ai.chronon.online.stats.DriftStore.binarySerializer import java.io.Serializable import scala.concurrent.Future @@ -52,8 +52,6 @@ class DriftStore(kvStore: KVStore, } } - private val deserializer: TDeserializer = compactDeserializer - private case class SummaryRequestContext(request: GetRequest, tileKey: TileKey, groupName: String) private case class SummaryResponseContext(summaries: Array[(TileSummary, Long)], tileKey: TileKey, groupName: 
String) @@ -76,7 +74,7 @@ class DriftStore(kvStore: KVStore, endMs: Option[Long], columnPrefix: Option[String]): Future[Seq[TileSummaryInfo]] = { - val serializer: TSerializer = compactSerializer + val serializer: TSerializer = binarySerializer.get() val tileKeyMap = tileKeysForJoin(joinConf, None, columnPrefix) val requestContextMap: Map[GetRequest, SummaryRequestContext] = tileKeyMap.flatMap { case (group, keys) => @@ -90,6 +88,7 @@ class DriftStore(kvStore: KVStore, val responseFuture = kvStore.multiGet(requestContextMap.keys.toSeq) responseFuture.map { responses => + val deserializer = binaryDeserializer.get() // deserialize the responses and surround with context val responseContextTries: Seq[Try[SummaryResponseContext]] = responses.map { response => val valuesTry = response.values @@ -200,9 +199,16 @@ object DriftStore { class SerializableSerializer(factory: TProtocolFactory) extends TSerializer(factory) with Serializable // crazy bug in compact protocol - do not change to compact - def compactSerializer: SerializableSerializer = new SerializableSerializer(new TBinaryProtocol.Factory()) - def compactDeserializer: TDeserializer = new TDeserializer(new TBinaryProtocol.Factory()) + @transient + lazy val binarySerializer: ThreadLocal[TSerializer] = new ThreadLocal[TSerializer] { + override def initialValue(): TSerializer = new TSerializer(new TBinaryProtocol.Factory()) + } + + @transient + lazy val binaryDeserializer: ThreadLocal[TDeserializer] = new ThreadLocal[TDeserializer] { + override def initialValue(): TDeserializer = new TDeserializer(new TBinaryProtocol.Factory()) + } // todo - drop this hard-coded list in favor of a well known list or exposing as part of summaries val percentileLabels: Seq[String] = Seq("p0", diff --git a/spark/src/main/scala/ai/chronon/spark/stats/drift/Summarizer.scala b/spark/src/main/scala/ai/chronon/spark/stats/drift/Summarizer.scala index 856850b892..2874f3b907 100644 --- a/spark/src/main/scala/ai/chronon/spark/stats/drift/Summarizer.scala +++ b/spark/src/main/scala/ai/chronon/spark/stats/drift/Summarizer.scala @@ -6,7 +6,7 @@ import ai.chronon.api._ import ai.chronon.online.Api import ai.chronon.online.KVStore.GetRequest import ai.chronon.online.KVStore.PutRequest -import ai.chronon.online.stats.DriftStore.compactSerializer +import ai.chronon.online.stats.DriftStore.binarySerializer import ai.chronon.spark.TableUtils import ai.chronon.spark.stats.drift.Expressions.CardinalityExpression import ai.chronon.spark.stats.drift.Expressions.SummaryExpression @@ -322,9 +322,10 @@ class SummaryPacker(confPath: String, val func: sql.Row => Seq[TileRow] = Expressions.summaryPopulatorFunc(summaryExpressions, df.schema, keyBuilder, tu.partitionColumn) - val serializer = compactSerializer val packedRdd: RDD[sql.Row] = df.rdd.flatMap(func).map { tileRow => // pack into bytes + val serializer = binarySerializer.get() + val partition = tileRow.partition val timestamp = tileRow.tileTs val summaries = tileRow.summaries From 92f0d11aa67c3c946a3e376ef9a9011892392bf9 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 16:01:35 -0500 Subject: [PATCH 101/152] Rebase + comments --- build.sbt | 1 + .../main/scala/ai/chronon/spark/scripts/DataServer.scala | 7 +------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/build.sbt b/build.sbt index fd4852dfb4..c80ede6a12 100644 --- a/build.sbt +++ b/build.sbt @@ -80,6 +80,7 @@ val jackson = Seq( "com.fasterxml.jackson.module" %% "jackson-module-scala" ).map(_ % jackson_2_15) +// Circe is used to ser / deser 
case class payloads for the Hub Play webservice val circe = Seq( "io.circe" %% "circe-core", "io.circe" %% "circe-generic", diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala b/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala index fba36f4d0c..cf935fd334 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala @@ -35,15 +35,10 @@ class DataServer(driftSeries: Seq[TileDriftSeries], summarySeries: Seq[TileSumma ctx.flush() } - private val serializer: ThreadLocal[SerializableSerializer] = - ThreadLocal.withInitial(new Supplier[SerializableSerializer] { - override def get(): SerializableSerializer = DriftStore.compactSerializer - }) - private def convertToBytesMap[T <: TBase[_, _]: Manifest: ClassTag]( series: T, keyF: T => TileSeriesKey): Map[String, String] = { - val serializerInstance = serializer.get() + val serializerInstance = DriftStore.binarySerializer.get() val encoder = Base64.getEncoder val keyBytes = serializerInstance.serialize(keyF(series)) val valueBytes = serializerInstance.serialize(series) From 4965c8a292407b7f901dd0b8c25647e0337aacf8 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 16:02:11 -0500 Subject: [PATCH 102/152] style: Apply scalafix and scalafmt changes --- spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala b/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala index cf935fd334..afd194a7d2 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala @@ -5,7 +5,6 @@ import ai.chronon.api.TileSeriesKey import ai.chronon.api.TileSummarySeries import ai.chronon.api.thrift.TBase import ai.chronon.online.stats.DriftStore -import ai.chronon.online.stats.DriftStore.SerializableSerializer import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.SerializationFeature import com.fasterxml.jackson.module.scala.DefaultScalaModule @@ -19,7 +18,6 @@ import io.netty.handler.codec.http._ import io.netty.util.CharsetUtil import java.util.Base64 -import java.util.function.Supplier import scala.reflect.ClassTag class DataServer(driftSeries: Seq[TileDriftSeries], summarySeries: Seq[TileSummarySeries], port: Int = 8181) { From 2979bb9769ad64b9e652f8803b702b9df992171e Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 27 Nov 2024 16:46:26 -0500 Subject: [PATCH 103/152] remove old import --- frontend/src/routes/joins/[slug]/+page.svelte | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 479090fccd..87b603dbba 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -18,7 +18,6 @@ import PageHeader from '$lib/components/PageHeader/PageHeader.svelte'; import Separator from '$lib/components/ui/separator/separator.svelte'; import ResetZoomButton from '$lib/components/ResetZoomButton/ResetZoomButton.svelte'; - import DateRangeSelector from '$lib/components/DateRangeSelector/DateRangeSelector.svelte'; import IntersectionObserver from 'svelte-intersection-observer'; import { fade } from 'svelte/transition'; import { Button } from '$lib/components/ui/button'; From df7dc77cd5ba48a622fd0961e2beae9efa774bcf Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 
3 Dec 2024 09:40:08 -0500
Subject: [PATCH 104/152] make isBarChart derived

---
 frontend/src/lib/components/EChart/EChart.svelte | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/frontend/src/lib/components/EChart/EChart.svelte b/frontend/src/lib/components/EChart/EChart.svelte
index a43b555f39..2615546b37 100644
--- a/frontend/src/lib/components/EChart/EChart.svelte
+++ b/frontend/src/lib/components/EChart/EChart.svelte
@@ -91,7 +91,11 @@
   let isCommandPressed = $state(false);
   let isMouseOverTooltip = $state(false);
   let hideTimeoutId: ReturnType<typeof setTimeout>;
-  let isBarChart = $state(false);
+
+  const isBarChart = $derived.by(() => {
+    const series = mergedOption.series as EChartOption.Series[];
+    return series?.some((s) => s.type === 'bar');
+  });

   function handleKeyDown(event: KeyboardEvent) {
     if (event.metaKey || event.ctrlKey) {
@@ -142,10 +146,6 @@
     chartInstance = echarts.init(chartDiv, theme);
     chartInstance.setOption(mergedOption);

-    // Set chart type
-    const series = mergedOption.series as EChartOption.Series[];
-    isBarChart = series?.[0]?.type === 'bar';
-
     chartInstance.on('click', (params: ECElementEvent) => {
       dispatch('click', {
         detail: params,

From 17500ed798fa97a3d28c7d98a47e8d651b471224 Mon Sep 17 00:00:00 2001
From: ken-zlai
Date: Tue, 3 Dec 2024 09:45:45 -0500
Subject: [PATCH 105/152] wire up percentile and compared chart

---
 frontend/src/lib/api/api.ts                   | 33 ++++++----
 .../PercentileChart/PercentileChart.svelte    | 18 ++---
 frontend/src/lib/types/Model/Model.ts         |  4 +-
 frontend/src/lib/util/sample-data.ts          |  8 +--
 frontend/src/routes/joins/[slug]/+page.svelte | 66 ++++++++++++++++---
 5 files changed, 91 insertions(+), 38 deletions(-)

diff --git a/frontend/src/lib/api/api.ts b/frontend/src/lib/api/api.ts
index c489c4124c..0777e87e2b 100644
--- a/frontend/src/lib/api/api.ts
+++ b/frontend/src/lib/api/api.ts
@@ -92,16 +92,27 @@ export async function getJoinTimeseries({
   return get(`join/${joinId}/timeseries?${params.toString()}`);
 }

-export async function getFeatureTimeseries(
-  featureName: string,
-  startTs: number,
-  endTs: number,
-  metricType: string = 'drift',
-  metrics: string = 'null',
-  offset: string = '10h',
-  algorithm: string = 'psi',
-  granularity: string = 'percentile'
-): Promise<FeatureResponse> {
+export async function getFeatureTimeseries({
+  joinId,
+  featureName,
+  startTs,
+  endTs,
+  metricType = 'drift',
+  metrics = 'null',
+  offset = '10h',
+  algorithm = 'psi',
+  granularity = 'aggregates'
+}: {
+  joinId: string;
+  featureName: string;
+  startTs: number;
+  endTs: number;
+  metricType?: string;
+  metrics?: string;
+  offset?: string;
+  algorithm?: string;
+  granularity?: string;
+}): Promise<FeatureResponse> {
   const params = new URLSearchParams({
     startTs: startTs.toString(),
     endTs: endTs.toString(),
@@ -111,5 +122,5 @@ export async function getFeatureTimeseries(
     algorithm,
     granularity
   });
-  return get(`feature/${featureName}/timeseries?${params.toString()}`);
+  return get(`join/${joinId}/feature/${featureName}/timeseries?${params.toString()}`);
 }
diff --git a/frontend/src/lib/components/PercentileChart/PercentileChart.svelte b/frontend/src/lib/components/PercentileChart/PercentileChart.svelte
index bfd15b48c9..f473b04bd6 100644
--- a/frontend/src/lib/components/PercentileChart/PercentileChart.svelte
+++ b/frontend/src/lib/components/PercentileChart/PercentileChart.svelte
@@ -1,23 +1,19 @@
@@ -29,10 +35,15 @@
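A quick usage sketch for the reworked getFeatureTimeseries above: the object-style parameters carry defaults, so callers only name what they override. The join id, feature name, and timestamps below are illustrative values borrowed from the tests later in this series:

    import { getFeatureTimeseries } from '$lib/api/api';

    // joinId, featureName and the time range are required; metricType, metrics,
    // offset, algorithm and granularity fall back to their declared defaults.
    const series = await getFeatureTimeseries({
      joinId: 'risk.user_transactions.txn_join',
      featureName: 'dim_user_account_type',
      startTs: 1673308800000,
      endTs: 1674172800000,
      granularity: 'percentile'
    });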
    - {#each joins as join} + {#each reorderedJoins as join} - + {join.name} From ec8c0f2f325487b7e8af01699f87a13216d2aa36 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 10:30:11 -0500 Subject: [PATCH 109/152] update tests --- frontend/src/lib/api/api.ts | 16 ------ frontend/src/lib/types/Model/Model.test.ts | 59 ++++------------------ 2 files changed, 10 insertions(+), 65 deletions(-) diff --git a/frontend/src/lib/api/api.ts b/frontend/src/lib/api/api.ts index 1382ff292b..2ba9d7117d 100644 --- a/frontend/src/lib/api/api.ts +++ b/frontend/src/lib/api/api.ts @@ -43,22 +43,6 @@ export async function getJoins(offset: number = 0, limit: number = 10): Promise< return get(`joins?${params.toString()}`); } -export async function getModelTimeseries( - name: string, - startTs: number, - endTs: number, - offset: string = '10h', - algorithm: string = 'psi' -): Promise { - const params = new URLSearchParams({ - startTs: startTs.toString(), - endTs: endTs.toString(), - offset, - algorithm - }); - return get(`model/${name}/timeseries?${params.toString()}`); -} - export async function search(term: string, limit: number = 20): Promise { const params = new URLSearchParams({ term, diff --git a/frontend/src/lib/types/Model/Model.test.ts b/frontend/src/lib/types/Model/Model.test.ts index ecfc4dbc98..23a86227d1 100644 --- a/frontend/src/lib/types/Model/Model.test.ts +++ b/frontend/src/lib/types/Model/Model.test.ts @@ -45,45 +45,6 @@ describe('Model types', () => { } }); - it('should match TimeSeriesResponse type', async () => { - const result = (await api.getModels()) as ModelsResponse; - expect(result.items.length).toBeGreaterThan(0); - - const modelName = result.items[0].name; - const timeseriesResult = (await api.getModelTimeseries( - modelName, - 1725926400000, - 1726106400000 - )) as TimeSeriesResponse; - - const expectedKeys = ['id', 'items']; - expect(Object.keys(timeseriesResult)).toEqual(expect.arrayContaining(expectedKeys)); - - // Log a warning if there are additional fields - const additionalKeys = Object.keys(timeseriesResult).filter( - (key) => !expectedKeys.includes(key) - ); - if (additionalKeys.length > 0) { - console.warn(`Additional fields found in TimeSeriesResponse: ${additionalKeys.join(', ')}`); - } - - expect(Array.isArray(timeseriesResult.items)).toBe(true); - - if (timeseriesResult.items.length > 0) { - const item = timeseriesResult.items[0]; - const expectedItemKeys = ['value', 'ts', 'label', 'nullValue']; - expect(Object.keys(item)).toEqual(expect.arrayContaining(expectedItemKeys)); - - // Log a warning if there are additional fields - const additionalItemKeys = Object.keys(item).filter((key) => !expectedItemKeys.includes(key)); - if (additionalItemKeys.length > 0) { - console.warn( - `Additional fields found in TimeSeriesResponse item: ${additionalItemKeys.join(', ')}` - ); - } - } - }); - it('should match ModelsResponse type for search results', async () => { const searchTerm = 'risk.transaction_model.v1'; const limit = 5; @@ -132,11 +93,11 @@ describe('Model types', () => { const result = (await api.getModels()) as ModelsResponse; expect(result.items.length).toBeGreaterThan(0); - const modelName = result.items[0].name; + const modelName = 'risk.user_transactions.txn_join'; const joinResult = (await api.getJoinTimeseries({ joinId: modelName, - startTs: 1725926400000, - endTs: 1726106400000 + startTs: 1673308800000, + endTs: 1674172800000 })) as JoinTimeSeriesResponse; const expectedKeys = ['name', 'items']; @@ -202,12 +163,12 @@ describe('Model types', () => { }); 
it('should match FeatureResponse type', async () => { - const featureName = 'test_feature'; - const featureResult = (await api.getFeatureTimeseries( - featureName, - 1725926400000, - 1726106400000 - )) as FeatureResponse; + const featureResult = await api.getFeatureTimeseries({ + joinId: 'risk.user_transactions.txn_join', + featureName: 'dim_user_account_type', + startTs: 1673308800000, + endTs: 1674172800000 + }); const expectedKeys = ['feature', 'points']; expect(Object.keys(featureResult)).toEqual(expect.arrayContaining(expectedKeys)); @@ -220,7 +181,7 @@ describe('Model types', () => { expect(Array.isArray(featureResult.points)).toBe(true); - if (featureResult.points.length > 0) { + if (featureResult.points && featureResult.points?.length > 0) { const point = featureResult.points[0]; const expectedPointKeys = ['value', 'ts', 'label', 'nullValue']; expect(Object.keys(point)).toEqual(expect.arrayContaining(expectedPointKeys)); From 49a3f37b40c7cbaf5d42239aeb38a1533d53d8af Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 10:33:37 -0500 Subject: [PATCH 110/152] remove unused imports --- frontend/src/lib/types/Model/Model.test.ts | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/frontend/src/lib/types/Model/Model.test.ts b/frontend/src/lib/types/Model/Model.test.ts index 23a86227d1..66aad9fdf3 100644 --- a/frontend/src/lib/types/Model/Model.test.ts +++ b/frontend/src/lib/types/Model/Model.test.ts @@ -1,12 +1,6 @@ import { describe, it, expect } from 'vitest'; import * as api from '$lib/api/api'; -import type { - ModelsResponse, - TimeSeriesResponse, - Model, - JoinTimeSeriesResponse, - FeatureResponse -} from '$lib/types/Model/Model'; +import type { ModelsResponse, Model, JoinTimeSeriesResponse } from '$lib/types/Model/Model'; describe('Model types', () => { it('should match ModelsResponse type', async () => { From 2d2f96cc7515e67c4076181a002bafc678ff5a02 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 11:15:52 -0500 Subject: [PATCH 111/152] null ratio chart --- frontend/src/routes/joins/[slug]/+page.svelte | 88 +++++++++++++------ 1 file changed, 59 insertions(+), 29 deletions(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index a8167d4226..713248e8da 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -7,11 +7,7 @@ import { ChartLine } from '@zipline-ai/icons'; import CollapsibleSection from '$lib/components/CollapsibleSection/CollapsibleSection.svelte'; import { connect } from 'echarts'; - import type { - FeatureResponse, - NullComparedFeatureResponse, - TimeSeriesItem - } from '$lib/types/Model/Model'; + import type { FeatureResponse, TimeSeriesItem } from '$lib/types/Model/Model'; import { ScrollArea } from '$lib/components/ui/scroll-area'; import { untrack } from 'svelte'; import PageHeader from '$lib/components/PageHeader/PageHeader.svelte'; @@ -23,7 +19,6 @@ import { parseDateRangeParams } from '$lib/util/date-ranges'; import { getFeatureTimeseries } from '$lib/api/api'; import { page } from '$app/stores'; - import { comparedFeatureNumericalSampleData, nullCountSampleData } from '$lib/util/sample-data'; import InfoTooltip from '$lib/components/InfoTooltip/InfoTooltip.svelte'; import { Table, TableBody, TableCell, TableRow } from '$lib/components/ui/table/index.js'; import TrueFalseBadge from '$lib/components/TrueFalseBadge/TrueFalseBadge.svelte'; @@ -193,7 +188,7 @@ } async function selectSeries(seriesName: 
string | undefined) { - [percentileData, comparedFeatureData] = [null, null]; + [percentileData, comparedFeatureData, nullData] = [null, null, null]; selectedSeries = seriesName; if (seriesName) { const { startTimestamp, endTimestamp } = parseDateRangeParams( @@ -201,22 +196,39 @@ ); try { - const featureData = await getFeatureTimeseries({ - joinId: joinTimeseries.name, - featureName: seriesName, - startTs: 1673308800000, // todo use startTimestamp - endTs: 1674172800000, // todo use endTimestamp - granularity: 'percentile', - metricType: 'drift', - metrics: 'value', - offset: '1D', - algorithm: 'psi' - }); + const [featureData, nullFeatureData] = await Promise.all([ + getFeatureTimeseries({ + joinId: joinTimeseries.name, + featureName: seriesName, + startTs: 1673308800000, // todo use startTimestamp + endTs: 1674172800000, // todo use endTimestamp + granularity: 'percentile', + metricType: 'drift', + metrics: 'value', + offset: '1D', + algorithm: 'psi' + }), + getFeatureTimeseries({ + joinId: joinTimeseries.name, + featureName: seriesName, + startTs: 1673308800000, // todo use startTimestamp + endTs: 1674172800000, // todo use endTimestamp + metricType: 'drift', + metrics: 'null', + offset: '1D', + algorithm: 'psi', + granularity: 'percentile' + }) + ]); + [percentileData, comparedFeatureData] = [featureData, featureData]; + nullData = nullFeatureData; + console.log(nullData); } catch (error) { console.error('Error fetching data:', error); percentileData = null; comparedFeatureData = null; + nullData = null; } } } @@ -285,43 +297,61 @@ }); } - function createNullRatioChartOption(data: NullComparedFeatureResponse): EChartOption { + let nullData: FeatureResponse | null = $state(null); + let nullRatioChartOption = $state({}); + + $effect(() => { + if (selectedSeries && selectedEvents[0]?.data) { + const timestamp = (selectedEvents[0].data as [number, number])[0]; + nullRatioChartOption = createNullRatioChartOption(nullData, timestamp); + } else { + nullRatioChartOption = {}; + } + }); + + function createNullRatioChartOption( + data: FeatureResponse | null, + timestamp: number + ): EChartOption { + if (!data?.current || !data?.baseline) return {}; + + // Get points at the selected timestamp + const currentPoint = data.current.find((point) => point.ts === timestamp); + const baselinePoint = data.baseline.find((point) => point.ts === timestamp); + + if (!currentPoint || !baselinePoint) return {}; + return createChartOption({ xAxis: { type: 'category', - data: ['Null Values', 'Non-null Values'] + data: ['Baseline', 'Current'] }, yAxis: { type: 'value' }, series: [ { - name: 'Baseline', + name: 'Null Values', type: 'bar', stack: 'total', emphasis: { focus: 'series' }, - data: [data.oldNullCount, data.oldValueCount] + data: [baselinePoint.nullValue, currentPoint.nullValue] } as EChartOption.Series, { - name: 'Current', + name: 'Non-null Values', type: 'bar', stack: 'total', emphasis: { focus: 'series' }, - data: [data.newNullCount, data.newValueCount] + data: [100 - baselinePoint.nullValue, 100 - currentPoint.nullValue] } as EChartOption.Series ] }); } - let nullRatioChartOption = $state({}); - $effect(() => { - nullRatioChartOption = createNullRatioChartOption(nullCountSampleData); - }); - $effect(() => { if (allCharts.length) { untrack(() => { From f3b75aeb51a80f3c44f00b6ddd935647669a8075 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 11:27:18 -0500 Subject: [PATCH 112/152] update sample data, clearer wording --- frontend/src/lib/util/sample-data.ts | 13 ++++++++----- 
frontend/src/routes/joins/[slug]/+page.svelte | 9 ++++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/frontend/src/lib/util/sample-data.ts b/frontend/src/lib/util/sample-data.ts index 81c6709961..01a6dd35b1 100644 --- a/frontend/src/lib/util/sample-data.ts +++ b/frontend/src/lib/util/sample-data.ts @@ -746,11 +746,14 @@ export const comparedFeatureCategoricalSampleData: RawComparedFeatureResponse = ] }; -export const nullCountSampleData: NullComparedFeatureResponse = { - oldNullCount: 10, - newNullCount: 20, - oldValueCount: 90, - newValueCount: 80 +export const nullCountSampleData: FeatureResponse = { + feature: 'feature_1', + current: [ + { ts: 1725926400000, value: 20, nullValue: 20 } // 20% null values in current + ], + baseline: [ + { ts: 1725926400000, value: 10, nullValue: 10 } // 10% null values in baseline + ] }; export function generatePercentileData( diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 713248e8da..013fb8a07d 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -327,11 +327,14 @@ data: ['Baseline', 'Current'] }, yAxis: { - type: 'value' + type: 'value', + axisLabel: { + formatter: '{value}%' + } }, series: [ { - name: 'Null Values', + name: 'Null Value Percentage', type: 'bar', stack: 'total', emphasis: { @@ -340,7 +343,7 @@ data: [baselinePoint.nullValue, currentPoint.nullValue] } as EChartOption.Series, { - name: 'Non-null Values', + name: 'Non-null Value Percentage', type: 'bar', stack: 'total', emphasis: { From 8a7551b9e523ea3c026f0d33267f1f1d3c412b44 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 11:28:02 -0500 Subject: [PATCH 113/152] remove console.log --- frontend/src/routes/joins/[slug]/+page.svelte | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 013fb8a07d..f6235d2213 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -223,7 +223,6 @@ [percentileData, comparedFeatureData] = [featureData, featureData]; nullData = nullFeatureData; - console.log(nullData); } catch (error) { console.error('Error fetching data:', error); percentileData = null; From 9bf521f3021d954b22dc652f323750ee68c68ce5 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 12:17:32 -0500 Subject: [PATCH 114/152] add isNumeric --- frontend/src/lib/types/Model/Model.ts | 1 + frontend/src/routes/joins/[slug]/+page.svelte | 70 +++++++++++-------- 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/frontend/src/lib/types/Model/Model.ts b/frontend/src/lib/types/Model/Model.ts index 6fba50cb80..b45377a094 100644 --- a/frontend/src/lib/types/Model/Model.ts +++ b/frontend/src/lib/types/Model/Model.ts @@ -47,6 +47,7 @@ export type JoinTimeSeriesResponse = { export type FeatureResponse = { feature: string; + isNumeric?: boolean; points?: TimeSeriesItem[]; baseline?: TimeSeriesItem[]; current?: TimeSeriesItem[]; diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index f6235d2213..97a95a52a9 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -188,8 +188,10 @@ } async function selectSeries(seriesName: string | undefined) { + // Reset all data states [percentileData, comparedFeatureData, nullData] = [null, null, null]; selectedSeries = seriesName; + if 
(seriesName) { const { startTimestamp, endTimestamp } = parseDateRangeParams( new URL($page.url).searchParams @@ -221,7 +223,14 @@ }) ]); - [percentileData, comparedFeatureData] = [featureData, featureData]; + if (featureData.isNumeric) { + percentileData = featureData; + comparedFeatureData = null; + } else { + percentileData = null; + comparedFeatureData = featureData; + } + nullData = nullFeatureData; } catch (error) { console.error('Error fetching data:', error); @@ -559,35 +568,38 @@ {/if} {#if selectedSeries} - - {#snippet collapsibleContent()} - - {/snippet} - - - - {#snippet headerContentRight()} - {#if isComparedFeatureZoomed} -
    - -
    - {/if} - {/snippet} - {#snippet collapsibleContent()} - - {/snippet} -
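
The branching that feeds this template — percentile charts for numeric features, baseline-vs-current comparison charts for categorical ones — reduces to a small discriminated check on featureData.isNumeric. A rough TypeScript sketch (the ChartData container and helper name are illustrative, not part of the patch):

    import type { FeatureResponse } from '$lib/types/Model/Model';

    // Only one slot is populated per feature, mirroring the
    // {#if percentileData} / {#if comparedFeatureData} guards in the template.
    type ChartData = {
      percentileData: FeatureResponse | null;
      comparedFeatureData: FeatureResponse | null;
    };

    function routeFeatureData(featureData: FeatureResponse): ChartData {
      return featureData.isNumeric
        ? { percentileData: featureData, comparedFeatureData: null }
        : { percentileData: null, comparedFeatureData: featureData };
    }
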
    + {#if percentileData} + + {#snippet collapsibleContent()} + + {/snippet} + + {/if} + {#if comparedFeatureData} + + {#snippet headerContentRight()} + {#if isComparedFeatureZoomed} +
    + +
    + {/if} + {/snippet} + {#snippet collapsibleContent()} + + {/snippet} +
    + {/if} {#snippet collapsibleContent()} Date: Tue, 3 Dec 2024 12:19:43 -0500 Subject: [PATCH 115/152] mock isNumeric --- frontend/src/routes/joins/[slug]/+page.svelte | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 97a95a52a9..27d72fde4b 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -187,6 +187,15 @@ return 'N/A'; } + function isNumericFeature(featureData: FeatureResponse): boolean { + // Check if any point in current data has a label starting with 'p' + return ( + featureData.current?.some( + (point) => typeof point.label === 'string' && point.label === 'p0' + ) ?? false + ); + } + async function selectSeries(seriesName: string | undefined) { // Reset all data states [percentileData, comparedFeatureData, nullData] = [null, null, null]; @@ -223,7 +232,8 @@ }) ]); - if (featureData.isNumeric) { + // todo use featureData.isNumeric when backend returns it + if (isNumericFeature(featureData)) { percentileData = featureData; comparedFeatureData = null; } else { From 48670d32c5821667f48ba824b108d776f81ec2f8 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 3 Dec 2024 12:50:24 -0500 Subject: [PATCH 116/152] Add field to indicate if feature / drift data is numeric --- .../controllers/TimeSeriesController.scala | 30 +++++++++++-------- hub/app/model/Model.scala | 4 +-- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index 68c0d43c5f..5d9f8694f2 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -145,6 +145,10 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon granularity: Granularity, offset: Option[String], algorithm: Option[String]): Future[Result] = { + def checkIfNumeric(summarySeries: TileSummarySeries) = { + summarySeries.percentiles.asScala != null && summarySeries.percentiles.asScala.exists(_ != null) + } + if (granularity == Raw) { Future.successful(BadRequest("We don't support Raw granularity for drift metric types")) } else { @@ -188,15 +192,19 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { merged => val currentSummarySeries = merged.head val baselineSummarySeries = merged.last + + val isCurrentNumeric = currentSummarySeries.headOption.forall(checkIfNumeric) + val isBaselineNumeric = baselineSummarySeries.headOption.forall(checkIfNumeric) + val currentFeatureTs = { if (currentSummarySeries.isEmpty) Seq.empty - else convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) + else convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, isCurrentNumeric, metric) } val baselineFeatureTs = { if (baselineSummarySeries.isEmpty) Seq.empty - else convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) + else convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, isBaselineNumeric, metric) } - val comparedTsData = ComparedFeatureTimeSeries(name, baselineFeatureTs, currentFeatureTs) + val comparedTsData = ComparedFeatureTimeSeries(name, isCurrentNumeric, baselineFeatureTs, currentFeatureTs) Ok(comparedTsData.asJson.noSpaces) } } @@ -207,14 +215,14 @@ class TimeSeriesController @Inject() (val controllerComponents: 
ControllerCompon private def convertTileDriftSeriesInfoToTimeSeries(tileDriftSeries: TileDriftSeries, metric: Metric): FeatureTimeSeries = { + // check if we have a numeric / categorical feature. If the percentile drift series has non-null doubles + // then we have a numeric feature at hand + val isNumeric = + tileDriftSeries.percentileDriftSeries.asScala != null && tileDriftSeries.percentileDriftSeries.asScala + .exists(_ != null) val lhsList = if (metric == NullMetric) { tileDriftSeries.nullRatioChangePercentSeries.asScala } else { - // check if we have a numeric / categorical feature. If the percentile drift series has non-null doubles - // then we have a numeric feature at hand - val isNumeric = - tileDriftSeries.percentileDriftSeries.asScala != null && tileDriftSeries.percentileDriftSeries.asScala - .exists(_ != null) if (isNumeric) tileDriftSeries.percentileDriftSeries.asScala else tileDriftSeries.histogramDriftSeries.asScala } @@ -222,19 +230,17 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon case (v, ts) => TimeSeriesPoint(v, ts) } - FeatureTimeSeries(tileDriftSeries.getKey.getColumn, points) + FeatureTimeSeries(tileDriftSeries.getKey.getColumn, isNumeric, points) } private def convertTileSummarySeriesToTimeSeries(summarySeries: TileSummarySeries, + isNumeric: Boolean, metric: Metric): Seq[TimeSeriesPoint] = { if (metric == NullMetric) { summarySeries.nullCount.asScala.zip(summarySeries.timestamps.asScala).map { case (nullCount, ts) => TimeSeriesPoint(0, ts, nullValue = Some(nullCount.intValue())) } } else { - // check if we have a numeric / categorical feature. If the percentile drift series has non-null doubles - // then we have a numeric feature at hand - val isNumeric = summarySeries.percentiles.asScala != null && summarySeries.percentiles.asScala.exists(_ != null) if (isNumeric) { summarySeries.percentiles.asScala.zip(summarySeries.timestamps.asScala).flatMap { case (percentiles, ts) => diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index 39aa971892..f8d3df928c 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -50,8 +50,8 @@ case object Percentile extends Granularity case object Aggregates extends Granularity case class TimeSeriesPoint(value: Double, ts: Long, label: Option[String] = None, nullValue: Option[Int] = None) -case class FeatureTimeSeries(feature: String, points: Seq[TimeSeriesPoint]) -case class ComparedFeatureTimeSeries(feature: String, baseline: Seq[TimeSeriesPoint], current: Seq[TimeSeriesPoint]) +case class FeatureTimeSeries(feature: String, isNumeric: Boolean, points: Seq[TimeSeriesPoint]) +case class ComparedFeatureTimeSeries(feature: String, isNumeric: Boolean, baseline: Seq[TimeSeriesPoint], current: Seq[TimeSeriesPoint]) case class GroupByTimeSeries(name: String, items: Seq[FeatureTimeSeries]) // Currently search only covers joins From d0dcb3de87d2edbd9acafa45f208201a22eba2f4 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 3 Dec 2024 12:50:50 -0500 Subject: [PATCH 117/152] style: Apply scalafix and scalafmt changes --- hub/app/controllers/TimeSeriesController.scala | 3 ++- hub/app/model/Model.scala | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index 5d9f8694f2..02bc4fc160 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -204,7 +204,8 @@ class TimeSeriesController @Inject() (val 
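
The numeric test threaded through these conversions treats a feature as numeric exactly when its percentile series contains at least one non-null entry; everything else falls back to the categorical histogram path. The same predicate, sketched in TypeScript against the serialized series (field typing assumed):

    // Mirrors checkIfNumeric: numeric iff the percentile series exists
    // and holds any non-null value; otherwise use the histogram series.
    function isNumericSeries(percentiles: (number | null)[] | null): boolean {
      return percentiles != null && percentiles.some((v) => v != null);
    }
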
controllerComponents: ControllerCompon if (baselineSummarySeries.isEmpty) Seq.empty else convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, isBaselineNumeric, metric) } - val comparedTsData = ComparedFeatureTimeSeries(name, isCurrentNumeric, baselineFeatureTs, currentFeatureTs) + val comparedTsData = + ComparedFeatureTimeSeries(name, isCurrentNumeric, baselineFeatureTs, currentFeatureTs) Ok(comparedTsData.asJson.noSpaces) } } diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index f8d3df928c..2c589b95ba 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -51,7 +51,10 @@ case object Aggregates extends Granularity case class TimeSeriesPoint(value: Double, ts: Long, label: Option[String] = None, nullValue: Option[Int] = None) case class FeatureTimeSeries(feature: String, isNumeric: Boolean, points: Seq[TimeSeriesPoint]) -case class ComparedFeatureTimeSeries(feature: String, isNumeric: Boolean, baseline: Seq[TimeSeriesPoint], current: Seq[TimeSeriesPoint]) +case class ComparedFeatureTimeSeries(feature: String, + isNumeric: Boolean, + baseline: Seq[TimeSeriesPoint], + current: Seq[TimeSeriesPoint]) case class GroupByTimeSeries(name: String, items: Seq[FeatureTimeSeries]) // Currently search only covers joins From 848cd1b079c1d2bfdc10fadecbf6f94f90b63b0e Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 13:56:54 -0500 Subject: [PATCH 118/152] use isNumeric from backend --- frontend/src/routes/joins/[slug]/+page.svelte | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 27d72fde4b..97a95a52a9 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -187,15 +187,6 @@ return 'N/A'; } - function isNumericFeature(featureData: FeatureResponse): boolean { - // Check if any point in current data has a label starting with 'p' - return ( - featureData.current?.some( - (point) => typeof point.label === 'string' && point.label === 'p0' - ) ?? 
false - ); - } - async function selectSeries(seriesName: string | undefined) { // Reset all data states [percentileData, comparedFeatureData, nullData] = [null, null, null]; @@ -232,8 +223,7 @@ }) ]); - // todo use featureData.isNumeric when backend returns it - if (isNumericFeature(featureData)) { + if (featureData.isNumeric) { percentileData = featureData; comparedFeatureData = null; } else { From 84b56f89a771bcbf85fad22cfd630e2a7c4d4ab5 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 13:58:59 -0500 Subject: [PATCH 119/152] no need to reset the data, as it made changing graphs choppy --- frontend/src/routes/joins/[slug]/+page.svelte | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 97a95a52a9..77f3d5acfa 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -189,7 +189,6 @@ async function selectSeries(seriesName: string | undefined) { // Reset all data states - [percentileData, comparedFeatureData, nullData] = [null, null, null]; selectedSeries = seriesName; if (seriesName) { From da71abf32ed144ef3c5edd990bc37af1c9cbca22 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 15:19:35 -0500 Subject: [PATCH 120/152] use real data for distribution percentile charts --- .../src/routes/joins/[slug]/+page.server.ts | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/frontend/src/routes/joins/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts index 2c2f37298d..ddb780360a 100644 --- a/frontend/src/routes/joins/[slug]/+page.server.ts +++ b/frontend/src/routes/joins/[slug]/+page.server.ts @@ -2,7 +2,6 @@ import type { PageServerLoad } from './$types'; import * as api from '$lib/api/api'; import type { JoinTimeSeriesResponse, Model, FeatureResponse } from '$lib/types/Model/Model'; import { parseDateRangeParams } from '$lib/util/date-ranges'; -import { generatePercentileData } from '$lib/util/sample-data'; import { getMetricTypeFromParams, type MetricType } from '$lib/types/MetricType/MetricType'; import { getSortDirection } from '$lib/types/SortDirection/SortDirection'; @@ -46,9 +45,31 @@ export const load: PageServerLoad = async ({ }); }); - const modelToReturn = models.items.find((m) => m.join.name === joinName); + // Get all unique feature names across all groups + const allFeatures = Array.from( + new Set(joinTimeseries.items.flatMap((group) => group.items.map((item) => item.feature))) + ); - const distributions = generatePercentileData(10, 240); // todo remove sample data when backend is ready + // Fetch percentile data for each feature + const distributionsPromises = allFeatures.map((featureName) => + api.getFeatureTimeseries({ + joinId: joinName, + featureName, + startTs: 1673308800000, // todo use dateRange.startTimestamp once backend has data for all joins + endTs: 1674172800000, // todo use dateRange.endTimestamp once backend has data for all joins + granularity: 'percentile', + metricType: 'drift', + metrics: 'value', + offset: '1D', + algorithm: 'psi' + }) + ); + + const distributions = (await Promise.all(distributionsPromises)).filter( + (response) => response.isNumeric + ); + + const modelToReturn = models.items.find((m) => m.join.name === joinName); return { joinTimeseries, From dccae9f6dfee30027f0f9304129bb8cf5f6578cd Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 20:01:31 -0500 Subject: [PATCH 121/152] ux improvement for clicking a 
point when holding cmd --- .../src/lib/components/EChart/EChart.svelte | 37 +++++++++++++++++-- frontend/src/routes/joins/[slug]/+page.svelte | 7 +++- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/frontend/src/lib/components/EChart/EChart.svelte b/frontend/src/lib/components/EChart/EChart.svelte index 2615546b37..6575ae624b 100644 --- a/frontend/src/lib/components/EChart/EChart.svelte +++ b/frontend/src/lib/components/EChart/EChart.svelte @@ -101,11 +101,38 @@ if (event.metaKey || event.ctrlKey) { isCommandPressed = true; disableChartInteractions(); + + if (exactX !== null && chartInstance) { + const option = chartInstance.getOption(); + const series = option.series as EChartOption.Series[]; + const firstSeries = series[0]; + + if (Array.isArray(firstSeries.data)) { + // Find the index of the point with matching x-value + const dataIndex = firstSeries.data.findIndex( + (point) => (point as [number, number])[0] === exactX + ); + + if (dataIndex !== -1) { + // for some reason, we need to showTip somewhere else first + chartInstance.dispatchAction({ + type: 'showTip', + seriesIndex: 0, + dataIndex: -1 + }); + chartInstance.dispatchAction({ + type: 'showTip', + seriesIndex: 0, + dataIndex: dataIndex + }); + } + } + } } } function handleKeyUp(event: KeyboardEvent) { - if (!event.metaKey && !event.ctrlKey) { + if (event.key === 'Meta' || event.key === 'Control') { isCommandPressed = false; enableChartInteractions(); if (!isMouseOverTooltip) { @@ -116,13 +143,13 @@ function disableChartInteractions() { chartInstance?.setOption({ - silent: true + triggerOn: 'none' } as EChartOption); } function enableChartInteractions() { chartInstance?.setOption({ - silent: false + triggerOn: 'mousemove' } as EChartOption); } @@ -208,6 +235,8 @@ zr.on('globalout', hideTooltip); } + let exactX = $state(null); + function showTooltip(params: { offsetX: number; offsetY: number }) { if (isCommandPressed) return; @@ -274,7 +303,7 @@ return Math.abs(currX - pointInGrid[0]) < Math.abs(prevX - pointInGrid[0]) ? 
curr : prev; }); - const exactX = (nearestPoint as [number, number])[0]; + exactX = (nearestPoint as [number, number])[0]; // Get values for all series at this exact x-coordinate const seriesData = series diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 77f3d5acfa..cde7790c89 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -75,8 +75,13 @@ name: feature, type: 'line', data: points.map((point) => [point.ts, point.value]), + symbolSize: 16, emphasis: { - focus: 'series' + focus: 'series', + itemStyle: { + borderWidth: 2, + borderColor: '#fff' + } } })) as EChartOption.Series[]; From acca95f8c06f08b3a3b7a9de66510a8982cd8156 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Tue, 3 Dec 2024 20:18:23 -0500 Subject: [PATCH 122/152] better ux after hiding a series --- .../CustomEChartLegend/CustomEChartLegend.svelte | 4 ++++ frontend/src/lib/util/chart.ts | 15 +++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte b/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte index 54587b3b8b..fe501b789a 100644 --- a/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte +++ b/frontend/src/lib/components/CustomEChartLegend/CustomEChartLegend.svelte @@ -27,6 +27,10 @@ groupSet.delete(seriesName); } else { groupSet.add(seriesName); + chart.dispatchAction({ + type: 'downplay', + seriesName + }); } hiddenSeries = { diff --git a/frontend/src/lib/util/chart.ts b/frontend/src/lib/util/chart.ts index 0d96021421..62f6a7d89f 100644 --- a/frontend/src/lib/util/chart.ts +++ b/frontend/src/lib/util/chart.ts @@ -7,10 +7,17 @@ export function handleChartHighlight( ) { if (!chart || !seriesName) return; - chart.dispatchAction({ - type, - seriesName - }); + // Get the series selected state from legend + const options = chart.getOption(); + const isSelected = options.legend?.[0]?.selected?.[seriesName]; + + // Only highlight if the series is selected (visible) + if (isSelected !== false) { + chart.dispatchAction({ + type, + seriesName + }); + } } export function getSeriesColor(chart: EChartsType | null, seriesName: string): string { From a9007697efb5854a9865e3f2035e1d648790b4f1 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 11:43:01 -0500 Subject: [PATCH 123/152] fix ts issue --- frontend/src/lib/util/chart.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/src/lib/util/chart.ts b/frontend/src/lib/util/chart.ts index 62f6a7d89f..fc0e5e208a 100644 --- a/frontend/src/lib/util/chart.ts +++ b/frontend/src/lib/util/chart.ts @@ -9,7 +9,8 @@ export function handleChartHighlight( // Get the series selected state from legend const options = chart.getOption(); - const isSelected = options.legend?.[0]?.selected?.[seriesName]; + const legendOpt = Array.isArray(options.legend) ? 
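
The Cmd/Ctrl handler above pins the tooltip by resolving the hovered x-coordinate to the nearest data point and dispatching showTip at its index. A condensed sketch of that lookup, assuming the first series holds [ts, value] tuples as on this page:

    import type { EChartsType } from 'echarts';

    // Pin the tooltip at the point whose x-value matches exactX. The extra
    // dispatch with dataIndex -1 reproduces the workaround in the patch:
    // ECharts needs a tooltip shown elsewhere before it honors the target.
    function pinTooltipAt(chart: EChartsType, data: [number, number][], exactX: number): void {
      const dataIndex = data.findIndex(([x]) => x === exactX);
      if (dataIndex === -1) return;
      chart.dispatchAction({ type: 'showTip', seriesIndex: 0, dataIndex: -1 });
      chart.dispatchAction({ type: 'showTip', seriesIndex: 0, dataIndex });
    }
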
options.legend[0] : options.legend; + const isSelected = legendOpt?.selected?.[seriesName]; // Only highlight if the series is selected (visible) if (isSelected !== false) { From 27e4b17b95e5aed39cd04641193f1bcab931c882 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 11:44:42 -0500 Subject: [PATCH 124/152] remove unused imports --- frontend/src/lib/api/api.ts | 3 +-- frontend/src/lib/util/sample-data.ts | 6 +----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/frontend/src/lib/api/api.ts b/frontend/src/lib/api/api.ts index 2ba9d7117d..b7a585b22a 100644 --- a/frontend/src/lib/api/api.ts +++ b/frontend/src/lib/api/api.ts @@ -2,8 +2,7 @@ import type { FeatureResponse, JoinsResponse, JoinTimeSeriesResponse, - ModelsResponse, - TimeSeriesResponse + ModelsResponse } from '$lib/types/Model/Model'; import { error } from '@sveltejs/kit'; import { browser } from '$app/environment'; diff --git a/frontend/src/lib/util/sample-data.ts b/frontend/src/lib/util/sample-data.ts index 01a6dd35b1..75680a32b9 100644 --- a/frontend/src/lib/util/sample-data.ts +++ b/frontend/src/lib/util/sample-data.ts @@ -1,8 +1,4 @@ -import type { - FeatureResponse, - NullComparedFeatureResponse, - RawComparedFeatureResponse -} from '$lib/types/Model/Model'; +import type { FeatureResponse, RawComparedFeatureResponse } from '$lib/types/Model/Model'; export const percentileSampleData: FeatureResponse = { feature: 'my_feat', From 329eed31518fc0e57473bef08454693662840699 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 13:32:55 -0500 Subject: [PATCH 125/152] generate data for the entire year of 2023 --- .../chronon/spark/scripts/ObservabilityDemoDataLoader.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala index f317488273..9ddcf2e7cb 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala @@ -37,13 +37,13 @@ object ObservabilityDemoDataLoader { val endDs: ScallopOption[String] = opt[String]( name = "end-ds", - default = Some("2023-02-30"), + default = Some("2023-12-31"), descr = "End date in YYYY-MM-DD format" ) val rowCount: ScallopOption[Int] = opt[Int]( name = "row-count", - default = Some(700000), + default = Some(1400000), descr = "Number of rows to generate" ) From 43a960f437b64cac0df7850fd1cd384b5eca9b54 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 13:33:20 -0500 Subject: [PATCH 126/152] make sure endTs is used to filter down data --- .../src/main/scala/ai/chronon/spark/utils/InMemoryKvStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/utils/InMemoryKvStore.scala b/spark/src/main/scala/ai/chronon/spark/utils/InMemoryKvStore.scala index 22ce5767db..bd4ef78ae0 100644 --- a/spark/src/main/scala/ai/chronon/spark/utils/InMemoryKvStore.scala +++ b/spark/src/main/scala/ai/chronon/spark/utils/InMemoryKvStore.scala @@ -61,7 +61,7 @@ class InMemoryKvStore(tableUtils: () => TableUtils) extends KVStore with Seriali else valueSeries .filter { - case (version, _) => req.startTsMillis.forall(version >= _) + case (version, _) => req.startTsMillis.forall(version >= _) && req.endTsMillis.forall(version <= _) } // filter version .map { case (version, bytes) => TimedValue(bytes, version) } } From 
9d7c0e8223f60427ebdebde152e9494a45dafbce Mon Sep 17 00:00:00 2001 From: Ken Morton Date: Wed, 4 Dec 2024 13:48:14 -0500 Subject: [PATCH 127/152] ensure 1/31/2023 is included --- .../ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala index 9ddcf2e7cb..b73197d023 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala @@ -37,7 +37,7 @@ object ObservabilityDemoDataLoader { val endDs: ScallopOption[String] = opt[String]( name = "end-ds", - default = Some("2023-12-31"), + default = Some("2024-1-1"), descr = "End date in YYYY-MM-DD format" ) From 5e6b530f1dbecd84e0fabf899e650bdbf1a81b17 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 14:58:08 -0500 Subject: [PATCH 128/152] date as YYYY-MM-DD --- .../ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala index b73197d023..258533a6ed 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala @@ -37,7 +37,7 @@ object ObservabilityDemoDataLoader { val endDs: ScallopOption[String] = opt[String]( name = "end-ds", - default = Some("2024-1-1"), + default = Some("2024-01-01"), descr = "End date in YYYY-MM-DD format" ) From 31e277f369423b54438581ee6d59fe05c1644305 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 15:04:30 -0500 Subject: [PATCH 129/152] distribution charts side by side bars --- frontend/src/routes/joins/[slug]/+page.svelte | 2 -- 1 file changed, 2 deletions(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index cde7790c89..4abf586f8c 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -284,7 +284,6 @@ { name: 'Baseline', type: 'bar', - stack: 'total', emphasis: { focus: 'series' }, @@ -296,7 +295,6 @@ { name: 'Current', type: 'bar', - stack: 'total', emphasis: { focus: 'series' }, From 643102db871a39bbd7e0f173a2855ae741451428 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 15:21:55 -0500 Subject: [PATCH 130/152] add alert --- .../ui/alert/alert-description.svelte | 13 ++++++++ .../components/ui/alert/alert-title.svelte | 21 ++++++++++++ .../src/lib/components/ui/alert/alert.svelte | 17 ++++++++++ frontend/src/lib/components/ui/alert/index.ts | 32 +++++++++++++++++++ 4 files changed, 83 insertions(+) create mode 100644 frontend/src/lib/components/ui/alert/alert-description.svelte create mode 100644 frontend/src/lib/components/ui/alert/alert-title.svelte create mode 100644 frontend/src/lib/components/ui/alert/alert.svelte create mode 100644 frontend/src/lib/components/ui/alert/index.ts diff --git a/frontend/src/lib/components/ui/alert/alert-description.svelte b/frontend/src/lib/components/ui/alert/alert-description.svelte new file mode 100644 index 0000000000..06d344cc33 --- /dev/null +++ b/frontend/src/lib/components/ui/alert/alert-description.svelte @@ -0,0 +1,13 @@ + + +
    + +
    diff --git a/frontend/src/lib/components/ui/alert/alert-title.svelte b/frontend/src/lib/components/ui/alert/alert-title.svelte new file mode 100644 index 0000000000..c63089bde4 --- /dev/null +++ b/frontend/src/lib/components/ui/alert/alert-title.svelte @@ -0,0 +1,21 @@ + + + + + diff --git a/frontend/src/lib/components/ui/alert/alert.svelte b/frontend/src/lib/components/ui/alert/alert.svelte new file mode 100644 index 0000000000..0bf6eec74d --- /dev/null +++ b/frontend/src/lib/components/ui/alert/alert.svelte @@ -0,0 +1,17 @@ + + + diff --git a/frontend/src/lib/components/ui/alert/index.ts b/frontend/src/lib/components/ui/alert/index.ts new file mode 100644 index 0000000000..1bb6987c09 --- /dev/null +++ b/frontend/src/lib/components/ui/alert/index.ts @@ -0,0 +1,32 @@ +import { type VariantProps, tv } from 'tailwind-variants'; + +import Root from './alert.svelte'; +import Description from './alert-description.svelte'; +import Title from './alert-title.svelte'; + +export const alertVariants = tv({ + base: '[&>svg]:text-foreground relative w-full rounded-lg border px-4 py-3 text-sm [&>svg+div]:translate-y-[-3px] [&>svg]:absolute [&>svg]:left-4 [&>svg]:top-4 [&>svg~*]:pl-7', + variants: { + variant: { + default: 'bg-background text-foreground', + destructive: + 'border-destructive/50 text-destructive dark:border-destructive [&>svg]:text-destructive' + } + }, + defaultVariants: { + variant: 'default' + } +}); + +export type Variant = VariantProps['variant']; +export type HeadingLevel = 'h1' | 'h2' | 'h3' | 'h4' | 'h5' | 'h6'; + +export { + Root, + Description, + Title, + // + Root as Alert, + Description as AlertDescription, + Title as AlertTitle +}; From d70720d60375134143dee2629387452ffd6e1c24 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 15:46:55 -0500 Subject: [PATCH 131/152] use real startTs and endTs with fallback treatment --- .../ChartControls/ChartControls.svelte | 31 +++++++-- .../src/routes/joins/[slug]/+page.server.ts | 66 +++++++++++++++++-- frontend/src/routes/joins/[slug]/+page.svelte | 39 +++++++++-- 3 files changed, 117 insertions(+), 19 deletions(-) diff --git a/frontend/src/lib/components/ChartControls/ChartControls.svelte b/frontend/src/lib/components/ChartControls/ChartControls.svelte index 95ffd9da9e..35652064b3 100644 --- a/frontend/src/lib/components/ChartControls/ChartControls.svelte +++ b/frontend/src/lib/components/ChartControls/ChartControls.svelte @@ -2,20 +2,39 @@ import ResetZoomButton from '$lib/components/ResetZoomButton/ResetZoomButton.svelte'; import MetricTypeToggle from '$lib/components/MetricTypeToggle/MetricTypeToggle.svelte'; import DateRangeSelector from '$lib/components/DateRangeSelector/DateRangeSelector.svelte'; + import * as Alert from '$lib/components/ui/alert/index.js'; + import { formatDate } from '$lib/util/format'; let { isZoomed = false, - onResetZoom + onResetZoom, + isUsingFallbackDates = false, + dateRange = { startTimestamp: 0, endTimestamp: 0 } }: { isZoomed: boolean; onResetZoom: () => void; + isUsingFallbackDates?: boolean; + dateRange?: { startTimestamp: number; endTimestamp: number }; } = $props(); -
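
alertVariants above follows the tailwind-variants pattern: tv() returns a class resolver whose variant prop picks one styling branch and merges it with the shared base classes. A small usage sketch (illustrative only):

    import { alertVariants, type Variant } from '$lib/components/ui/alert';

    // Resolves to the base classes plus the destructive branch; the
    // 'warning' variant added in a later commit composes the same way.
    const variant: Variant = 'destructive';
    const alertClasses = alertVariants({ variant });
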
    - {#if isZoomed} - +
    + {#if isUsingFallbackDates} +
    + + + No data for that date range. Showing data between {formatDate(dateRange.startTimestamp)} and + {formatDate(dateRange.endTimestamp)} + +
    {/if} - - + +
    + {#if isZoomed} + + {/if} + + +
    diff --git a/frontend/src/routes/joins/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts index ddb780360a..3378b5cbf4 100644 --- a/frontend/src/routes/joins/[slug]/+page.server.ts +++ b/frontend/src/routes/joins/[slug]/+page.server.ts @@ -5,6 +5,9 @@ import { parseDateRangeParams } from '$lib/util/date-ranges'; import { getMetricTypeFromParams, type MetricType } from '$lib/types/MetricType/MetricType'; import { getSortDirection } from '$lib/types/SortDirection/SortDirection'; +const FALLBACK_START_TS = 1672531200000; // 2023-01-01 +const FALLBACK_END_TS = 1677628800000; // 2023-03-01 + export const load: PageServerLoad = async ({ params, url @@ -13,17 +16,68 @@ export const load: PageServerLoad = async ({ model?: Model; distributions: FeatureResponse[]; metricType: MetricType; + dateRange: { + startTimestamp: number; + endTimestamp: number; + dateRangeValue: string; + isUsingFallback: boolean; + }; }> => { - const dateRange = parseDateRangeParams(url.searchParams); + const requestedDateRange = parseDateRangeParams(url.searchParams); const joinName = params.slug; const metricType = getMetricTypeFromParams(url.searchParams); const sortDirection = getSortDirection(url.searchParams); + // Try with requested date range first + try { + const data = await fetchAllData( + joinName, + requestedDateRange.startTimestamp, + requestedDateRange.endTimestamp, + metricType, + sortDirection + ); + return { + ...data, + dateRange: { + ...requestedDateRange, + isUsingFallback: false + } + }; + } catch (error) { + console.error('Error fetching data:', error); + // If the requested range fails, fall back to the known working range + const data = await fetchAllData( + joinName, + FALLBACK_START_TS, + FALLBACK_END_TS, + metricType, + sortDirection + ); + return { + ...data, + dateRange: { + startTimestamp: FALLBACK_START_TS, + endTimestamp: FALLBACK_END_TS, + dateRangeValue: requestedDateRange.dateRangeValue, // Preserve the user's selected range + isUsingFallback: true + } + }; + } +}; + +async function fetchAllData( + joinName: string, + startTs: number, + endTs: number, + metricType: MetricType, + sortDirection: string +) { const [joinTimeseries, models] = await Promise.all([ api.getJoinTimeseries({ joinId: joinName, - startTs: 1673308800000, // todo use dateRange.startTimestamp once backend has data for all joins - endTs: 1674172800000, // todo use dateRange.endTimestamp once backend has data for all joins + startTs, + endTs, metricType: 'drift', metrics: 'value', offset: undefined, @@ -55,8 +109,8 @@ export const load: PageServerLoad = async ({ api.getFeatureTimeseries({ joinId: joinName, featureName, - startTs: 1673308800000, // todo use dateRange.startTimestamp once backend has data for all joins - endTs: 1674172800000, // todo use dateRange.endTimestamp once backend has data for all joins + startTs, + endTs, granularity: 'percentile', metricType: 'drift', metrics: 'value', @@ -77,4 +131,4 @@ export const load: PageServerLoad = async ({ distributions, metricType }; -}; +} diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 4abf586f8c..dc3c213a3b 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -37,6 +37,7 @@ const joinTimeseries = $derived(data.joinTimeseries); const model = $derived(data.model); const distributions = $derived(data.distributions); + const isUsingFallbackDates = $derived(data.dateRange.isUsingFallback); let isFeatureMonitoringOpen = 
$state(true); let isSheetOpen = $state(false); let selectedEvents = $state([]); @@ -206,8 +207,8 @@ getFeatureTimeseries({ joinId: joinTimeseries.name, featureName: seriesName, - startTs: 1673308800000, // todo use startTimestamp - endTs: 1674172800000, // todo use endTimestamp + startTs: startTimestamp, + endTs: endTimestamp, granularity: 'percentile', metricType: 'drift', metrics: 'value', @@ -217,8 +218,8 @@ getFeatureTimeseries({ joinId: joinTimeseries.name, featureName: seriesName, - startTs: 1673308800000, // todo use startTimestamp - endTs: 1674172800000, // todo use endTimestamp + startTs: startTimestamp, + endTs: endTimestamp, metricType: 'drift', metrics: 'null', offset: '1D', @@ -414,7 +415,15 @@ class="sticky top-0 z-20 bg-neutral-200 border-b border-border -mx-8 py-2 px-8 border-l" transition:fade={{ duration: 150 }} > - +
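
The load-with-fallback shape used here — try the requested window, and on failure retry a known-good window while flagging the substitution for the alert banner — generalizes to a small helper. A hedged sketch (the helper name is hypothetical):

    // Attempt the requested [startTs, endTs] window; on failure, retry the
    // fallback window and mark the result so the UI can surface the switch.
    async function fetchWithFallback<T>(
      fetcher: (startTs: number, endTs: number) => Promise<T>,
      requested: { startTs: number; endTs: number },
      fallback: { startTs: number; endTs: number }
    ): Promise<{ data: T; isUsingFallback: boolean }> {
      try {
        return { data: await fetcher(requested.startTs, requested.endTs), isUsingFallback: false };
      } catch {
        return { data: await fetcher(fallback.startTs, fallback.endTs), isUsingFallback: true };
      }
    }
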
    {/if} @@ -451,7 +460,15 @@
    - + @@ -536,7 +553,15 @@ {selectedSeries ? `${selectedSeries} at ` : ''}{formatEventDate()} - + From 987876f5b881f29d03000f68b71cafde718460b4 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 17:04:21 -0500 Subject: [PATCH 132/152] warning alert --- frontend/src/lib/components/ChartControls/ChartControls.svelte | 2 +- frontend/src/lib/components/ui/alert/index.ts | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/frontend/src/lib/components/ChartControls/ChartControls.svelte b/frontend/src/lib/components/ChartControls/ChartControls.svelte index 35652064b3..6b684a9535 100644 --- a/frontend/src/lib/components/ChartControls/ChartControls.svelte +++ b/frontend/src/lib/components/ChartControls/ChartControls.svelte @@ -21,7 +21,7 @@
    {#if isUsingFallbackDates}
    - + No data for that date range. Showing data between {formatDate(dateRange.startTimestamp)} and {formatDate(dateRange.endTimestamp)}svg]:text-destructive' + 'border-destructive/50 text-destructive dark:border-destructive [&>svg]:text-destructive', + warning: 'border-warning-800 text-warning-800 [&>svg]:text-warning-800' } }, defaultVariants: { From 29da77e23ee1f851c47f05faa31bf75463f5fb29 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 17:05:35 -0500 Subject: [PATCH 133/152] stop jerky behavior of picking from cal --- .../lib/components/DateRangeSelector/DateRangeSelector.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/lib/components/DateRangeSelector/DateRangeSelector.svelte b/frontend/src/lib/components/DateRangeSelector/DateRangeSelector.svelte index 5a0d1b93c3..b339d334a8 100644 --- a/frontend/src/lib/components/DateRangeSelector/DateRangeSelector.svelte +++ b/frontend/src/lib/components/DateRangeSelector/DateRangeSelector.svelte @@ -93,7 +93,7 @@ {selectDateRange?.label || 'Select range'} - {#if calendarDateRange && calendarDateRange.start} + {#if !calendarDateRangePopoverOpen && calendarDateRange && calendarDateRange.start} {#if calendarDateRange.end} {df.format(calendarDateRange.start.toDate(getLocalTimeZone()))} - {df.format( calendarDateRange.end.toDate(getLocalTimeZone()) From 06d828148276f59897b2919fcb83ee735ac330c9 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 17:13:55 -0500 Subject: [PATCH 134/152] dont stack percentile charts --- .../src/lib/components/PercentileChart/PercentileChart.svelte | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/src/lib/components/PercentileChart/PercentileChart.svelte b/frontend/src/lib/components/PercentileChart/PercentileChart.svelte index f473b04bd6..ecae9d4623 100644 --- a/frontend/src/lib/components/PercentileChart/PercentileChart.svelte +++ b/frontend/src/lib/components/PercentileChart/PercentileChart.svelte @@ -17,7 +17,6 @@ .map((label) => ({ name: label, type: 'line', - stack: 'Total', emphasis: { focus: 'series' }, From 6707d9b9c03906d699a282e79a304870239f39f4 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Wed, 4 Dec 2024 17:22:54 -0500 Subject: [PATCH 135/152] modify action buttons --- .../ActionButtons/ActionButtons.svelte | 21 ++++++++++++------- frontend/src/routes/joins/[slug]/+page.svelte | 2 +- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/frontend/src/lib/components/ActionButtons/ActionButtons.svelte b/frontend/src/lib/components/ActionButtons/ActionButtons.svelte index 26b38d2332..ec6c029dfc 100644 --- a/frontend/src/lib/components/ActionButtons/ActionButtons.svelte +++ b/frontend/src/lib/components/ActionButtons/ActionButtons.svelte @@ -6,8 +6,11 @@ import { page } from '$app/stores'; import { getSortDirection, type SortDirection } from '$lib/types/SortDirection/SortDirection'; - let { showCluster = false, class: className }: { showCluster?: boolean; class?: string } = - $props(); + let { + showCluster = false, + class: className, + showSort = false + }: { showCluster?: boolean; class?: string; showSort?: boolean } = $props(); let activeCluster = showCluster ? 'GroupBys' : null; @@ -40,16 +43,18 @@
    - + {/if} + - {#if showCluster} - diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index dc3c213a3b..3fc5f849f1 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -488,7 +488,7 @@ - + {#each joinTimeseries.items as group (group.name)} Date: Thu, 5 Dec 2024 09:53:05 -0500 Subject: [PATCH 136/152] use start and end of entire range --- frontend/src/routes/joins/[slug]/+page.svelte | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 3fc5f849f1..24ce3b7115 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -16,9 +16,7 @@ import IntersectionObserver from 'svelte-intersection-observer'; import { fade } from 'svelte/transition'; import { Button } from '$lib/components/ui/button'; - import { parseDateRangeParams } from '$lib/util/date-ranges'; import { getFeatureTimeseries } from '$lib/api/api'; - import { page } from '$app/stores'; import InfoTooltip from '$lib/components/InfoTooltip/InfoTooltip.svelte'; import { Table, TableBody, TableCell, TableRow } from '$lib/components/ui/table/index.js'; import TrueFalseBadge from '$lib/components/TrueFalseBadge/TrueFalseBadge.svelte'; @@ -198,17 +196,13 @@ selectedSeries = seriesName; if (seriesName) { - const { startTimestamp, endTimestamp } = parseDateRangeParams( - new URL($page.url).searchParams - ); - try { const [featureData, nullFeatureData] = await Promise.all([ getFeatureTimeseries({ joinId: joinTimeseries.name, featureName: seriesName, - startTs: startTimestamp, - endTs: endTimestamp, + startTs: data.dateRange.startTimestamp, + endTs: data.dateRange.endTimestamp, granularity: 'percentile', metricType: 'drift', metrics: 'value', @@ -218,8 +212,8 @@ getFeatureTimeseries({ joinId: joinTimeseries.name, featureName: seriesName, - startTs: startTimestamp, - endTs: endTimestamp, + startTs: data.dateRange.startTimestamp, + endTs: data.dateRange.endTimestamp, metricType: 'drift', metrics: 'null', offset: '1D', From 05dc6ff02e5ca393ea782cc924bf99bcf720f7b8 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Thu, 5 Dec 2024 10:20:39 -0500 Subject: [PATCH 137/152] gitignore for metals generated stuff --- .gitignore | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0a3c7db8e6..cb82ba238b 100644 --- a/.gitignore +++ b/.gitignore @@ -76,4 +76,15 @@ releases /cloud_aws/dynamodb-local-metadata.json # Elastic Search files -/docker-init/elasticsearch-data \ No newline at end of file +/docker-init/elasticsearch-data + +# Metals and Bloop +.metals/ +.bloop/ +.bsp/ +.worksheet/ +.project/ + +# Metals-generated sbt files +/project/**/metals.sbt +/project/**/metals.sbt.lock \ No newline at end of file From ae79a94d8aa7bf4111a00fa88a22ae5b26a88ee5 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Thu, 5 Dec 2024 10:31:41 -0500 Subject: [PATCH 138/152] remove duplicate .bsp --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index cb82ba238b..37daedfb53 100644 --- a/.gitignore +++ b/.gitignore @@ -81,7 +81,6 @@ releases # Metals and Bloop .metals/ .bloop/ -.bsp/ .worksheet/ .project/ From deeb85edd05a766de11f88a73f0c2654805107e3 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Thu, 5 Dec 2024 14:10:30 -0800 Subject: [PATCH 139/152] fixing logic in timeseries controller to 
align summaries --- .../controllers/TimeSeriesController.scala | 26 +++--- .../ai/chronon/online/stats/DriftStore.scala | 1 - .../spark/test/stats/drift/DriftTest.scala | 93 +++++++++++++++++-- 3 files changed, 97 insertions(+), 23 deletions(-) diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index 02bc4fc160..c8097499a0 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -17,6 +17,8 @@ import scala.concurrent.Future import scala.concurrent.duration._ import scala.jdk.CollectionConverters._ import scala.util.Failure +import scala.util.ScalaJavaConversions.ListOps +import scala.util.ScalaJavaConversions.MapOps import scala.util.Success /** @@ -243,20 +245,20 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } } else { if (isNumeric) { - summarySeries.percentiles.asScala.zip(summarySeries.timestamps.asScala).flatMap { - case (percentiles, ts) => - DriftStore.breaks(20).zip(percentiles.asScala).map { - case (l, value) => TimeSeriesPoint(value, ts, Some(l)) - } + val percentileSeriesPerBreak = summarySeries.percentiles.toScala + val timeStamps = summarySeries.timestamps.toScala + val breaks = DriftStore.breaks(20) + percentileSeriesPerBreak.zip(breaks).flatMap { + case (percentileSeries, break) => + percentileSeries.toScala.zip(timeStamps).map { case (value, ts) => TimeSeriesPoint(value, ts, Some(break)) } } } else { - summarySeries.timestamps.asScala.zipWithIndex.flatMap { - case (ts, idx) => - summarySeries.histogram.asScala.map { - case (label, values) => - TimeSeriesPoint(values.get(idx).toDouble, ts, Some(label)) - } - } + val histogramOfSeries = summarySeries.histogram.toScala + val timeStamps = summarySeries.timestamps.toScala + histogramOfSeries.flatMap { + case (label, values) => + values.toScala.zip(timeStamps).map { case (value, ts) => TimeSeriesPoint(value.toDouble, ts, Some(label)) } + }.toSeq } } } diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index f9e00aa587..b1776488b4 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -210,6 +210,5 @@ object DriftStore { override def initialValue(): TDeserializer = new TDeserializer(new TBinaryProtocol.Factory()) } - // todo - drop this hard-coded list in favor of a well known list or exposing as part of summaries def breaks(count: Int): Seq[String] = (0 to count).map(_ * (100 / count)).map("p" + _.toString) } diff --git a/spark/src/test/scala/ai/chronon/spark/test/stats/drift/DriftTest.scala b/spark/src/test/scala/ai/chronon/spark/test/stats/drift/DriftTest.scala index 666f3fa9f6..08e796b886 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/stats/drift/DriftTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/stats/drift/DriftTest.scala @@ -7,6 +7,7 @@ import ai.chronon.api.DriftMetric import ai.chronon.api.Extensions.MetadataOps import ai.chronon.api.Extensions.WindowOps import ai.chronon.api.PartitionSpec +import ai.chronon.api.TileSummarySeries import ai.chronon.api.Window import ai.chronon.online.KVStore import ai.chronon.online.stats.DriftStore @@ -23,8 +24,12 @@ import org.scalatest.matchers.should.Matchers import java.util.concurrent.TimeUnit import scala.concurrent.Await +import scala.concurrent.Future import scala.concurrent.duration.Duration import 
scala.util.ScalaJavaConversions.IteratorOps +import scala.util.ScalaJavaConversions.ListOps +import scala.util.ScalaJavaConversions.MapOps +import scala.util.Success class DriftTest extends AnyFlatSpec with Matchers { @@ -126,7 +131,7 @@ class DriftTest extends AnyFlatSpec with Matchers { startMs, endMs ) - val summarySeries = Await.result(summarySeriesFuture.get, Duration.create(10, TimeUnit.SECONDS)) + val summarySeries = Await.result(summarySeriesFuture.get, Duration.create(100, TimeUnit.SECONDS)) val (summaryNulls, summaryTotals) = summarySeries.iterator.foldLeft(0 -> 0) { case ((nulls, total), s) => if (s.getPercentiles == null) { @@ -153,14 +158,82 @@ class DriftTest extends AnyFlatSpec with Matchers { val window = new Window(10, ai.chronon.api.TimeUnit.HOURS) val joinPath = joinName.replaceFirst("\\.", "/") - println("Looking up current summary series") - val maybeCurrentSummarySeries = driftStore.getSummarySeries(joinPath, startTs, endTs, Some(name)).get - val currentSummarySeries = Await.result(maybeCurrentSummarySeries, Duration.create(10, TimeUnit.SECONDS)) - println("Now looking up baseline summary series") - val maybeBaselineSummarySeries = driftStore.getSummarySeries(joinPath, startTs - window.millis, endTs - window.millis, Some(name)) - val baselineSummarySeries = Await.result(maybeBaselineSummarySeries.get, Duration.create(10, TimeUnit.SECONDS)) - - println(s"Current summary series: $currentSummarySeries") - println(s"Baseline summary series: $baselineSummarySeries") + + implicit val execContext = scala.concurrent.ExecutionContext.global + val metric = ValuesMetric + val maybeCurrentSummarySeries = driftStore.getSummarySeries(joinPath, startTs, endTs, Some(name)) + val maybeBaselineSummarySeries = + driftStore.getSummarySeries(joinPath, startTs - window.millis, endTs - window.millis, Some(name)) + val result = (maybeCurrentSummarySeries, maybeBaselineSummarySeries) match { + case (Success(currentSummarySeriesFuture), Success(baselineSummarySeriesFuture)) => + Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { merged => + val currentSummarySeries = merged.head + val baselineSummarySeries = merged.last + val isCurrentNumeric = currentSummarySeries.headOption.forall(checkIfNumeric) + val isBaselineNumeric = baselineSummarySeries.headOption.forall(checkIfNumeric) + + val currentFeatureTs = { + if (currentSummarySeries.isEmpty) Seq.empty + else convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, isCurrentNumeric, metric) + } + val baselineFeatureTs = { + if (baselineSummarySeries.isEmpty) Seq.empty + else convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, isBaselineNumeric, metric) + } + + ComparedFeatureTimeSeries(name, isCurrentNumeric, baselineFeatureTs, currentFeatureTs) + } + } + println(Await.result(result, Duration.create(10, TimeUnit.SECONDS))) + } + + // this is clunky copy of code, but was necessary to run the logic end-to-end without mocking drift store + // TODO move this into TimeSeriesControllerSpec and refactor that test to be more end-to-end. 
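
The substance of this fix is a transposition contract: summarySeries.percentiles holds one series per percentile break, so each break label ('p0', 'p5', ..., 'p100') pairs with an entire time series, not with one timestamp's array. The same alignment rendered in TypeScript (types simplified):

    type TimeSeriesPoint = { value: number; ts: number; label?: string };

    // breaks(20) yields ['p0', 'p5', ..., 'p100']: one label per series.
    const breaks = (count: number): string[] =>
      Array.from({ length: count + 1 }, (_, i) => `p${i * (100 / count)}`);

    // percentileSeriesPerBreak[i] is the full series for breaks(20)[i];
    // zipping series with breaks (not timestamps) is what the fix aligns.
    function toLabeledPoints(
      percentileSeriesPerBreak: number[][],
      timestamps: number[]
    ): TimeSeriesPoint[] {
      const labels = breaks(20);
      return percentileSeriesPerBreak.flatMap((series, i) =>
        series.map((value, j) => ({ value, ts: timestamps[j], label: labels[i] }))
      );
    }
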
+ case class ComparedFeatureTimeSeries(feature: String, + isNumeric: Boolean, + baseline: Seq[TimeSeriesPoint], + current: Seq[TimeSeriesPoint]) + + sealed trait Metric + + /** Roll up over null counts */ + case object NullMetric extends Metric + + /** Roll up over raw values */ + case object ValuesMetric extends Metric + + + case class TimeSeriesPoint(value: Double, ts: Long, label: Option[String] = None, nullValue: Option[Int] = None) + + def checkIfNumeric(summarySeries: TileSummarySeries): Boolean = { + val ptiles = summarySeries.percentiles.toScala + ptiles != null && ptiles.exists(_ != null) + } + + + + private def convertTileSummarySeriesToTimeSeries(summarySeries: TileSummarySeries, + isNumeric: Boolean, + metric: Metric): Seq[TimeSeriesPoint] = { + if (metric == NullMetric) { + summarySeries.nullCount.toScala.zip(summarySeries.timestamps.toScala).map { + case (nullCount, ts) => TimeSeriesPoint(0, ts, nullValue = Some(nullCount.intValue())) + } + } else { + if (isNumeric) { + val percentileSeriesPerBreak = summarySeries.percentiles.toScala + val timeStamps = summarySeries.timestamps.toScala + val breaks = DriftStore.breaks(20) + percentileSeriesPerBreak.zip(breaks).flatMap{ case (percentileSeries, break) => + percentileSeries.toScala.zip(timeStamps).map{case (value, ts) => TimeSeriesPoint(value, ts, Some(break))} + } + } else { + val histogramOfSeries = summarySeries.histogram.toScala + val timeStamps = summarySeries.timestamps.toScala + histogramOfSeries.flatMap{ case (label, values) => + values.toScala.zip(timeStamps).map{case (value, ts) => TimeSeriesPoint(value.toDouble, ts, Some(label))} + }.toSeq + } + } } } \ No newline at end of file From 1d6f5c2e447a53e78dac35a7a2484d7399341971 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Fri, 6 Dec 2024 12:41:12 -0500 Subject: [PATCH 140/152] let user select text from collapsible section --- .../lib/components/CollapsibleSection/CollapsibleSection.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/lib/components/CollapsibleSection/CollapsibleSection.svelte b/frontend/src/lib/components/CollapsibleSection/CollapsibleSection.svelte index 9da3feab1e..9a07f7abac 100644 --- a/frontend/src/lib/components/CollapsibleSection/CollapsibleSection.svelte +++ b/frontend/src/lib/components/CollapsibleSection/CollapsibleSection.svelte @@ -47,7 +47,7 @@ size="16" class="transition-transform duration-200 {open ? '' : 'rotate-180'}" /> -

    {title}

    +

    {title}

    {#if headerContentLeft} From aae53323e96c73c65916f4a35a640b2ab2064fa0 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Fri, 6 Dec 2024 13:49:02 -0500 Subject: [PATCH 141/152] move loading of distributions into client side for better performance --- .../src/routes/joins/[slug]/+page.server.ts | 49 ++++--------- frontend/src/routes/joins/[slug]/+page.svelte | 69 ++++++++++++++++--- 2 files changed, 72 insertions(+), 46 deletions(-) diff --git a/frontend/src/routes/joins/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts index 3378b5cbf4..6b1ef8fdc7 100644 --- a/frontend/src/routes/joins/[slug]/+page.server.ts +++ b/frontend/src/routes/joins/[slug]/+page.server.ts @@ -1,6 +1,6 @@ import type { PageServerLoad } from './$types'; import * as api from '$lib/api/api'; -import type { JoinTimeSeriesResponse, Model, FeatureResponse } from '$lib/types/Model/Model'; +import type { JoinTimeSeriesResponse, Model } from '$lib/types/Model/Model'; import { parseDateRangeParams } from '$lib/util/date-ranges'; import { getMetricTypeFromParams, type MetricType } from '$lib/types/MetricType/MetricType'; import { getSortDirection } from '$lib/types/SortDirection/SortDirection'; @@ -14,7 +14,6 @@ export const load: PageServerLoad = async ({ }): Promise<{ joinTimeseries: JoinTimeSeriesResponse; model?: Model; - distributions: FeatureResponse[]; metricType: MetricType; dateRange: { startTimestamp: number; @@ -30,15 +29,18 @@ export const load: PageServerLoad = async ({ // Try with requested date range first try { - const data = await fetchAllData( + const { joinTimeseries, model } = await fetchInitialData( joinName, requestedDateRange.startTimestamp, requestedDateRange.endTimestamp, metricType, sortDirection ); + return { - ...data, + joinTimeseries, + model, + metricType, dateRange: { ...requestedDateRange, isUsingFallback: false @@ -47,26 +49,29 @@ export const load: PageServerLoad = async ({ } catch (error) { console.error('Error fetching data:', error); // If the requested range fails, fall back to the known working range - const data = await fetchAllData( + const { joinTimeseries, model } = await fetchInitialData( joinName, FALLBACK_START_TS, FALLBACK_END_TS, metricType, sortDirection ); + return { - ...data, + joinTimeseries, + model, + metricType, dateRange: { startTimestamp: FALLBACK_START_TS, endTimestamp: FALLBACK_END_TS, - dateRangeValue: requestedDateRange.dateRangeValue, // Preserve the user's selected range + dateRangeValue: requestedDateRange.dateRangeValue, isUsingFallback: true } }; } }; -async function fetchAllData( +async function fetchInitialData( joinName: string, startTs: number, endTs: number, @@ -99,36 +104,10 @@ async function fetchAllData( }); }); - // Get all unique feature names across all groups - const allFeatures = Array.from( - new Set(joinTimeseries.items.flatMap((group) => group.items.map((item) => item.feature))) - ); - - // Fetch percentile data for each feature - const distributionsPromises = allFeatures.map((featureName) => - api.getFeatureTimeseries({ - joinId: joinName, - featureName, - startTs, - endTs, - granularity: 'percentile', - metricType: 'drift', - metrics: 'value', - offset: '1D', - algorithm: 'psi' - }) - ); - - const distributions = (await Promise.all(distributionsPromises)).filter( - (response) => response.isNumeric - ); - const modelToReturn = models.items.find((m) => m.join.name === joinName); return { joinTimeseries, - model: modelToReturn, - distributions, - metricType + model: modelToReturn }; } diff --git 
a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index 24ce3b7115..d3a9b9db7f 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -29,12 +29,12 @@ import { getSeriesColor } from '$lib/util/chart'; import { handleChartHighlight } from '$lib/util/chart'; import ChartControls from '$lib/components/ChartControls/ChartControls.svelte'; + import { onMount } from 'svelte'; const { data } = $props(); let scale = $derived(METRIC_SCALES[data.metricType]); const joinTimeseries = $derived(data.joinTimeseries); const model = $derived(data.model); - const distributions = $derived(data.distributions); const isUsingFallbackDates = $derived(data.dateRange.isUsingFallback); let isFeatureMonitoringOpen = $state(true); let isSheetOpen = $state(false); @@ -402,6 +402,47 @@ handleChartHighlight(chart, seriesName, type); } } + + let distributions: FeatureResponse[] = $state([]); + let isLoadingDistributions = $state(false); + + async function loadDistributions() { + if (distributions.length > 0 || isLoadingDistributions) return; + + isLoadingDistributions = true; + try { + // Get all unique feature names across all groups + const allFeatures = Array.from( + new Set(joinTimeseries.items.flatMap((group) => group.items.map((item) => item.feature))) + ); + + // Fetch percentile data for each feature + const distributionsPromises = allFeatures.map((featureName) => + getFeatureTimeseries({ + joinId: joinTimeseries.name, + featureName, + startTs: data.dateRange.startTimestamp, + endTs: data.dateRange.endTimestamp, + granularity: 'percentile', + metricType: 'drift', + metrics: 'value', + offset: '1D', + algorithm: 'psi' + }) + ); + + const responses = await Promise.all(distributionsPromises); + distributions = responses.filter((response) => response.isNumeric); + } catch (error) { + console.error('Error loading distributions:', error); + } finally { + isLoadingDistributions = false; + } + } + + onMount(() => { + loadDistributions(); + }); {#if shouldShowStickyHeader} @@ -513,16 +554,22 @@ {/each} - {#each distributions as feature} - - {#snippet collapsibleContent()} - - {/snippet} - - {/each} + {#if isLoadingDistributions} +
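
Moving the distribution fetches out of the server load and into onMount trades first-paint completeness for a faster initial response, and the guard at the top of loadDistributions makes repeated calls safe. A minimal sketch of that guard pattern, with the fetch abstracted away:

    let distributions: unknown[] = [];
    let isLoadingDistributions = false;

    // Re-entrancy guard: skip if data is already present or a fetch is in
    // flight, so calls from onMount and elsewhere stay idempotent.
    async function loadOnce(fetchAll: () => Promise<unknown[]>): Promise<void> {
      if (distributions.length > 0 || isLoadingDistributions) return;
      isLoadingDistributions = true;
      try {
        distributions = await fetchAll();
      } finally {
        isLoadingDistributions = false;
      }
    }
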
    Loading distributions...
    + {:else if distributions.length === 0} +
    No distribution data available
    + {:else} + {#each distributions as feature} + + {#snippet collapsibleContent()} + + {/snippet} + + {/each} + {/if}
    {/snippet} From 2c55852beb07306a44bc884ce31dea0873e45c45 Mon Sep 17 00:00:00 2001 From: ken-zlai Date: Fri, 6 Dec 2024 13:51:22 -0500 Subject: [PATCH 142/152] adjust margin --- frontend/src/routes/joins/[slug]/+page.svelte | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte index d3a9b9db7f..2e272fed9a 100644 --- a/frontend/src/routes/joins/[slug]/+page.svelte +++ b/frontend/src/routes/joins/[slug]/+page.svelte @@ -555,9 +555,9 @@ {#if isLoadingDistributions} -
    Loading distributions...
    +
    Loading distributions...
    {:else if distributions.length === 0} -
    No distribution data available
    +
    No distribution data available
				{:else}
					{#each distributions as feature}

From 543cab0c6a9bafe605148d181108de68bcf17e17 Mon Sep 17 00:00:00 2001
From: ken-zlai
Date: Fri, 6 Dec 2024 14:24:12 -0500
Subject: [PATCH 143/152] sort in distributions and drift tab

---
 .../ActionButtons/ActionButtons.svelte        | 24 ++++++++++----
 .../lib/types/SortDirection/SortDirection.ts  |  6 ----
 frontend/src/lib/utils/sort.ts                | 32 +++++++++++++++++++
 .../src/routes/joins/[slug]/+page.server.ts   |  4 +--
 frontend/src/routes/joins/[slug]/+page.svelte | 12 +++++--
 5 files changed, 62 insertions(+), 16 deletions(-)
 delete mode 100644 frontend/src/lib/types/SortDirection/SortDirection.ts
 create mode 100644 frontend/src/lib/utils/sort.ts

diff --git a/frontend/src/lib/components/ActionButtons/ActionButtons.svelte b/frontend/src/lib/components/ActionButtons/ActionButtons.svelte
index ec6c029dfc..8c1be6a6f3 100644
--- a/frontend/src/lib/components/ActionButtons/ActionButtons.svelte
+++ b/frontend/src/lib/components/ActionButtons/ActionButtons.svelte
@@ -4,22 +4,34 @@
 	import { Icon, Plus, ArrowsUpDown, Square3Stack3d, XMark } from 'svelte-hero-icons';
 	import { goto } from '$app/navigation';
 	import { page } from '$app/stores';
-	import { getSortDirection, type SortDirection } from '$lib/types/SortDirection/SortDirection';
+	import {
+		getSortDirection,
+		updateContextSort,
+		type SortDirection,
+		type SortContext
+	} from '$lib/utils/sort';

 	let {
 		showCluster = false,
 		class: className,
-		showSort = false
-	}: { showCluster?: boolean; class?: string; showSort?: boolean } = $props();
+		showSort = false,
+		context = 'drift'
+	}: {
+		showCluster?: boolean;
+		class?: string;
+		showSort?: boolean;
+		context?: SortContext;
+	} = $props();

 	let activeCluster = showCluster ? 'GroupBys' : null;
-	let currentSort: SortDirection = $derived.by(() => getSortDirection($page.url.searchParams));
+	let currentSort: SortDirection = $derived.by(() =>
+		getSortDirection($page.url.searchParams, context)
+	);

 	function handleSort() {
 		const newSort: SortDirection = currentSort === 'asc' ? 'desc' : 'asc';
-		const url = new URL($page.url);
-		url.searchParams.set('sort', newSort);
+		const url = updateContextSort($page.url, context, newSort);
 		goto(url, { replaceState: true });
 	}

diff --git a/frontend/src/lib/types/SortDirection/SortDirection.ts b/frontend/src/lib/types/SortDirection/SortDirection.ts
deleted file mode 100644
index e419548882..0000000000
--- a/frontend/src/lib/types/SortDirection/SortDirection.ts
+++ /dev/null
@@ -1,6 +0,0 @@
-export type SortDirection = 'asc' | 'desc';
-
-export function getSortDirection(searchParams: URLSearchParams): SortDirection {
-	const param = searchParams.get('sort');
-	return param === 'asc' || param === 'desc' ? param : 'asc';
-}
diff --git a/frontend/src/lib/utils/sort.ts b/frontend/src/lib/utils/sort.ts
new file mode 100644
index 0000000000..a517842dd8
--- /dev/null
+++ b/frontend/src/lib/utils/sort.ts
@@ -0,0 +1,32 @@
+import type { FeatureResponse } from '$lib/types/Model/Model';
+
+export type SortDirection = 'asc' | 'desc';
+export type SortContext = 'drift' | 'distributions';
+
+export function getSortParamKey(context: SortContext): string {
+	return `${context}Sort`;
+}
+
+export function getSortDirection(
+	searchParams: URLSearchParams,
+	context: SortContext
+): SortDirection {
+	const param = searchParams.get(getSortParamKey(context));
+	return param === 'desc' ? 'desc' : 'asc';
+}
+
+export function updateContextSort(url: URL, context: SortContext, direction: SortDirection): URL {
+	const newUrl = new URL(url);
+	newUrl.searchParams.set(getSortParamKey(context), direction);
+	return newUrl;
+}
+
+export function sortDistributions(
+	distributions: FeatureResponse[],
+	direction: SortDirection
+): FeatureResponse[] {
+	return [...distributions].sort((a, b) => {
+		const comparison = a.feature.localeCompare(b.feature);
+		return direction === 'asc' ? comparison : -comparison;
+	});
+}
diff --git a/frontend/src/routes/joins/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts
index 6b1ef8fdc7..46bfe6f069 100644
--- a/frontend/src/routes/joins/[slug]/+page.server.ts
+++ b/frontend/src/routes/joins/[slug]/+page.server.ts
@@ -3,7 +3,7 @@ import * as api from '$lib/api/api';
 import type { JoinTimeSeriesResponse, Model } from '$lib/types/Model/Model';
 import { parseDateRangeParams } from '$lib/util/date-ranges';
 import { getMetricTypeFromParams, type MetricType } from '$lib/types/MetricType/MetricType';
-import { getSortDirection } from '$lib/types/SortDirection/SortDirection';
+import { getSortDirection } from '$lib/utils/sort';

 const FALLBACK_START_TS = 1672531200000; // 2023-01-01
 const FALLBACK_END_TS = 1677628800000; // 2023-03-01
@@ -25,7 +25,7 @@ export const load: PageServerLoad = async ({
 	const requestedDateRange = parseDateRangeParams(url.searchParams);
 	const joinName = params.slug;
 	const metricType = getMetricTypeFromParams(url.searchParams);
-	const sortDirection = getSortDirection(url.searchParams);
+	const sortDirection = getSortDirection(url.searchParams, 'drift');

 	// Try with requested date range first
 	try {
diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte
index 2e272fed9a..cd7309aa66 100644
--- a/frontend/src/routes/joins/[slug]/+page.svelte
+++ b/frontend/src/routes/joins/[slug]/+page.svelte
@@ -30,6 +30,8 @@
 	import { handleChartHighlight } from '$lib/util/chart';
 	import ChartControls from '$lib/components/ChartControls/ChartControls.svelte';
 	import { onMount } from 'svelte';
+	import { page } from '$app/stores';
+	import { getSortDirection, sortDistributions } from '$lib/utils/sort';

 	const { data } = $props();
 	let scale = $derived(METRIC_SCALES[data.metricType]);
@@ -443,6 +445,11 @@
 	onMount(() => {
 		loadDistributions();
 	});
+
+	const sortedDistributions = $derived.by(() => {
+		const distributionsSort = getSortDirection($page.url.searchParams, 'distributions');
+		return sortDistributions(distributions, distributionsSort);
+	});

 {#if shouldShowStickyHeader}
@@ -523,7 +530,7 @@
-
+
 {#each joinTimeseries.items as group (group.name)}
+
 {#if isLoadingDistributions}
    Loading distributions...
    {:else if distributions.length === 0}
    No distribution data available
 {:else}
-	{#each distributions as feature}
+	{#each sortedDistributions as feature}
 {#snippet collapsibleContent()}

Date: Fri, 6 Dec 2024 14:26:36 -0500
Subject: [PATCH 144/152] move drift sort logic to sort.ts

---
 frontend/src/lib/utils/sort.ts              | 22 ++++++++++++++++++-
 .../src/routes/joins/[slug]/+page.server.ts | 20 ++++-------------
 2 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/frontend/src/lib/utils/sort.ts b/frontend/src/lib/utils/sort.ts
index a517842dd8..ddb3e9c362 100644
--- a/frontend/src/lib/utils/sort.ts
+++ b/frontend/src/lib/utils/sort.ts
@@ -1,4 +1,4 @@
-import type { FeatureResponse } from '$lib/types/Model/Model';
+import type { FeatureResponse, JoinTimeSeriesResponse } from '$lib/types/Model/Model';

 export type SortDirection = 'asc' | 'desc';
 export type SortContext = 'drift' | 'distributions';
@@ -21,6 +21,26 @@ export function updateContextSort(url: URL, context: SortContext, direction: SortDirection): URL {
 	const newUrl = new URL(url);
 	newUrl.searchParams.set(getSortParamKey(context), direction);
 	return newUrl;
 }

+export function sortDrift(
+	joinTimeseries: JoinTimeSeriesResponse,
+	direction: SortDirection
+): JoinTimeSeriesResponse {
+	const sorted = { ...joinTimeseries };
+
+	// Sort main groups
+	sorted.items = [...joinTimeseries.items].sort((a, b) => {
+		const comparison = a.name.localeCompare(b.name);
+		return direction === 'asc' ? comparison : -comparison;
+	});
+
+	// Sort features within each group
+	sorted.items.forEach((group) => {
+		group.items.sort((a, b) => a.feature.localeCompare(b.feature));
+	});
+
+	return sorted;
+}
+
 export function sortDistributions(
 	distributions: FeatureResponse[],
 	direction: SortDirection
diff --git a/frontend/src/routes/joins/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts
index 46bfe6f069..cf7f0c9668 100644
--- a/frontend/src/routes/joins/[slug]/+page.server.ts
+++ b/frontend/src/routes/joins/[slug]/+page.server.ts
@@ -3,7 +3,7 @@ import * as api from '$lib/api/api';
 import type { JoinTimeSeriesResponse, Model } from '$lib/types/Model/Model';
 import { parseDateRangeParams } from '$lib/util/date-ranges';
 import { getMetricTypeFromParams, type MetricType } from '$lib/types/MetricType/MetricType';
-import { getSortDirection } from '$lib/utils/sort';
+import { getSortDirection, sortDrift, type SortDirection } from '$lib/utils/sort';

 const FALLBACK_START_TS = 1672531200000; // 2023-01-01
 const FALLBACK_END_TS = 1677628800000; // 2023-03-01
@@ -76,7 +76,7 @@ async function fetchInitialData(
 	startTs: number,
 	endTs: number,
 	metricType: MetricType,
-	sortDirection: string
+	sortDirection: SortDirection
 ) {
 	const [joinTimeseries, models] = await Promise.all([
 		api.getJoinTimeseries({
@@ -91,23 +91,11 @@ async function fetchInitialData(
 		api.getModels()
 	]);

-	// Sort main groups
-	joinTimeseries.items.sort((a, b) => {
-		const comparison = a.name.localeCompare(b.name);
-		return sortDirection === 'asc' ? comparison : -comparison;
-	});
-
-	// Sort features within each group
-	joinTimeseries.items.forEach((group) => {
-		group.items.sort((a, b) => {
-			return a.feature.localeCompare(b.feature);
-		});
-	});
-
+	const sortedJoinTimeseries = sortDrift(joinTimeseries, sortDirection);
 	const modelToReturn = models.items.find((m) => m.join.name === joinName);

 	return {
-		joinTimeseries,
+		joinTimeseries: sortedJoinTimeseries,
 		model: modelToReturn
 	};
 }

From 8ddabfb21d0db17da54733d70392d560c23d6494 Mon Sep 17 00:00:00 2001
From: ken-zlai
Date: Fri, 6 Dec 2024 14:34:55 -0500
Subject: [PATCH 145/152] move sort to util folder instead of utils

---
 frontend/src/lib/components/ActionButtons/ActionButtons.svelte | 2 +-
 frontend/src/lib/{utils => util}/sort.ts                       | 0
 frontend/src/routes/joins/[slug]/+page.server.ts               | 2 +-
 frontend/src/routes/joins/[slug]/+page.svelte                  | 2 +-
 4 files changed, 3 insertions(+), 3 deletions(-)
 rename frontend/src/lib/{utils => util}/sort.ts (100%)

diff --git a/frontend/src/lib/components/ActionButtons/ActionButtons.svelte b/frontend/src/lib/components/ActionButtons/ActionButtons.svelte
index 8c1be6a6f3..25ed1a4883 100644
--- a/frontend/src/lib/components/ActionButtons/ActionButtons.svelte
+++ b/frontend/src/lib/components/ActionButtons/ActionButtons.svelte
@@ -9,7 +9,7 @@
 		updateContextSort,
 		type SortDirection,
 		type SortContext
-	} from '$lib/utils/sort';
+	} from '$lib/util/sort';

 	let {
 		showCluster = false,
diff --git a/frontend/src/lib/utils/sort.ts b/frontend/src/lib/util/sort.ts
similarity index 100%
rename from frontend/src/lib/utils/sort.ts
rename to frontend/src/lib/util/sort.ts
diff --git a/frontend/src/routes/joins/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts
index cf7f0c9668..fb9c82ce8d 100644
--- a/frontend/src/routes/joins/[slug]/+page.server.ts
+++ b/frontend/src/routes/joins/[slug]/+page.server.ts
@@ -3,7 +3,7 @@ import * as api from '$lib/api/api';
 import type { JoinTimeSeriesResponse, Model } from '$lib/types/Model/Model';
 import { parseDateRangeParams } from '$lib/util/date-ranges';
 import { getMetricTypeFromParams, type MetricType } from '$lib/types/MetricType/MetricType';
-import { getSortDirection, sortDrift, type SortDirection } from '$lib/utils/sort';
+import { getSortDirection, sortDrift, type SortDirection } from '$lib/util/sort';

 const FALLBACK_START_TS = 1672531200000; // 2023-01-01
 const FALLBACK_END_TS = 1677628800000; // 2023-03-01
diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte
index cd7309aa66..48d8e38e23 100644
--- a/frontend/src/routes/joins/[slug]/+page.svelte
+++ b/frontend/src/routes/joins/[slug]/+page.svelte
@@ -31,7 +31,7 @@
 	import { handleChartHighlight } from '$lib/util/chart';
 	import ChartControls from '$lib/components/ChartControls/ChartControls.svelte';
 	import { onMount } from 'svelte';
 	import { page } from '$app/stores';
-	import { getSortDirection, sortDistributions } from '$lib/utils/sort';
+	import { getSortDirection, sortDistributions } from '$lib/util/sort.js';

 	const { data } = $props();
 	let scale = $derived(METRIC_SCALES[data.metricType]);

From b6b7e992f0c530b813b8e9982c7ec412fe1e1f20 Mon Sep 17 00:00:00 2001
From: ken-zlai
Date: Fri, 6 Dec 2024 15:18:26 -0500
Subject: [PATCH 146/152] move action buttons into chart controls

---
 .../ChartControls/ChartControls.svelte        | 22 +++++++++++++++++--
 frontend/src/routes/joins/[slug]/+page.svelte | 22 ++++++++++++++-----
 2 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/frontend/src/lib/components/ChartControls/ChartControls.svelte b/frontend/src/lib/components/ChartControls/ChartControls.svelte
index 6b684a9535..803e8310c3 100644
--- a/frontend/src/lib/components/ChartControls/ChartControls.svelte
+++ b/frontend/src/lib/components/ChartControls/ChartControls.svelte
@@ -2,19 +2,29 @@
 	import ResetZoomButton from '$lib/components/ResetZoomButton/ResetZoomButton.svelte';
 	import MetricTypeToggle from '$lib/components/MetricTypeToggle/MetricTypeToggle.svelte';
 	import DateRangeSelector from '$lib/components/DateRangeSelector/DateRangeSelector.svelte';
+	import ActionButtons from '$lib/components/ActionButtons/ActionButtons.svelte';
 	import * as Alert from '$lib/components/ui/alert/index.js';
 	import { formatDate } from '$lib/util/format';
+	import type { SortContext } from '$lib/util/sort';

 	let {
 		isZoomed = false,
 		onResetZoom,
 		isUsingFallbackDates = false,
-		dateRange = { startTimestamp: 0, endTimestamp: 0 }
+		dateRange = { startTimestamp: 0, endTimestamp: 0 },
+		showActionButtons = false,
+		showCluster = false,
+		showSort = false,
+		context
 	}: {
 		isZoomed: boolean;
 		onResetZoom: () => void;
 		isUsingFallbackDates?: boolean;
 		dateRange?: { startTimestamp: number; endTimestamp: number };
+		showActionButtons?: boolean;
+		showCluster?: boolean;
+		showSort?: boolean;
+		context?: SortContext;
 	} = $props();
@@ -34,7 +44,15 @@
 	{#if isZoomed}
 	{/if}
-
+	{#if context === 'drift'}
+
+	{/if}
+
+	{#if showActionButtons}
+
+
+
+	{/if}
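
(The markup lines in the hunk above lost their tags during extraction; based on the imports and prop names in this patch, they presumably render `MetricTypeToggle` inside the drift check and `ActionButtons` inside the `showActionButtons` block. For illustration, here is how the context-keyed sort helpers these props feed into behave — a minimal sketch under stated assumptions, with a hypothetical URL, not part of the patch:)

```ts
import { getSortDirection, updateContextSort } from '$lib/util/sort';

// Hypothetical page URL: the drift tab has been toggled to descending.
const url = new URL('https://app.example.com/joins/my_join?driftSort=desc');

// Each SortContext reads its own query param ('driftSort' / 'distributionsSort'),
// so the two tabs keep independent sort state.
getSortDirection(url.searchParams, 'drift'); // 'desc'
getSortDirection(url.searchParams, 'distributions'); // 'asc' (the default)

// updateContextSort clones the URL and only touches its own context's param.
const next = updateContextSort(url, 'distributions', 'desc');
console.log(next.search); // '?driftSort=desc&distributionsSort=desc'
```
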
diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte
index 48d8e38e23..734884443f 100644
--- a/frontend/src/routes/joins/[slug]/+page.svelte
+++ b/frontend/src/routes/joins/[slug]/+page.svelte
@@ -20,7 +20,6 @@
 	import InfoTooltip from '$lib/components/InfoTooltip/InfoTooltip.svelte';
 	import { Table, TableBody, TableCell, TableRow } from '$lib/components/ui/table/index.js';
 	import TrueFalseBadge from '$lib/components/TrueFalseBadge/TrueFalseBadge.svelte';
-	import ActionButtons from '$lib/components/ActionButtons/ActionButtons.svelte';
 	import { Dialog, DialogContent, DialogHeader } from '$lib/components/ui/dialog';
 	import { formatDate, formatValue } from '$lib/util/format';
 	import PercentileChart from '$lib/components/PercentileChart/PercentileChart.svelte';
@@ -31,7 +30,7 @@
 	import ChartControls from '$lib/components/ChartControls/ChartControls.svelte';
 	import { onMount } from 'svelte';
 	import { page } from '$app/stores';
-	import { getSortDirection, sortDistributions } from '$lib/util/sort.js';
+	import { getSortDirection, sortDistributions, type SortContext } from '$lib/util/sort';

 	const { data } = $props();
 	let scale = $derived(METRIC_SCALES[data.metricType]);
@@ -450,6 +449,8 @@
 		const distributionsSort = getSortDirection($page.url.searchParams, 'distributions');
 		return sortDistributions(distributions, distributionsSort);
 	});
+
+	let selectedTab = $state('drift');

 {#if shouldShowStickyHeader}
@@ -465,6 +466,10 @@
 				startTimestamp: data.dateRange.startTimestamp,
 				endTimestamp: data.dateRange.endTimestamp
 			}}
+			showActionButtons={true}
+			showCluster={selectedTab === 'drift'}
+			showSort={true}
+			context={selectedTab}
 		/>
 {/if}
@@ -510,13 +515,17 @@
 				startTimestamp: data.dateRange.startTimestamp,
 				endTimestamp: data.dateRange.endTimestamp
 			}}
+			showActionButtons={true}
+			showCluster={selectedTab === 'drift'}
+			showSort={true}
+			context={selectedTab}
 		/>
 {#snippet collapsibleContent()}
-
+
@@ -530,8 +539,6 @@
-
-
 {#each joinTimeseries.items as group (group.name)}
-
 {#if isLoadingDistributions}
    Loading distributions...
 {:else if distributions.length === 0}
@@ -610,6 +616,10 @@
 				startTimestamp: data.dateRange.startTimestamp,
 				endTimestamp: data.dateRange.endTimestamp
 			}}
+			showActionButtons={false}
+			showCluster={selectedTab === 'drift'}
+			showSort={false}
+			context={selectedTab}
 		/>

From 698221045dddaf127ebf309f11f57c1aa8011235 Mon Sep 17 00:00:00 2001
From: ken-zlai
Date: Fri, 6 Dec 2024 15:39:03 -0500
Subject: [PATCH 147/152] update fallback times

---
 frontend/src/routes/joins/[slug]/+page.server.ts | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/frontend/src/routes/joins/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts
index fb9c82ce8d..33723150fe 100644
--- a/frontend/src/routes/joins/[slug]/+page.server.ts
+++ b/frontend/src/routes/joins/[slug]/+page.server.ts
@@ -5,8 +5,8 @@ import { parseDateRangeParams } from '$lib/util/date-ranges';
 import { getMetricTypeFromParams, type MetricType } from '$lib/types/MetricType/MetricType';
 import { getSortDirection, sortDrift, type SortDirection } from '$lib/util/sort';

-const FALLBACK_START_TS = 1672531200000; // 2023-01-01
-const FALLBACK_END_TS = 1677628800000; // 2023-03-01
+const FALLBACK_START_TS = 1698796800000; // 2023-11-01
+const FALLBACK_END_TS = 1703980800000; // 2023-12-31

 export const load: PageServerLoad = async ({
 	params,

From 630b4695a1e57a5a52f5efc5b130e36e3389f751 Mon Sep 17 00:00:00 2001
From: ken-zlai
Date: Fri, 6 Dec 2024 16:32:00 -0500
Subject: [PATCH 148/152] fix buggy cmd key behavior

---
 frontend/src/lib/components/EChart/EChart.svelte | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/frontend/src/lib/components/EChart/EChart.svelte b/frontend/src/lib/components/EChart/EChart.svelte
index 6575ae624b..3e70c8cf25 100644
--- a/frontend/src/lib/components/EChart/EChart.svelte
+++ b/frontend/src/lib/components/EChart/EChart.svelte
@@ -98,7 +98,7 @@
 	});

 	function handleKeyDown(event: KeyboardEvent) {
-		if (event.metaKey || event.ctrlKey) {
+		if ((event.metaKey || event.ctrlKey) && event.type === 'keydown') {
 			isCommandPressed = true;

 			disableChartInteractions();

From fd6afde8b5500a5253564c3e80b0d2b88231674d Mon Sep 17 00:00:00 2001
From: Ken Morton
Date: Wed, 11 Dec 2024 12:36:46 -0500
Subject: [PATCH 149/152] sync the metric type toggle ui when url changes
 (#120)

## Summary

Stacked on #89.
Quick change to make sure `MetricTypeToggle.svelte` stays synced with the URL query params.

[Demo](https://drive.google.com/file/d/1YNBxIi0fP-Qi-LVIw5FexwZQVEfeCWaS/view?usp=drive_link)

## Checklist
- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

---
 .../lib/components/MetricTypeToggle/MetricTypeToggle.svelte | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte b/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte
index 609788e8a7..028dc6f762 100644
--- a/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte
+++ b/frontend/src/lib/components/MetricTypeToggle/MetricTypeToggle.svelte
@@ -9,13 +9,12 @@
 	import { page } from '$app/stores';
 	import { goto } from '$app/navigation';

-	let selected = getMetricTypeFromParams(new URL($page.url).searchParams);
+	let selected = $derived(getMetricTypeFromParams(new URL($page.url).searchParams));

 	function toggle(value: MetricType) {
 		const url = new URL($page.url);
 		url.searchParams.set('metric', value);
 		goto(url, { replaceState: true });
-		selected = value;
 	}

From c0c84f237979179f1020b3d55ffe698ef89b4aaf Mon Sep 17 00:00:00 2001
From: Ken Morton
Date: Wed, 11 Dec 2024 12:56:01 -0500
Subject: [PATCH 150/152] update markpoint when you change the drift calculation metric (#119)

## Summary

Stacking this PR on #89. This makes sure the mark point line updates when the chart data changes.

[Demo](https://drive.google.com/file/d/1utbMc0nxHFsnr7fHzP_QcDPzWt8MJcB8/view?usp=drive_link)

## Checklist
- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

---
 frontend/src/routes/joins/[slug]/+page.svelte | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/frontend/src/routes/joins/[slug]/+page.svelte b/frontend/src/routes/joins/[slug]/+page.svelte
index 734884443f..7e95cad423 100644
--- a/frontend/src/routes/joins/[slug]/+page.svelte
+++ b/frontend/src/routes/joins/[slug]/+page.svelte
@@ -451,6 +451,40 @@
 	});

 	let selectedTab = $state('drift');
+
+	// update selectedEvents when joinTimeseries changes
+	$effect(() => {
+		if (joinTimeseries) {
+			untrack(() => {
+				// Only update selectedEvents if we have a previously selected point
+				if (selectedEvents.length > 0 && selectedEvents[0]?.data && dialogGroupChart) {
+					const [timestamp] = selectedEvents[0].data as [number, number];
+					const seriesName = selectedEvents[0].seriesName;
+
+					// Get the updated series data from the chart
+					const series = dialogGroupChart.getOption().series as EChartOption.Series[];
+					const updatedSeries = series.find((s) => s.name === seriesName);
+
+					if (updatedSeries && Array.isArray(updatedSeries.data)) {
+						// Find the point at the same timestamp
+						const updatedPoint = updatedSeries.data.find((point) => {
+							const [pointTimestamp] = point as [number, number];
+							return pointTimestamp === timestamp;
+						}) as [number, number] | undefined;
+
+						if (updatedPoint) {
+							selectedEvents = [
+								{
+									...selectedEvents[0],
+									data: updatedPoint
+								}
+							];
+						}
+					}
+				}
+			});
+		}
+	});

From ebee549edb2b7108d2e819e97956f528fde41be7 Mon Sep 17 00:00:00 2001
From: ken-zlai
Date: Wed, 11 Dec 2024 15:25:55 -0500
Subject: [PATCH 151/152] Revert "running observability demo"

This reverts commit c5e32703ddf94aeef1dd9aa3ba55265432f9be3d.
---
 spark/src/main/scala/ai/chronon/spark/TableUtils.scala | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index 1a4adc9b15..04525119e2 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -507,9 +507,9 @@ case class TableUtils(sparkSession: SparkSession) {
       sql(creationSql)
     } catch {
       case _: TableAlreadyExistsException =>
-        println(s"Table $tableName already exists, skipping creation")
+        logger.info(s"Table $tableName already exists, skipping creation")
       case e: Exception =>
-        println(s"Failed to create table $tableName", e)
+        logger.error(s"Failed to create table $tableName", e)
        throw e
     }
   }
@@ -537,7 +537,6 @@ case class TableUtils(sparkSession: SparkSession) {
       // so that an exception will be thrown below
       dfRearranged
     }
-    println(s"Repartitioning and writing into table $tableName".yellow)
    repartitionAndWrite(finalizedDf, tableName, saveMode, stats, sortByCols)
  }

From 12981cc30ee323016d318b0db4fda296bdb60036 Mon Sep 17 00:00:00 2001
From: Ken Morton
Date: Wed, 11 Dec 2024 21:26:41 -0500
Subject: [PATCH 152/152] fix: load less sample for better performance (#123)

## Summary

Wanted to scope back the amount of data generated, as it was leading to some performance issues. There is still room for optimization, but this fixes the immediate problems.

## Checklist
- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

---
 frontend/src/routes/joins/[slug]/+page.server.ts            | 4 ++--
 .../chronon/spark/scripts/ObservabilityDemoDataLoader.scala | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/frontend/src/routes/joins/[slug]/+page.server.ts b/frontend/src/routes/joins/[slug]/+page.server.ts
index 33723150fe..fb9c82ce8d 100644
--- a/frontend/src/routes/joins/[slug]/+page.server.ts
+++ b/frontend/src/routes/joins/[slug]/+page.server.ts
@@ -5,8 +5,8 @@ import { parseDateRangeParams } from '$lib/util/date-ranges';
 import { getMetricTypeFromParams, type MetricType } from '$lib/types/MetricType/MetricType';
 import { getSortDirection, sortDrift, type SortDirection } from '$lib/util/sort';

-const FALLBACK_START_TS = 1698796800000; // 2023-11-01
-const FALLBACK_END_TS = 1703980800000; // 2023-12-31
+const FALLBACK_START_TS = 1672531200000; // 2023-01-01
+const FALLBACK_END_TS = 1677628800000; // 2023-03-01

 export const load: PageServerLoad = async ({
 	params,
diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala
index 258533a6ed..8bf4f52190 100644
--- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala
+++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala
@@ -37,13 +37,13 @@ object ObservabilityDemoDataLoader {

     val endDs: ScallopOption[String] = opt[String](
       name = "end-ds",
-      default = Some("2024-01-01"),
+      default = Some("2023-03-01"),
       descr = "End date in YYYY-MM-DD format"
     )

     val rowCount: ScallopOption[Int] = opt[Int](
       name = "row-count",
-      default = Some(1400000),
+      default = Some(700000),
       descr = "Number of rows to generate"
     )
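
(Note how this patch also reverts the frontend fallback window from patch 147 so it stays inside the smaller generated dataset, whose `end-ds` now defaults to 2023-03-01. A quick sanity-check sketch — not part of the patch — confirming the restored epoch-millisecond constants match their comments:)

```ts
const FALLBACK_START_TS = 1672531200000;
const FALLBACK_END_TS = 1677628800000;

console.log(new Date(FALLBACK_START_TS).toISOString()); // 2023-01-01T00:00:00.000Z
console.log(new Date(FALLBACK_END_TS).toISOString()); // 2023-03-01T00:00:00.000Z, the loader's new end-ds default
```
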