Skip to content

Commit 44808b4

Browse files
authored
[WASM] Implement concat embeddings (#17404)
* [WASM] Implement concat embeddings * Make concatEmbeddings optional for backward compatibility
1 parent 9e2a75d commit 44808b4

File tree

3 files changed

+84
-1
lines changed

3 files changed

+84
-1
lines changed

src/target/source/codegen_webgpu.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re
125125
name_supply_->ReserveName("var");
126126
name_supply_->ReserveName("let");
127127
name_supply_->ReserveName("const");
128+
name_supply_->ReserveName("std");
128129

129130
// skip the first underscore, so SSA variable starts from
130131
name_supply_->FreshName("v_");

web/emcc/wasm_runtime.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,5 +173,51 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat").set_body([](TVMArgs args, TVMRe
173173
}
174174
*ret = Array<ObjectRef>(data);
175175
});
176+
177+
NDArray ConcatEmbeddings(const std::vector<NDArray>& embeddings) {
178+
// Get output shape
179+
int64_t hidden_size = embeddings[0]->shape[1];
180+
DLDataType dtype = embeddings[0]->dtype;
181+
DLDevice device = embeddings[0]->device;
182+
int seqLen = 0;
183+
for (int i = 0; i < embeddings.size(); ++i) {
184+
ICHECK_EQ(embeddings[i]->ndim, 2);
185+
ICHECK_EQ(embeddings[i]->shape[1], hidden_size);
186+
seqLen += embeddings[i]->shape[0];
187+
}
188+
189+
// Create output
190+
std::vector<int64_t> shape;
191+
shape.push_back(seqLen);
192+
shape.push_back(hidden_size);
193+
NDArray result = NDArray::Empty(shape, dtype, device);
194+
195+
// Copy
196+
int offset = 0;
197+
for (int i = 0; i < embeddings.size(); i++) {
198+
const DLTensor& copy_src = *(embeddings[i].operator->());
199+
const DLTensor* p_copy_dst = result.operator->();
200+
DLTensor copy_dst = *p_copy_dst;
201+
copy_dst.shape = embeddings[i]->shape;
202+
copy_dst.byte_offset =
203+
offset * hidden_size * ((embeddings[i]->dtype.bits * embeddings[i]->dtype.lanes + 7) / 8);
204+
NDArray::CopyFromTo(&copy_src, &copy_dst);
205+
offset += embeddings[i]->shape[0];
206+
}
207+
208+
return result;
209+
}
210+
211+
// Concatenate n NDArrays
212+
TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings").set_body([](TVMArgs args, TVMRetValue* ret) {
213+
std::vector<NDArray> embeddings;
214+
for (int i = 0; i < args.size(); ++i) {
215+
ICHECK_EQ(args[i].type_code(), kTVMNDArrayHandle);
216+
embeddings.push_back(args[i]);
217+
}
218+
NDArray result = ConcatEmbeddings(std::move(embeddings));
219+
*ret = result;
220+
});
221+
176222
} // namespace runtime
177223
} // namespace tvm

web/src/runtime.ts

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ class RuntimeContext implements Disposable {
174174
applyRepetitionPenalty: PackedFunc;
175175
applyPresenceAndFrequencyPenalty: PackedFunc;
176176
applySoftmaxWithTemperature: PackedFunc;
177+
concatEmbeddings: PackedFunc | undefined;
177178

178179
private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
179180

@@ -199,6 +200,11 @@ class RuntimeContext implements Disposable {
199200
this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty");
200201
this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty");
201202
this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
203+
try {
204+
this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings");
205+
} catch {
206+
// TODO: remove soon. Older artifacts do not have this, try-catch for backward compatibility.
207+
}
202208
}
203209

204210
dispose(): void {
@@ -223,6 +229,7 @@ class RuntimeContext implements Disposable {
223229
this.applyRepetitionPenalty.dispose();
224230
this.applyPresenceAndFrequencyPenalty.dispose();
225231
this.applySoftmaxWithTemperature.dispose();
232+
this.concatEmbeddings?.dispose();
226233
}
227234

228235
beginScope(): void {
@@ -575,7 +582,10 @@ export class NDArray implements Disposable {
575582
* @param data The source data array.
576583
* @returns this
577584
*/
578-
copyFrom(data: NDArray | Array<number> | Float32Array): this {
585+
copyFrom(
586+
data: NDArray | Array<number> | Float32Array | Float64Array |
587+
Int32Array | Int8Array | Uint8Array | Uint8ClampedArray
588+
): this {
579589
if (data instanceof NDArray) {
580590
this.lib.checkCall(
581591
(this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)(
@@ -608,6 +618,8 @@ export class NDArray implements Disposable {
608618
buffer = Int8Array.from(data).buffer;
609619
} else if (this.dtype === "uint8") {
610620
buffer = Uint8Array.from(data).buffer;
621+
} else if (this.dtype === "uint32") {
622+
buffer = Uint32Array.from(data).buffer;
611623
} else {
612624
throw new Error("Unsupported data type " + this.dtype);
613625
}
@@ -1906,6 +1918,30 @@ export class Instance implements Disposable {
19061918
return this.ctx.arrayConcat(...listOfArrays) as TVMArray;
19071919
}
19081920

1921+
/**
1922+
* Join a sequence of NDArrays that represent embeddings.
1923+
* @param inputs A list of embeddings in NDArrays, each array i has shape (m_i, hidden_size).
1924+
* @returns An NDArray of shape (\sum_{i} {m}, hidden_size)
1925+
*/
1926+
concatEmbeddings(embeddings: Array<NDArray>): NDArray {
1927+
// 1. Check shape validity
1928+
const hidden_size = embeddings[0].shape[1];
1929+
embeddings.forEach((input) => {
1930+
if (input.shape.length !== 2 || input.shape[1] !== hidden_size) {
1931+
throw new Error("Expect embeddings to concatenate have shape (m_i, hidden_size).");
1932+
}
1933+
})
1934+
1935+
// 2. Call global func
1936+
if (this.ctx.concatEmbeddings === undefined) {
1937+
throw new Error(
1938+
"Global function tvmjs.runtime.ConcatEmbeddings was " +
1939+
"not found, but called concatEmbeddings."
1940+
);
1941+
}
1942+
return this.ctx.concatEmbeddings(...embeddings) as NDArray;
1943+
}
1944+
19091945
/**
19101946
* Create a {@link TVMString} that can be consumed by runtime.
19111947
*

0 commit comments

Comments
 (0)