Skip to content

Commit

Permalink
fix upload
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent b6e7fba commit a782667
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export interface GpuDataManager {
/**
* upload data to GPU. if the ID already exists in cache, returns the cached value without uploading anything.
*/
upload(id: GpuDataId, data: Tensor.NumberType, gpuDataType: GpuDataType): Promise<GpuData>;
upload(data: Tensor.NumberType, gpuDataType: GpuDataType): Promise<GpuData>;
/**
* create new data on GPU.
*/
Expand Down Expand Up @@ -54,7 +54,7 @@ class GpuDataManagerImpl implements GpuDataManager {
this.downloadCache = new Map();
}

async upload(id: GpuDataId, data: Tensor.NumberType, gpuDataType: GpuDataType): Promise<GpuData> {
async upload(data: Tensor.NumberType, gpuDataType: GpuDataType): Promise<GpuData> {
if (gpuDataType !== GpuDataType.default) {
throw new Error('we only support default GPU data type now');
}
Expand All @@ -72,8 +72,8 @@ class GpuDataManagerImpl implements GpuDataManager {
new Uint8Array(arrayBuffer).set(new Uint8Array(srcArrayBuffer, srcOffset, srcLength));
gpuBuffer.unmap();

const gpuData = {id, type: GpuDataType.default, buffer: gpuBuffer};
this.storageCache.set(id, {gpuData, size: srcLength});
const gpuData = {id: Guid.create(), type: GpuDataType.default, buffer: gpuBuffer};
this.storageCache.set(gpuData.id, {gpuData, size: srcLength});
return gpuData;
}

Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/onnxjs/backends/webgpu/inference-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export class WebGpuInferenceHandler implements InferenceHandler {
this.dataManager = createTensorDataManager(session.backend.device);
}

private uploadGpuData(tensor: Tensor, textureType: GpuDataType): GpuData {
private async uploadGpuData(tensor: Tensor, textureType: GpuDataType): Promise<GpuData> {
if (this.session.isInitializer(tensor.dataId)) {
return this.session.dataManager.uploadTensorToGpu(tensor, textureType);
}
Expand All @@ -46,7 +46,7 @@ export class WebGpuInferenceHandler implements InferenceHandler {
// create info for input
const inputDatas: GpuData[] = [];
for (let i = 0; i < program.inputTypes.length; ++i) {
inputDatas[i] = this.uploadGpuData(inputs[i], program.inputTypes[i]);
inputDatas[i] = await this.uploadGpuData(inputs[i], program.inputTypes[i]);
}

const key = getProgramInfoUniqueKey(program, inputDatas);
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/tensor-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export interface TensorDataManager {
/**
* upload a CPU tensor to GPU.
*/
uploadTensorToGpu(tensor: Tensor, gpuDataType: GpuDataType): GpuData;
uploadTensorToGpu(tensor: Tensor, gpuDataType: GpuDataType): Promise<GpuData>;

/**
* create a new GPU tensor.
Expand Down Expand Up @@ -55,7 +55,7 @@ class TensorDataManagerImpl implements TensorDataManager {
tensorIds.add(tensorId);
}

uploadTensorToGpu(tensor: Tensor, gpuDataType: GpuDataType): GpuData {
async uploadTensorToGpu(tensor: Tensor, gpuDataType: GpuDataType): Promise<GpuData> {
const gpuDataId = this.map.get(tensor.dataId);
if (gpuDataId) {
const gpuData = this.gpuDataManager.get(gpuDataId);
Expand All @@ -65,7 +65,7 @@ class TensorDataManagerImpl implements TensorDataManager {
return gpuData;
}

const gpuData = this.gpuDataManager.create(tensor.type, tensor.dims, gpuDataType);
const gpuData = await this.gpuDataManager.upload(tensor.numberData, gpuDataType);
this.registerIdMapping(tensor.dataId, gpuData.id);
return gpuData;
}
Expand Down

0 comments on commit a782667

Please sign in to comment.