Skip to content

Commit

Permalink
Fix missing isTypedArray when mixing versions of @tensorflow packages (
Browse files Browse the repository at this point in the history
…#7489)

A new function, `isTypedArray` was added to the `platform` interface by #7181
and first published in tfjs-core 4.2.0. This made 4.2.0 incompatible with
earlier versions of backends that implemented `platform`, such as node and
react-native. This change adds a fallback to the use of `isTypedArray` so
earlier versions of platforms that don't implement `isTypedArray` will not throw
an error.

Note that the fallback behavior may not be perfect, such as when running Jest
tests in node. See #7175 for more details and upgrade all @tensorflow scoped
packages to ^4.2.0 to avoid this.
  • Loading branch information
mattsoulanille authored Mar 21, 2023
1 parent 8fc6fe9 commit 7a3e61a
Show file tree
Hide file tree
Showing 14 changed files with 348 additions and 309 deletions.
22 changes: 22 additions & 0 deletions tfjs-core/src/platforms/is_typed_array_browser.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

export function isTypedArrayBrowser(a: unknown): a is Uint8Array
| Float32Array | Int32Array | Uint8ClampedArray {
return a instanceof Float32Array || a instanceof Int32Array ||
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
}
4 changes: 2 additions & 2 deletions tfjs-core/src/platforms/platform_browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {BrowserLocalStorage, BrowserLocalStorageManager} from '../io/local_stora
import {ModelStoreManagerRegistry} from '../io/model_management';

import {Platform} from './platform';
import {isTypedArrayBrowser} from './is_typed_array_browser';

export class PlatformBrowser implements Platform {
// According to the spec, the built-in encoder can do only UTF-8 encoding.
Expand Down Expand Up @@ -93,8 +94,7 @@ export class PlatformBrowser implements Platform {

isTypedArray(a: unknown): a is Uint8Array | Float32Array | Int32Array
| Uint8ClampedArray {
return a instanceof Float32Array || a instanceof Int32Array ||
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
return isTypedArrayBrowser(a);
}
}

Expand Down
8 changes: 7 additions & 1 deletion tfjs-core/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

import {env} from './environment';
import {isTypedArrayBrowser} from './platforms/is_typed_array_browser';
import {BackendValues, DataType, RecursiveArray, TensorLike, TypedArray} from './types';
import * as base from './util_base';
export * from './util_base';
Expand Down Expand Up @@ -134,7 +135,12 @@ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string {

export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array|
Uint8ClampedArray {
return env().platform.isTypedArray(a);
// TODO(mattsoulanille): Remove this fallback in 5.0.0
if (env().platform.isTypedArray != null) {
return env().platform.isTypedArray(a);
} else {
return isTypedArrayBrowser(a);
}
}

// NOTE: We explicitly type out what T extends instead of any so that
Expand Down
17 changes: 17 additions & 0 deletions tfjs-core/src/util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {ALL_ENVS, describeWithFlags} from './jasmine_util';
import {complex, scalar, tensor2d} from './ops/ops';
import {inferShape} from './tensor_util_env';
import * as util from './util';
import {env} from './environment';

describe('Util', () => {
it('Correctly gets size from shape', () => {
Expand Down Expand Up @@ -133,6 +134,22 @@ describe('Util', () => {
];
expect(inferShape(a, 'string')).toEqual([2, 2, 1]);
});
describe('isTypedArray', () => {
it('checks if a value is a typed array', () => {
expect(util.isTypedArray(new Uint8Array([1,2,3]))).toBeTrue();
expect(util.isTypedArray([1,2,3])).toBeFalse();
});
it('uses fallback if platform is missing isTypedArray', () => {
const tmpIsTypedArray = env().platform.isTypedArray;
try {
env().platform.isTypedArray = null;
expect(util.isTypedArray(new Uint8Array([1,2,3]))).toBeTrue();
expect(util.isTypedArray([1,2,3])).toBeFalse();
} finally {
env().platform.isTypedArray = tmpIsTypedArray;
}
});
});
});

describe('util.flatten', () => {
Expand Down
6 changes: 3 additions & 3 deletions tfjs-node-gpu/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@
},
"devDependencies": {
"@tensorflow/tfjs-core": "link:../link-package/node_modules/@tensorflow/tfjs-core",
"@types/jasmine": "~2.8.6",
"@types/jasmine": "~4.0.3",
"@types/node": "^10.5.1",
"@types/progress": "^2.0.1",
"@types/rimraf": "~2.0.2",
"@types/yargs": "^13.0.3",
"clang-format": "~1.8.0",
"jasmine": "~3.1.0",
"jasmine": "~4.2.1",
"node-fetch": "~2.6.1",
"nyc": "^15.1.0",
"tmp": "^0.0.33",
"ts-node": "^5.0.1",
"ts-node": "~8.8.2",
"tslint": "~6.1.3",
"tslint-no-circular-imports": "^0.7.0",
"typescript": "4.9.4",
Expand Down
6 changes: 3 additions & 3 deletions tfjs-node/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@
},
"devDependencies": {
"@tensorflow/tfjs-core": "link:../link-package/node_modules/@tensorflow/tfjs-core",
"@types/jasmine": "~2.8.6",
"@types/jasmine": "~4.0.3",
"@types/node": "^10.5.1",
"@types/progress": "^2.0.1",
"@types/rimraf": "~2.0.2",
"@types/yargs": "^13.0.3",
"clang-format": "~1.8.0",
"jasmine": "~3.1.0",
"jasmine": "~4.2.1",
"node-fetch": "~2.6.1",
"nyc": "^15.1.0",
"tmp": "^0.0.33",
"ts-node": "^5.0.1",
"ts-node": "~8.8.2",
"tslint": "~6.1.3",
"tslint-no-circular-imports": "^0.7.0",
"typescript": "4.9.4",
Expand Down
3 changes: 2 additions & 1 deletion tfjs-node/src/callbacks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ import * as ProgressBar from 'progress';

import {summaryFileWriter, SummaryFileWriter} from './tensorboard';

type LogFunction = (message: string) => void;
// A helper class created for testing with the jasmine `spyOn` method, which
// operates only on member methods of objects.
// tslint:disable-next-line:no-any
export const progressBarHelper: {ProgressBar: any, log: Function} = {
export const progressBarHelper: {ProgressBar: any, log: LogFunction} = {
ProgressBar,
log: console.log
};
Expand Down
55 changes: 18 additions & 37 deletions tfjs-node/src/image_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -273,50 +273,31 @@ describe('decode images', () => {
.toBe(beforeNumTFTensors + 1);
});

it('throw error if request non int32 dtype', async done => {
try {
const uint8array = await getUint8ArrayFromImage(
'test_objects/images/image_png_test.png');
tf.node.decodeImage(uint8array, 0, 'uint8');
done.fail();
} catch (error) {
expect(error.message)
.toBe(
'decodeImage could only return Tensor of type `int32` for now.');
done();
}
it('throw error if request non int32 dtype', async () => {
const uint8array = await getUint8ArrayFromImage(
'test_objects/images/image_png_test.png');
expect(() => tf.node.decodeImage(uint8array, 0, 'uint8')).toThrowError(
'decodeImage could only return Tensor of type `int32` for now.');
});

it('throw error if decode invalid image type', async done => {
try {
const uint8array = await getUint8ArrayFromImage('package.json');
tf.node.decodeImage(uint8array);
done.fail();
} catch (error) {
expect(error.message)
.toBe(
'Expected image (BMP, JPEG, PNG, or GIF), ' +
'but got unsupported image type');
done();
}
it('throw error if decode invalid image type', async () => {
const uint8array = await getUint8ArrayFromImage('package.json');
expect(() => tf.node.decodeImage(uint8array)).toThrowError(
'Expected image (BMP, JPEG, PNG, or GIF), ' +
'but got unsupported image type');
});

it('throw error if backend is not tensorflow', async done => {
it('throw error if backend is not tensorflow', async () => {
const testBackend = new TestKernelBackend();
registerBackend('fake', () => testBackend);
setBackend('fake');
try {
const testBackend = new TestKernelBackend();
registerBackend('fake', () => testBackend);
setBackend('fake');

const uint8array = await getUint8ArrayFromImage(
'test_objects/images/image_png_test.png');
tf.node.decodeImage(uint8array);
done.fail();
} catch (err) {
expect(err.message)
.toBe(
'Expect the current backend to be "tensorflow", but got "fake"');
'test_objects/images/image_png_test.png');
expect(() => tf.node.decodeImage(uint8array)).toThrowError(
'Expect the current backend to be "tensorflow", but got "fake"');
} finally {
setBackend('tensorflow');
done();
}
});
});
Expand Down
Loading

0 comments on commit 7a3e61a

Please sign in to comment.