Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,33 +53,23 @@ Advanced usage:
```js
import loadCsv from 'tensorflow-load-csv';

const {
features,
labels,
testFeatures,
testLabels,
mean, // tensor holding mean of features, ignores testFeatures
variance, // tensor holding variance of features, ignores testFeatures
} = loadCsv('./data.csv', {
const { features, labels, testFeatures, testLabels } = loadCsv('./data.csv', {
featureColumns: ['lat', 'lng', 'height'],
labelColumns: ['temperature'],
mappings: {
height: (ft) => ft * 0.3048, // feet to meters
temperature: (f) => (f < 50 ? [1, 0] : [0, 1]), // cold or hot classification
}, // Map values based on which column they are in before they are loaded into tensors.
flatten: ['temperature'], // Flattens the array result of a mapping so that each member is a new column.
shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use it as a seed for the shuffling.
splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data, or a percentage string (e.g. 10%).
prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for linear regression.
standardise: true, // Calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use as a seed for the shuffling.
splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data, or a percentage string (e.g. '10%').
standardise: ['height'], // Calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for regression problems.
});

features.print();
labels.print();

testFeatures.print();
testLabels.print();

mean.print();
variance.print();
```
8 changes: 4 additions & 4 deletions jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ module.exports = {
coveragePathIgnorePatterns: ['/node_modules/', '/tests/'],
coverageThreshold: {
global: {
branches: 90,
functions: 95,
lines: 95,
statements: 95,
branches: 100,
functions: 100,
lines: 100,
statements: 100,
},
},
collectCoverageFrom: ['src/*.{js,ts}'],
Expand Down
4 changes: 2 additions & 2 deletions src/loadCsv.models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ export interface CsvReadOptions {
*/
prependOnes?: boolean;
/**
* If true, calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
* Calculates mean and variance for given columns using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
*/
standardise?: boolean | string[];
standardise?: string[];
/**
* Useful for classification problems, if you have mapped a column's values to an array using `mappings`, you can choose to flatten it here so that each element becomes a new column.
*
Expand Down
28 changes: 16 additions & 12 deletions src/loadCsv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import filterColumns from './filterColumns';
import splitTestData from './splitTestData';
import applyMappings from './applyMappings';
import shuffle from './shuffle';
import standardise from './standardise';

const defaultShuffleSeed = 'mncv9340ur';

Expand All @@ -31,10 +32,10 @@ const loadCsv = (
featureColumns,
labelColumns,
mappings = {},
shuffle: shouldShuffle = false,
shuffle: shouldShuffleOrSeed = false,
splitTest,
prependOnes = false,
standardise = false,
standardise: columnsToStandardise = [],
flatten = [],
}: CsvReadOptions
) => {
Expand All @@ -54,11 +55,13 @@ const loadCsv = (
};

tables.labels.shift();
tables.features.shift();
const featureColumnNames = tables.features.shift() as string[];

if (shouldShuffle) {
if (shouldShuffleOrSeed) {
const seed =
typeof shouldShuffle === 'string' ? shouldShuffle : defaultShuffleSeed;
typeof shouldShuffleOrSeed === 'string'
? shouldShuffleOrSeed
: defaultShuffleSeed;
tables.features = shuffle(tables.features, seed);
tables.labels = shuffle(tables.labels, seed);
}
Expand All @@ -76,11 +79,14 @@ const loadCsv = (
const labels = tf.tensor(tables.labels);
const testLabels = tf.tensor(tables.testLabels);

const { mean, variance } = tf.moments(features, 0);

if (standardise) {
features = features.sub(mean).div(variance.pow(0.5));
testFeatures = testFeatures.sub(mean).div(variance.pow(0.5));
if (columnsToStandardise.length > 0) {
const result = standardise(
features,
testFeatures,
featureColumnNames.map((c) => columnsToStandardise.includes(c))
);
features = result.features;
testFeatures = result.testFeatures;
}

if (prependOnes) {
Expand All @@ -93,8 +99,6 @@ const loadCsv = (
labels,
testFeatures,
testLabels,
mean,
variance,
};
};

Expand Down
68 changes: 68 additions & 0 deletions src/standardise.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import * as tf from '@tensorflow/tfjs';

/**
 * Standardises (zero mean, unit variance) selected columns of a feature
 * tensor, applying statistics computed from `features` only to both
 * `features` and `testFeatures`, so the test set never leaks into the
 * training statistics.
 *
 * @param features - 2D tensor of training features, shape [rows, columns].
 * @param testFeatures - 2D tensor of test features with the same column count.
 * @param indicesToStandardise - Per-column flags; `true` standardises that
 *   column, `false` passes it through unchanged.
 * @returns New `features` and `testFeatures` tensors with the flagged columns
 *   standardised. The input tensors are not mutated.
 * @throws Error if either tensor has fewer than two dimensions, if the column
 *   counts differ, or if the flag array length does not match the column count.
 */
const standardise = (
  features: tf.Tensor<tf.Rank>,
  testFeatures: tf.Tensor<tf.Rank>,
  indicesToStandardise: boolean[]
): {
  features: tf.Tensor<tf.Rank>;
  testFeatures: tf.Tensor<tf.Rank>;
} => {
  if (features.shape.length < 2 || testFeatures.shape.length < 2) {
    throw new Error(
      'features and testFeatures must have at least two dimensions'
    );
  }

  if (features.shape[1] !== testFeatures.shape[1]) {
    throw new Error(
      'Length of the second dimension of features and testFeatures must be the same'
    );
  }

  if (features.shape[1] !== indicesToStandardise.length) {
    throw new Error(
      'Length of indicesToStandardise must match the length of the second dimension of features'
    );
  }

  const columnCount = features.shape[1];
  if (columnCount === 0) {
    return { features, testFeatures };
  }

  // Collect the (possibly transformed) column slices and concatenate once at
  // the end; chaining concat per column would recopy earlier columns on every
  // iteration (O(columns²) data movement).
  const featureSlices: tf.Tensor<tf.Rank>[] = [];
  const testFeatureSlices: tf.Tensor<tf.Rank>[] = [];

  for (let i = 0; i < columnCount; i++) {
    let featureSlice = features.slice([0, i], [features.shape[0], 1]);
    let testFeatureSlice = testFeatures.slice(
      [0, i],
      [testFeatures.shape[0], 1]
    );
    if (indicesToStandardise[i]) {
      // Statistics come from the training slice only.
      const { mean, variance } = tf.moments(featureSlice);
      // NOTE(review): a constant column has zero variance, so this division
      // yields NaN — same as the original behaviour; confirm whether a guard
      // is wanted.
      const std = variance.sqrt();
      featureSlice = featureSlice.sub(mean).div(std);
      testFeatureSlice = testFeatureSlice.sub(mean).div(std);
    }
    featureSlices.push(featureSlice);
    testFeatureSlices.push(testFeatureSlice);
  }

  return {
    features: tf.concat(featureSlices, 1),
    testFeatures: tf.concat(testFeatureSlices, 1),
  };
};

export default standardise;
97 changes: 28 additions & 69 deletions tests/loadCsv.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,95 +46,54 @@ test('Loading with only the required options should work', () => {
]);
});

