@@ -174,6 +174,7 @@ class RuntimeContext implements Disposable {
174174 applyRepetitionPenalty : PackedFunc ;
175175 applyPresenceAndFrequencyPenalty : PackedFunc ;
176176 applySoftmaxWithTemperature : PackedFunc ;
177+ concatEmbeddings : PackedFunc | undefined ;
177178
178179 private autoDisposeScope : Array < Array < Disposable | undefined > > = [ ] ;
179180
@@ -199,6 +200,11 @@ class RuntimeContext implements Disposable {
199200 this . applyRepetitionPenalty = getGlobalFunc ( "vm.builtin.apply_repetition_penalty" ) ;
200201 this . applyPresenceAndFrequencyPenalty = getGlobalFunc ( "vm.builtin.apply_presence_and_frequency_penalty" ) ;
201202 this . applySoftmaxWithTemperature = getGlobalFunc ( "vm.builtin.apply_softmax_with_temperature" ) ;
203+ try {
204+ this . concatEmbeddings = getGlobalFunc ( "tvmjs.runtime.ConcatEmbeddings" ) ;
205+ } catch {
206+ // TODO: remove soon. Older artifacts do not have this, try-catch for backward compatibility.
207+ }
202208 }
203209
204210 dispose ( ) : void {
@@ -223,6 +229,7 @@ class RuntimeContext implements Disposable {
223229 this . applyRepetitionPenalty . dispose ( ) ;
224230 this . applyPresenceAndFrequencyPenalty . dispose ( ) ;
225231 this . applySoftmaxWithTemperature . dispose ( ) ;
232+ this . concatEmbeddings ?. dispose ( ) ;
226233 }
227234
228235 beginScope ( ) : void {
@@ -575,7 +582,10 @@ export class NDArray implements Disposable {
575582 * @param data The source data array.
576583 * @returns this
577584 */
578- copyFrom ( data : NDArray | Array < number > | Float32Array ) : this {
585+ copyFrom (
586+ data : NDArray | Array < number > | Float32Array | Float64Array |
587+ Int32Array | Int8Array | Uint8Array | Uint8ClampedArray
588+ ) : this {
579589 if ( data instanceof NDArray ) {
580590 this . lib . checkCall (
581591 ( this . lib . exports . TVMArrayCopyFromTo as ctypes . FTVMArrayCopyFromTo ) (
@@ -608,6 +618,8 @@ export class NDArray implements Disposable {
608618 buffer = Int8Array . from ( data ) . buffer ;
609619 } else if ( this . dtype === "uint8" ) {
610620 buffer = Uint8Array . from ( data ) . buffer ;
621+ } else if ( this . dtype === "uint32" ) {
622+ buffer = Uint32Array . from ( data ) . buffer ;
611623 } else {
612624 throw new Error ( "Unsupported data type " + this . dtype ) ;
613625 }
@@ -1906,6 +1918,30 @@ export class Instance implements Disposable {
19061918 return this . ctx . arrayConcat ( ...listOfArrays ) as TVMArray ;
19071919 }
19081920
1921+ /**
1922+ * Join a sequence of NDArrays that represent embeddings.
1923+ * @param inputs A list of embeddings in NDArrays, each array i has shape (m_i, hidden_size).
1924+ * @returns An NDArray of shape (\sum_{i} {m}, hidden_size)
1925+ */
1926+ concatEmbeddings ( embeddings : Array < NDArray > ) : NDArray {
1927+ // 1. Check shape validity
1928+ const hidden_size = embeddings [ 0 ] . shape [ 1 ] ;
1929+ embeddings . forEach ( ( input ) => {
1930+ if ( input . shape . length !== 2 || input . shape [ 1 ] !== hidden_size ) {
1931+ throw new Error ( "Expect embeddings to concatenate have shape (m_i, hidden_size)." ) ;
1932+ }
1933+ } )
1934+
1935+ // 2. Call global func
1936+ if ( this . ctx . concatEmbeddings === undefined ) {
1937+ throw new Error (
1938+ "Global function tvmjs.runtime.ConcatEmbeddings was " +
1939+ "not found, but called concatEmbeddings."
1940+ ) ;
1941+ }
1942+ return this . ctx . concatEmbeddings ( ...embeddings ) as NDArray ;
1943+ }
1944+
19091945 /**
19101946 * Create a {@link TVMString} that can be consumed by runtime.
19111947 *
0 commit comments