Skip to content

Commit

Permalink
Fix avgPool3d (#7133)
Browse files Browse the repository at this point in the history
BUG

* fix webgl

* rename

* add strides check

* Update conv_util.ts

* reduce valid

* lint

* reduce

* add tests

* isArray

* Update pool_gpu.ts

* Update pool2d_webgpu.ts

* Update avg_pool_3d_test.ts

* Update avg_pool_3d_test.ts

* skip tests for node
  • Loading branch information
Linchenn authored Dec 6, 2022
1 parent ad4153d commit cd8c668
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 37 deletions.
5 changes: 3 additions & 2 deletions tfjs-backend-cpu/src/utils/pool_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,9 @@ export function pool3d(
}
}
const outputOffset = outputColOffset + channel;
outputVals[outputOffset] =
poolType === 'avg' ? avgValue / count : minMaxValue;
outputVals[outputOffset] = poolType === 'avg' ?
avgValue / Math.max(count, 1) :
minMaxValue;
}
}
}
Expand Down
9 changes: 6 additions & 3 deletions tfjs-backend-webgl/src/pool_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ export class Pool2DProgram implements GPGPUProgram {
let returnValue = `${poolType}(${poolType}(${poolType}(` +
'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
if (poolType === 'avg') {
returnValue = `avgValue / count`;
returnValue = `avgValue / max(count, 1.0)`;
}

const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
Expand Down Expand Up @@ -342,7 +342,10 @@ export class Pool3DProgram implements GPGPUProgram {
let returnValue = `${poolType}(${poolType}(${poolType}(` +
'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
if (poolType === 'avg') {
returnValue = `avgValue / count`;
// Use `max(count, 1.0)` instead of `count` in case count === 0.0.
// If count === 0.0, `avgValue` is always 0.0, so we change `count`'s
// value to avoid dividing by zero.
returnValue = `avgValue / max(count, 1.0)`;
}

const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
Expand Down Expand Up @@ -448,8 +451,8 @@ export class Pool3DProgram implements GPGPUProgram {
${updateSnippet}
}
}
setOutput(${returnValue});
}
setOutput(${returnValue});
}
`;
}
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/pool2d_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export class Pool2DProgram implements WebGPUProgram {

let returnValue = `resultValue`;
if (this.poolType === 'avg') {
returnValue = `resultValue / count`;
returnValue = `resultValue / max(count, 1.0)`;
}

const userCode = `
Expand Down
7 changes: 6 additions & 1 deletion tfjs-core/src/ops/avg_pool_3d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {checkPadOnDimRoundingMode} from './conv_util';
import {cast} from './cast';
import {checkPadOnDimRoundingMode} from './conv_util';
import {op} from './operation';
import {reshape} from './reshape';

Expand Down Expand Up @@ -86,6 +86,11 @@ function avgPool3d_<T extends Tensor4D|Tensor5D>(
dataFormat === 'NDHWC',
() => `Error in avgPool3d: Only NDHWC is currently supported, ` +
`but got dataFormat of ${dataFormat}`);
util.assert(
(typeof strides === 'number' && strides > 0) ||
(Array.isArray(strides) && strides[0] > 0 && strides[1] > 0 &&
strides[2] > 0),
() => `Error in avgPool3d: Stride must be > 0, but got '${strides}'`);
checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode);
const inputs: AvgPool3DInputs = {x: x5D};
const attrs:
Expand Down
44 changes: 44 additions & 0 deletions tfjs-core/src/ops/avg_pool_3d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => {
expectArraysClose(await result.data(), [4.5]);
});

it('x=[2,2,2,1] f=[1,2,2] s=1 p=valid', async () => {
const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]);

const result = tf.avgPool3d(x, [1, 2, 2], 1, 'valid');

expect(result.shape).toEqual([2, 1, 1, 1]);
expectArraysClose(await result.data(), [2.5, 6.5]);
});

it('x=[1,1,1,1,1] f=[1,1,1] s=1 [0] => [0]', async () => {
const x = tf.tensor5d([0], [1, 1, 1, 1, 1]);

Expand Down Expand Up @@ -150,6 +159,41 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => {
expectArraysClose(await result.data(), expected);
});

it('x=[1,1,1,1,1] f=[1,1,3] s=1 p=valid', async () => {
// The output tensor has a zero-sized dimension if any filter dimension is
// larger than the corresponding input dimension.
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const expected: number[] = [];
const result = tf.avgPool3d(x, [1, 1, 3], 1, 'valid');

expect(result.shape).toEqual([1, 1, 1, 0, 1]);
expectArraysClose(await result.data(), expected);
});

it('x=[1,1,1,4,1] f=[1,1,1] s=[1,1,2] p=0', async () => {
// Works if the padding is a number.
const x = tf.ones([1, 1, 1, 4, 1]) as tf.Tensor5D;
const expected = [1, 1];
const result = tf.avgPool3d(x, [1, 1, 1], [1, 1, 2], 0);

expect(result.shape).toEqual([1, 1, 1, 2, 1]);
expectArraysClose(await result.data(), expected);
});

it('x=[1,1,1,1,1] f=[2,2,2] s=1 p=2', async () => {
// Works if the padding is at least as large as the filter size.
const x = tf.ones([1, 1, 1, 1, 1]) as tf.Tensor5D;
const expected = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
];
const result = tf.avgPool3d(x, [2, 2, 2], 1, 2);

expect(result.shape).toEqual([1, 4, 4, 4, 1]);
expectArraysClose(await result.data(), expected);
});

