From 71dc2bbeb3e87a6744e575d823606554efe9388c Mon Sep 17 00:00:00 2001 From: julieg18 Date: Fri, 17 Mar 2023 11:15:57 -0500 Subject: [PATCH] Create custom plots when plots are requested * creates custom plots when the plots are requested in `getCustomPlots` --- extension/src/plots/model/collect.test.ts | 120 +++++++++------------- extension/src/plots/model/collect.ts | 76 ++++++++------ extension/src/plots/model/index.ts | 87 +++++----------- extension/src/plots/webview/messages.ts | 14 ++- 4 files changed, 133 insertions(+), 164 deletions(-) diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index c31bf79162..295beb3e65 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -4,9 +4,9 @@ import { collectData, collectTemplates, collectOverrideRevisionDetails, - collectCustomPlots, - collectCustomPlotData + collectCustomPlots } from './collect' +import { isCheckpointPlot } from './custom' import plotsDiffFixture from '../../test/fixtures/plotsDiff/output' import customPlotsFixture, { customPlotsOrderFixture, @@ -18,9 +18,10 @@ import { } from '../../cli/dvc/contract' import { sameContents } from '../../util/array' import { - CheckpointPlot, - CustomPlot, CustomPlotData, + CustomPlotType, + DEFAULT_NB_ITEMS_PER_ROW, + DEFAULT_PLOT_HEIGHT, TemplatePlot } from '../webview/contract' import { getCLICommitId } from '../../test/fixtures/plotsDiff/util' @@ -31,81 +32,62 @@ const logsLossPath = join('logs', 'loss.tsv') const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot -const getCustomPlotFromCustomPlotData = ({ - id, - metric, - param, - type, - values -}: CustomPlotData) => - ({ - id, - metric, - param, - type, - values - } as CustomPlot) - describe('collectCustomPlots', () => { + const defaultFuncArgs = { + experiments: experimentsWithCheckpoints, + hasCheckpoints: true, + height: DEFAULT_PLOT_HEIGHT, + nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW, + plotsOrderValues: customPlotsOrderFixture, + selectedRevisions: customPlotsFixture.colors?.domain + } + it('should return the expected data from the test fixture', () => { - const expectedOutput: CustomPlot[] = customPlotsFixture.plots.map( - getCustomPlotFromCustomPlotData - ) - const data = collectCustomPlots( - customPlotsOrderFixture, - experimentsWithCheckpoints - ) + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots + const data = collectCustomPlots(defaultFuncArgs) expect(data).toStrictEqual(expectedOutput) }) -}) -describe('collectCustomPlotData', () => { - it('should return the expected data from test fixture', () => { - const expectedMetricVsParamPlotData = customPlotsFixture.plots[0] - const expectedCheckpointsPlotData = customPlotsFixture.plots[2] - const metricVsParamPlot = getCustomPlotFromCustomPlotData( - expectedMetricVsParamPlotData - ) - const checkpointsPlot = getCustomPlotFromCustomPlotData( - expectedCheckpointsPlotData + it('should return only custom plots if there no selected revisions', () => { + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter( + plot => plot.type !== CustomPlotType.CHECKPOINT ) + const data = collectCustomPlots({ + ...defaultFuncArgs, + selectedRevisions: undefined + }) - const metricVsParamData = collectCustomPlotData( - metricVsParamPlot, - customPlotsFixture.colors, - customPlotsFixture.nbItemsPerRow, - customPlotsFixture.height - ) + expect(data).toStrictEqual(expectedOutput) + }) - const checkpointsData = collectCustomPlotData( - { - ...checkpointsPlot, - values: [ - ...checkpointsPlot.values, - { - group: 'exp-123', - iteration: 1, - y: 1.4534177053451538 - }, - { - group: 'exp-123', - iteration: 2, - y: 1.757687 - }, - { - group: 'exp-123', - iteration: 3, - y: 1.989894 - } - ] - } as CheckpointPlot, - customPlotsFixture.colors, - customPlotsFixture.nbItemsPerRow, - customPlotsFixture.height + it('should return only custom plots if checkpoints are not enabled', () => { + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter( + plot => plot.type !== CustomPlotType.CHECKPOINT ) + const data = collectCustomPlots({ + ...defaultFuncArgs, + hasCheckpoints: false + }) - expect(metricVsParamData).toStrictEqual(expectedMetricVsParamPlotData) - expect(checkpointsData).toStrictEqual(expectedCheckpointsPlotData) + expect(data).toStrictEqual(expectedOutput) + }) + + it('should return checkpoint plots with values only containing selected experiments data', () => { + const domain = customPlotsFixture.colors?.domain.slice(1) as string[] + + const expectedOutput = customPlotsFixture.plots.map(plot => ({ + ...plot, + values: isCheckpointPlot(plot) + ? plot.values.filter(value => domain.includes(value.group)) + : plot.values + })) + + const data = collectCustomPlots({ + ...defaultFuncArgs, + selectedRevisions: domain + }) + + expect(data).toStrictEqual(expectedOutput) }) }) diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 3493545017..7612e164ae 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -5,7 +5,6 @@ import { getFullValuePath, CHECKPOINTS_PARAM, CustomPlotsOrderValue, - isCheckpointPlot, isCheckpointValue } from './custom' import { getRevisionFirstThreeColumns } from './util' @@ -20,7 +19,6 @@ import { TemplatePlotSection, PlotsType, Revision, - CustomPlot, CustomPlotData, MetricVsParamPlotValues } from '../webview/contract' @@ -126,10 +124,13 @@ const getMetricVsParamValues = ( return values } -const getCustomPlot = ( +const getCustomPlotData = ( orderValue: CustomPlotsOrderValue, - experiments: ExperimentWithCheckpoints[] -): CustomPlot => { + experiments: ExperimentWithCheckpoints[], + selectedRevisions: string[] | undefined = [], + height: number, + nbItemsPerRow: number +): CustomPlotData => { const { metric, param, type } = orderValue const metricPath = getFullValuePath( ColumnType.METRICS, @@ -139,8 +140,12 @@ const getCustomPlot = ( const paramPath = getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) + const selectedExperiments = experiments.filter(({ name, label }) => + selectedRevisions.includes(name || label) + ) + const values = isCheckpointValue(type) - ? getCheckpointValues(experiments, metricPath) + ? getCheckpointValues(selectedExperiments, metricPath) : getMetricVsParamValues(experiments, metricPath, paramPath) return { @@ -148,37 +153,46 @@ const getCustomPlot = ( metric, param, type, - values - } as CustomPlot + values, + yTitle: truncateVerticalTitle(metric, nbItemsPerRow, height) as string + } as CustomPlotData } -export const collectCustomPlots = ( - plotsOrderValues: CustomPlotsOrderValue[], +export const collectCustomPlots = ({ + plotsOrderValues, + experiments, + hasCheckpoints, + selectedRevisions, + height, + nbItemsPerRow +}: { + plotsOrderValues: CustomPlotsOrderValue[] experiments: ExperimentWithCheckpoints[] -): CustomPlot[] => { - return plotsOrderValues.map(plotOrderValue => - getCustomPlot(plotOrderValue, experiments) - ) -} - -export const collectCustomPlotData = ( - plot: CustomPlot, - colors: ColorScale | undefined, - nbItemsPerRow: number, + hasCheckpoints: boolean + selectedRevisions: string[] | undefined height: number -): CustomPlotData => { - const selectedExperiments = colors?.domain - const filteredValues = isCheckpointPlot(plot) - ? plot.values.filter(value => - (selectedExperiments as string[]).includes(value.group) + nbItemsPerRow: number +}): CustomPlotData[] => { + const plots = [] + const shouldSkipCheckpointPlots = !hasCheckpoints || !selectedRevisions + + for (const value of plotsOrderValues) { + if (shouldSkipCheckpointPlots && isCheckpointValue(value.type)) { + continue + } + + plots.push( + getCustomPlotData( + value, + experiments, + selectedRevisions, + height, + nbItemsPerRow ) - : plot.values + ) + } - return { - ...plot, - values: filteredValues, - yTitle: truncateVerticalTitle(plot.metric, nbItemsPerRow, height) as string - } as CustomPlotData + return plots } type RevisionPathData = { [path: string]: Record[] } diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index d4431d0e7d..0f5705b18c 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -11,15 +11,10 @@ import { collectCommitRevisionDetails, collectOverrideRevisionDetails, collectCustomPlots, - getCustomPlotId, - collectCustomPlotData + getCustomPlotId } from './collect' import { getRevisionFirstThreeColumns } from './util' -import { - cleanupOldOrderValue, - CustomPlotsOrderValue, - isCheckpointPlot -} from './custom' +import { cleanupOldOrderValue, CustomPlotsOrderValue } from './custom' import { CheckpointPlot, ComparisonPlots, @@ -31,8 +26,6 @@ import { SectionCollapsed, CustomPlotData, CustomPlotsData, - CustomPlot, - ColorScale, DEFAULT_HEIGHT, DEFAULT_NB_ITEMS_PER_ROW, PlotHeight @@ -78,8 +71,6 @@ export class PlotsModel extends ModelWithPersistence { private multiSourceVariations: MultiSourceVariations = {} private multiSourceEncoding: MultiSourceEncoding = {} - private customPlots?: CustomPlot[] - constructor( dvcRoot: string, experiments: Experiments, @@ -103,8 +94,6 @@ export class PlotsModel extends ModelWithPersistence { } public transformAndSetExperiments() { - this.recreateCustomPlots() - return this.removeStaleData() } @@ -149,7 +138,13 @@ export class PlotsModel extends ModelWithPersistence { } public getCustomPlots(): CustomPlotsData | undefined { - if (!this.customPlots) { + const experimentsWithNoCommitData = this.experiments.hasCheckpoints() + ? this.experiments + .getExperimentsWithCheckpoints() + .filter(({ checkpoints }) => !!checkpoints) + : this.experiments.getExperiments() + + if (experimentsWithNoCommitData.length === 0) { return } @@ -158,32 +153,29 @@ export class PlotsModel extends ModelWithPersistence { .getSelectedExperiments() .map(({ displayColor, id: revision }) => ({ displayColor, revision })) ) + const height = this.getHeight(Section.CUSTOM_PLOTS) + const nbItemsPerRow = this.getNbItemsPerRow(Section.CUSTOM_PLOTS) + const plotsOrderValues = this.getCustomPlotsOrder() + + const plots: CustomPlotData[] = collectCustomPlots({ + experiments: experimentsWithNoCommitData, + hasCheckpoints: this.experiments.hasCheckpoints(), + height, + nbItemsPerRow, + plotsOrderValues, + selectedRevisions: colors?.domain + }) - return { - colors, - height: this.getHeight(Section.CUSTOM_PLOTS), - nbItemsPerRow: this.getNbItemsPerRow(Section.CUSTOM_PLOTS), - plots: this.getCustomPlotsData(this.customPlots, colors) - } - } - - public recreateCustomPlots() { - const experimentsWithNoCommitData = this.experiments.hasCheckpoints() - ? this.experiments - .getExperimentsWithCheckpoints() - .filter(({ checkpoints }) => !!checkpoints) - : this.experiments.getExperiments() - - if (experimentsWithNoCommitData.length === 0) { - this.customPlots = undefined + if (plots.length === 0 && plotsOrderValues.length > 0) { return } - const customPlots: CustomPlot[] = collectCustomPlots( - this.getCustomPlotsOrder(), - experimentsWithNoCommitData - ) - this.customPlots = customPlots + return { + colors, + height, + nbItemsPerRow, + plots + } } public getCustomPlotsOrder() { @@ -194,7 +186,6 @@ export class PlotsModel extends ModelWithPersistence { public updateCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) { this.customPlotsOrder = plotsOrder - this.recreateCustomPlots() } public setCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) { @@ -454,28 +445,6 @@ export class PlotsModel extends ModelWithPersistence { return this.commitRevisions[label] || label } - private getCustomPlotsData( - plots: CustomPlot[], - colors: ColorScale | undefined - ): CustomPlotData[] { - const selectedExperimentsExist = !!colors - const filteredPlots: CustomPlotData[] = [] - for (const plot of plots) { - if (!selectedExperimentsExist && isCheckpointPlot(plot)) { - continue - } - filteredPlots.push( - collectCustomPlotData( - plot, - colors, - this.getNbItemsPerRow(Section.CUSTOM_PLOTS), - this.getHeight(Section.CUSTOM_PLOTS) - ) - ) - } - return filteredPlots - } - private getSelectedComparisonPlots( paths: string[], selectedRevisions: string[] diff --git a/extension/src/plots/webview/messages.ts b/extension/src/plots/webview/messages.ts index 6a9d82441f..a0bd20cc0f 100644 --- a/extension/src/plots/webview/messages.ts +++ b/extension/src/plots/webview/messages.ts @@ -37,6 +37,7 @@ import { doesCustomPlotAlreadyExist, isCheckpointValue } from '../model/custom' +import { getCustomPlotId } from '../model/collect' export class WebviewMessages { private readonly paths: PathsModel @@ -278,20 +279,23 @@ export class WebviewMessages { } private setCustomPlotsOrder(plotIds: string[]) { - const customPlots = this.plots.getCustomPlots()?.plots - if (!customPlots) { - return - } + const customPlotsOrderWithId = this.plots + .getCustomPlotsOrder() + .map(value => ({ + ...value, + id: getCustomPlotId(value.metric, value.param) + })) const newOrder: CustomPlotsOrderValue[] = reorderObjectList( plotIds, - customPlots, + customPlotsOrderWithId, 'id' ).map(({ metric, param, type }) => ({ metric, param, type })) + this.plots.setCustomPlotsOrder(newOrder) this.sendCustomPlotsAndEvent(EventName.VIEWS_REORDER_PLOTS_CUSTOM) }