Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
webgl: add full texture cache
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed May 31, 2019
1 parent 09b3d36 commit 5a412bf
Show file tree
Hide file tree
Showing 14 changed files with 140 additions and 82 deletions.
4 changes: 4 additions & 0 deletions lib/api/onnx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ export declare namespace Backend {
* set or get the maximum batch size for matmul. 0 means to disable batching.
*/
matmulMaxBatchSize?: number;
/**
* set or get the texture cache mode
*/
textureCacheMode?: 'initializerOnly'|'full';
}
/**
* set options for the WebAssembly backend
Expand Down
4 changes: 4 additions & 0 deletions lib/backends/backend-webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ export class WebGLBackend implements Backend, WebGLOptions {
glContext: WebGLContext;
contextId?: 'webgl'|'webgl2';
matmulMaxBatchSize?: number;
textureCacheMode?: 'initializerOnly'|'full';

initialize(): boolean {
try {
this.glContext = createWebGLContext(this.contextId);
if (typeof this.matmulMaxBatchSize !== 'number') {
this.matmulMaxBatchSize = 16;
}
if (typeof this.textureCacheMode !== 'string') {
this.textureCacheMode = 'full';
}
Logger.verbose('WebGLBackend', `Created WebGLContext: ${typeof this.glContext}`);
return true;
} catch (e) {
Expand Down
25 changes: 9 additions & 16 deletions lib/backends/webgl/inference-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,27 @@ import {Tensor} from '../../tensor';
import {ShapeUtil} from '../../util';

import {WebGLUint8Encode} from './ops/uint8-encode';
import {ProgramManager} from './program-manager';
import {WebGLSessionHandler} from './session-handler';
import {Encoder} from './texture-data-encoder';
import {TextureHelper} from './texture-helper';
import {WidthHeightPrefs} from './texture-layout-strategy';
import {TextureData, TextureLayout, WebGLOperator} from './types';
import {getPackedShape} from './utils';

export class WebGLInferenceHandler implements InferenceHandler {
textureHelper: TextureHelper;
programManager: ProgramManager;
private textureDataCache: Map<Tensor.Id, TextureData>;
constructor(public session: WebGLSessionHandler) {
this.textureHelper = session.textureHelper;
this.programManager = session.programManager;
this.textureDataCache = new Map();
}

run(op: WebGLOperator, inputs: Tensor[]): Tensor[] {
let artifact = this.programManager.getArtifact(op);
let artifact = this.session.programManager.getArtifact(op);
if (!artifact) {
const programInfo = op.createProgramInfo(this, inputs);
artifact = this.programManager.build(programInfo);
this.programManager.setArtifact(op, artifact);
artifact = this.session.programManager.build(programInfo);
this.session.programManager.setArtifact(op, artifact);
}
const runData = op.createRunData(this, artifact.programInfo, inputs);
this.programManager.run(artifact, runData);
this.session.programManager.run(artifact, runData);
return [runData.outputTextureData.tensor];
}

Expand Down Expand Up @@ -90,7 +84,7 @@ export class WebGLInferenceHandler implements InferenceHandler {
layout: TextureLayout, dataType: Tensor.DataType, data?: Tensor.NumberType, tensor?: Tensor,
usage?: Encoder.Usage): TextureData {
Logger.verbose('InferenceHandler', `Creating TextureData: layout:[${JSON.stringify(layout)}]`);
const texture = this.textureHelper.createTextureFromLayout(dataType, layout, data, usage);
const texture = this.session.textureManager.createTextureFromLayout(dataType, layout, data, usage);
return this.createTextureDataFromTexture(layout, dataType, texture, tensor);
}

Expand Down Expand Up @@ -175,18 +169,17 @@ export class WebGLInferenceHandler implements InferenceHandler {
}

dispose(): void {
this.textureHelper.clearActiveTextures();
this.textureDataCache.forEach(td => this.textureHelper.releaseTexture(td));
this.session.textureManager.clearActiveTextures();
this.textureDataCache.forEach(td => this.session.textureManager.releaseTexture(td));
this.textureDataCache = new Map();
}

readTexture(textureData: TextureData): Tensor.NumberType {
if (!this.session.backend.glContext.isFloat32DownloadSupported) {
const op = new WebGLUint8Encode();
const uint8TD = op.runInternal(this, textureData);
return this.textureHelper.readUint8TextureAsFloat(uint8TD);
return this.session.textureManager.readUint8TextureAsFloat(uint8TD);
}
const values = this.textureHelper.readTexture(textureData, textureData.tensor.type, textureData.channels);
return values;
return this.session.textureManager.readTexture(textureData, textureData.tensor.type, textureData.channels);
}
}
40 changes: 19 additions & 21 deletions lib/backends/webgl/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ import {WebGLContext} from '../webgl-context';

export class WebGLConv extends Conv {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
const programManager = inferenceHandler.programManager;
const programManager = inferenceHandler.session.programManager;
if (!this.artifacts) {
this.artifacts = [];
const programInfos = this.createProgramInfos(inferenceHandler, inputs);
for (let i = 0; i < programInfos.length; ++i) {
const artifact = inferenceHandler.programManager.build(programInfos[i]);
const artifact = inferenceHandler.session.programManager.build(programInfos[i]);
this.artifacts.push(artifact);
}
}
Expand Down Expand Up @@ -70,40 +70,38 @@ export class WebGLConv extends Conv {
inputTDs.push(inferenceHandler.getOrCreateTextureData(b));
}
const outputTD = inferenceHandler.createTextureDataFromLayout(programInfos[1].outputLayout, inputs[0].type);
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported;
const runDataDotProduct = {
inputTextureDatas: inputTDs,
outputTextureData: outputTD,
uniformData: {},
preRun: blendEnabled ?
(glContext: WebGLContext, artifact: Artifact) => {
const gl = glContext.gl;
gl.enable(gl.BLEND);
glContext.checkError();
gl.blendEquation(gl.FUNC_ADD);
glContext.checkError();
gl.blendFunc(gl.ONE, gl.ONE);
glContext.checkError();
} :
undefined,
postRun: blendEnabled ?
(glContext: WebGLContext, artifact: Artifact) => {
const gl = glContext.gl;
gl.disable(gl.BLEND);
glContext.checkError();
} :
undefined,
draw: (glContext: WebGLContext, artifact: Artifact) => {
const gl = glContext.gl;
const sharedDim = artifact.programInfo.params!.sharedDim as number;
const sharedDimReadSize = artifact.programInfo.params!.sharedDimReadSize as number;
const sharedDimOffsetLocation = artifact.uniformLocations.find(l => l.name === 'sharedDimOffset')!.location;
let blend = false;
for (let k = 0; k < sharedDim; k += sharedDimReadSize) {
Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`);

if (k === sharedDimReadSize) {
blend = true;
gl.enable(gl.BLEND);
glContext.checkError();
gl.blendEquation(gl.FUNC_ADD);
glContext.checkError();
gl.blendFunc(gl.ONE, gl.ONE);
glContext.checkError();
}

gl.uniform1i(sharedDimOffsetLocation, k);
glContext.checkError();
glContext.draw();
}

if (blend) {
gl.disable(gl.BLEND);
glContext.checkError();
}
}
};
return [runtDataIm2Col, runDataDotProduct];
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/webgl/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ export class WebGLSoftmax extends Softmax {
this.artifacts = [];
const programInfos = this.createProgramInfos(inferenceHandler, inputs);
programInfos.forEach((pi, i) => {
const artifact = inferenceHandler.programManager.build(pi);
const artifact = inferenceHandler.session.programManager.build(pi);
this.artifacts.push(artifact);
});
}

const runDatas = this.createRunDatas(inferenceHandler, this.artifacts.map(a => a.programInfo), inputs);
runDatas.forEach((v, i) => inferenceHandler.programManager.run(this.artifacts[i], v));
runDatas.forEach((v, i) => inferenceHandler.session.programManager.run(this.artifacts[i], v));
// return only the last output
return [runDatas[runDatas.length - 1].outputTextureData.tensor];
}
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/webgl/ops/split.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ export class WebGLSplit extends Split {
this.artifacts = [];
for (let i = 0; i < count; ++i) {
const programInfo = this.createProgramInfo(inferenceHandler, inputs[0], i);
const artifact = inferenceHandler.programManager.build(programInfo);
const artifact = inferenceHandler.session.programManager.build(programInfo);
this.artifacts.push(artifact);
}
}
const results: Tensor[] = [];

this.artifacts.forEach(artifact => {
const rundata = this.createRunData(inferenceHandler, artifact.programInfo, inputs);
inferenceHandler.programManager.run(artifact, rundata);
inferenceHandler.session.programManager.run(artifact, rundata);
results.push(rundata.outputTextureData.tensor);
});
return results;
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/webgl/ops/uint8-encode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ export class WebGLUint8Encode {
${glsl.output} = encodeAsUint8(value);
}`;
const programInfo = {inputLayouts: [input], outputLayout, samplers: ['X'], shaderSource, hasMain: true};
const artifact = inferenceHandler.programManager.build(programInfo);
const artifact = inferenceHandler.session.programManager.build(programInfo);

const encoder = inferenceHandler.session.backend.glContext.getEncoder('byte', 4);
const texture =
inferenceHandler.session.backend.glContext.allocateTexture(outputLayout.width, outputLayout.height, encoder);
const outputTextureData = inferenceHandler.createSharedTextureData(outputLayout, 'uint8', texture, {});
const runData = {inputTextureDatas: [input], outputTextureData, uniformData: {}};

inferenceHandler.programManager.run(artifact, runData);
inferenceHandler.session.programManager.run(artifact, runData);
return runData.outputTextureData;
}
}
8 changes: 0 additions & 8 deletions lib/backends/webgl/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ export class ProgramManager {
}
run(buildArtifact: Artifact, runData: RunData): void {
this.profiler.event('backend', 'ProgramManager.run', () => {
if (runData.preRun) {
Logger.verbose('ProgramManager', 'PreRun');
runData.preRun(this.glContext, buildArtifact);
}
const gl = this.glContext.gl;
const program = buildArtifact.program;
gl.useProgram(program);
Expand All @@ -56,10 +52,6 @@ export class ProgramManager {
this.doDraw(buildArtifact, runData);
gl.flush();
});
if (runData.postRun) {
Logger.verbose('ProgramManager', 'PostRun');
runData.postRun(this.glContext, buildArtifact);
}
});
}
dispose(): void {
Expand Down
12 changes: 7 additions & 5 deletions lib/backends/webgl/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@ import {WebGLBackend} from '../backend-webgl';
import {WebGLInferenceHandler} from './inference-handler';
import {WEBGL_OP_RESOLVE_RULES} from './op-resolve-rules';
import {ProgramManager} from './program-manager';
import {TextureHelper} from './texture-helper';
import {AlwaysKeepOriginalSizeStrategy, TextureLayoutStrategy} from './texture-layout-strategy';
import {TextureManager} from './texture-manager';
import {TextureData} from './types';

export class WebGLSessionHandler implements SessionHandler {
programManager: ProgramManager;
textureHelper: TextureHelper;
textureManager: TextureManager;
layoutStrategy: TextureLayoutStrategy;
textureDataCache: Map<Tensor.Id, TextureData>;
initializers: Set<Tensor.Id>;

constructor(public readonly backend: WebGLBackend, public readonly context: Session.Context) {
this.programManager = new ProgramManager(this.context.profiler, backend.glContext);
this.layoutStrategy = new AlwaysKeepOriginalSizeStrategy(backend.glContext.maxTextureSize);
this.textureHelper = new TextureHelper(backend.glContext, this.layoutStrategy, this.context.profiler);
this.textureManager = new TextureManager(
backend.glContext, this.layoutStrategy, this.context.profiler,
{reuseTextures: backend.textureCacheMode === 'full'});
this.textureDataCache = new Map();
}

Expand All @@ -50,8 +52,8 @@ export class WebGLSessionHandler implements SessionHandler {
}
dispose(): void {
this.programManager.dispose();
this.textureHelper.clearActiveTextures();
this.textureDataCache.forEach(td => this.textureHelper.releaseTexture(td));
this.textureManager.clearActiveTextures();
this.textureDataCache.forEach(td => this.textureManager.releaseTexture(td, true));
this.textureDataCache = new Map();
}
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator {
Expand Down
Loading

0 comments on commit 5a412bf

Please sign in to comment.