From 32296174922b0b4f1d2aef84b268e31df566d794 Mon Sep 17 00:00:00 2001 From: Baris Sencan Date: Wed, 5 Aug 2020 21:55:26 +0100 Subject: [PATCH 1/2] feat(lib): Support passing a seed to the shuffle option --- README.md | 2 +- src/filterColumns.ts | 10 +++-- src/loadCsv.models.ts | 8 +++- src/loadCsv.ts | 85 +++++++++++++++++++++++++------------------ tests/loadCsv.test.ts | 24 ++++++++++++ 5 files changed, 87 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 15ad69d..10cc5cd 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ const { } = loadCsv('./data.csv', { featureColumns: ['lat', 'lng', 'height'], labelColumns: ['temperature'], - shuffle: true, + shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use it as a seed for the shuffling. splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data. prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for linear regression. standardise: true, // Calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels. diff --git a/src/filterColumns.ts b/src/filterColumns.ts index 4ea6f37..aec6b0a 100644 --- a/src/filterColumns.ts +++ b/src/filterColumns.ts @@ -1,8 +1,12 @@ -const filterColumns = (data: (string | number)[][], columnNames: string[]) => { - const indexKeepDecisions = data[0].map( +import { CsvTable } from './loadCsv.models'; + +const filterColumns = (table: CsvTable, columnNames: string[]) => { + const indexKeepDecisions = table[0].map( (header) => columnNames.indexOf(header as string) > -1 ); - return data.map((row) => row.filter((_, index) => indexKeepDecisions[index])); + return table.map((row) => + row.filter((_, index) => indexKeepDecisions[index]) + ); }; export default filterColumns; diff --git a/src/loadCsv.models.ts b/src/loadCsv.models.ts index 47044f0..70eab87 100644 --- a/src/loadCsv.models.ts +++ b/src/loadCsv.models.ts @@ -8,9 +8,11 @@ export interface CsvReadOptions { */ labelColumns: string[]; /** - * If true, shuffles all rows. + * If true, shuffles all rows with a fixed seed, meaning that shuffling the same data will always result in the same shuffled data. + * + * You pass a string instead of a boolean to customise the shuffle seed. */ - shuffle?: boolean; + shuffle?: boolean | string; /** * If true, splits your features and labels in half and moves them into testFeatures and testLabels. * @@ -26,3 +28,5 @@ export interface CsvReadOptions { */ standardise?: boolean; } + +export type CsvTable = (string | number)[][]; diff --git a/src/loadCsv.ts b/src/loadCsv.ts index 2d6f615..1785241 100644 --- a/src/loadCsv.ts +++ b/src/loadCsv.ts @@ -3,10 +3,28 @@ import fs from 'fs'; import * as tf from '@tensorflow/tfjs'; import { shuffle } from 'shuffle-seed'; -import { CsvReadOptions } from './loadCsv.models'; +import { CsvReadOptions, CsvTable } from './loadCsv.models'; import filterColumns from './filterColumns'; -const shuffleSeed = 'mncv9340ur'; // TODO: Randomise this. +const defaultShuffleSeed = 'mncv9340ur'; + +const splitTestData = ( + features: CsvTable, + labels: CsvTable, + splitTest: boolean | number +) => { + const length = + typeof splitTest === 'number' + ? Math.max(0, Math.min(splitTest, features.length - 1)) + : Math.floor(features.length / 2); + + return { + testFeatures: features.slice(length), + testLabels: labels.slice(length), + features: features.slice(0, length), + labels: labels.slice(0, length), + }; +}; const loadCsv = (filename: string, options: CsvReadOptions) => { const { @@ -32,58 +50,53 @@ const loadCsv = (filename: string, options: CsvReadOptions) => { }) ); - let labels = filterColumns(data, labelColumns); - let features = filterColumns(data, featureColumns); - let testFeatures: (string | number)[][] = []; - let testLabels: (string | number)[][] = []; + const tables: { [key: string]: CsvTable } = { + labels: filterColumns(data, labelColumns), + features: filterColumns(data, featureColumns), + testFeatures: [], + testLabels: [], + }; - features.shift(); - labels.shift(); + tables.labels.shift(); + tables.features.shift(); if (shouldShuffle) { - features = shuffle(features, shuffleSeed); - labels = shuffle(labels, shuffleSeed); + const seed = + typeof shouldShuffle === 'string' ? shouldShuffle : defaultShuffleSeed; + tables.features = shuffle(tables.features, seed); + tables.labels = shuffle(tables.labels, seed); } if (splitTest) { - const length = - typeof splitTest === 'number' - ? Math.max(0, Math.min(splitTest, features.length - 1)) - : Math.floor(features.length / 2); - - testFeatures = features.slice(length); - testLabels = labels.slice(length); - features = features.slice(0, length); - labels = labels.slice(0, length); + Object.assign( + tables, + splitTestData(tables.features, tables.labels, splitTest) + ); } - let featuresTensor = tf.tensor(features); - let testFeaturesTensor = tf.tensor(testFeatures); + let features = tf.tensor(tables.features); + let testFeatures = tf.tensor(tables.testFeatures); - const labelsTensor = tf.tensor(labels); - const testLabelsTensor = tf.tensor(testLabels); + const labels = tf.tensor(tables.labels); + const testLabels = tf.tensor(tables.testLabels); - const { mean, variance } = tf.moments(featuresTensor, 0); + const { mean, variance } = tf.moments(features, 0); if (standardise) { - featuresTensor = featuresTensor.sub(mean).div(variance.pow(0.5)); - testFeaturesTensor = testFeaturesTensor.sub(mean).div(variance.pow(0.5)); + features = features.sub(mean).div(variance.pow(0.5)); + testFeatures = testFeatures.sub(mean).div(variance.pow(0.5)); } if (prependOnes) { - featuresTensor = tf - .ones([featuresTensor.shape[0], 1]) - .concat(featuresTensor, 1); - testFeaturesTensor = tf - .ones([testFeaturesTensor.shape[0], 1]) - .concat(testFeaturesTensor, 1); + features = tf.ones([features.shape[0], 1]).concat(features, 1); + testFeatures = tf.ones([testFeatures.shape[0], 1]).concat(testFeatures, 1); } return { - features: featuresTensor, - labels: labelsTensor, - testFeatures: testFeaturesTensor, - testLabels: testLabelsTensor, + features, + labels, + testFeatures, + testLabels, mean, variance, }; diff --git a/tests/loadCsv.test.ts b/tests/loadCsv.test.ts index a23dd09..b96e941 100644 --- a/tests/loadCsv.test.ts +++ b/tests/loadCsv.test.ts @@ -52,6 +52,30 @@ test('Shuffling should work and preserve feature - label pairs', () => { ]); }); +test('Shuffling with a custom seed should work', () => { + const { features, labels } = loadCsv(filePath, { + featureColumns: ['lat', 'lng'], + labelColumns: ['country'], + shuffle: 'hello-is-it-me-you-are-looking-for', + }); + // @ts-ignore + expect(features.arraySync()).toBeDeepCloseTo( + [ + [5, 40.34], + [102, -164], + [0.234, 1.47], + [-93.2, 103.34], + ], + 3 + ); + expect(labels.arraySync()).toMatchObject([ + ['Landistan'], + ['Landotzka'], + ['SomeCountria'], + ['SomeOtherCountria'], + ]); +}); + test('Loading with all extra options other than shuffle as true should work', () => { const { features, From 0159e303a3348226ed6545420391554d0b349601 Mon Sep 17 00:00:00 2001 From: Baris Sencan Date: Wed, 5 Aug 2020 21:58:52 +0100 Subject: [PATCH 2/2] docs(lib): Fix typo --- src/loadCsv.models.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/loadCsv.models.ts b/src/loadCsv.models.ts index 70eab87..0ffb99a 100644 --- a/src/loadCsv.models.ts +++ b/src/loadCsv.models.ts @@ -10,7 +10,7 @@ export interface CsvReadOptions { /** * If true, shuffles all rows with a fixed seed, meaning that shuffling the same data will always result in the same shuffled data. * - * You pass a string instead of a boolean to customise the shuffle seed. + * You can pass a string instead of a boolean to customise the shuffle seed. */ shuffle?: boolean | string; /**