Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

src: clean up operator resolve rules #210

Merged
merged 1 commit on
Aug 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 164 additions & 139 deletions docs/operators.md

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions lib/backends/cpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Acosh', '', '9+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.acosh)],
['Add', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 + e2))],
['And', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 && e2))],
['ArgMax', '', '1+', () => new CpuArgMax()],
['ArgMax', '', '1-11', () => new CpuArgMax()],
['Asin', '', '7+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.asin)],
['Asinh', '', '9+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.asinh)],
['Atan', '', '7+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.atan)],
['Atanh', '', '9+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.atanh)],
['AveragePool', '', '7+', () => new CpuAveragePool()], // TODO: support new attributes for AveragePool-10
['AveragePool', '', '7-10', () => new CpuAveragePool()], // TODO: support new attributes for AveragePool-10
['BatchNormalization', '', '7+', () => new CpuBatchNormalization()],
['Ceil', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.ceil)],
['Clip', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.clip, unaryOps.clipInitializer)],
['Clip', '', '6-10', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.clip, unaryOps.clipInitializer)],
['Concat', '', '4+', () => new CpuConcat()],
['Conv', '', '1+', () => new CpuConv()],
['Cos', '', '7+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.cos)],
Expand All @@ -58,7 +58,8 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Flatten', '', '1+', () => new CpuFlatten()],
['Floor', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.floor)],
['Gather', '', '1+', () => new CpuGather()],
['Gemm', '', '7+', () => new CpuGemm()],
['Gemm', '', '7-10', () => new CpuGemm(false)],
['Gemm', '', '11+', () => new CpuGemm(true)],
['GlobalAveragePool', '', '1+', () => new CpuGlobalAveragePool()],
['GlobalMaxPool', '', '1+', () => new CpuGlobalMaxPool()],
['ImageScaler', '', '1+', () => new CpuImageScaler()],
Expand All @@ -68,7 +69,7 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Log', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.log)],
['LRN', '', '1+', () => new CpuLrn()],
['MatMul', '', '1+', () => new CpuMatMul()],
['MaxPool', '', '1+', () => new CpuMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', '', '1-9', () => new CpuMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['Mul', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 * e2))],
['Neg', '', '6+', () => new CpuUnaryOp(NUMBER_TYPES, unaryOps.neg)],
['Not', '', '1+', () => new CpuUnaryOp(['bool'], unaryOps.not, undefined, 'bool')],
Expand Down
2 changes: 1 addition & 1 deletion lib/backends/cpu/ops/argMax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export class CpuArgMax extends ArgMax {

export function argMax(x: Tensor, axis: number, keepdims: boolean): Tensor {
const rank = x.dims ? x.dims.length : 1;
axis = ShapeUtil.parseAxis(axis, rank);
axis = ShapeUtil.normalizeAxis(axis, rank);
const outputDims = ReduceUtil.calcReduceShape(x.dims, [axis], true);
const X = x.data;
const Y = new Int32Array(ShapeUtil.size(outputDims));
Expand Down
5 changes: 3 additions & 2 deletions lib/backends/cpu/ops/gather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export class CpuGather extends Gather {
}

export function gather(x: Tensor, indices: Tensor, axis: number): Tensor {
axis = ShapeUtil.parseAxis(axis, x.dims.length);
axis = ShapeUtil.normalizeAxis(axis, x.dims.length);
const dims = x.dims.slice();
const newDims = dims.slice();
const indicesData = indices.data;
Expand All @@ -24,7 +24,8 @@ export function gather(x: Tensor, indices: Tensor, axis: number): Tensor {
for (let i = 0; i < Y.length; ++i) {
const newLogicalIndex = ShapeUtil.offsetToIndices(i, newDimsStrides);
const oldLogicalIndex = newLogicalIndex.slice();
oldLogicalIndex[axis] = indicesData[newLogicalIndex[axis]] as number;
const idx = indicesData[newLogicalIndex[axis]] as number;
oldLogicalIndex[axis] = idx < 0 ? idx + dims[axis] : idx;
const oldOffset = ShapeUtil.indicesToOffset(oldLogicalIndex, dimsStrides);
Y[i] = X[oldOffset] as number;
}
Expand Down
10 changes: 6 additions & 4 deletions lib/backends/cpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@ import {matMul2d} from './matmul';

export class CpuGemm extends Gemm {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = gemm(inputs[0], inputs[1], inputs[2], this.alpha, this.beta, this.transA, this.transB);
const output = gemm(
inputs[0], inputs[1], this.alpha, this.beta, this.transA, this.transB,
inputs.length === 3 ? inputs[2] : undefined);
return [output];
}
}

export function gemm(a: Tensor, b: Tensor, c: Tensor, alpha: number, beta: number, transA: boolean, transB: boolean) {
const [M, N, K] = util.GemmUtil.getShapeOfGemmResult(a.dims, transA, b.dims, transB, c.dims);
export function gemm(a: Tensor, b: Tensor, alpha: number, beta: number, transA: boolean, transB: boolean, c?: Tensor) {
const [M, N, K] = util.GemmUtil.getShapeOfGemmResult(a.dims, transA, b.dims, transB, c?.dims);

// The result will always be of the shape [M,N]
const output = new Tensor([M, N], a.type);
// broadcast and assign value from C to output
if (util.BroadcastUtil.calc(output, c, (a, b) => b, true) !== output) {
if (c && util.BroadcastUtil.calc(output, c, (a, b) => b, true) !== output) {
throw new Error(`tensor C is not broadcastable to [M,N]`);
}

Expand Down
14 changes: 7 additions & 7 deletions lib/backends/cpu/ops/reduce.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,49 @@ import {CpuInferenceHandler} from '../inference-handler';

export class CpuReduceSum extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]> {
const output = reduceSum(inputs[0], this.axes, this.keepDims);
const output = reduceSum(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceSumSquare extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceSumSquare(inputs[0], this.axes, this.keepDims);
const output = reduceSumSquare(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceLogSum extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceLogSum(inputs[0], this.axes, this.keepDims);
const output = reduceLogSum(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceMax extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceMax(inputs[0], this.axes, this.keepDims);
const output = reduceMax(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceMin extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceMin(inputs[0], this.axes, this.keepDims);
const output = reduceMin(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceMean extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceMean(inputs[0], this.axes, this.keepDims);
const output = reduceMean(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceProd extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceProd(inputs[0], this.axes, this.keepDims);
const output = reduceProd(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}
Expand Down
6 changes: 3 additions & 3 deletions lib/backends/cpu/ops/slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,18 @@ export function slice(
if (axes.length === 0) {
axes = x.dims.map((val, ind) => ind);
}
axes = axes.map(axis => ShapeUtil.parseAxis(axis, x.dims.length));
axes = ShapeUtil.normalizeAxes(axes, x.dims.length);
starts = starts.map((start, ind) => {
if (start > x.dims[axes[ind]] - 1) {
return x.dims[axes[ind]];
}
return ShapeUtil.parseAxis(start, x.dims[axes[ind]]);
return ShapeUtil.normalizeAxis(start, x.dims[axes[ind]]);
});
ends = ends.map((end, ind) => {
if (end > x.dims[axes[ind]] - 1) {
return x.dims[axes[ind]];
}
return ShapeUtil.parseAxis(end, x.dims[axes[ind]]);
return ShapeUtil.normalizeAxis(end, x.dims[axes[ind]]);
});
const size: number[] = [];
const adjustedStarts: number[] = [];
Expand Down
6 changes: 3 additions & 3 deletions lib/backends/cpu/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ export function softmax(x: Tensor, axis: number): Tensor {
const inputDimensions = x.dims;
const inputRank = inputDimensions.length;

const axisCorrected = util.ShapeUtil.parseAxis(axis, inputRank);
const N = util.ShapeUtil.sizeToDimension(inputDimensions, axisCorrected);
const D = util.ShapeUtil.sizeFromDimension(inputDimensions, axisCorrected);
axis = util.ShapeUtil.normalizeAxis(axis, inputRank);
const N = util.ShapeUtil.sizeToDimension(inputDimensions, axis);
const D = util.ShapeUtil.sizeFromDimension(inputDimensions, axis);

const X = x.numberData;

Expand Down
9 changes: 5 additions & 4 deletions lib/backends/wasm/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ import {WasmSum} from './ops/sum';
export const WASM_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Add', '', '7+', () => new WasmBinaryOp(['float32'], 'Add')],
['And', '', '7+', () => new WasmBinaryOp(['bool'], 'And')],
['AveragePool', '', '7+', () => new WasmAveragePool()], // TODO: support new attributes for AveragePool-10
['AveragePool', '', '7-10', () => new WasmAveragePool()], // TODO: support new attributes for AveragePool-10
['BatchNormalization', '', '7+', () => new WasmBatchNormalization()],
['Clip', '', '6+', () => new WasmClip()],
['Clip', '', '6-10', () => new WasmClip()],
['Conv', '', '1+', () => new WasmConv()],
['Div', '', '7+', () => new WasmBinaryOp(['float32'], 'Div')],
['Gemm', '', '7+', () => new WasmGemm()],
['Gemm', '', '7-10', () => new WasmGemm(false)],
['Gemm', '', '11+', () => new WasmGemm(true)],
['GlobalAveragePool', '', '1+', () => new WasmGlobalAveragePool()],
['GlobalMaxPool', '', '1+', () => new WasmGlobalMaxPool()],
['InstanceNormalization', '', '6+', () => new WasmInstanceNormalization()],
['MatMul', '', '1+', () => new WasmMatMul()],
['MaxPool', '', '1+', () => new WasmMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', '', '1-9', () => new WasmMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['Mul', '', '7+', () => new WasmBinaryOp(['float32'], 'Mul')],
['Or', '', '7+', () => new WasmBinaryOp(['bool'], 'Or')],
['PRelu', '', '7+', () => new WasmBinaryOp(['float32'], 'PRelu')],
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/wasm/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ export class WasmGemm extends Gemm {
const b = inputs[1];
const c = inputs[2];

const [M, N] = GemmUtil.getShapeOfGemmResult(a.dims, this.transA, b.dims, this.transB, c.dims);
const [M, N] = GemmUtil.getShapeOfGemmResult(a.dims, this.transA, b.dims, this.transB, c?.dims);
const y = new Tensor([M, N], a.type);
if (!BroadcastUtil.calc(y, c, (a, b) => (b), true)) {
if (c && !BroadcastUtil.calc(y, c, (a, b) => (b), true)) {
throw new Error(`c is not broadcastable to the shape of the result of the Gemm operator`);
}
WasmBinding.getInstance().ccall(
Expand Down
6 changes: 3 additions & 3 deletions lib/backends/wasm/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import {WasmInferenceHandler} from '../inference-handler';
export class WasmSoftmax extends Softmax {
run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
const x = inputs[0];
const axisCorrected = ShapeUtil.parseAxis(this.axis, x.dims.length);
const N = ShapeUtil.sizeToDimension(x.dims, axisCorrected);
const D = ShapeUtil.sizeFromDimension(x.dims, axisCorrected);
const axis = ShapeUtil.normalizeAxis(this.axis, x.dims.length);
const N = ShapeUtil.sizeToDimension(x.dims, axis);
const D = ShapeUtil.sizeFromDimension(x.dims, axis);
const y = new Tensor(x.dims, x.type);
WasmBinding.getInstance().ccall(
'_softmax_f32', [x.floatData, 'float32ptr'], [y.floatData, 'float32ptr', 'out'], [N, 'int32'], [D, 'int32']);
Expand Down
9 changes: 5 additions & 4 deletions lib/backends/webgl/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['And', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslAnd())],
['Asin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAsin())],
['Atan', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAtan())],
['AveragePool', '', '7+', () => new WebGLAveragePool()], // TODO: support new attributes for AveragePool-10
['AveragePool', '', '7-10', () => new WebGLAveragePool()], // TODO: support new attributes for AveragePool-10
['BatchNormalization', '', '7+', () => new WebGLBatchNormalization()],
['Ceil', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCeil())],
['Clip', '', '6+', () => new WebGLClip()],
['Clip', '', '6-10', () => new WebGLClip()],
['Concat', '', '4+', () => new WebGLConcat()],
['Conv', '', '1+', () => new WebGLConv()],
['Cos', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCos())],
Expand All @@ -55,7 +55,8 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Flatten', '', '1+', () => new WebGLFlatten()],
['Floor', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslFloor())],
['Gather', '', '1+', () => new WebGLGather()],
['Gemm', '', '7+', () => new WebGLGemm()],
['Gemm', '', '7-10', () => new WebGLGemm(false)],
['Gemm', '', '11+', () => new WebGLGemm(true)],
['GlobalAveragePool', '', '1+', () => new WebGLGlobalAveragePool()],
['GlobalMaxPool', '', '1+', () => new WebGLGlobalMaxPool()],
['Greater', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslGreater(), undefined, 'bool')],
Expand All @@ -66,7 +67,7 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Less', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslLess(), undefined, 'bool')],
['Log', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslLog())],
['MatMul', '', '1+', () => new WebGLMatMul()],
['MaxPool', '', '1+', () => new WebGLMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', '', '1-9', () => new WebGLMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['Mul', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslMul())],
['Neg', '', '6+', () => new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslNeg())],
['Not', '', '1+', () => new unaryOps.WebGLUnaryOp(['bool'], unaryOps.glslNot())],
Expand Down
19 changes: 11 additions & 8 deletions lib/backends/webgl/ops/gather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import {Gather} from '../../../ops/gather';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData, WebGLOperator} from '../types';

Expand All @@ -19,22 +20,23 @@ export class WebGLGather extends Gather implements WebGLOperator {
throw Error('A scalar tensor output has not been supported');
}

const axis = ShapeUtil.normalizeAxis(this.axis, inputShape.length);
const indexCopyOps: string[] = [];
for (let i = 0; i < outputShape.length; i++) {
// outputShape is divided into three parts: A, B, C
// |0 this.axis| this.axis + indexDataShape.length| end|
// | A | B | C |
// |0 axis| axis + indexDataShape.length | end|
// | A | B | C |
//
// inputIdx: [A, inputs[1][B], C]
if (i < this.axis) { // A
if (i < axis) { // A
outputShape[i] = inputShape[i];
indexCopyOps.push(`inputIdx[${i}] = outputIdx[${i}];`);
} else {
if (i < this.axis + indexDataShape.length) { // B
outputShape[i] = indexDataShape[i - this.axis];
indexCopyOps.push(`indexDataIdx[${i - this.axis}] = outputIdx[${i}];`);
if (i < axis + indexDataShape.length) { // B
outputShape[i] = indexDataShape[i - axis];
indexCopyOps.push(`indexDataIdx[${i - axis}] = outputIdx[${i}];`);
} else { // C
outputShape[i] = inputShape[i - indexDataShape.length + 1]; // skip 1 for this.axis
outputShape[i] = inputShape[i - indexDataShape.length + 1]; // skip 1 for axis
indexCopyOps.push(`inputIdx[${i - indexDataShape.length + 1}] = outputIdx[${i}];`);
}
}
Expand All @@ -48,7 +50,8 @@ export class WebGLGather extends Gather implements WebGLOperator {
int inputIdx[${irank}];
int indexDataIdx[${iDrank}];
${indexCopyOps.join('\n ')}
inputIdx[${this.axis}] = int(_B(indexDataIdx));
int idx = int(_B(indexDataIdx));
inputIdx[${axis}] = idx < 0 ? idx + ${inputShape[axis]} : idx;
return _A(inputIdx);
}`;
return {
Expand Down
16 changes: 9 additions & 7 deletions lib/backends/webgl/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ export class WebGLGemm extends Gemm implements WebGLOperator {
createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const aShape = inputs[0].dims.slice();
const bShape = inputs[1].dims.slice();
const cShape = inputs[2].dims.slice();
const [M, N] = GemmUtil.getShapeOfGemmResult(aShape, this.transA, bShape, this.transB, cShape);
const [M, N] = GemmUtil.getShapeOfGemmResult(
aShape, this.transA, bShape, this.transB, inputs.length === 3 ? inputs[2].dims : undefined);
const oShape = [M, N];
if (!oShape) {
throw new Error('Can\'t use gemm on the given tensors');
Expand All @@ -35,16 +35,18 @@ export class WebGLGemm extends Gemm implements WebGLOperator {
line = `value += _A(a) * _B(b);`;
}
const rank = oShape.length;
const cRank = cShape.length;
const declareC = inputs.length === 3 ? `int c[${inputs[2].dims.length}];` : '';
const broadcastC = inputs.length === 3 ? `bcastIndices_C(indices, c);` : '';
const calculateC = inputs.length === 3 ? `value += beta * _C(c);` : '';
const shaderSource = `
float process(int indices[${rank}]) {
int a[${rank}];
int b[${rank}];
int c[${cRank}];
${declareC}
copyVec(indices, a);
copyVec(indices, b);
bcastIndices_C(indices, c);
${broadcastC}
float value = 0.0;
for (int k=0; k<${sharedDim}; ++k) {
Expand All @@ -54,14 +56,14 @@ export class WebGLGemm extends Gemm implements WebGLOperator {
}
value = value * alpha;
value += beta * _C(c);
${calculateC}
return value;
}`;
const inputLayouts = inputs.map(t => inferenceHandler.getOrCreateTextureLayout(t));
return {
inputLayouts,
outputLayout: inferenceHandler.createTextureLayoutFromShape(oShape),
samplers: ['A', 'B', 'C'],
samplers: inputs.length === 3 ? ['A', 'B', 'C'] : ['A', 'B'],
variables: [{name: 'alpha', type: 'float'}, {name: 'beta', type: 'float'}],
shaderSource,
};
Expand Down
Loading