test('Shuffling should work and preserve feature - label pairs', () => {
const { features, labels } = loadCsv(filePath, {
test('Loading with all extra options should work', () => {
const { features, labels, testFeatures, testLabels } = loadCsv(filePath, {
featureColumns: ['lat', 'lng'],
labelColumns: ['country'],
mappings: {
country: (name) => (name as string).toUpperCase(),
lat: (lat) => ((lat as number) > 0 ? [0, 1] : [1, 0]), // South or North classification
},
flatten: ['lat'],
shuffle: true,
splitTest: true,
prependOnes: true,
standardise: ['lng'],
});
// @ts-ignore
expect(features.arraySync()).toBeDeepCloseTo(
[
[5, 40.34],
[0.234, 1.47],
[-93.2, 103.34],
[102, -164],
[1, 0, 1, 1],
[1, 0, 1, -1],
],
3
);
expect(labels.arraySync()).toMatchObject([
['Landistan'],
['SomeCountria'],
['SomeOtherCountria'],
['Landotzka'],
]);
});

test('Shuffling with a custom seed should work', () => {
const { features, labels } = loadCsv(filePath, {
featureColumns: ['lat', 'lng'],
labelColumns: ['country'],
shuffle: 'hello-is-it-me-you-are-looking-for',
});
expect(labels.arraySync()).toMatchObject([['LANDISTAN'], ['SOMECOUNTRIA']]);
// @ts-ignore
expect(features.arraySync()).toBeDeepCloseTo(
expect(testFeatures.arraySync()).toBeDeepCloseTo(
[
[-93.2, 103.34],
[102, -164],
[5, 40.34],
[0.234, 1.47],
[1, 1, 0, 4.241],
[1, 0, 1, -9.514],
],
3
);
expect(labels.arraySync()).toMatchObject([
['SomeOtherCountria'],
['Landotzka'],
['Landistan'],
['SomeCountria'],
expect(testLabels.arraySync()).toMatchObject([
['SOMEOTHERCOUNTRIA'],
['LANDOTZKA'],
]);
});

test('Loading with all extra options other than shuffle as true should work', () => {
const {
features,
labels,
testFeatures,
testLabels,
mean,
variance,
} = loadCsv(filePath, {
test('Loading with custom seed should use the custom seed', () => {
const { features } = loadCsv(filePath, {
featureColumns: ['lat', 'lng'],
labelColumns: ['country'],
mappings: {
country: (name) => (name as string).toUpperCase(),
},
splitTest: true,
prependOnes: true,
standardise: true,
shuffle: true,
});
const { features: featuresCustom } = loadCsv(filePath, {
featureColumns: ['lat', 'lng'],
labelColumns: ['country'],
shuffle: 'sdhjhdf',
});
// @ts-ignore
expect(features.arraySync()).toBeDeepCloseTo(
[
[1, 1, -1],
[1, -1, 1],
],
3
);
expect(labels.arraySync()).toMatchObject([
['SOMECOUNTRIA'],
['SOMEOTHERCOUNTRIA'],
]);
// @ts-ignore
expect(testFeatures.arraySync()).toBeDeepCloseTo(
[
[1, 1.102, -0.236],
[1, 3.178, -4.248],
],
3
);
expect(testLabels.arraySync()).toMatchObject([['LANDISTAN'], ['LANDOTZKA']]);
// @ts-ignore
expect(mean.arraySync()).toBeDeepCloseTo([-46.482, 52.404], 3);
// @ts-ignore
expect(variance.arraySync()).toBeDeepCloseTo([2182.478, 2594.374], 3);
expect(features).not.toBeDeepCloseTo(featuresCustom, 1);
});
Loading