Releases: huggingface/transformers.js
3.2.0
🔥 Transformers.js v3.2 — Moonshine for real-time speech recognition, Phi-3.5 Vision for multi-frame image understanding and reasoning, and more!
Table of contents:
🤖 New models: Moonshine, Phi-3.5 Vision, EXAONE
Moonshine for real-time speech recognition
Moonshine is a family of speech-to-text models optimized for fast and accurate automatic speech recognition (ASR) on resource-constrained devices. They are well-suited to real-time, on-device applications like live transcription and voice command recognition, and are perfect for in-browser usage (demo coming soon). See #1099 for more information and here for the list of supported models.
Example: Automatic speech recognition w/ Moonshine tiny.
import { pipeline } from "@huggingface/transformers";
const transcriber = await pipeline("automatic-speech-recognition", "onnx-community/moonshine-tiny-ONNX");
const output = await transcriber("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav");
console.log(output);
// { text: 'And so my fellow Americans ask not what your country can do for you as what you can do for your country.' }
See example using the MoonshineForConditionalGeneration API
import { MoonshineForConditionalGeneration, AutoProcessor, read_audio } from "@huggingface/transformers";
// Load model and processor
const model_id = "onnx-community/moonshine-tiny-ONNX";
const model = await MoonshineForConditionalGeneration.from_pretrained(model_id, {
dtype: "q4",
});
const processor = await AutoProcessor.from_pretrained(model_id);
// Load audio and prepare inputs
const audio = await read_audio("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav", 16000);
const inputs = await processor(audio);
// Generate outputs
const outputs = await model.generate({ ...inputs, max_new_tokens: 100 });
// Decode outputs
const decoded = processor.batch_decode(outputs, { skip_special_tokens: true });
console.log(decoded[0]);
// And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.
Phi-3.5 Vision for multi-frame image understanding and reasoning
Phi-3.5 Vision is a lightweight, state-of-the-art, open multimodal model that can be used for multi-frame image understanding and reasoning. See #1094 for more information and here for the list of supported models.
Examples:
See example code
Example: Single-frame (critique an image)
import {
AutoProcessor,
AutoModelForCausalLM,
TextStreamer,
load_image,
} from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/Phi-3.5-vision-instruct";
const processor = await AutoProcessor.from_pretrained(model_id, {
legacy: true, // Use legacy to match python version
});
const model = await AutoModelForCausalLM.from_pretrained(model_id, {
dtype: {
vision_encoder: "q4", // 'q4' or 'q4f16'
prepare_inputs_embeds: "q4", // 'q4' or 'q4f16'
model: "q4f16", // 'q4f16'
},
});
// Load image
const image = await load_image("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/meme.png");
// Prepare inputs
const messages = [
{ role: "user", content: "<|image_1|>What's funny about this image?" },
];
const prompt = processor.tokenizer.apply_chat_template(messages, {
tokenize: false,
add_generation_prompt: true,
});
const inputs = await processor(prompt, image, { num_crops: 4 });
// (Optional) Set up text streamer
const streamer = new TextStreamer(processor.tokenizer, {
skip_prompt: true,
skip_special_tokens: true,
});
// Generate response
const output = await model.generate({
...inputs,
streamer,
max_new_tokens: 256,
});
Or, decode the output at the end:
// Decode and display the answer
const generated_ids = output.slice(null, [inputs.input_ids.dims[1], null]);
const answer = processor.batch_decode(generated_ids, {
skip_special_tokens: true,
});
console.log(answer[0]);
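The `output.slice(null, [inputs.input_ids.dims[1], null])` call above keeps every sequence in the batch but drops the leading prompt tokens, so that only the newly generated tokens are decoded. A minimal sketch of the same idea on plain arrays (the `stripPrompt` helper and the token ids are hypothetical, used only for illustration):

```javascript
// Hypothetical helper: each sequence holds the prompt tokens followed by the generated tokens.
function stripPrompt(sequences, promptLength) {
  // Keep every row, drop the first `promptLength` columns.
  return sequences.map((ids) => ids.slice(promptLength));
}

const promptLength = 3; // stands in for inputs.input_ids.dims[1]
const output = [[101, 7, 8, 42, 43, 2]]; // made-up token ids
console.log(stripPrompt(output, promptLength)); // [ [ 42, 43, 2 ] ]
```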
Example: Multi-frame (summarize slides)
import {
AutoProcessor,
AutoModelForCausalLM,
TextStreamer,
load_image,
} from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/Phi-3.5-vision-instruct";
const processor = await AutoProcessor.from_pretrained(model_id, {
legacy: true, // Use legacy to match python version
});
const model = await AutoModelForCausalLM.from_pretrained(model_id, {
dtype: {
vision_encoder: "q4", // 'q4' or 'q4f16'
prepare_inputs_embeds: "q4", // 'q4' or 'q4f16'
model: "q4f16", // 'q4f16'
},
});
// Load images
const urls = [
"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg",
"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-2-2048.jpg",
"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-3-2048.jpg",
];
const images = await Promise.all(urls.map(load_image));
// Prepare inputs
const placeholder = images.map((_, i) => `<|image_${i + 1}|>\n`).join("");
const messages = [
{ role: "user", content: placeholder + "Summarize the deck of slides." },
];
const prompt = processor.tokenizer.apply_chat_template(messages, {
tokenize: false,
add_generation_prompt: true,
});
const inputs = await processor(prompt, images, { num_crops: 4 });
// (Optional) Set up text streamer
const streamer = new TextStreamer(processor.tokenizer, {
skip_prompt: true,
skip_special_tokens: true,
});
// Generate response
const output = await model.generate({
...inputs,
streamer,
max_new_tokens: 256,
});
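For reference, the placeholder logic in the multi-frame example can be sketched on its own. It prefixes the question with one numbered `<|image_N|>` tag per image, each followed by a newline. The `buildImagePrompt` helper name is hypothetical:

```javascript
// Hypothetical helper mirroring the placeholder construction in the example above.
function buildImagePrompt(numImages, question) {
  const placeholders = Array.from(
    { length: numImages },
    (_, i) => `<|image_${i + 1}|>\n`, // image tags are 1-indexed
  ).join("");
  return placeholders + question;
}

// JSON.stringify makes the newlines visible in the printed result
console.log(JSON.stringify(buildImagePrompt(3, "Summarize the deck of slides.")));
// "<|image_1|>\n<|image_2|>\n<|image_3|>\nSummarize the deck of slides."
```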
EXAONE 3.5 for bilingual (English and Korean) text generation
EXAONE 3.5 is a collection of instruction-tuned bilingual (English and Korean) generative models, developed and released by LG AI Research. See #1084 for more information and here for the list of supported models.
Example: Text-generation w/ EXAONE-3.5-2.4B-Instruct:
import { pipeline } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
"text-generation",
"onnx-community/EXAONE-3.5-2.4B-Instruct",
{ dtype: "q4f16" },
);
// Define the list of messages
const messages = [
...
3.1.2
🤖 New models
- Add support for PaliGemma (& PaliGemma2) in #1074
Example: Image captioning with onnx-community/paligemma2-3b-ft-docci-448.
import { AutoProcessor, PaliGemmaForConditionalGeneration, load_image } from '@huggingface/transformers';
// Load processor and model
const model_id = 'onnx-community/paligemma2-3b-ft-docci-448';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await PaliGemmaForConditionalGeneration.from_pretrained(model_id, {
  dtype: {
    embed_tokens: 'fp16', // or 'q8'
    vision_encoder: 'fp16', // or 'q4', 'q8'
    decoder_model_merged: 'q4', // or 'q4f16'
  },
});
// Prepare inputs
const url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg'
const raw_image = await load_image(url);
const prompt = '<image>caption en'; // Caption the image in English
const inputs = await processor(raw_image, prompt);
// Generate a response
const output = await model.generate({
  ...inputs,
  max_new_tokens: 100,
})
const generated_ids = output.slice(null, [inputs.input_ids.dims[1], null]);
const answer = processor.batch_decode(
  generated_ids,
  { skip_special_tokens: true },
);
console.log(answer[0]);
// A side view of a light blue 1970s Volkswagen Beetle parked on a gray cement road. It is facing to the right. It has a reflection on the side of it. Behind it is a yellow building with a brown double door on the right. It has a white frame around it. Part of a gray cement wall is visible on the far left.
List of supported models: https://huggingface.co/models?library=transformers.js&other=paligemma
- Add support for I-JEPA in #1073
Example: Image feature extraction with onnx-community/ijepa_vith14_1k.
import { pipeline, cos_sim } from "@huggingface/transformers";
// Create an image feature extraction pipeline
const extractor = await pipeline(
  "image-feature-extraction",
  "onnx-community/ijepa_vith14_1k",
  { dtype: "q8" },
);
// Compute image embeddings
const url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
const url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg"
const output = await extractor([url_1, url_2]);
const pooled_output = output.mean(1); // Apply mean pooling
// Compute cosine similarity
const similarity = cos_sim(pooled_output[0].data, pooled_output[1].data);
console.log(similarity); // 0.5168613045518973
List of supported models: https://huggingface.co/models?library=transformers.js&other=ijepa
- Add support for OLMo2 in #1076. List of supported models: https://huggingface.co/models?library=transformers.js&other=olmo2
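The I-JEPA example above mean-pools the per-patch features of each image and then compares the pooled vectors with cosine similarity. A rough, library-free sketch of those two steps on plain arrays (the helper names and the tiny inputs are made up for illustration; the library's own `mean` and `cos_sim` operate on tensors and typed arrays):

```javascript
// Average the patch vectors of one image into a single embedding.
function meanPool(patches) {
  const dim = patches[0].length;
  const pooled = new Array(dim).fill(0);
  for (const patch of patches) {
    for (let j = 0; j < dim; ++j) pooled[j] += patch[j] / patches.length;
  }
  return pooled;
}

// Cosine similarity: dot product divided by the product of the vector norms.
function cosineSimilarity(a, b) {
  let dot = 0, normA = 0, normB = 0;
  for (let i = 0; i < a.length; ++i) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

const imageA = meanPool([[1, 0], [3, 0]]); // [2, 0]
const imageB = meanPool([[0, 2], [0, 4]]); // [0, 3]
console.log(cosineSimilarity(imageA, imageA)); // 1
console.log(cosineSimilarity(imageA, imageB)); // 0
```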
🐛 Bug fixes
- Fix whisper timestamp extraction for tokenizers with added tokens by @aravindMahadevan in #804
- Add missing 'ready' status in the ProgressInfo type by @ocavue in #1070
🛠️ Other improvements
- Add function to apply mask to RawImage by @BritishWerewolf in #1020
- Bump versions + webpack improvements in #1075
🤗 New contributors
- @aravindMahadevan made their first contribution in #804
Full Changelog: 3.1.1...3.1.2
3.1.1
🤖 New models
- Add support for Idefics3 (SmolVLM) in #1059
import {
  AutoProcessor,
  AutoModelForVision2Seq,
  load_image,
} from "@huggingface/transformers";
// Initialize processor and model
const model_id = "HuggingFaceTB/SmolVLM-Instruct";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForVision2Seq.from_pretrained(model_id, {
  dtype: {
    embed_tokens: "fp16", // "fp32", "fp16", "q8"
    vision_encoder: "q4", // "fp32", "fp16", "q8", "q4", "q4f16"
    decoder_model_merged: "q4", // "q8", "q4", "q4f16"
  }
});
// Load images
const image1 = await load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg");
const image2 = await load_image("https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg");
// Create input messages
const messages = [
  {
    role: "user",
    content: [
      { type: "image" },
      { type: "image" },
      { type: "text", text: "Can you describe the two images?" },
    ],
  },
];
// Prepare inputs
const text = processor.apply_chat_template(messages, { add_generation_prompt: true });
const inputs = await processor(text, [image1, image2], {
  // Set `do_image_splitting: true` to split images into multiple patches.
  // NOTE: This uses more memory, but can provide more accurate results.
  do_image_splitting: false,
});
// Generate outputs
const generated_ids = await model.generate({
  ...inputs,
  max_new_tokens: 500,
});
const generated_texts = processor.batch_decode(
  generated_ids.slice(null, [inputs.input_ids.dims.at(-1), null]),
  { skip_special_tokens: true },
);
console.log(generated_texts[0]);
// ' In the first image, there is a green statue of liberty on a pedestal in the middle of the water. The water is surrounded by trees and buildings in the background. In the second image, there are pink and red flowers with a bee on the pink flower.'
🐛 Bug fixes
- Fix repetition penalty logits processor in #1062
- Fix optional chaining for batch size calculation in PreTrainedModel by @emojiiii in #1063
📝 Documentation improvements
- Add an example and type enhancement for TextStreamer by @seonglae in #1066
- The smallest typo fix for webgpu.md by @JoramMillenaar in #1068
🛠️ Other improvements
- Only log warning if type not explicitly set to "custom" in #1061
- Improve browser vs. webworker detection in #1067
🤗 New contributors
- @emojiiii made their first contribution in #1063
- @seonglae made their first contribution in #1066
- @JoramMillenaar made their first contribution in #1068
Full Changelog: 3.1.0...3.1.1
3.1.0
🚀 Transformers.js v3.1 — any-to-any, text-to-image, image-to-text, pose estimation, time series forecasting, and more!
Table of contents:
- 🤖 New models: Janus, Qwen2-VL, JinaCLIP, LLaVA-OneVision, ViTPose, MGP-STR, PatchTST, PatchTSMixer.
- 🐛 Bug fixes
- 📝 Documentation improvements
- 🛠️ Other improvements
- 🤗 New contributors
🤖 New models: Janus, Qwen2-VL, JinaCLIP, LLaVA-OneVision, ViTPose, MGP-STR, PatchTST, PatchTSMixer.
Janus for Any-to-Any generation (e.g., image-to-text and text-to-image)
First of all, this release adds support for Janus, a novel autoregressive framework that unifies multimodal understanding and generation. The most popular model, deepseek-ai/Janus-1.3B, is tagged as an "any-to-any" model, and has specifically been trained for the following tasks:
Example: Image-Text-to-Text
import { AutoProcessor, MultiModalityCausalLM } from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/Janus-1.3B-ONNX";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await MultiModalityCausalLM.from_pretrained(model_id);
// Prepare inputs
const conversation = [
{
role: "User",
content: "<image_placeholder>\nConvert the formula into latex code.",
images: ["https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/quadratic_formula.png"],
},
];
const inputs = await processor(conversation);
// Generate response
const outputs = await model.generate({
...inputs,
max_new_tokens: 150,
do_sample: false,
});
// Decode output
const new_tokens = outputs.slice(null, [inputs.input_ids.dims.at(-1), null]);
const decoded = processor.batch_decode(new_tokens, { skip_special_tokens: true });
console.log(decoded[0]);
Sample output:
Sure, here is the LaTeX code for the given formula:
```
x = \frac{-b \pm \sqrt{b^2 - 4a c}}{2a}
```
This code represents the mathematical expression for the variable \( x \).
Example: Text-to-Image
import { AutoProcessor, MultiModalityCausalLM } from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/Janus-1.3B-ONNX";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await MultiModalityCausalLM.from_pretrained(model_id);
// Prepare inputs
const conversation = [
{
role: "User",
content: "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
},
];
const inputs = await processor(conversation, { chat_template: "text_to_image" });
// Generate response
const num_image_tokens = processor.num_image_tokens;
const outputs = await model.generate_images({
...inputs,
min_new_tokens: num_image_tokens,
max_new_tokens: num_image_tokens,
do_sample: true,
});
// Save the generated image
await outputs[0].save("test.png");
Sample outputs:
Want to play around with the model? Check out our online WebGPU demo! 👇
Janus-WebGPU.mp4
Qwen2-VL for Image-Text-to-Text
Next, we added support for Qwen2-VL, the multimodal large language model series developed by the Qwen team at Alibaba Cloud. It introduces the Naive Dynamic Resolution mechanism, allowing the model to process images of varying resolutions and leading to more efficient and accurate visual representations.
Example: Image-Text-to-Text
import { AutoProcessor, Qwen2VLForConditionalGeneration, RawImage } from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/Qwen2-VL-2B-Instruct";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await Qwen2VLForConditionalGeneration.from_pretrained(model_id);
// Prepare inputs
const url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg";
const image = await (await RawImage.read(url)).resize(448, 448);
const conversation = [
{
role: "user",
content: [
{ type: "image" },
{ type: "text", text: "Describe this image." },
],
},
];
const text = processor.apply_chat_template(conversation, { add_generation_prompt: true });
const inputs = await processor(text, image);
// Perform inference
const outputs = await model.generate({
...inputs,
max_new_tokens: 128,
});
// Decode output
const decoded = processor.batch_decode(
outputs.slice(null, [inputs.input_ids.dims.at(-1), null]),
{ skip_special_tokens: true },
);
console.log(decoded[0]);
// The image depicts a serene beach scene with a woman and a dog. The woman is sitting on the sand, wearing a plaid shirt, and appears to be engaged in a playful interaction with the dog. The dog, which is a large breed, is sitting on its hind legs and appears to be reaching out to the woman, possibly to give her a high-five or a paw. The background shows the ocean with gentle waves, and the sky is clear, suggesting it might be either sunrise or sunset. The overall atmosphere is calm and relaxed, capturing a moment of connection between the woman and the dog.
JinaCLIP for multimodal embeddings
JinaCLIP is a series of general-purpose multilingual multimodal embedding models for text & images, created by Jina AI.
Example: Compute text and/or image embeddings with jinaai/jina-clip-v2:
import { AutoModel, AutoProcessor, RawImage, matmul } from "@huggingface/transformers";
// Load processor and model
const model_id = "jinaai/jina-clip-v2";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModel.from_pretrained(model_id, { dtype: "q4" /* e.g., "fp16", "q8", or "q4" */ });
// Prepare inputs
const urls = ["https://i.ibb.co/nQNGqL0/beach1.jpg", "https://i.ibb.co/r5w8hG8/beach2.jpg"];
const images = await Promise.all(urls.map(url => RawImage.read(url)));
const sentences = [
"غروب جميل على الشاطئ", // Arabic
"海滩上美丽的日落", // Chinese
"Un beau coucher de soleil sur la plage", // French
"Ein wunderschöner Sonnenuntergang am Strand", // German
"Ένα όμορφο ηλιοβασίλεμα πάνω από την παραλία", // Greek
"समुद्र तट पर एक खूबसूरत सूर्यास्त", // Hindi
"Un bellissimo tramonto sulla spiaggia", // Italian
"浜辺に沈む美しい夕日", // Japanese
"해변 위로 아름다운 일몰", // Korean
];
// Encode text and images
const inputs = await processor(sentences, images, { padding: true, truncation: true });
const { l2norm_text_embeddings, l2norm_image_embeddings } = await model(inputs);
// Encode query (text-only)
const query_prefix = "Represent the query for retrieving evidence documents: ";
const query_inputs = await processor(query_prefix + "beautiful sunset over the beach");
const { l2norm_text_embeddings: query_embeddings } = await model(query_inputs);
// Compute text-image similarity scores
const text_to_image_scores = await matmul(query_embeddings, l2norm_image_embeddings.transpose(1, 0));
console.log("text-image similarity scores", text_to_image_scores.tolist()[0]); // [0.29530206322669983, 0.3183615803718567]
// Compute image-image similarity scores
const image_to_image_score = await matmul(l2norm_image_embeddings[0], l2norm_image_embeddings[1]);
console.log("image-image similarity score", image_to_image_score.item()); // 0.9344457387924194
// Compute text-text similarity scores
const text_to_text_scores = await matmul(query_embeddings, l2norm_text_embeddings.transpose(1, 0));
console.log("text-text similarity scores", text_to_text_scores.tolist()[0]); // [0.5566609501838684, 0.7028406858444214, 0.582255482673645, 0.6648036241531372, 0.5462006330490112, 0.6791588068008423, 0.6192430257797241, 0.6258729100227356, 0.6453716158866882]
LLaVA-OneVision for Image-Text-to-Text
LLaVA-OneVision is a Vision-Language Model that can generate text conditioned on one or several images/videos. The model consists of a SigLIP vision encoder and a Qwen2 language backbone.
Example: Multi-round conversations w/ PKV caching
import { AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration, RawImage } from '@huggingface/transformers';
// Load tokenizer, processor and model
const model_id = 'llava-hf/llava-onevision-qwen2-0.5b-ov-hf';
...
3.0.2
What's new?
- Add support for MobileLLM in #1003
Example: Text generation with onnx-community/MobileLLM-125M.
import { pipeline } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
  "text-generation",
  "onnx-community/MobileLLM-125M",
  { dtype: "fp32" },
);
// Define the list of messages
const text = "Q: What is the capital of France?\nA: Paris\nQ: What is the capital of England?\nA:";
// Generate a response
const output = await generator(text, { max_new_tokens: 30 });
console.log(output[0].generated_text);
Example output
Q: What is the capital of France?
A: Paris
Q: What is the capital of England?
A: London
Q: What is the capital of Scotland?
A: Edinburgh
Q: What is the capital of Wales?
A: Cardiff
- Add support for OLMo in #1011
Example: Text generation with onnx-community/AMD-OLMo-1B-SFT-DPO.
import { pipeline } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
  "text-generation",
  "onnx-community/AMD-OLMo-1B-SFT-DPO",
  { dtype: "q4" },
);
// Define the list of messages
const messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Tell me a joke." },
];
// Generate a response
const output = await generator(messages, { max_new_tokens: 128 });
console.log(output[0].generated_text.at(-1).content);
Example output
Why don't scientists trust atoms? Because they make up everything!
- Fix CommonJS bundling in #1012. Thanks @jens-ghc for reporting!
- Remove duplicate gemma value from NO_PER_CHANNEL_REDUCE_RANGE_MODEL by @bekzod in #1005
🤗 New contributors
Full Changelog: 3.0.1...3.0.2
3.0.1
3.0.0
Transformers.js v3: WebGPU Support, New Models & Tasks, New Quantizations, Deno & Bun Compatibility, and More…
After more than a year of development, we're excited to announce the release of 🤗 Transformers.js v3!
You can get started by installing Transformers.js v3 from NPM using:
npm i @huggingface/transformers
Then, importing the library with
import { pipeline } from "@huggingface/transformers";
or, via a CDN
import { pipeline } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0";
For more information, check out the documentation.
⚡ WebGPU support (up to 100x faster than WASM!)
WebGPU is a new web standard for accelerated graphics and compute. The API enables web developers to use the underlying system's GPU to carry out high-performance computations directly in the browser. WebGPU is the successor to WebGL and provides significantly better performance, because it allows for more direct interaction with modern GPUs. Lastly, it supports general-purpose GPU computations, which makes it just perfect for machine learning!
Warning
As of October 2024, global WebGPU support is around 70% (according to caniuse.com), meaning some users may not be able to use the API.
If the following demos do not work in your browser, you may need to enable it using a feature flag:
Usage in Transformers.js v3
Thanks to our collaboration with ONNX Runtime Web, enabling WebGPU acceleration is as simple as setting device: 'webgpu' when loading a model. Let's see some examples!
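Because WebGPU is not available in every browser, an application may want to check for it before asking for the "webgpu" device. The sketch below is one possible approach, not something the library does for you. The `pickDevice` helper is hypothetical, and falling back to "wasm" is an assumption about which device you want as the alternative:

```javascript
// Hypothetical helper: resolves to "webgpu" when a GPU adapter is available, else "wasm".
// `nav` defaults to the global navigator so the function can also be exercised with a stub.
async function pickDevice(nav = globalThis.navigator) {
  if (nav && nav.gpu) {
    try {
      // requestAdapter() resolves to null when no suitable adapter exists
      const adapter = await nav.gpu.requestAdapter();
      if (adapter) return "webgpu";
    } catch {
      // Ignore the error and fall through to the WASM backend
    }
  }
  return "wasm";
}

pickDevice().then((device) => console.log(`Using device: ${device}`));
```

The resolved value can then be passed as the device option, for example `{ device: await pickDevice() }`.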
Example: Compute text embeddings on WebGPU (demo)
import { pipeline } from "@huggingface/transformers";
// Create a feature-extraction pipeline
const extractor = await pipeline(
"feature-extraction",
"mixedbread-ai/mxbai-embed-xsmall-v1",
{ device: "webgpu" },
);
// Compute embeddings
const texts = ["Hello world!", "This is an example sentence."];
const embeddings = await extractor(texts, { pooling: "mean", normalize: true });
console.log(embeddings.tolist());
// [
// [-0.016986183822155, 0.03228696808218956, -0.0013630966423079371, ... ],
// [0.09050482511520386, 0.07207386940717697, 0.05762749910354614, ... ],
// ]
Example: Perform automatic speech recognition with OpenAI whisper on WebGPU (demo)
import { pipeline } from "@huggingface/transformers";
// Create automatic speech recognition pipeline
const transcriber = await pipeline(
"automatic-speech-recognition",
"onnx-community/whisper-tiny.en",
{ device: "webgpu" },
);
// Transcribe audio from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav";
const output = await transcriber(url);
console.log(output);
// { text: ' And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.' }
Example: Perform image classification with MobileNetV4 on WebGPU (demo)
import { pipeline } from "@huggingface/transformers";
// Create image classification pipeline
const classifier = await pipeline(
"image-classification",
"onnx-community/mobilenetv4_conv_small.e2400_r224_in1k",
{ device: "webgpu" },
);
// Classify an image from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg";
const output = await classifier(url);
console.log(output);
// [
// { label: 'tiger, Panthera tigris', score: 0.6149784922599792 },
// { label: 'tiger cat', score: 0.30281734466552734 },
// { label: 'tabby, tabby cat', score: 0.0019135422771796584 },
// { label: 'lynx, catamount', score: 0.0012161266058683395 },
// { label: 'Egyptian cat', score: 0.0011465961579233408 }
// ]
🔢 New quantization formats (dtypes)
Before Transformers.js v3, we used the quantized option to specify whether to use a quantized (q8) or full-precision (fp32) variant of the model by setting quantized to true or false, respectively. Now, we've added the ability to select from a much larger list with the dtype parameter.
The list of available quantizations depends on the model, but some common ones are: full-precision ("fp32"), half-precision ("fp16"), 8-bit ("q8", "int8", "uint8"), and 4-bit ("q4", "bnb4", "q4f16").
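When migrating older code, the mapping described above (quantized: true corresponds to "q8", quantized: false to "fp32") can be written down as a small helper. This is only a sketch: `migrateOptions` is a hypothetical name and is not part of the library.

```javascript
// Hypothetical migration helper: converts a v2-style `quantized` flag into a v3-style `dtype`.
function migrateOptions({ quantized, ...rest } = {}) {
  if (quantized === undefined) return rest; // nothing to migrate
  return { ...rest, dtype: quantized ? "q8" : "fp32" };
}

console.log(migrateOptions({ quantized: false })); // { dtype: 'fp32' }
console.log(migrateOptions({ quantized: true, device: "webgpu" })); // { device: 'webgpu', dtype: 'q8' }
```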
Basic usage
Example: Run Qwen2.5-0.5B-Instruct in 4-bit quantization (demo)
import { pipeline } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
"text-generation",
"onnx-community/Qwen2.5-0.5B-Instruct",
{ dtype: "q4", device: "webgpu" },
);
// Define the list of messages
const messages = [
{ role: "system", content: "You are a helpful assistant." },
{ role: "user", content: "Tell me a funny joke." },
];
// Generate a response
const output = await generator(messages, { max_new_tokens: 128 });
console.log(output[0].generated_text.at(-1).content);
Per-module dtypes
Some encoder-decoder models, like Whisper or Florence-2, are extremely sensitive to quantization settings, especially those of the encoder. For this reason, we added the ability to select per-module dtypes, which can be done by providing a mapping from module name to dtype.
Example: Run Florence-2 on WebGPU (demo)
import { Florence2ForConditionalGeneration } from "@huggingface/transformers";
const model = await Florence2ForConditionalGeneration.from_pretrained(
"onnx-community/Florence-2-base-ft",
{
dtype: {
embed_tokens: "fp16",
vision_encoder: "fp16",
encoder_model: "q4",
decoder_model_merged: "q4",
},
device: "webgpu",
},
);
See full code example
import {
Florence2ForConditionalGeneration,
AutoProcessor,
AutoTokenizer,
RawImage,
} from "@huggingface/transformers";
// Load model, processor, and tokenizer
const model_id = "onnx-community/Florence-2-base-ft";
const model = await Florence2ForConditionalGeneration.from_pretrained(
model_id,
{
dtype: {
embed_tokens: "fp16",
vision_encoder: "fp16",
encoder_model: "q4",
decoder_model_merged: "q4",
},
device: "webgpu",
},
);
const processor = await AutoProcessor.from_pretrained(model_id);
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
// Load image and prepare vision inputs
const url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
const image = await RawImage.fromURL(url);
const vision_inputs = await processor(image);
// Specify task and prepare text inputs
const task = "<MORE_DETAILED_CAPTION>";
const prompts = processor.construct_prompts(task);
const text_inputs = tokenizer(prompts);
// Generate text
const generated_ids = await model.generate({
...text_inputs,
...vision_inputs,
max_new_tokens: 100,
});
// Decode generated text
const generated_text = tokenizer.batch_decode(generated_ids, {
skip_special_tokens: false,
})[0];
// Post-process the generated text
const result = processor.post_process_generation(
generated_text,
task,
image.size,
);
console.log(result);
// { '<MORE_DETAILED_CAPTION>': 'A green car is parked in front of a tan building. The building has a brown door and two brown windows. The car is a two door and the door is closed. The green car has black tires.' }
🏛 A total of 120 supported architectures
This release increases the total number of supported architectures to 120 (see full list), spanning a wide range of input modalities and tasks. Notable ...
2.17.2
🚀 What's new?
- Add support for MobileViTv2 in #721
import { pipeline } from '@xenova/transformers';
// Create an image classification pipeline
const classifier = await pipeline('image-classification', 'Xenova/mobilevitv2-1.0-imagenet1k-256', {
  quantized: false,
});
// Classify an image
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg';
const output = await classifier(url);
// [{ label: 'tiger, Panthera tigris', score: 0.6491137742996216 }]
See here for the full list of supported models.
- Add support for FastViT in #749
import { pipeline } from '@xenova/transformers';
// Create an image classification pipeline
const classifier = await pipeline('image-classification', 'Xenova/fastvit_t12.apple_in1k', {
  quantized: false
});
// Classify an image
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg';
const output = await classifier(url, { topk: 5 });
// [
// { label: 'tiger, Panthera tigris', score: 0.6649345755577087 },
// { label: 'tiger cat', score: 0.12454754114151001 },
// { label: 'lynx, catamount', score: 0.0010689536575227976 },
// { label: 'dhole, Cuon alpinus', score: 0.0010422508930787444 },
// { label: 'silky terrier, Sydney silky', score: 0.0009548701345920563 }
// ]
See here for the full list of supported models.
- Optimize FFT in #766
- Add sequence post processor in #771
- Update pipelines.js to allow for token_embeddings as well by @NikhilVerma in #770
- Remove old import from stream/web for ReadableStream in #752
- docs: update vanilla-js.md by @eltociear in #738
- Fix CI in #768
- Update Next.js demos to 14.2.3 in #772
🤗 New contributors
- @eltociear made their first contribution in #738
- @KTibow made their first contribution in #737
- @NawarA made their first contribution in #594
- @NikhilVerma made their first contribution in #770
Full Changelog: 2.17.1...2.17.2
2.17.1
2.17.0
What's new?
💬 Improved text-generation pipeline for conversational models
This version adds support for passing an array of chat messages (with "role" and "content" properties) to the text-generation pipeline (PR). Check out the list of supported models here.
Example: Chat with Xenova/Qwen1.5-0.5B-Chat.
import { pipeline } from '@xenova/transformers';
// Create text-generation pipeline
const generator = await pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
// Define the list of messages
const messages = [
{ role: 'system', content: 'You are a helpful assistant.' },
{ role: 'user', content: 'Tell me a funny joke.' }
]
// Generate text
const output = await generator(messages, {
max_new_tokens: 128,
do_sample: false,
})
console.log(output[0].generated_text);
// [
// { role: 'system', content: 'You are a helpful assistant.' },
// { role: 'user', content: 'Tell me a funny joke.' },
// { role: 'assistant', content: "Sure, here's one:\n\nWhy was the math book sad?\n\nBecause it had too many problems.\n\nI hope you found that joke amusing! Do you have any other questions or topics you'd like to discuss?" },
// ]
We also added the return_full_text parameter, which means if you set return_full_text=false, only the newly-generated tokens will be returned (only applicable if passing the raw text prompt to the pipeline).
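To illustrate the effect of return_full_text on a raw text prompt, here is a simplified sketch that works on strings. The pipeline's actual implementation may differ. The `formatGeneration` helper, the sample strings, and the default of `true` are assumptions made for this illustration:

```javascript
// Hypothetical post-processing step: optionally drop the prompt from the generated text.
function formatGeneration(prompt, fullText, { return_full_text = true } = {}) {
  if (return_full_text || !fullText.startsWith(prompt)) return fullText;
  return fullText.slice(prompt.length); // keep only the continuation
}

const prompt = "Once upon a time,";
const fullText = "Once upon a time, there was a fox.";
console.log(JSON.stringify(formatGeneration(prompt, fullText)));
// "Once upon a time, there was a fox."
console.log(JSON.stringify(formatGeneration(prompt, fullText, { return_full_text: false })));
// " there was a fox."
```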
🔢 Binary embedding quantization support
Transformers.js v2.17 adds two new parameters to the feature-extraction pipeline ("quantize" and "precision"), enabling you to generate binary embeddings. These can be used with certain embedding models to shrink the size of the document embeddings for retrieval. This results in reductions in index size/memory usage (for storage) and improvements in retrieval speed. Surprisingly, you can still achieve up to ~95% of the original performance, but at 32x storage savings and up to 32x retrieval speeds! 🤯 Thanks to @jonathanpv for this addition in #691!
import { pipeline } from '@xenova/transformers';
// Create feature-extraction pipeline
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
// Compute binary embeddings
const output = await extractor('This is a simple test.', { pooling: 'mean', quantize: true, precision: 'binary' });
// Tensor {
// type: 'int8',
// data: Int8Array [49, 108, 24, ...],
// dims: [1, 48]
// }
As you can see, this produces a 32x smaller output tensor (a 4x reduction in data type with Float32Array → Int8Array, as well as an 8x reduction in dimensionality from 384 → 48). For more information, check out this PR in sentence-transformers, which inspired this update!
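To make the 32x figure concrete, here is a library-free sketch of one common binary quantization scheme: threshold each value at zero, pack eight bits into one byte, and shift the result into the signed int8 range. Treat it as an assumption about the general technique, not as the library's exact implementation. The helper names and the synthetic embedding are made up for illustration:

```javascript
// Sketch: sign-threshold each value and pack 8 bits per byte (most significant bit first).
function quantizeBinary(embedding) {
  if (embedding.length % 8 !== 0) {
    throw new Error("Embedding length must be a multiple of 8");
  }
  const packed = new Int8Array(embedding.length / 8);
  for (let byte = 0; byte < packed.length; ++byte) {
    let bits = 0;
    for (let bit = 0; bit < 8; ++bit) {
      bits = (bits << 1) | (embedding[byte * 8 + bit] > 0 ? 1 : 0);
    }
    packed[byte] = bits - 128; // shift 0..255 into the signed int8 range
  }
  return packed;
}

// Binary embeddings are typically compared by counting the differing bits.
function hammingDistance(a, b) {
  let distance = 0;
  for (let i = 0; i < a.length; ++i) {
    let diff = (a[i] ^ b[i]) & 0xff; // the -128 offset cancels out under XOR
    while (diff) {
      distance += diff & 1;
      diff >>= 1;
    }
  }
  return distance;
}

const embedding = Float32Array.from({ length: 384 }, (_, i) => Math.sin(i)); // synthetic values
const binary = quantizeBinary(embedding);
console.log(embedding.byteLength, binary.byteLength); // 1536 48
```

The two byte lengths differ by a factor of 32, matching the 384-dimensional float32 vector versus the 48-element int8 output shown above.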
🛠️ Misc. improvements
🤗 New contributors
- @pulsejet made their first contribution in #667
- @jonathanpv made their first contribution in #691
Full Changelog: 2.16.1...2.17.0