Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ const {
} = loadCsv('./data.csv', {
featureColumns: ['lat', 'lng', 'height'],
labelColumns: ['temperature'],
shuffle: true,
shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use it as a seed for the shuffling.
splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data.
prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for linear regression.
standardise: true, // Calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
Expand Down
10 changes: 7 additions & 3 deletions src/filterColumns.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
const filterColumns = (data: (string | number)[][], columnNames: string[]) => {
const indexKeepDecisions = data[0].map(
import { CsvTable } from './loadCsv.models';

const filterColumns = (table: CsvTable, columnNames: string[]) => {
const indexKeepDecisions = table[0].map(
(header) => columnNames.indexOf(header as string) > -1
);
return data.map((row) => row.filter((_, index) => indexKeepDecisions[index]));
return table.map((row) =>
row.filter((_, index) => indexKeepDecisions[index])
);
};

export default filterColumns;
8 changes: 6 additions & 2 deletions src/loadCsv.models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ export interface CsvReadOptions {
*/
labelColumns: string[];
/**
* If true, shuffles all rows.
* If true, shuffles all rows with a fixed seed, meaning that shuffling the same data will always result in the same shuffled data.
*
* You can pass a string instead of a boolean to customise the shuffle seed.
*/
shuffle?: boolean;
shuffle?: boolean | string;
/**
* If true, splits your features and labels in half and moves them into testFeatures and testLabels.
*
Expand All @@ -26,3 +28,5 @@ export interface CsvReadOptions {
*/
standardise?: boolean;
}

export type CsvTable = (string | number)[][];
85 changes: 49 additions & 36 deletions src/loadCsv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,28 @@ import fs from 'fs';
import * as tf from '@tensorflow/tfjs';
import { shuffle } from 'shuffle-seed';

import { CsvReadOptions } from './loadCsv.models';
import { CsvReadOptions, CsvTable } from './loadCsv.models';
import filterColumns from './filterColumns';

const shuffleSeed = 'mncv9340ur'; // TODO: Randomise this.
const defaultShuffleSeed = 'mncv9340ur';

const splitTestData = (
features: CsvTable,
labels: CsvTable,
splitTest: boolean | number
) => {
const length =
typeof splitTest === 'number'
? Math.max(0, Math.min(splitTest, features.length - 1))
: Math.floor(features.length / 2);

return {
testFeatures: features.slice(length),
testLabels: labels.slice(length),
features: features.slice(0, length),
labels: labels.slice(0, length),
};
};

const loadCsv = (filename: string, options: CsvReadOptions) => {
const {
Expand All @@ -32,58 +50,53 @@ const loadCsv = (filename: string, options: CsvReadOptions) => {
})
);

let labels = filterColumns(data, labelColumns);
let features = filterColumns(data, featureColumns);
let testFeatures: (string | number)[][] = [];
let testLabels: (string | number)[][] = [];
const tables: { [key: string]: CsvTable } = {
labels: filterColumns(data, labelColumns),
features: filterColumns(data, featureColumns),
testFeatures: [],
testLabels: [],
};

features.shift();
labels.shift();
tables.labels.shift();
tables.features.shift();

if (shouldShuffle) {
features = shuffle(features, shuffleSeed);
labels = shuffle(labels, shuffleSeed);
const seed =
typeof shouldShuffle === 'string' ? shouldShuffle : defaultShuffleSeed;
tables.features = shuffle(tables.features, seed);
tables.labels = shuffle(tables.labels, seed);
}

if (splitTest) {
const length =
typeof splitTest === 'number'
? Math.max(0, Math.min(splitTest, features.length - 1))
: Math.floor(features.length / 2);

testFeatures = features.slice(length);
testLabels = labels.slice(length);
features = features.slice(0, length);
labels = labels.slice(0, length);
Object.assign(
tables,
splitTestData(tables.features, tables.labels, splitTest)
);
}

let featuresTensor = tf.tensor(features);
let testFeaturesTensor = tf.tensor(testFeatures);
let features = tf.tensor(tables.features);
let testFeatures = tf.tensor(tables.testFeatures);

const labelsTensor = tf.tensor(labels);
const testLabelsTensor = tf.tensor(testLabels);
const labels = tf.tensor(tables.labels);
const testLabels = tf.tensor(tables.testLabels);

const { mean, variance } = tf.moments(featuresTensor, 0);
const { mean, variance } = tf.moments(features, 0);

if (standardise) {
featuresTensor = featuresTensor.sub(mean).div(variance.pow(0.5));
testFeaturesTensor = testFeaturesTensor.sub(mean).div(variance.pow(0.5));
features = features.sub(mean).div(variance.pow(0.5));
testFeatures = testFeatures.sub(mean).div(variance.pow(0.5));
}

if (prependOnes) {
featuresTensor = tf
.ones([featuresTensor.shape[0], 1])
.concat(featuresTensor, 1);
testFeaturesTensor = tf
.ones([testFeaturesTensor.shape[0], 1])
.concat(testFeaturesTensor, 1);
features = tf.ones([features.shape[0], 1]).concat(features, 1);
testFeatures = tf.ones([testFeatures.shape[0], 1]).concat(testFeatures, 1);
}

return {
features: featuresTensor,
labels: labelsTensor,
testFeatures: testFeaturesTensor,
testLabels: testLabelsTensor,
features,
labels,
testFeatures,
testLabels,
mean,
variance,
};
Expand Down
24 changes: 24 additions & 0 deletions tests/loadCsv.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,30 @@ test('Shuffling should work and preserve feature - label pairs', () => {
]);
});

test('Shuffling with a custom seed should work', () => {
const { features, labels } = loadCsv(filePath, {
featureColumns: ['lat', 'lng'],
labelColumns: ['country'],
shuffle: 'hello-is-it-me-you-are-looking-for',
});
// @ts-ignore
expect(features.arraySync()).toBeDeepCloseTo(
[
[5, 40.34],
[102, -164],
[0.234, 1.47],
[-93.2, 103.34],
],
3
);
expect(labels.arraySync()).toMatchObject([
['Landistan'],
['Landotzka'],
['SomeCountria'],
['SomeOtherCountria'],
]);
});

test('Loading with all extra options other than shuffle as true should work', () => {
const {
features,
Expand Down