it('throws when x is not rank 5', async () => {
// tslint:disable-next-line:no-any
const x: any = tf.tensor1d([1]);
Expand Down
51 changes: 21 additions & 30 deletions tfjs-core/src/ops/conv_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -365,24 +365,23 @@ function computeOutputShape2D(
}

function computeOutputShape4D(
inShape: [number, number, number, number], fieldSize: number,
outChannels: number, stride: number, zeroPad?: number,
inShape: [number, number, number, number],
filterShape: [number, number, number], outChannels: number,
strides: [number, number, number], zeroPad?: number,
roundingMode?: 'floor'|'round'|'ceil'): [number, number, number, number] {
if (zeroPad == null) {
zeroPad = computeDefaultPad(inShape, fieldSize, stride);
zeroPad = computeDefaultPad(inShape, filterShape[0], strides[0]);
}
const inputDepth = inShape[0];
const inputRows = inShape[1];
const inputCols = inShape[2];

const outputDepths =
round((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
const outputRows =
round((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
const outputCols =
round((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);

return [outputDepths, outputRows, outputCols, outChannels];
const outShape: [number, number, number, number] = [0, 0, 0, outChannels];
for (let index = 0; index < 3; index++) {
if (inShape[index] + 2 * zeroPad >= filterShape[index]) {
outShape[index] = round(
(inShape[index] - filterShape[index] + 2 * zeroPad) / strides[index] +
1,
roundingMode);
}
}
return outShape;
}

export function computeDefaultPad(
Expand Down Expand Up @@ -496,6 +495,10 @@ function get3DPadAndOutInfo(
let outHeight: number;
let outWidth: number;

if (pad === 'valid') {
pad = 0;
}

if (typeof pad === 'number') {
const padType = (pad === 0) ? 'VALID' : 'NUMBER';
padInfo = {
Expand All @@ -508,8 +511,9 @@ function get3DPadAndOutInfo(
type: padType
};
const outShape = computeOutputShape4D(
[inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad,
roundingMode);
[inDepth, inHeight, inWidth, 1],
[filterDepth, filterHeight, filterWidth], 1,
[strideDepth, strideHeight, strideWidth], pad, roundingMode);
outDepth = outShape[0];
outHeight = outShape[1];
outWidth = outShape[2];
Expand All @@ -529,19 +533,6 @@ function get3DPadAndOutInfo(
const right = padAlongWidth - left;

padInfo = {top, bottom, left, right, front, back, type: 'SAME'};
} else if (pad === 'valid') {
padInfo = {
top: 0,
bottom: 0,
left: 0,
right: 0,
front: 0,
back: 0,
type: 'VALID'
};
outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
} else {
throw Error(`Unknown padding parameter: ${pad}`);
}
Expand Down
4 changes: 4 additions & 0 deletions tfjs-node/src/run_tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ const IGNORE_LIST: string[] = [
'avgPool test-tensorflow {} gradient x=[3,3,1] f=[3,3] s=1 p=explicit',
// tslint:disable-next-line:max-line-length
'avgPool3d test-tensorflow {} x=[1,2,2,2,1] f=[2,2,2] s=1 p=1 roundingMode=floor',
// https://github.com/tensorflow/tensorflow/issues/58758
'avgPool3d test-tensorflow {} x=[1,1,1,1,1] f=[1,1,3] s=1 p=valid',
// Node backend which uses TF 2.11.0 doesn't support number padding
'avgPool3d test-tensorflow {} x=[1,1,1,1,1] f=[2,2,2] s=1 p=2',
// Node backend which uses TF 2.4.0 doesn't support explicit padding
'maxPool test-tensorflow {} x=[3,3,1] f=[3,3] s=1 p=explicit',
'maxPoolBackprop test-tensorflow {} gradient x=[3,3,1] f=3 s=1 p=explicit',
Expand Down

0 comments on commit cd8c668

Please sign in to comment.