Skip to content

Commit 2c9af0f

Browse files
authored
[Runtime] Allow aborting fetchNDArray through AbortSignal (#17208)
[Runtime] Allow aborting fetchNDArray
1 parent f62445c commit 2c9af0f

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

web/src/artifact_cache.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,14 @@ export interface ArtifactCacheTemplate {
5858
*
5959
* @param url: The url to the data to be cached.
6060
* @param storetype: Only applies to `ArtifactIndexedDBCache`. Since `indexedDB` stores the actual
61+
* @param signal: An optional AbortSignal to abort data retrival
6162
* data rather than a request, we specify `storagetype`. There are two options:
6263
* 1. "json": IndexedDB stores `fetch(url).json()`
6364
* 2. "arraybuffer": IndexedDB stores `fetch(url).arrayBuffer()`
6465
*
6566
* @note This is an async function.
6667
*/
67-
addToCache(url: string, storetype?: string): Promise<void>;
68+
addToCache(url: string, storetype?: string, signal?: AbortSignal): Promise<void>;
6869

6970
/**
7071
* check if cache has all keys in Cache
@@ -126,8 +127,8 @@ export class ArtifactCache implements ArtifactCacheTemplate {
126127
}
127128

128129
// eslint-disable-next-line @typescript-eslint/no-unused-vars
129-
async addToCache(url: string, storetype?: string) {
130-
const request = new Request(url);
130+
async addToCache(url: string, storetype?: string, signal?: AbortSignal) {
131+
const request = new Request(url, signal ? { signal } : undefined);
131132
if (this.cache === undefined) {
132133
this.cache = await caches.open(this.scope);
133134
}
@@ -282,15 +283,15 @@ export class ArtifactIndexedDBCache implements ArtifactCacheTemplate {
282283
});
283284
}
284285

285-
async addToCache(url: string, storetype?: string): Promise<void> {
286+
async addToCache(url: string, storetype?: string, signal?: AbortSignal): Promise<void> {
286287
await this.initDB(); // await the initDB process
287288
// If already cached, nothing to do
288289
const isInDB = await this.isUrlInDB(url);
289290
if (isInDB) {
290291
return;
291292
}
292293
try {
293-
const response = await fetch(url);
294+
const response = await fetch(url, signal ? { signal } : undefined);
294295
if (!response.ok) {
295296
throw new Error('Network response was not ok');
296297
}

web/src/runtime.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,13 +1444,15 @@ export class Instance implements Disposable {
14441444
* @param device The device to be fetched to.
14451445
* @param cacheScope The scope identifier of the cache
14461446
* @param cacheType The type of the cache: "cache" or "indexedDB"
1447+
* @param signal An optional AbortSignal to abort the fetch
14471448
* @returns The meta data
14481449
*/
14491450
async fetchNDArrayCache(
14501451
ndarrayCacheUrl: string,
14511452
device: DLDevice,
14521453
cacheScope = "tvmjs",
1453-
cacheType = "cache"
1454+
cacheType = "cache",
1455+
signal?: AbortSignal,
14541456
): Promise<any> {
14551457
let artifactCache: ArtifactCacheTemplate;
14561458
if (cacheType === undefined || cacheType.toLowerCase() === "cache") {
@@ -1465,7 +1467,8 @@ export class Instance implements Disposable {
14651467
const list = await artifactCache.fetchWithCache(jsonUrl, "json");
14661468
await this.fetchNDArrayCacheInternal(
14671469
ndarrayCacheUrl,
1468-
list["records"] as Array<NDArrayShardEntry>, device, artifactCache);
1470+
list["records"] as Array<NDArrayShardEntry>, device, artifactCache,
1471+
signal);
14691472
this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record<string, any>) };
14701473
}
14711474

@@ -1477,12 +1480,14 @@ export class Instance implements Disposable {
14771480
* @param list The list of array data.
14781481
* @param device The device to store the data to.
14791482
* @param artifactCache The artifact cache
1483+
* @param signal An optional AbortSignal to abort the fetch
14801484
*/
14811485
private async fetchNDArrayCacheInternal(
14821486
ndarrayCacheUrl: string,
14831487
list: Array<NDArrayShardEntry>,
14841488
device: DLDevice,
1485-
artifactCache: ArtifactCacheTemplate
1489+
artifactCache: ArtifactCacheTemplate,
1490+
signal?: AbortSignal,
14861491
) {
14871492
const perf = compact.getPerformance();
14881493
const tstart = perf.now();
@@ -1537,7 +1542,7 @@ export class Instance implements Disposable {
15371542
const shard = list[i];
15381543
const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
15391544
try {
1540-
await artifactCache.addToCache(dataUrl, "arraybuffer");
1545+
await artifactCache.addToCache(dataUrl, "arraybuffer", signal);
15411546
} catch (err) {
15421547
this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
15431548
throw err;

0 commit comments

Comments
 (0)