diff --git a/src/base/feature_extraction_utils.js b/src/base/feature_extraction_utils.js new file mode 100644 index 000000000..53a5e4941 --- /dev/null +++ b/src/base/feature_extraction_utils.js @@ -0,0 +1,54 @@ +import { FEATURE_EXTRACTOR_NAME } from "../utils/constants.js"; +import { Callable } from "../utils/generic.js"; +import { getModelJSON } from "../utils/hub.js"; + +/** + * Base class for feature extractors. + */ +export class FeatureExtractor extends Callable { + /** + * Constructs a new FeatureExtractor instance. + * + * @param {Object} config The configuration for the feature extractor. + */ + constructor(config) { + super(); + this.config = config + } + + /** + * Instantiate one of the processor classes of the library from a pretrained model. + * + * The processor class to instantiate is selected based on the `image_processor_type` (or `feature_extractor_type`; legacy) + * property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) + * + * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: + * - A string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co. + * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + * user or organization name, like `dbmdz/bert-base-german-cased`. + * - A path to a *directory* containing processor files, e.g., `./my_model_directory/`. + * @param {import('../utils/hub.js').PretrainedOptions} options Additional options for loading the processor. + * + * @returns {Promise} A new instance of the Processor class. + */ + static async from_pretrained(pretrained_model_name_or_path, options) { + const preprocessorConfig = await getModelJSON(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, true, options); + return new this(preprocessorConfig); + } +} + + +/** + * Helper function to validate audio inputs. + * @param {any} audio The audio data. + * @param {string} feature_extractor The name of the feature extractor. + * @private + */ +export function validate_audio_inputs(audio, feature_extractor) { + if (!(audio instanceof Float32Array || audio instanceof Float64Array)) { + throw new Error( + `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead. ` + + `If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.` + ) + } +} diff --git a/src/base/image_processors_utils.js b/src/base/image_processors_utils.js new file mode 100644 index 000000000..3cd12740a --- /dev/null +++ b/src/base/image_processors_utils.js @@ -0,0 +1,1043 @@ +import { Callable } from "../utils/generic.js"; +import { Tensor, interpolate,stack } from "../utils/tensor.js"; +import { bankers_round, max, min, softmax } from "../utils/maths.js"; +import { RawImage } from "../utils/image.js"; +import { calculateReflectOffset } from "../utils/core.js"; +import { getModelJSON } from "../utils/hub.js"; +import { IMAGE_PROCESSOR_NAME } from '../utils/constants.js'; + +/** + * Named tuple to indicate the order we are using is (height x width), + * even though the Graphics' industry standard is (width x height). + * @typedef {[height: number, width: number]} HeightWidth + */ + + +/** + * @typedef {object} ImageProcessorResult + * @property {Tensor} pixel_values The pixel values of the batched preprocessed images. 
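+ * (In channel-first format, i.e. with dims `[batch_size, num_channels, height, width]`, as produced by stacking the per-image outputs of `preprocess` in `_call`.)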
+ * @property {HeightWidth[]} original_sizes Array of two-dimensional tuples like [[480, 640]]. + * @property {HeightWidth[]} reshaped_input_sizes Array of two-dimensional tuples like [[1000, 1330]]. + */ + + + +/** + * Helper function to constrain a value to be a multiple of a number. + * @param {number} val The value to constrain. + * @param {number} multiple The number to constrain to. + * @param {number} [minVal=0] The minimum value to constrain to. + * @param {number} [maxVal=null] The maximum value to constrain to. + * @returns {number} The constrained value. + * @private + */ +function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) { + const a = val / multiple; + let x = bankers_round(a) * multiple; + + if (maxVal !== null && x > maxVal) { + x = Math.floor(a) * multiple; + } + + if (x < minVal) { + x = Math.ceil(a) * multiple; + } + + return x; +} + +/** + * Rounds the height and width down to the closest multiple of size_divisibility + * @param {[number, number]} size The size of the image + * @param {number} divisor The divisor to use. + * @returns {[number, number]} The rounded size. + */ +function enforce_size_divisibility([width, height], divisor) { + return [ + Math.max(Math.floor(width / divisor), 1) * divisor, + Math.max(Math.floor(height / divisor), 1) * divisor + ]; +} + + +// Helper functions + +/** + * Converts bounding boxes from center format to corners format. + * + * @param {number[]} arr The coordinate for the center of the box and its width, height dimensions (center_x, center_y, width, height) + * @returns {number[]} The coodinates for the top-left and bottom-right corners of the box (top_left_x, top_left_y, bottom_right_x, bottom_right_y) + */ +function center_to_corners_format([centerX, centerY, width, height]) { + return [ + centerX - width / 2, + centerY - height / 2, + centerX + width / 2, + centerY + height / 2 + ]; +} + +/** + * Post-processes the outputs of the model (for object detection). + * @param {Object} outputs The outputs of the model that must be post-processed + * @param {Tensor} outputs.logits The logits + * @param {Tensor} outputs.pred_boxes The predicted boxes. + * @param {number} [threshold=0.5] The threshold to use for the scores. + * @param {[number, number][]} [target_sizes=null] The sizes of the original images. + * @param {boolean} [is_zero_shot=false] Whether zero-shot object detection was performed. + * @return {Object[]} An array of objects containing the post-processed outputs. + */ +export function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null, is_zero_shot = false) { + const out_logits = outputs.logits; + const out_bbox = outputs.pred_boxes; + const [batch_size, num_boxes, num_classes] = out_logits.dims; + + if (target_sizes !== null && target_sizes.length !== batch_size) { + throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits") + } + let toReturn = []; + for (let i = 0; i < batch_size; ++i) { + let target_size = target_sizes !== null ? 
target_sizes[i] : null; + let info = { + boxes: [], + classes: [], + scores: [] + } + let logits = out_logits[i]; + let bbox = out_bbox[i]; + + for (let j = 0; j < num_boxes; ++j) { + let logit = logits[j]; + + let indices = []; + let probs; + if (is_zero_shot) { + // Get indices of classes with high enough probability + probs = logit.sigmoid().data; + for (let k = 0; k < probs.length; ++k) { + if (probs[k] > threshold) { + indices.push(k); + } + } + + } else { + // Get most probable class + let maxIndex = max(logit.data)[1]; + + if (maxIndex === num_classes - 1) { + // This is the background class, skip it + continue; + } + // Compute softmax over classes + probs = softmax(logit.data); + + if (probs[maxIndex] < threshold) { + continue; + } + indices.push(maxIndex); + } + + for (const index of indices) { + + // Some class has a high enough probability + /** @type {number[]} */ + let box = bbox[j].data; + + // convert to [x0, y0, x1, y1] format + box = center_to_corners_format(box) + if (target_size !== null) { + box = box.map((x, i) => x * target_size[(i + 1) % 2]) + } + + info.boxes.push(box); + info.classes.push(index); + info.scores.push(probs[index]); + } + } + toReturn.push(info); + } + return toReturn; +} + + +/** + * Post-processes the outputs of the model (for semantic segmentation). + * @param {*} outputs Raw outputs of the model. + * @param {[number, number][]} [target_sizes=null] List of tuples corresponding to the requested final size + * (height, width) of each prediction. If unset, predictions will not be resized. + * @returns {{segmentation: Tensor; labels: number[]}[]} The semantic segmentation maps. + */ +export function post_process_semantic_segmentation(outputs, target_sizes = null) { + + const logits = outputs.logits; + const batch_size = logits.dims[0]; + + if (target_sizes !== null && target_sizes.length !== batch_size) { + throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits") + } + + const toReturn = []; + for (let i = 0; i < batch_size; ++i) { + const target_size = target_sizes !== null ? target_sizes[i] : null; + + let data = logits[i]; + + // 1. If target_size is not null, we need to resize the masks to the target size + if (target_size !== null) { + // resize the masks to the target size + data = interpolate(data, target_size, 'bilinear', false); + } + const [height, width] = target_size ?? data.dims.slice(-2); + + const segmentation = new Tensor( + 'int32', + new Int32Array(height * width), + [height, width] + ); + + // Buffer to store current largest value + const buffer = data[0].data; + const segmentation_data = segmentation.data; + for (let j = 1; j < data.dims[0]; ++j) { + const row = data[j].data; + for (let k = 0; k < row.length; ++k) { + if (row[k] > buffer[k]) { + buffer[k] = row[k]; + segmentation_data[k] = j; + } + } + } + + // Store which objects have labels + // This is much more efficient that creating a set of the final values + const hasLabel = new Array(data.dims[0]); + for (let j = 0; j < segmentation_data.length; ++j) { + const index = segmentation_data[j]; + hasLabel[index] = index; + } + /** @type {number[]} The unique list of labels that were detected */ + const labels = hasLabel.filter(x => x !== undefined); + + toReturn.push({ segmentation, labels }); + } + return toReturn; +} + + +/** + * Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and `labels`. + * @param {Tensor} class_logits The class logits. 
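+ * (Of shape `[num_queries, num_labels + 1]`; the final class corresponds to the background and is skipped.)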
+ * @param {Tensor} mask_logits The mask logits. + * @param {number} object_mask_threshold A number between 0 and 1 used to binarize the masks. + * @param {number} num_labels The number of labels. + * @returns {[Tensor[], number[], number[]]} The binarized masks, the scores, and the labels. + * @private + */ +function remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels) { + + const mask_probs_item = []; + const pred_scores_item = []; + const pred_labels_item = []; + + for (let j = 0; j < class_logits.dims[0]; ++j) { + const cls = class_logits[j]; + const mask = mask_logits[j]; + + const pred_label = max(cls.data)[1]; + if (pred_label === num_labels) { + // Is the background, so we ignore it + continue; + } + + const scores = softmax(cls.data); + const pred_score = scores[pred_label]; + if (pred_score > object_mask_threshold) { + mask_probs_item.push(mask); + pred_scores_item.push(pred_score); + pred_labels_item.push(pred_label); + } + } + + return [mask_probs_item, pred_scores_item, pred_labels_item]; +} + +/** + * Checks whether the segment is valid or not. + * @param {Int32Array} mask_labels Labels for each pixel in the mask. + * @param {Tensor[]} mask_probs Probabilities for each pixel in the masks. + * @param {number} k The class id of the segment. + * @param {number} mask_threshold The mask threshold. + * @param {number} overlap_mask_area_threshold The overlap mask area threshold. + * @returns {[boolean, number[]]} Whether the segment is valid or not, and the indices of the valid labels. + * @private + */ +function check_segment_validity( + mask_labels, + mask_probs, + k, + mask_threshold = 0.5, + overlap_mask_area_threshold = 0.8 +) { + // mask_k is a 1D array of indices, indicating where the mask is equal to k + const mask_k = []; + let mask_k_area = 0; + let original_area = 0; + + const mask_probs_k_data = mask_probs[k].data; + + // Compute the area of all the stuff in query k + for (let i = 0; i < mask_labels.length; ++i) { + if (mask_labels[i] === k) { + mask_k.push(i); + ++mask_k_area; + } + + if (mask_probs_k_data[i] >= mask_threshold) { + ++original_area; + } + } + let mask_exists = mask_k_area > 0 && original_area > 0; + + // Eliminate disconnected tiny segments + if (mask_exists) { + // Perform additional check + let area_ratio = mask_k_area / original_area; + mask_exists = area_ratio > overlap_mask_area_threshold; + } + + return [mask_exists, mask_k] +} + +/** + * Computes the segments. + * @param {Tensor[]} mask_probs The mask probabilities. + * @param {number[]} pred_scores The predicted scores. + * @param {number[]} pred_labels The predicted labels. + * @param {number} mask_threshold The mask threshold. + * @param {number} overlap_mask_area_threshold The overlap mask area threshold. + * @param {Set} label_ids_to_fuse The label ids to fuse. + * @param {number[]} target_size The target size of the image. + * @returns {[Tensor, Array<{id: number, label_id: number, score: number}>]} The computed segments. + * @private + */ +function compute_segments( + mask_probs, + pred_scores, + pred_labels, + mask_threshold, + overlap_mask_area_threshold, + label_ids_to_fuse = null, + target_size = null, +) { + const [height, width] = target_size ?? mask_probs[0].dims; + + const segmentation = new Tensor( + 'int32', + new Int32Array(height * width), + [height, width] + ); + const segments = []; + + // 1. 
If target_size is not null, we need to resize the masks to the target size + if (target_size !== null) { + // resize the masks to the target size + for (let i = 0; i < mask_probs.length; ++i) { + mask_probs[i] = interpolate(mask_probs[i], target_size, 'bilinear', false); + } + } + + // 2. Weigh each mask by its prediction score + // NOTE: `mask_probs` is updated in-place + // + // Temporary storage for the best label/scores for each pixel ([height, width]): + const mask_labels = new Int32Array(mask_probs[0].data.length); + const bestScores = new Float32Array(mask_probs[0].data.length); + + for (let i = 0; i < mask_probs.length; ++i) { + let score = pred_scores[i]; + + const mask_probs_i_data = mask_probs[i].data; + + for (let j = 0; j < mask_probs_i_data.length; ++j) { + mask_probs_i_data[j] *= score + if (mask_probs_i_data[j] > bestScores[j]) { + mask_labels[j] = i; + bestScores[j] = mask_probs_i_data[j]; + } + } + } + + let current_segment_id = 0; + + // let stuff_memory_list = {} + const segmentation_data = segmentation.data; + for (let k = 0; k < pred_labels.length; ++k) { + const pred_class = pred_labels[k]; + + // TODO add `should_fuse` + // let should_fuse = pred_class in label_ids_to_fuse + + // Check if mask exists and large enough to be a segment + const [mask_exists, mask_k] = check_segment_validity( + mask_labels, + mask_probs, + k, + mask_threshold, + overlap_mask_area_threshold + ) + + if (!mask_exists) { + // Nothing to see here + continue; + } + + // TODO + // if (pred_class in stuff_memory_list) { + // current_segment_id = stuff_memory_list[pred_class] + // } else { + // current_segment_id += 1; + // } + ++current_segment_id; + + + // Add current object segment to final segmentation map + for (const index of mask_k) { + segmentation_data[index] = current_segment_id; + } + + segments.push({ + id: current_segment_id, + label_id: pred_class, + // was_fused: should_fuse, TODO + score: pred_scores[k], + }) + + // TODO + // if(should_fuse){ + // stuff_memory_list[pred_class] = current_segment_id + // } + } + + return [segmentation, segments]; +} + + +/** + * Post-process the model output to generate the final panoptic segmentation. + * @param {*} outputs The model output to post process + * @param {number} [threshold=0.5] The probability score threshold to keep predicted instance masks. + * @param {number} [mask_threshold=0.5] Threshold to use when turning the predicted masks into binary values. + * @param {number} [overlap_mask_area_threshold=0.8] The overlap mask area threshold to merge or discard small disconnected parts within each binary instance mask. + * @param {Set} [label_ids_to_fuse=null] The labels in this state will have all their instances be fused together. + * @param {[number, number][]} [target_sizes=null] The target sizes to resize the masks to. + * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>} + */ +export function post_process_panoptic_segmentation( + outputs, + threshold = 0.5, + mask_threshold = 0.5, + overlap_mask_area_threshold = 0.8, + label_ids_to_fuse = null, + target_sizes = null, +) { + if (label_ids_to_fuse === null) { + console.warn("`label_ids_to_fuse` unset. No instance will be fused.") + label_ids_to_fuse = new Set(); + } + + const class_queries_logits = outputs.class_queries_logits ?? outputs.logits; // [batch_size, num_queries, num_classes+1] + const masks_queries_logits = outputs.masks_queries_logits ?? 
outputs.pred_masks; // [batch_size, num_queries, height, width] + + const mask_probs = masks_queries_logits.sigmoid() // [batch_size, num_queries, height, width] + + let [batch_size, num_queries, num_labels] = class_queries_logits.dims; + num_labels -= 1; // Remove last class (background) + + if (target_sizes !== null && target_sizes.length !== batch_size) { + throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits") + } + + let toReturn = []; + for (let i = 0; i < batch_size; ++i) { + let target_size = target_sizes !== null ? target_sizes[i] : null; + + let class_logits = class_queries_logits[i]; + let mask_logits = mask_probs[i]; + + let [mask_probs_item, pred_scores_item, pred_labels_item] = remove_low_and_no_objects(class_logits, mask_logits, threshold, num_labels); + + if (pred_labels_item.length === 0) { + // No mask found + let [height, width] = target_size ?? mask_logits.dims.slice(-2); + + let segmentation = new Tensor( + 'int32', + new Int32Array(height * width).fill(-1), + [height, width] + ) + toReturn.push({ + segmentation: segmentation, + segments_info: [] + }); + continue; + } + + + // Get segmentation map and segment information of batch item + let [segmentation, segments] = compute_segments( + mask_probs_item, + pred_scores_item, + pred_labels_item, + mask_threshold, + overlap_mask_area_threshold, + label_ids_to_fuse, + target_size, + ) + + toReturn.push({ + segmentation: segmentation, + segments_info: segments + }) + } + + return toReturn; +} + + +/** + * Post-processes the outputs of the model (for instance segmentation). + * @param {*} outputs Raw outputs of the model. + * @param {number} [threshold=0.5] The probability score threshold to keep predicted instance masks. + * @param {[number, number][]} [target_sizes=null] List of tuples corresponding to the requested final size + * (height, width) of each prediction. If unset, predictions will not be resized. + * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>} + */ +export function post_process_instance_segmentation(outputs, threshold = 0.5, target_sizes = null) { + throw new Error('`post_process_instance_segmentation` is not yet implemented.'); +} + + +/** + * @typedef {Object} ImageProcessorConfig A configuration object used to create an image processor. + * @property {function} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates. + * @property {number[]} [image_mean] The mean values for image normalization. + * @property {number[]} [image_std] The standard deviation values for image normalization. + * @property {boolean} [do_rescale] Whether to rescale the image pixel values to the [0,1] range. + * @property {number} [rescale_factor] The factor to use for rescaling the image pixel values. + * @property {boolean} [do_normalize] Whether to normalize the image pixel values. + * @property {boolean} [do_resize] Whether to resize the image. + * @property {number} [resample] What method to use for resampling. + * @property {number|Object} [size] The size to resize the image to. + * @property {number|Object} [image_size] The size to resize the image to (same as `size`). + * @property {boolean} [do_flip_channel_order=false] Whether to flip the color channels from RGB to BGR. + * Can be overridden by the `do_flip_channel_order` parameter in the `preprocess` method. 
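+ * @property {number|Object} [crop_size] The size to center crop the image to, used when `do_center_crop` is `true`.
+ * @property {boolean} [do_pad] Whether to pad the image to `pad_size` (or, if `size_divisibility` is set, so that each dimension is a multiple of it).
+ * @property {number|Object} [pad_size] The size to pad the image to. If unset and `size` specifies a width and height, the resize size is reused as the pad size.
+ * @property {number} [size_divisibility] If set, output dimensions are constrained to be multiples of this value (alias: `size_divisor`).
+ * @property {boolean} [do_convert_rgb=true] Whether to convert the image to RGB before further processing.
+ * @property {boolean} [do_crop_margin] Whether to crop the margin of the image before any other processing (used by Nougat-style processors).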
+ * @property {boolean} [do_center_crop] Whether to center crop the image to the specified `crop_size`. + * Can be overridden by `do_center_crop` in the `preprocess` method. + * @property {boolean} [do_thumbnail] Whether to resize the image using thumbnail method. + * @property {boolean} [keep_aspect_ratio] If `true`, the image is resized to the largest possible size such that the aspect ratio is preserved. + * Can be overidden by `keep_aspect_ratio` in `preprocess`. + * @property {number} [ensure_multiple_of] If `do_resize` is `true`, the image is resized to a size that is a multiple of this value. + * Can be overidden by `ensure_multiple_of` in `preprocess`. + * + * @property {number[]} [mean] The mean values for image normalization (same as `image_mean`). + * @property {number[]} [std] The standard deviation values for image normalization (same as `image_std`). + */ + +export class ImageProcessor extends Callable { + + /** + * Constructs a new `ImageProcessor`. + * @param {ImageProcessorConfig} config The configuration object. + */ + constructor(config) { + super(); + + this.image_mean = config.image_mean ?? config.mean; + this.image_std = config.image_std ?? config.std; + + this.resample = config.resample ?? 2; // 2 => bilinear + this.do_rescale = config.do_rescale ?? true; + this.rescale_factor = config.rescale_factor ?? (1 / 255); + this.do_normalize = config.do_normalize; + + this.do_thumbnail = config.do_thumbnail; + this.size = config.size ?? config.image_size; + this.do_resize = config.do_resize ?? (this.size !== undefined); + this.size_divisibility = config.size_divisibility ?? config.size_divisor; + + this.do_center_crop = config.do_center_crop; + this.crop_size = config.crop_size; + this.do_convert_rgb = config.do_convert_rgb ?? true; + this.do_crop_margin = config.do_crop_margin; + + this.pad_size = config.pad_size; + this.do_pad = config.do_pad; + + if (this.do_pad && !this.pad_size && this.size && this.size.width !== undefined && this.size.height !== undefined) { + // Should pad, but no pad size specified + // We infer the pad size from the resize size + this.pad_size = this.size + } + + this.do_flip_channel_order = config.do_flip_channel_order ?? false; + + this.config = config; + } + + /** + * Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any + * corresponding dimension of the specified size. + * @param {RawImage} image The image to be resized. + * @param {{height:number, width:number}} size The size `{"height": h, "width": w}` to resize the image to. + * @param {string | 0 | 1 | 2 | 3 | 4 | 5} [resample=2] The resampling filter to use. + * @returns {Promise} The resized image. + */ + async thumbnail(image, size, resample = 2) { + const input_height = image.height; + const input_width = image.width; + + const output_height = size.height; + const output_width = size.width; + + // We always resize to the smallest of either the input or output size. + let height = Math.min(input_height, output_height) + let width = Math.min(input_width, output_width) + + if (height === input_height && width === input_width) { + return image; + } + if (input_height > input_width) { + width = Math.floor(input_width * height / input_height); + } else if (input_width > input_height) { + height = Math.floor(input_height * width / input_width); + } + return await image.resize(width, height, { resample }); + } + + + /** + * Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the threshold). 
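+ * The image is first converted to grayscale and its pixel values are normalized by the image's min/max range;
+ * the crop is then the bounding box of all pixels whose normalized value falls below `gray_threshold / 255`.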
+ * @param {RawImage} image The image to be cropped. + * @param {number} gray_threshold Value below which pixels are considered to be gray. + * @returns {Promise} The cropped image. + */ + async crop_margin(image, gray_threshold = 200) { + + const gray_image = image.clone().grayscale(); + + const minValue = min(gray_image.data)[0]; + const maxValue = max(gray_image.data)[0]; + const diff = maxValue - minValue; + + if (diff === 0) { + return image; + } + + const threshold = gray_threshold / 255; + + let x_min = gray_image.width, y_min = gray_image.height, x_max = 0, y_max = 0; + const gray_image_data = gray_image.data; + for (let j = 0; j < gray_image.height; ++j) { + const row = j * gray_image.width; + for (let i = 0; i < gray_image.width; ++i) { + if ((gray_image_data[row + i] - minValue) / diff < threshold) { + // We have a non-zero pixel, so we update the min/max values accordingly + x_min = Math.min(x_min, i); + y_min = Math.min(y_min, j); + x_max = Math.max(x_max, i); + y_max = Math.max(y_max, j); + } + } + } + + image = await image.crop([x_min, y_min, x_max, y_max]); + return image; + } + + /** + * Pad the image by a certain amount. + * @param {Float32Array} pixelData The pixel data to pad. + * @param {number[]} imgDims The dimensions of the image (height, width, channels). + * @param {{width:number; height:number}|number} padSize The dimensions of the padded image. + * @param {Object} options The options for padding. + * @param {'constant'|'symmetric'} [options.mode='constant'] The type of padding to add. + * @param {boolean} [options.center=false] Whether to center the image. + * @param {number|number[]} [options.constant_values=0] The constant value to use for padding. + * @returns {[Float32Array, number[]]} The padded pixel data and image dimensions. + */ + pad_image(pixelData, imgDims, padSize, { + mode = 'constant', + center = false, + constant_values = 0, + } = {}) { + const [imageHeight, imageWidth, imageChannels] = imgDims; + + let paddedImageWidth, paddedImageHeight; + if (typeof padSize === 'number') { + paddedImageWidth = padSize; + paddedImageHeight = padSize; + } else { + paddedImageWidth = padSize.width; + paddedImageHeight = padSize.height; + } + + // Only add padding if there is a difference in size + if (paddedImageWidth !== imageWidth || paddedImageHeight !== imageHeight) { + const paddedPixelData = new Float32Array(paddedImageWidth * paddedImageHeight * imageChannels); + if (Array.isArray(constant_values)) { + // Fill with constant values, cycling through the array + for (let i = 0; i < paddedPixelData.length; ++i) { + paddedPixelData[i] = constant_values[i % imageChannels]; + } + } else if (constant_values !== 0) { + paddedPixelData.fill(constant_values); + } + + const [left, top] = center + ? 
[Math.floor((paddedImageWidth - imageWidth) / 2), Math.floor((paddedImageHeight - imageHeight) / 2)] + : [0, 0]; + + // Copy the original image into the padded image + for (let i = 0; i < imageHeight; ++i) { + const a = (i + top) * paddedImageWidth; + const b = i * imageWidth; + for (let j = 0; j < imageWidth; ++j) { + const c = (a + j + left) * imageChannels; + const d = (b + j) * imageChannels; + for (let k = 0; k < imageChannels; ++k) { + paddedPixelData[c + k] = pixelData[d + k]; + } + } + } + + if (mode === 'symmetric') { + if (center) { + throw new Error('`center` padding is not supported when `mode` is set to `symmetric`.'); + // TODO: Implement this + } + const h1 = imageHeight - 1; + const w1 = imageWidth - 1; + for (let i = 0; i < paddedImageHeight; ++i) { + const a = i * paddedImageWidth; + const b = calculateReflectOffset(i, h1) * imageWidth; + + for (let j = 0; j < paddedImageWidth; ++j) { + if (i < imageHeight && j < imageWidth) continue; // Do not overwrite original image + const c = (a + j) * imageChannels; + const d = (b + calculateReflectOffset(j, w1)) * imageChannels; + + // Copy channel-wise + for (let k = 0; k < imageChannels; ++k) { + paddedPixelData[c + k] = pixelData[d + k]; + } + } + } + } + + + // Update pixel data and image dimensions + pixelData = paddedPixelData; + imgDims = [paddedImageHeight, paddedImageWidth, imageChannels] + } + return [pixelData, imgDims]; + } + + /** + * Rescale the image' pixel values by `this.rescale_factor`. + * @param {Float32Array} pixelData The pixel data to rescale. + * @returns {void} + */ + rescale(pixelData) { + for (let i = 0; i < pixelData.length; ++i) { + pixelData[i] = this.rescale_factor * pixelData[i]; + } + } + + /** + * Find the target (width, height) dimension of the output image after + * resizing given the input image and the desired size. + * @param {RawImage} image The image to resize. + * @param {any} size The size to use for resizing the image. + * @returns {[number, number]} The target (width, height) dimension of the output image after resizing. + */ + get_resize_output_image_size(image, size) { + // `size` comes in many forms, so we need to handle them all here: + // 1. `size` is an integer, in which case we resize the image to be a square + + const [srcWidth, srcHeight] = image.size; + + let shortest_edge; + let longest_edge; + + if (this.do_thumbnail) { + // NOTE: custom logic for `Donut` models + const { height, width } = size; + shortest_edge = Math.min(height, width) + } + // Support both formats for backwards compatibility + else if (Number.isInteger(size)) { + shortest_edge = size; + longest_edge = this.config.max_size ?? shortest_edge; + + } else if (size !== undefined) { + // Extract known properties from `size` + shortest_edge = size.shortest_edge; + longest_edge = size.longest_edge; + } + + // If `longest_edge` and `shortest_edge` are set, maintain aspect ratio and resize to `shortest_edge` + // while keeping the largest dimension <= `longest_edge` + if (shortest_edge !== undefined || longest_edge !== undefined) { + // http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/ + // Try resize so that shortest edge is `shortest_edge` (target) + const shortResizeFactor = shortest_edge === undefined + ? 
1 // If `shortest_edge` is not set, don't upscale + : Math.max(shortest_edge / srcWidth, shortest_edge / srcHeight); + + const newWidth = srcWidth * shortResizeFactor; + const newHeight = srcHeight * shortResizeFactor; + + // The new width and height might be greater than `longest_edge`, so + // we downscale again to ensure the largest dimension is `longest_edge` + const longResizeFactor = longest_edge === undefined + ? 1 // If `longest_edge` is not set, don't downscale + : Math.min(longest_edge / newWidth, longest_edge / newHeight); + + // To avoid certain floating point precision issues, we round to 2 decimal places + let finalWidth = Math.floor(Number((newWidth * longResizeFactor).toFixed(2))); + let finalHeight = Math.floor(Number((newHeight * longResizeFactor).toFixed(2))); + + if (this.size_divisibility !== undefined) { + [finalWidth, finalHeight] = enforce_size_divisibility([finalWidth, finalHeight], this.size_divisibility) + } + return [finalWidth, finalHeight]; + + } else if (size !== undefined && size.width !== undefined && size.height !== undefined) { + // If `width` and `height` are set, resize to those dimensions + + let newWidth = size.width; + let newHeight = size.height; + + // Custom for DPT models + if (this.config.keep_aspect_ratio && this.config.ensure_multiple_of) { + + // determine new height and width + let scale_height = newHeight / srcHeight; + let scale_width = newWidth / srcWidth; + + // scale as little as possible + if (Math.abs(1 - scale_width) < Math.abs(1 - scale_height)) { + // fit width + scale_height = scale_width; + } else { + // fit height + scale_width = scale_height; + } + + newHeight = constraint_to_multiple_of(scale_height * srcHeight, this.config.ensure_multiple_of); + newWidth = constraint_to_multiple_of(scale_width * srcWidth, this.config.ensure_multiple_of); + } + + return [newWidth, newHeight]; + + } else if (this.size_divisibility !== undefined) { + return enforce_size_divisibility([srcWidth, srcHeight], this.size_divisibility); + } else { + throw new Error(`Could not resize image due to unsupported \`this.size\` option in config: ${JSON.stringify(size)}`); + } + } + + /** + * Resizes the image. + * @param {RawImage} image The image to resize. + * @returns {Promise} The resized image. + */ + async resize(image) { + const [newWidth, newHeight] = this.get_resize_output_image_size(image, this.size); + return await image.resize(newWidth, newHeight, { + resample: this.resample, + }); + } + + /** + * @typedef {object} PreprocessedImage + * @property {HeightWidth} original_size The original size of the image. + * @property {HeightWidth} reshaped_input_size The reshaped input size of the image. + * @property {Tensor} pixel_values The pixel values of the preprocessed image. + */ + + /** + * Preprocesses the given image. + * + * @param {RawImage} image The image to preprocess. + * @param {Object} overrides The overrides for the preprocessing options. + * @returns {Promise} The preprocessed image. + */ + async preprocess(image, { + do_normalize = null, + do_pad = null, + do_convert_rgb = null, + do_convert_grayscale = null, + do_flip_channel_order = null, + } = {}) { + if (this.do_crop_margin) { + // NOTE: Specific to nougat processors. This is done before resizing, + // and can be interpreted as a pre-preprocessing step. + image = await this.crop_margin(image); + } + + const [srcWidth, srcHeight] = image.size; // original image size + + // Convert image to RGB if specified in config. + if (do_convert_rgb ?? 
this.do_convert_rgb) {
+            image = image.rgb();
+        } else if (do_convert_grayscale) {
+            image = image.grayscale();
+        }
+
+        // TODO:
+        // For efficiency reasons, it might be best to merge the resize and center crop operations into one.
+
+        // Resize all images
+        if (this.do_resize) {
+            image = await this.resize(image);
+        }
+
+        // Resize the image using thumbnail method.
+        if (this.do_thumbnail) {
+            image = await this.thumbnail(image, this.size, this.resample);
+        }
+
+        if (this.do_center_crop) {
+
+            let crop_width;
+            let crop_height;
+            if (Number.isInteger(this.crop_size)) {
+                crop_width = this.crop_size;
+                crop_height = this.crop_size;
+            } else {
+                crop_width = this.crop_size.width;
+                crop_height = this.crop_size.height;
+            }
+
+            image = await image.center_crop(crop_width, crop_height);
+        }
+
+        /** @type {HeightWidth} */
+        const reshaped_input_size = [image.height, image.width];
+
+        // NOTE: All pixel-level manipulation (i.e., modifying `pixelData`)
+        // occurs with data in the hwc format (height, width, channels),
+        // to emulate the behavior of the original Python code (w/ numpy).
+        let pixelData = Float32Array.from(image.data);
+        let imgDims = [image.height, image.width, image.channels];
+
+        if (this.do_rescale) {
+            this.rescale(pixelData);
+        }
+
+        if (do_normalize ?? this.do_normalize) {
+            let image_mean = this.image_mean;
+            if (!Array.isArray(this.image_mean)) {
+                image_mean = new Array(image.channels).fill(image_mean);
+            }
+
+            let image_std = this.image_std;
+            if (!Array.isArray(this.image_std)) {
+                image_std = new Array(image.channels).fill(image_std);
+            }
+
+            if (image_mean.length !== image.channels || image_std.length !== image.channels) {
+                throw new Error(`When set to arrays, the length of \`image_mean\` (${image_mean.length}) and \`image_std\` (${image_std.length}) must match the number of channels in the image (${image.channels}).`);
+            }
+
+            for (let i = 0; i < pixelData.length; i += image.channels) {
+                for (let j = 0; j < image.channels; ++j) {
+                    pixelData[i + j] = (pixelData[i + j] - image_mean[j]) / image_std[j];
+                }
+            }
+        }
+
+        // do padding after rescaling/normalizing
+        if (do_pad ?? this.do_pad) {
+            if (this.pad_size) {
+                const padded = this.pad_image(pixelData, [image.height, image.width, image.channels], this.pad_size);
+                [pixelData, imgDims] = padded; // Update pixel data and image dimensions
+            } else if (this.size_divisibility) {
+                const [paddedWidth, paddedHeight] = enforce_size_divisibility([imgDims[1], imgDims[0]], this.size_divisibility);
+                [pixelData, imgDims] = this.pad_image(pixelData, imgDims, { width: paddedWidth, height: paddedHeight });
+            }
+        }
+
+        if (do_flip_channel_order ?? this.do_flip_channel_order) {
+            if (imgDims[2] !== 3) {
+                throw new Error('Flipping channel order is only supported for RGB images.');
+            }
+            // Convert RGB to BGR
+            for (let i = 0; i < pixelData.length; i += 3) {
+                const temp = pixelData[i];
+                pixelData[i] = pixelData[i + 2];
+                pixelData[i + 2] = temp;
+            }
+        }
+
+        const pixel_values = new Tensor('float32', pixelData, imgDims)
+            .permute(2, 0, 1); // convert to channel dimension format (hwc -> chw)
+
+        return {
+            original_size: [srcHeight, srcWidth],
+            reshaped_input_size: reshaped_input_size,
+            pixel_values,
+        }
+    }
+
+    /**
+     * Calls the feature extraction process on an array of images,
+     * preprocesses each image, and concatenates the resulting
+     * features into a single Tensor.
+     * @param {RawImage[]} images The image(s) to extract features from.
+     * @param {...any} args Additional arguments.
+ * @returns {Promise} An object containing the concatenated pixel values (and other metadata) of the preprocessed images. + */ + async _call(images, ...args) { + if (!Array.isArray(images)) { + images = [images]; + } + /** @type {PreprocessedImage[]} */ + const imageData = await Promise.all(images.map(x => this.preprocess(x))); + + // Stack pixel values + const pixel_values = stack(imageData.map(x => x.pixel_values), 0); + + return { + pixel_values, + + // Original sizes of images + original_sizes: imageData.map(x => x.original_size), + + // Reshaped sizes of images, before padding or cropping + reshaped_input_sizes: imageData.map(x => x.reshaped_input_size), + } + } + + + /** + * Instantiate one of the processor classes of the library from a pretrained model. + * + * The processor class to instantiate is selected based on the `image_processor_type` (or `feature_extractor_type`; legacy) + * property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) + * + * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: + * - A string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co. + * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + * user or organization name, like `dbmdz/bert-base-german-cased`. + * - A path to a *directory* containing processor files, e.g., `./my_model_directory/`. + * @param {import('../utils/hub.js').PretrainedOptions} options Additional options for loading the processor. + * + * @returns {Promise} A new instance of the Processor class. + */ + static async from_pretrained(pretrained_model_name_or_path, options) { + const preprocessorConfig = await getModelJSON(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME, true, options); + return new this(preprocessorConfig); + } +} diff --git a/src/base/processing_utils.js b/src/base/processing_utils.js new file mode 100644 index 000000000..e782f5c61 --- /dev/null +++ b/src/base/processing_utils.js @@ -0,0 +1,131 @@ + +/** + * @file Processors are used to prepare inputs (e.g., text, image or audio) for a model. + * + * **Example:** Using a `WhisperProcessor` to prepare an audio input for a model. + * ```javascript + * import { AutoProcessor, read_audio } from '@huggingface/transformers'; + * + * const processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en'); + * const audio = await read_audio('https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac', 16000); + * const { input_features } = await processor(audio); + * // Tensor { + * // data: Float32Array(240000) [0.4752984642982483, 0.5597258806228638, 0.56434166431427, ...], + * // dims: [1, 80, 3000], + * // type: 'float32', + * // size: 240000, + * // } + * ``` + * + * @module processors + */ +import { PROCESSOR_NAME } from '../utils/constants.js'; +import { + Callable, +} from '../utils/generic.js'; +import { getModelJSON } from '../utils/hub.js'; + +/** + * @typedef {Object} ProcessorProperties Additional processor-specific properties. + * @typedef {import('../utils/hub.js').PretrainedOptions & ProcessorProperties} PretrainedProcessorOptions + */ + + +/** + * Represents a Processor that extracts features from an input. 
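+ *
+ * **Example:** A minimal subclass sketch (for illustration only; `MyProcessor`, the model id, and the use of
+ * `AutoTokenizer` here are assumptions, not definitions from this file). Static `*_class` properties are picked up
+ * by `from_pretrained`, which instantiates each component and stores it under the property name with the `_class`
+ * suffix removed.
+ * ```javascript
+ * class MyProcessor extends Processor {
+ *     static image_processor_class = AutoImageProcessor;
+ *     static tokenizer_class = AutoTokenizer;
+ * }
+ * const processor = await MyProcessor.from_pretrained('org/model-id');
+ * const image_inputs = await processor(image); // dispatches to the first available component (here, the image processor)
+ * ```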
+ */ +export class Processor extends Callable { + static classes = [ + 'image_processor_class', + 'tokenizer_class', + 'feature_extractor_class', + ] + static uses_processor_config = false; + + /** + * Creates a new Processor with the given components + * @param {Object} config + * @param {Record} components + */ + constructor(config, components) { + super(); + this.config = config; + this.components = components; + } + + /** + * @returns {import('./image_processors_utils.js').ImageProcessor|undefined} The image processor of the processor, if it exists. + */ + get image_processor() { + return this.components.image_processor; + } + + /** + * @returns {import('../tokenizers.js').PreTrainedTokenizer|undefined} The tokenizer of the processor, if it exists. + */ + get tokenizer() { + return this.components.tokenizer; + } + + /** + * @returns {import('./feature_extraction_utils.js').FeatureExtractor|undefined} The feature extractor of the processor, if it exists. + */ + get feature_extractor() { + return this.components.feature_extractor; + } + + /** + * Calls the feature_extractor function with the given input. + * @param {any} input The input to extract features from. + * @param {...any} args Additional arguments. + * @returns {Promise} A Promise that resolves with the extracted features. + */ + async _call(input, ...args) { + for (const item of [this.image_processor, this.feature_extractor, this.tokenizer]) { + if (item) { + return item(input, ...args); + } + } + throw new Error('No image processor, feature extractor, or tokenizer found.'); + } + + + /** + * Instantiate one of the processor classes of the library from a pretrained model. + * + * The processor class to instantiate is selected based on the `feature_extractor_type` property of the config object + * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) + * + * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: + * - A string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co. + * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + * user or organization name, like `dbmdz/bert-base-german-cased`. + * - A path to a *directory* containing processor files, e.g., `./my_model_directory/`. + * @param {PretrainedProcessorOptions} options Additional options for loading the processor. + * + * @returns {Promise} A new instance of the Processor class. + */ + static async from_pretrained(pretrained_model_name_or_path, options) { + + // console.log('FROM PRETRAINED'); + // console.log(this.classes); + // console.log(this.classes.map((cls) => cls in this)); + + const [config, components] = await Promise.all([ + // TODO: + this.uses_processor_config + ? 
getModelJSON(pretrained_model_name_or_path, PROCESSOR_NAME, true, options) + : {}, + Promise.all( + this.classes + .filter((cls) => cls in this) + .map(async (cls) => { + const component = await this[cls].from_pretrained(pretrained_model_name_or_path, options); + return [cls.replace(/_class$/,''), component]; + }) + ).then(Object.fromEntries) + ]); + + return new this(config, components); + } +} diff --git a/src/configs.js b/src/configs.js index 4bc95cf80..033e311fa 100644 --- a/src/configs.js +++ b/src/configs.js @@ -69,6 +69,9 @@ function getNormalizedConfig(config) { case 'musicgen': init_normalized_config = getNormalizedConfig(config.decoder); break; + case 'multi_modality': + init_normalized_config = getNormalizedConfig(config.language_config); + break; // Decoder-only models case 'gpt2': @@ -216,14 +219,12 @@ function getNormalizedConfig(config) { */ export function getKeyValueShapes(config, { prefix = 'past_key_values', + batch_size=1, } = {}) { /** @type {Record} */ const decoderFeeds = {}; const normalized_config = config.normalized_config; - // TODO support batches (i.e., batch_size > 1) - const batch_size = 1; - if (normalized_config.is_encoder_decoder && ( 'num_encoder_heads' in normalized_config && 'num_decoder_heads' in normalized_config )) { diff --git a/src/models.js b/src/models.js index d357a83e4..14e2aaee5 100644 --- a/src/models.js +++ b/src/models.js @@ -61,7 +61,6 @@ import { } from './utils/generic.js'; import { - isIntegralNumber, mergeArrays, pick, } from './utils/core.js'; @@ -99,6 +98,7 @@ import { import { cat, + full, full_like, mean, ones, @@ -108,6 +108,7 @@ import { Tensor, zeros_like, } from './utils/tensor.js'; +import { RawImage } from './utils/image.js'; import { dynamic_time_warping, medianFilter } from './utils/maths.js'; import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; @@ -128,6 +129,7 @@ const MODEL_TYPES = { MaskGeneration: 5, ImageTextToText: 6, Musicgen: 7, + MultiModality: 8, } ////////////////////////////////////////////////// @@ -386,7 +388,7 @@ async function sessionRun(session, inputs) { } catch (e) { // This usually occurs when the inputs are of the wrong type. console.error(`An error occurred during model execution: "${e}".`); - console.error('Inputs given to model:', checkedInputs); + console.error('Inputs given to model:', checkedInputs) throw e; } } @@ -716,6 +718,52 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) { } } +function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) { + const has_past_key_values = !!model_inputs.past_key_values; + + if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { + if (has_past_key_values) { + model_inputs.input_ids = cat([ + model_inputs.input_ids, + model_inputs.input_ids, + ], 0) + // NOTE: attention_mask handled in generation + } else { + model_inputs.input_ids = cat([ + model_inputs.input_ids, + full_like(model_inputs.input_ids, BigInt(generation_config.pad_token_id)), + ], 0); + model_inputs.attention_mask = cat([ + model_inputs.attention_mask, + full_like(model_inputs.attention_mask, 0n), + ], 0); + } + } + + if (has_past_key_values || !model_inputs.pixel_values) { + model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0); + } + + if (has_past_key_values) { + const num_img_tokens = 0; + const num_text_tokens = 1; + const has_image = num_img_tokens > 0 ? 
1 : 0; + + const batch_size = 1; + model_inputs.images_seq_mask = new Tensor( + 'bool', + new Array(num_img_tokens + num_text_tokens).fill(true).fill(false, 0, num_text_tokens), + [batch_size, num_img_tokens + num_text_tokens], + ); + model_inputs.images_emb_mask = new Tensor( + 'bool', + new Array(num_img_tokens).fill(!!has_image), + [batch_size, 1, num_img_tokens], + ); + } + return model_inputs; +} + ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -769,6 +817,11 @@ export class PreTrainedModel extends Callable { this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation; break; + case MODEL_TYPES.MultiModality: + this.can_generate = true; + this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation; + break; + default: // should be MODEL_TYPES.EncoderOnly this._forward = encoderForward; @@ -912,6 +965,21 @@ export class PreTrainedModel extends Callable { }, options), ]); + } else if (modelType === MODEL_TYPES.MultiModality) { + info = await Promise.all([ + constructSessions(pretrained_model_name_or_path, { + prepare_inputs_embeds: 'prepare_inputs_embeds', + model: 'language_model', + lm_head: 'lm_head', + gen_head: 'gen_head', + gen_img_embeds: 'gen_img_embeds', + image_decode: 'image_decode', + }, options), + getOptionalConfigs(pretrained_model_name_or_path, { + generation_config: 'generation_config.json', + }, options), + ]); + } else { // should be MODEL_TYPES.EncoderOnly if (modelType !== MODEL_TYPES.EncoderOnly) { console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`) @@ -1658,7 +1726,8 @@ export class PreTrainedModel extends Callable { const dtype = session?.config?.kv_cache_dtype ?? 'float32'; const empty = (dtype === 'float16') ? new Uint16Array() : []; - const shapes = getKeyValueShapes(this.config); + const batch_size = (decoderFeeds[this.main_input_name] ?? decoderFeeds.attention_mask).dims?.[0] ?? 1; + const shapes = getKeyValueShapes(this.config, { batch_size }); for (const name in shapes) { decoderFeeds[name] = new Tensor(dtype, empty, shapes[name]); @@ -5954,6 +6023,111 @@ export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel ////////////////////////////////////////////////// +export class MultiModalityPreTrainedModel extends PreTrainedModel { } +export class MultiModalityCausalLM extends MultiModalityPreTrainedModel { + forward_params = [ + // prepare_inputs_embeds + 'input_ids', + 'pixel_values', + 'images_seq_mask', + 'images_emb_mask', + + // language_model + 'attention_mask', + 'position_ids', + 'past_key_values', + ]; + + constructor(...args) { + super(...args); + + // State-based approach to switch out which heads to use during generation + this._generation_mode = 'text'; + } + + async forward(model_inputs) { + const mode = this._generation_mode ?? 
'text'; + + // TODO support re-using PKVs for input_ids.dims[1] !== 1 + // if (model_inputs.past_key_values) { + // // && model_inputs.input_ids.dims[1] === 1 + // } + + let output_1; + if (mode === 'text' || !model_inputs.past_key_values) { + const session = this.sessions['prepare_inputs_embeds']; + const prep_inputs = pick(model_inputs, session.inputNames); + output_1 = await sessionRun(session, prep_inputs); + } else { + const session = this.sessions['gen_img_embeds']; + const prep_inputs = pick({ + image_ids: model_inputs.input_ids, + }, session.inputNames); + output_1 = await sessionRun(session, prep_inputs); + } + + const input_2 = { ...model_inputs, ...output_1 } + const output_2 = await decoderForward(this, input_2); + + const head = this.sessions[ + mode === 'text' + ? 'lm_head' + : 'gen_head' + ]; + if (!head) { + throw new Error(`Unable to find "${head}" generation head`); + } + + const output_3 = await sessionRun(head, pick(output_2, head.inputNames)) + + return { + ...output_1, + ...output_2, + ...output_3, + }; + } + + /** + * @param {import('./generation/parameters.js').GenerationFunctionParameters} options + */ + async generate(options) { + this._generation_mode = 'text'; + return super.generate(options); + } + + /** + * @param {import('./generation/parameters.js').GenerationFunctionParameters} options + */ + async generate_images(options) { + this._generation_mode = 'image'; + + const start_num_tokens = (options.inputs ?? options[this.main_input_name]).dims[1]; + const all_tokens = await super.generate(options); + + const generated_tokens = (/** @type {Tensor} */(all_tokens)).slice(null, [start_num_tokens, null]) + + const image_decode = this.sessions['image_decode']; + const { decoded_image } = await sessionRun(image_decode, { + generated_tokens, + }); + + // Equivalent to `np.clip((dec + 1) / 2 * 255, 0, 255)` + const clamped = decoded_image + .add_(1) + .mul_(255 / 2) + .clamp_(0, 255) + .to('uint8'); + + // Return as a list of images + const images = []; + for (const tensor of clamped) { + const img = RawImage.fromTensor(tensor); + images.push(img); + } + return images; + } +} + ////////////////////////////////////////////////// // AutoModels, used to simplify construction of PreTrainedModels // (uses config to instantiate correct class) @@ -6232,6 +6406,11 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([ ['stablelm', ['StableLmForCausalLM', StableLmForCausalLM]], ]); +const MODEL_FOR_MULTIMODALITY_MAPPING_NAMES = new Map([ + ['multi_modality', ['MultiModalityCausalLM', MultiModalityCausalLM]], +]); + + const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([ ['bert', ['BertForMaskedLM', BertForMaskedLM]], ['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]], @@ -6404,6 +6583,7 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.DecoderOnly], + [MODEL_FOR_MULTIMODALITY_MAPPING_NAMES, MODEL_TYPES.MultiModality], [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq], diff --git a/src/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.js b/src/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.js new file mode 100644 index 000000000..9533f47b1 --- /dev/null +++ 
b/src/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.js @@ -0,0 +1,90 @@ +import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js'; +import { Tensor } from '../../utils/tensor.js'; +import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js'; + + +export class ASTFeatureExtractor extends FeatureExtractor { + + constructor(config) { + super(config); + + const sampling_rate = this.config.sampling_rate; + const mel_filters = mel_filter_bank( + 256, // num_frequency_bins + this.config.num_mel_bins, // num_mel_filters + 20, // min_frequency + Math.floor(sampling_rate / 2), // max_frequency + sampling_rate, // sampling_rate + null, // norm + "kaldi", // mel_scale + true, // triangularize_in_mel_space + ); + + // Do padding: + for (let i = 0; i < mel_filters.length; ++i) { + mel_filters[i].push(0); + } + this.mel_filters = mel_filters; + + this.window = window_function(400, 'hann', { + periodic: false, + }) + + this.mean = this.config.mean; + this.std = this.config.std; + } + + /** + * Computes the log-Mel spectrogram of the provided audio waveform. + * @param {Float32Array|Float64Array} waveform The audio waveform to process. + * @param {number} max_length The maximum number of frames to return. + * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. + */ + async _extract_fbank_features(waveform, max_length) { + // NOTE: We don't pad/truncate since that is passed in as `max_num_frames` + return spectrogram( + waveform, + this.window, // window + 400, // frame_length + 160, // hop_length + { + fft_length: 512, + power: 2.0, + center: false, + preemphasis: 0.97, + mel_filters: this.mel_filters, + log_mel: 'log', + mel_floor: 1.192092955078125e-07, + remove_dc_offset: true, + + // Custom + max_num_frames: max_length, + transpose: true, + } + ) + } + + + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @returns {Promise<{ input_values: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. + */ + async _call(audio) { + validate_audio_inputs(audio, 'ASTFeatureExtractor'); + + const features = await this._extract_fbank_features(audio, this.config.max_length); + if (this.config.do_normalize) { + // Normalize the input audio spectrogram to have mean=0, std=0.5 + const denom = this.std * 2; + const features_data = features.data; + for (let i = 0; i < features_data.length; ++i) { + features_data[i] = (features_data[i] - this.mean) / denom; + } + } + + return { + input_values: features.unsqueeze_(0) + }; + } +} diff --git a/src/models/auto/feature_extraction_auto.js b/src/models/auto/feature_extraction_auto.js new file mode 100644 index 000000000..5a18eabb9 --- /dev/null +++ b/src/models/auto/feature_extraction_auto.js @@ -0,0 +1,41 @@ + +import { FEATURE_EXTRACTOR_NAME, GITHUB_ISSUE_URL } from '../../utils/constants.js'; +import { getModelJSON } from '../../utils/hub.js'; +import { FeatureExtractor } from '../../base/feature_extraction_utils.js'; +import * as AllFeatureExtractors from '../feature_extractors.js'; + +export class AutoFeatureExtractor { + + /** + * Instantiate one of the feature extractor classes of the library from a pretrained model. 
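+ * (for example, `ASTFeatureExtractor` or `ClapFeatureExtractor`).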
+ * + * The processor class to instantiate is selected based on the `feature_extractor_type` property of + * the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) + * + * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: + * - A string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co. + * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + * user or organization name, like `dbmdz/bert-base-german-cased`. + * - A path to a *directory* containing processor files, e.g., `./my_model_directory/`. + * @param {import('../../utils/hub.js').PretrainedOptions} options Additional options for loading the processor. + * + * @returns {Promise} A new instance of the Processor class. + */ + + /** @type {typeof FeatureExtractor.from_pretrained} */ + static async from_pretrained(pretrained_model_name_or_path, options={}) { + + const preprocessorConfig = await getModelJSON(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, true, options); + + // Determine feature extractor class + const key = preprocessorConfig.feature_extractor_type; + const feature_extractor_class = AllFeatureExtractors[key]; + + if (!feature_extractor_class) { + throw new Error(`Unknown feature_extractor_type: '${key}'. Please report this at ${GITHUB_ISSUE_URL}.`); + } + + // Instantiate feature extractor + return new feature_extractor_class(preprocessorConfig); + } +} diff --git a/src/models/auto/image_processing_auto.js b/src/models/auto/image_processing_auto.js new file mode 100644 index 000000000..07f6c1a0d --- /dev/null +++ b/src/models/auto/image_processing_auto.js @@ -0,0 +1,29 @@ + +import { GITHUB_ISSUE_URL, IMAGE_PROCESSOR_NAME } from '../../utils/constants.js'; +import { getModelJSON } from '../../utils/hub.js'; +import { ImageProcessor } from '../../base/image_processors_utils.js'; +import * as AllImageProcessors from '../image_processors.js'; + +export class AutoImageProcessor { + + /** @type {typeof ImageProcessor.from_pretrained} */ + static async from_pretrained(pretrained_model_name_or_path, options={}) { + + const preprocessorConfig = await getModelJSON(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME, true, options); + + // Determine image processor class + const key = preprocessorConfig.image_processor_type ?? preprocessorConfig.feature_extractor_type; + let image_processor_class = AllImageProcessors[key]; + + if (!image_processor_class) { + if (key !== undefined) { + // Only log a warning if the class is not found and the key is set. + console.warn(`Image processor type '${key}' not found, assuming base ImageProcessor. 
Please report this at ${GITHUB_ISSUE_URL}.`) + } + image_processor_class = ImageProcessor; + } + + // Instantiate image processor + return new image_processor_class(preprocessorConfig); + } +} diff --git a/src/models/auto/processing_auto.js b/src/models/auto/processing_auto.js new file mode 100644 index 000000000..3b462b6e9 --- /dev/null +++ b/src/models/auto/processing_auto.js @@ -0,0 +1,100 @@ + + +import { IMAGE_PROCESSOR_NAME } from '../../utils/constants.js'; +import { getModelJSON } from '../../utils/hub.js'; +import { Processor } from '../../base/processing_utils.js'; + +import * as AllProcessors from '../processors.js'; +import * as AllImageProcessors from '../image_processors.js'; +import * as AllFeatureExtractors from '../feature_extractors.js'; + +/** + * Helper class which is used to instantiate pretrained processors with the `from_pretrained` function. + * The chosen processor class is determined by the type specified in the processor config. + * + * **Example:** Load a processor using `from_pretrained`. + * ```javascript + * let processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en'); + * ``` + * + * **Example:** Run an image through a processor. + * ```javascript + * let processor = await AutoProcessor.from_pretrained('Xenova/clip-vit-base-patch16'); + * let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg'); + * let image_inputs = await processor(image); + * // { + * // "pixel_values": { + * // "dims": [ 1, 3, 224, 224 ], + * // "type": "float32", + * // "data": Float32Array [ -1.558687686920166, -1.558687686920166, -1.5440893173217773, ... ], + * // "size": 150528 + * // }, + * // "original_sizes": [ + * // [ 533, 800 ] + * // ], + * // "reshaped_input_sizes": [ + * // [ 224, 224 ] + * // ] + * // } + * ``` + */ +export class AutoProcessor { + + /** + * Instantiate one of the processor classes of the library from a pretrained model. + * + * The processor class to instantiate is selected based on the `image_processor_type` (or `feature_extractor_type`; legacy) + * property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) + * + * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: + * - A string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co. + * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + * user or organization name, like `dbmdz/bert-base-german-cased`. + * - A path to a *directory* containing processor files, e.g., `./my_model_directory/`. + * @param {import('../../utils/hub.js').PretrainedOptions} options Additional options for loading the processor. + * + * @returns {Promise} A new instance of the Processor class. 
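+ *
+ * **Example:** Inspect the components resolved for a checkpoint (a sketch; exactly which components are present depends on the checkpoint's preprocessor config).
+ * ```javascript
+ * const processor = await AutoProcessor.from_pretrained('Xenova/clip-vit-base-patch16');
+ * // Depending on the config, the returned processor exposes an image processor and/or a feature extractor:
+ * console.log(processor.image_processor); // e.g. a CLIPImageProcessor (or legacy CLIPFeatureExtractor) instance
+ * ```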
+ */ + + /** @type {typeof Processor.from_pretrained} */ + static async from_pretrained(pretrained_model_name_or_path, options={}) { + + // TODO: first check for processor.json + const preprocessorConfig = await getModelJSON(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME, true, options); + + const { image_processor_type, feature_extractor_type, processor_class } = preprocessorConfig; + if (processor_class && AllProcessors[processor_class]) { + return AllProcessors[processor_class].from_pretrained(pretrained_model_name_or_path, options); + } + + if (!image_processor_type && !feature_extractor_type) { + throw new Error('No `image_processor_type` or `feature_extractor_type` found in the config.'); + } + + const components = {}; + if (image_processor_type) { + const image_processor_class = AllImageProcessors[image_processor_type]; + if (!image_processor_class) { + throw new Error(`Unknown image_processor_type: '${image_processor_type}'.`); + } + components.image_processor = new image_processor_class(preprocessorConfig); + } + + if (feature_extractor_type) { + const image_processor_class = AllImageProcessors[feature_extractor_type]; + if (image_processor_class) { + // Handle legacy case where image processors were specified as feature extractors + components.image_processor = new image_processor_class(preprocessorConfig); + } else { + const feature_extractor_class = AllFeatureExtractors[feature_extractor_type]; + if (!feature_extractor_class) { + throw new Error(`Unknown feature_extractor_type: '${feature_extractor_type}'.`); + } + components.feature_extractor = new feature_extractor_class(preprocessorConfig); + } + } + + const config = {}; + return new Processor(config, components); + } +} diff --git a/src/models/beit/image_processing_beit.js b/src/models/beit/image_processing_beit.js new file mode 100644 index 000000000..006399edf --- /dev/null +++ b/src/models/beit/image_processing_beit.js @@ -0,0 +1,5 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class BeitFeatureExtractor extends ImageProcessor { } diff --git a/src/models/bit/image_processing_bit.js b/src/models/bit/image_processing_bit.js new file mode 100644 index 000000000..66db82277 --- /dev/null +++ b/src/models/bit/image_processing_bit.js @@ -0,0 +1,5 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class BitImageProcessor extends ImageProcessor { } diff --git a/src/models/chinese_clip/image_processing_chinese_clip.js b/src/models/chinese_clip/image_processing_chinese_clip.js new file mode 100644 index 000000000..d720eb662 --- /dev/null +++ b/src/models/chinese_clip/image_processing_chinese_clip.js @@ -0,0 +1,5 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class ChineseCLIPFeatureExtractor extends ImageProcessor { } diff --git a/src/models/clap/feature_extraction_clap.js b/src/models/clap/feature_extraction_clap.js new file mode 100644 index 000000000..5261a10b5 --- /dev/null +++ b/src/models/clap/feature_extraction_clap.js @@ -0,0 +1,159 @@ +import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js'; +import { Tensor } from '../../utils/tensor.js'; +import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js'; + + +export class ClapFeatureExtractor extends FeatureExtractor { + + constructor(config) { + super(config); + + this.mel_filters = mel_filter_bank( + this.config.nb_frequency_bins, // num_frequency_bins + this.config.feature_size, // 
num_mel_filters + this.config.frequency_min, // min_frequency + this.config.frequency_max, // max_frequency + this.config.sampling_rate, // sampling_rate + null, // norm + "htk", // mel_scale + ); + + this.mel_filters_slaney = mel_filter_bank( + this.config.nb_frequency_bins, // num_frequency_bins + this.config.feature_size, // num_mel_filters + this.config.frequency_min, // min_frequency + this.config.frequency_max, // max_frequency + this.config.sampling_rate, // sampling_rate + "slaney", // norm + "slaney", // mel_scale + ); + + this.window = window_function(this.config.fft_window_size, 'hann') + + } + + + /** + * Extracts the mel spectrogram and prepares it for the model based on the `truncation` and `padding` arguments. + * + * Four different paths are possible: + * - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram + * will be computed on the entire audio. 3 random crops and a downsampled version of the full mel spectrogram + * are then stacked together. They will later be used for `feature_fusion`. + * - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is + * padded based on `padding`. + * - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded + * based on `padding`, and is repeated `4` times. + * - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel + * spectrogram will be computed on a random crop of the waveform. + * + * @param {Float32Array|Float64Array} waveform The input waveform. + * @param {number} max_length The maximum length of the waveform. + * @param {string} truncation The truncation strategy to use. + * @param {string} padding The padding strategy to use. + * @returns {Promise} An object containing the mel spectrogram data as a Float32Array, its dimensions as an array of numbers, and a boolean indicating whether the waveform was longer than the max length. + * @private + */ + async _get_input_mel(waveform, max_length, truncation, padding) { + + /** @type {Tensor} */ + let input_mel; + let longer = false; + const diff = waveform.length - max_length; + if (diff > 0) { + if (truncation === 'rand_trunc') { + longer = true; + const idx = Math.floor(Math.random() * (diff + 1)); + waveform = waveform.subarray(idx, idx + max_length); + + input_mel = await this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples); + } else { + // TODO implement fusion strategy + throw new Error(`Truncation strategy "${truncation}" not implemented`) + } + } else { + if (diff < 0) { + let padded = new Float64Array(max_length); // already padded with zeros + padded.set(waveform); + + if (padding === 'repeat') { + for (let i = waveform.length; i < max_length; i += waveform.length) { + padded.set(waveform.subarray(0, Math.min(waveform.length, max_length - i)), i); + } + } else if (padding === 'repeatpad') { + for (let i = waveform.length; i < -diff; i += waveform.length) { + padded.set(waveform, i); + } + } + waveform = padded; + } + + if (truncation === 'fusion') { + throw new Error(`Truncation strategy "${truncation}" not implemented`) + } + + input_mel = await this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples); + } + + return input_mel.unsqueeze_(0); + } + + /** + * Compute the log-mel spectrogram of the provided `waveform` using the Hann window.
+ * In CLAP, two different filter banks are used depending on the truncation pattern: + * - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from + * calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation` + * is set to `"fusion"`. + * - `self.mel_filteres_slaney` : they correspond to the default parameters of `librosa` which used + * `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original + * implementation when the truncation mode is not `"fusion"`. + * + * @param {Float32Array|Float64Array} waveform The audio waveform to process. + * @param {number[][]} mel_filters The mel filters to use. + * @param {number} [max_length=null] The maximum number of frames to return. + * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. + */ + async _extract_fbank_features(waveform, mel_filters, max_length = null) { + // NOTE: We don't pad/truncate since that is passed in as `max_num_frames` + return spectrogram( + waveform, + this.window, // window + this.config.fft_window_size, // frame_length + this.config.hop_length, // hop_length + { + power: 2.0, + mel_filters, + log_mel: 'dB', + + // Custom + max_num_frames: max_length, + do_pad: false, + transpose: true, + } + ) + } + + + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. + */ + async _call(audio, { + max_length = null, + } = {}) { + validate_audio_inputs(audio, 'ClapFeatureExtractor'); + + // convert to mel spectrogram, truncate and pad if needed. + const padded_inputs = await this._get_input_mel( + audio, + max_length ?? this.config.nb_max_samples, + this.config.truncation, + this.config.padding, + ); + + return { + input_features: padded_inputs.unsqueeze_(0), + } + } +} diff --git a/src/models/clip/image_processing_clip.js b/src/models/clip/image_processing_clip.js new file mode 100644 index 000000000..3f2f9dcb0 --- /dev/null +++ b/src/models/clip/image_processing_clip.js @@ -0,0 +1,6 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class CLIPImageProcessor extends ImageProcessor { } +export class CLIPFeatureExtractor extends CLIPImageProcessor { } diff --git a/src/models/convnext/image_processing_convnext.js b/src/models/convnext/image_processing_convnext.js new file mode 100644 index 000000000..525e736cd --- /dev/null +++ b/src/models/convnext/image_processing_convnext.js @@ -0,0 +1,45 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class ConvNextImageProcessor extends ImageProcessor { + constructor(config) { + super(config); + + /** + * Percentage of the image to crop. Only has an effect if this.size < 384. + */ + this.crop_pct = this.config.crop_pct ?? 
(224 / 256); + } + + async resize(image) { + const shortest_edge = this.size?.shortest_edge; + if (shortest_edge === undefined) { + throw new Error(`Size dictionary must contain 'shortest_edge' key.`); + } + + if (shortest_edge < 384) { + // maintain same ratio, resizing shortest edge to shortest_edge/crop_pct + const resize_shortest_edge = Math.floor(shortest_edge / this.crop_pct); + + const [newWidth, newHeight] = this.get_resize_output_image_size(image, { + shortest_edge: resize_shortest_edge, + }); + + image = await image.resize(newWidth, newHeight, { + resample: this.resample, + }); + + // then crop to (shortest_edge, shortest_edge) + image = await image.center_crop(shortest_edge, shortest_edge); + } else { + // warping (no cropping) when evaluated at 384 or larger + image = await image.resize(shortest_edge, shortest_edge, { + resample: this.resample, + }); + } + + return image; + } +} +export class ConvNextFeatureExtractor extends ConvNextImageProcessor { } diff --git a/src/models/deit/image_processing_deit.js b/src/models/deit/image_processing_deit.js new file mode 100644 index 000000000..fd3857842 --- /dev/null +++ b/src/models/deit/image_processing_deit.js @@ -0,0 +1,6 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class DeiTImageProcessor extends ImageProcessor { } +export class DeiTFeatureExtractor extends DeiTImageProcessor { } \ No newline at end of file diff --git a/src/models/detr/image_processing_detr.js b/src/models/detr/image_processing_detr.js new file mode 100644 index 000000000..40ce1232f --- /dev/null +++ b/src/models/detr/image_processing_detr.js @@ -0,0 +1,52 @@ +import { + ImageProcessor, + post_process_object_detection, + post_process_panoptic_segmentation, + post_process_instance_segmentation, +} from "../../base/image_processors_utils.js"; + +import { full } from '../../utils/tensor.js'; + + +/** + * @typedef {object} DetrFeatureExtractorResultProps + * @property {import('../../utils/tensor.js').Tensor} pixel_mask + * @typedef {import('../../base/image_processors_utils.js').ImageProcessorResult & DetrFeatureExtractorResultProps} DetrFeatureExtractorResult + */ + +export class DetrImageProcessor extends ImageProcessor { + /** + * Calls the feature extraction process on an array of images, preprocesses + * each image, and concatenates the resulting features into a single Tensor. + * @param {import('../../utils/image.js').RawImage[]} images The image(s) to extract features from. + * @returns {Promise} An object containing the concatenated pixel values of the preprocessed images. + */ + async _call(images) { + const result = await super._call(images); + + // TODO support differently-sized images, for now assume all images are the same size. 
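+ // NOTE: `pixel_mask` indicates which pixels are real (1) and which are padding (0); since all images in the
+ // batch are assumed to share the same size, the mask is simply filled with ones at a fixed 64x64 resolution.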
+ // TODO support different mask sizes (not just 64x64) + // Currently, just fill pixel mask with 1s + const maskSize = [result.pixel_values.dims[0], 64, 64]; + const pixel_mask = full(maskSize, 1n); + + return { ...result, pixel_mask }; + } + + /** @type {typeof post_process_object_detection} */ + post_process_object_detection(...args) { + return post_process_object_detection(...args); + } + + /** @type {typeof post_process_panoptic_segmentation} */ + post_process_panoptic_segmentation(...args) { + return post_process_panoptic_segmentation(...args); + } + + /** @type {typeof post_process_instance_segmentation} */ + post_process_instance_segmentation(...args) { + return post_process_instance_segmentation(...args); + } +} + +export class DetrFeatureExtractor extends DetrImageProcessor { } // NOTE: extends DetrImageProcessor diff --git a/src/models/donut/image_processing_donut.js b/src/models/donut/image_processing_donut.js new file mode 100644 index 000000000..f848a9fa5 --- /dev/null +++ b/src/models/donut/image_processing_donut.js @@ -0,0 +1,31 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class DonutImageProcessor extends ImageProcessor { + pad_image(pixelData, imgDims, padSize, options = {}) { + const [imageHeight, imageWidth, imageChannels] = imgDims; + + let image_mean = this.image_mean; + if (!Array.isArray(this.image_mean)) { + image_mean = new Array(imageChannels).fill(image_mean); + } + + let image_std = this.image_std; + if (!Array.isArray(image_std)) { + image_std = new Array(imageChannels).fill(image_std); + } + + const constant_values = image_mean.map((x, i) => - x / image_std[i]); + + return super.pad_image(pixelData, imgDims, padSize, { + center: true, + + // Since normalization is done after padding, we need to use certain constant values to ensure the same behaviour is observed. + // For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451 + constant_values, + ...options, + }); + } +} +export class DonutFeatureExtractor extends DonutImageProcessor { } diff --git a/src/models/dpt/image_processing_dpt.js b/src/models/dpt/image_processing_dpt.js new file mode 100644 index 000000000..0c19175e7 --- /dev/null +++ b/src/models/dpt/image_processing_dpt.js @@ -0,0 +1,6 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class DPTImageProcessor extends ImageProcessor { } +export class DPTFeatureExtractor extends DPTImageProcessor { } // NOTE: extends DPTImageProcessor diff --git a/src/models/efficientnet/image_processing_efficientnet.js b/src/models/efficientnet/image_processing_efficientnet.js new file mode 100644 index 000000000..9fde87156 --- /dev/null +++ b/src/models/efficientnet/image_processing_efficientnet.js @@ -0,0 +1,13 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class EfficientNetImageProcessor extends ImageProcessor { + constructor(config) { + super(config); + this.include_top = this.config.include_top ??
true; + if (this.include_top) { + this.image_std = this.image_std.map(x => x * x); + } + } +} diff --git a/src/models/feature_extractors.js b/src/models/feature_extractors.js new file mode 100644 index 000000000..6364da7a8 --- /dev/null +++ b/src/models/feature_extractors.js @@ -0,0 +1,9 @@ + +export * from './audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.js'; +export * from './clap/feature_extraction_clap.js'; +export * from './pyannote/feature_extraction_pyannote.js'; +export * from './seamless_m4t/feature_extraction_seamless_m4t.js'; +export * from './speecht5/feature_extraction_speecht5.js'; +export * from './wav2vec2/feature_extraction_wav2vec2.js'; +export * from './wespeaker/feature_extraction_wespeaker.js'; +export * from './whisper/feature_extraction_whisper.js'; diff --git a/src/models/florence2/processing_florence2.js b/src/models/florence2/processing_florence2.js new file mode 100644 index 000000000..ec644df25 --- /dev/null +++ b/src/models/florence2/processing_florence2.js @@ -0,0 +1,128 @@ +import { Processor } from "../../base/processing_utils.js"; +import { AutoImageProcessor } from "../auto/image_processing_auto.js"; +import { AutoTokenizer } from "../../tokenizers.js"; + +export class Florence2Processor extends Processor { + static tokenizer_class = AutoTokenizer + static image_processor_class = AutoImageProcessor + + constructor(config, components) { + super(config, components); + + const { + tasks_answer_post_processing_type, + task_prompts_without_inputs, + task_prompts_with_input, + } = this.image_processor.config; + + /** @type {Map} */ + this.tasks_answer_post_processing_type = new Map(Object.entries(tasks_answer_post_processing_type ?? {})); + + /** @type {Map} */ + this.task_prompts_without_inputs = new Map(Object.entries(task_prompts_without_inputs ?? {})); + + /** @type {Map} */ + this.task_prompts_with_input = new Map(Object.entries(task_prompts_with_input ?? {})); + + this.regexes = { + quad_boxes: /(.+?)/gm, + bboxes: /([^<]+)?/gm, + } + this.size_per_bin = 1000; + } + + /** + * Helper function to construct prompts from input texts + * @param {string|string[]} text + * @returns {string[]} + */ + construct_prompts(text) { + if (typeof text === 'string') { + text = [text]; + } + + const prompts = []; + for (const t of text) { + // 1. fixed task prompts without additional inputs + if (this.task_prompts_without_inputs.has(t)) { + prompts.push(this.task_prompts_without_inputs.get(t)); + } + // 2. task prompts with additional inputs + else { + for (const [task, prompt] of this.task_prompts_with_input) { + if (t.includes(task)) { + prompts.push(prompt.replaceAll('{input}', t).replaceAll(task, '')); + break; + } + } + + // 3. default prompt + if (prompts.length !== text.length) { + prompts.push(t); + } + } + } + return prompts; + } + + /** + * Post-process the output of the model to each of the task outputs. + * @param {string} text The text to post-process. + * @param {string} task The task to post-process the text for. + * @param {[number, number]} image_size The size of the image. height x width. + */ + post_process_generation(text, task, image_size) { + const task_answer_post_processing_type = this.tasks_answer_post_processing_type.get(task) ?? 
'pure_text'; + + // remove the special tokens + text = text.replaceAll('', '').replaceAll('', ''); + + let final_answer; + switch (task_answer_post_processing_type) { + case 'pure_text': + final_answer = text; + break; + + case 'description_with_bboxes': + case 'bboxes': + case 'phrase_grounding': + case 'ocr': + const key = task_answer_post_processing_type === 'ocr' ? 'quad_boxes' : 'bboxes'; + const matches = text.matchAll(this.regexes[key]); + const labels = []; + const items = []; + for (const [_, label, ...locations] of matches) { + // Push new label, or duplicate the last label + labels.push(label ? label.trim() : labels.at(-1) ?? ''); + items.push(locations.map((x, i) => + // NOTE: Add 0.5 to use the center position of the bin as the coordinate. + (Number(x) + 0.5) / this.size_per_bin * image_size[i % 2]) + ); + } + final_answer = { labels, [key]: items }; + break; + + default: + throw new Error(`Task "${task}" (of type "${task_answer_post_processing_type}") not yet implemented.`); + } + + return { [task]: final_answer } + } + + // NOTE: images and text are switched from the python version + // `images` is required, `text` is optional + async _call(images, text=null, kwargs = {}) { + + if (!images && !text){ + throw new Error('Either text or images must be provided'); + } + + const image_inputs = await this.image_processor(images, kwargs); + const text_inputs = text ? this.tokenizer(text, kwargs) : {}; + + return { + ...image_inputs, + ...text_inputs, + } + } +} diff --git a/src/models/glpn/image_processing_glpn.js b/src/models/glpn/image_processing_glpn.js new file mode 100644 index 000000000..609f1b996 --- /dev/null +++ b/src/models/glpn/image_processing_glpn.js @@ -0,0 +1,5 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class GLPNFeatureExtractor extends ImageProcessor { } diff --git a/src/models/image_processors.js b/src/models/image_processors.js new file mode 100644 index 000000000..70c7d30ff --- /dev/null +++ b/src/models/image_processors.js @@ -0,0 +1,33 @@ + +export * from './beit/image_processing_beit.js' +export * from './bit/image_processing_bit.js' +export * from './chinese_clip/image_processing_chinese_clip.js' +export * from './clip/image_processing_clip.js' +export * from './convnext/image_processing_convnext.js' +export * from './deit/image_processing_deit.js' +export * from './detr/image_processing_detr.js' +export * from './donut/image_processing_donut.js' +export * from './dpt/image_processing_dpt.js' +export * from './efficientnet/image_processing_efficientnet.js' +export * from './glpn/image_processing_glpn.js' +export * from './janus/image_processing_janus.js' +export * from './jina_clip/image_processing_jina_clip.js' +export * from './mask2former/image_processing_mask2former.js' +export * from './maskformer/image_processing_maskformer.js' +export * from './mobilenet_v1/image_processing_mobilenet_v1.js' +export * from './mobilenet_v2/image_processing_mobilenet_v2.js' +export * from './mobilenet_v3/image_processing_mobilenet_v3.js' +export * from './mobilenet_v4/image_processing_mobilenet_v4.js' +export * from './mobilevit/image_processing_mobilevit.js' +export * from './nougat/image_processing_nougat.js' +export * from './owlv2/image_processing_owlv2.js' +export * from './owlvit/image_processing_owlvit.js' +export * from './pvt/image_processing_pvt.js' +export * from './rt_detr/image_processing_rt_detr.js' +export * from './sam/image_processing_sam.js' +export * from './segformer/image_processing_segformer.js' +export 
* from './siglip/image_processing_siglip.js' +export * from './swin2sr/image_processing_swin2sr.js' +export * from './vit/image_processing_vit.js' +export * from './vitmatte/image_processing_vitmatte.js' +export * from './yolos/image_processing_yolos.js' diff --git a/src/models/janus/image_processing_janus.js b/src/models/janus/image_processing_janus.js new file mode 100644 index 000000000..4dae64ff4 --- /dev/null +++ b/src/models/janus/image_processing_janus.js @@ -0,0 +1,26 @@ + +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class VLMImageProcessor extends ImageProcessor { + constructor(config) { + super({ + do_pad: true, + pad_size: { + width: config.image_size, + height: config.image_size, + }, + ...config, + }); + this.constant_values = this.config.background_color.map(x => x * this.rescale_factor) + } + + pad_image(pixelData, imgDims, padSize, options) { + return super.pad_image(pixelData, imgDims, padSize, { + constant_values: this.constant_values, + center: true, + ...options, + }); + } +} diff --git a/src/models/janus/processing_janus.js b/src/models/janus/processing_janus.js new file mode 100644 index 000000000..434192e18 --- /dev/null +++ b/src/models/janus/processing_janus.js @@ -0,0 +1,115 @@ + +import { Processor } from "../../base/processing_utils.js"; +import { AutoImageProcessor } from "../auto/image_processing_auto.js"; +import { AutoTokenizer } from "../../tokenizers.js"; +import { mergeArrays } from "../../utils/core.js"; +import { Tensor } from "../../utils/tensor.js"; +import { RawImage } from "../../utils/image.js"; + +export class VLChatProcessor extends Processor { + static image_processor_class = AutoImageProcessor + static tokenizer_class = AutoTokenizer + static uses_processor_config = true; + + constructor(config, components) { + super(config, components); + + this.image_tag = this.config.image_tag; + this.image_start_tag = this.config.image_start_tag; + this.image_end_tag = this.config.image_end_tag; + this.num_image_tokens = this.config.num_image_tokens; + } + + /** + * @typedef {Object} MultimodalMessageProperties Additional properties for multimodal messages. + * @property {(RawImage | string | URL)[]} [images] The images in the message. + * @typedef {(import('../../tokenizers.js').Message & MultimodalMessageProperties)[]} MultimodalConversation The conversation possibly containing multimodal inputs. + */ + + /** + * @param {MultimodalConversation} conversation The chat messages to process. + * @param {Object} options Additional options for processing. + * @param {RawImage|RawImage[]} [options.images] The images to process, if not set in the conversation. + * @param {string} [options.chat_template="default"] The chat template to use. + * @returns {Promise<{input_ids: Tensor; attention_mask: Tensor; images_seq_mask: Tensor; images_emb_mask: Tensor;} & import('../../base/image_processors_utils.js').ImageProcessorResult>} The processed input. 
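+ *
+ * **Example:** Process a multimodal conversation (a sketch; the model id, role name and image URL are assumptions, and the image tag embedded in the text must match `processor.image_tag` from the checkpoint's processor config).
+ * ```javascript
+ * const processor = await AutoProcessor.from_pretrained('onnx-community/Janus-1.3B-ONNX');
+ * const conversation = [{
+ *   role: 'User',
+ *   content: `${processor.image_tag}\nDescribe this image.`,
+ *   images: ['https://example.com/cat.png'],
+ * }];
+ * const { input_ids, attention_mask, images_seq_mask, images_emb_mask, pixel_values } = await processor(conversation);
+ * ```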
+ */ + async _call(conversation, { + images = null, + chat_template = "default", + }={}) { + if (!images) { + images = await Promise.all( + conversation + .filter((msg) => msg.images) + .flatMap((msg) => msg.images) + .map((img) => RawImage.read(img)) + ); + } else if (!Array.isArray(images)) { + images = [images]; + } + + const tokenizer = this.tokenizer; + const result = tokenizer.apply_chat_template(conversation, { + tokenize: false, + add_generation_prompt: true, + chat_template, + }); + + const encode = (text) => tokenizer.encode(text, { add_special_tokens: false }); + const parts = (/** @type {string} */(result)) + .split(this.image_tag); + const num_images = parts.length - 1; + if (images.length !== num_images) { + throw new Error(`Number of images provided (${images.length}) does not match number of "${this.image_tag}" image tags (${num_images})`); + } + + const [ + image_placeholder_tag_id, + image_start_tag_id, + image_end_tag_id, + ] = tokenizer.model.convert_tokens_to_ids([ + this.image_tag, + this.image_start_tag, + this.image_end_tag, + ]); + + let input_ids = encode(parts[0]); + let images_seq_mask = new Array(input_ids.length).fill(false); + for (let i = 1; i < parts.length; ++i) { + const placeholder_image_tokens = new Array(this.num_image_tokens).fill(image_placeholder_tag_id); + const tokens = encode(parts[i]); + input_ids = mergeArrays( + input_ids, + [image_start_tag_id], placeholder_image_tokens, [image_end_tag_id], + tokens, + ); + const image_mask = new Array(this.num_image_tokens).fill(true); + images_seq_mask = mergeArrays( + images_seq_mask, + [false], image_mask, [false], + new Array(tokens.length).fill(false), + ); + } + + const dims = [1, input_ids.length]; + const final = { + input_ids: new Tensor('int64', input_ids, dims), + attention_mask: new Tensor('int64', new Array(input_ids.length).fill(1), dims), + images_seq_mask: new Tensor('bool', images_seq_mask, dims), + images_emb_mask: new Tensor( + 'bool', + new Array(num_images * this.num_image_tokens).fill(true), + [1, num_images, this.num_image_tokens], + ), + } + + if (images && images.length > 0) { + const image_inputs = await this.image_processor(images); + // Set the batch_size dimension to 1 + image_inputs.pixel_values.unsqueeze_(0); + return { ...final, ...image_inputs }; + } + + return final; + } +} diff --git a/src/models/jina_clip/image_processing_jina_clip.js b/src/models/jina_clip/image_processing_jina_clip.js new file mode 100644 index 000000000..648e80d42 --- /dev/null +++ b/src/models/jina_clip/image_processing_jina_clip.js @@ -0,0 +1,5 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class JinaCLIPImageProcessor extends ImageProcessor {} diff --git a/src/models/mask2former/image_processing_mask2former.js b/src/models/mask2former/image_processing_mask2former.js new file mode 100644 index 000000000..5e02b5c38 --- /dev/null +++ b/src/models/mask2former/image_processing_mask2former.js @@ -0,0 +1,5 @@ + +import { MaskFormerImageProcessor } from "../maskformer/image_processing_maskformer.js"; + +// NOTE: extends MaskFormerImageProcessor +export class Mask2FormerImageProcessor extends MaskFormerImageProcessor { } diff --git a/src/models/maskformer/image_processing_maskformer.js b/src/models/maskformer/image_processing_maskformer.js new file mode 100644 index 000000000..6b90b0451 --- /dev/null +++ b/src/models/maskformer/image_processing_maskformer.js @@ -0,0 +1,18 @@ +import { + ImageProcessor, + post_process_panoptic_segmentation, + 
post_process_instance_segmentation, +} from "../../base/image_processors_utils.js"; + +export class MaskFormerImageProcessor extends ImageProcessor { + + /** @type {typeof post_process_panoptic_segmentation} */ + post_process_panoptic_segmentation(...args) { + return post_process_panoptic_segmentation(...args); + } + /** @type {typeof post_process_instance_segmentation} */ + post_process_instance_segmentation(...args) { + return post_process_instance_segmentation(...args); + } +} +export class MaskFormerFeatureExtractor extends MaskFormerImageProcessor { } diff --git a/src/models/mobilenet_v1/image_processing_mobilenet_v1.js b/src/models/mobilenet_v1/image_processing_mobilenet_v1.js new file mode 100644 index 000000000..61246131e --- /dev/null +++ b/src/models/mobilenet_v1/image_processing_mobilenet_v1.js @@ -0,0 +1,7 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + + +export class MobileNetV1ImageProcessor extends ImageProcessor { } +export class MobileNetV1FeatureExtractor extends MobileNetV1ImageProcessor { } diff --git a/src/models/mobilenet_v2/image_processing_mobilenet_v2.js b/src/models/mobilenet_v2/image_processing_mobilenet_v2.js new file mode 100644 index 000000000..1d80a67a3 --- /dev/null +++ b/src/models/mobilenet_v2/image_processing_mobilenet_v2.js @@ -0,0 +1,7 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + + +export class MobileNetV2ImageProcessor extends ImageProcessor { } +export class MobileNetV2FeatureExtractor extends MobileNetV2ImageProcessor { } diff --git a/src/models/mobilenet_v3/image_processing_mobilenet_v3.js b/src/models/mobilenet_v3/image_processing_mobilenet_v3.js new file mode 100644 index 000000000..3a935d30d --- /dev/null +++ b/src/models/mobilenet_v3/image_processing_mobilenet_v3.js @@ -0,0 +1,7 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + + +export class MobileNetV3ImageProcessor extends ImageProcessor { } +export class MobileNetV3FeatureExtractor extends MobileNetV3ImageProcessor { } diff --git a/src/models/mobilenet_v4/image_processing_mobilenet_v4.js b/src/models/mobilenet_v4/image_processing_mobilenet_v4.js new file mode 100644 index 000000000..fc6401f73 --- /dev/null +++ b/src/models/mobilenet_v4/image_processing_mobilenet_v4.js @@ -0,0 +1,7 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + + +export class MobileNetV4ImageProcessor extends ImageProcessor { } +export class MobileNetV4FeatureExtractor extends MobileNetV4ImageProcessor { } diff --git a/src/models/mobilevit/image_processing_mobilevit.js b/src/models/mobilevit/image_processing_mobilevit.js new file mode 100644 index 000000000..356570c68 --- /dev/null +++ b/src/models/mobilevit/image_processing_mobilevit.js @@ -0,0 +1,6 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class MobileViTImageProcessor extends ImageProcessor { } +export class MobileViTFeatureExtractor extends MobileViTImageProcessor { } diff --git a/src/models/nougat/image_processing_nougat.js b/src/models/nougat/image_processing_nougat.js new file mode 100644 index 000000000..c845fce3a --- /dev/null +++ b/src/models/nougat/image_processing_nougat.js @@ -0,0 +1,5 @@ + +import { DonutImageProcessor } from "../donut/image_processing_donut.js"; + +// NOTE: extends DonutImageProcessor +export class NougatImageProcessor extends DonutImageProcessor { } diff --git a/src/models/owlv2/image_processing_owlv2.js b/src/models/owlv2/image_processing_owlv2.js new file mode 
100644 index 000000000..224f49cc1 --- /dev/null +++ b/src/models/owlv2/image_processing_owlv2.js @@ -0,0 +1,5 @@ + +import { OwlViTImageProcessor } from "../owlvit/image_processing_owlvit.js"; + +// NOTE: extends OwlViTImageProcessor +export class Owlv2ImageProcessor extends OwlViTImageProcessor { } diff --git a/src/models/owlvit/image_processing_owlvit.js b/src/models/owlvit/image_processing_owlvit.js new file mode 100644 index 000000000..e7c3c69cf --- /dev/null +++ b/src/models/owlvit/image_processing_owlvit.js @@ -0,0 +1,12 @@ +import { + ImageProcessor, + post_process_object_detection, +} from "../../base/image_processors_utils.js"; + +export class OwlViTImageProcessor extends ImageProcessor { + /** @type {typeof post_process_object_detection} */ + post_process_object_detection(...args) { + return post_process_object_detection(...args); + } +} +export class OwlViTFeatureExtractor extends OwlViTImageProcessor { } diff --git a/src/models/owlvit/processing_owlvit.js b/src/models/owlvit/processing_owlvit.js new file mode 100644 index 000000000..f596dbe19 --- /dev/null +++ b/src/models/owlvit/processing_owlvit.js @@ -0,0 +1,7 @@ +import { Processor } from "../../base/processing_utils.js"; +import { AutoImageProcessor } from "../auto/image_processing_auto.js"; +import { AutoTokenizer } from "../../tokenizers.js"; +export class OwlViTProcessor extends Processor { + static tokenizer_class = AutoTokenizer + static image_processor_class = AutoImageProcessor +} diff --git a/src/models/processors.js b/src/models/processors.js new file mode 100644 index 000000000..e32eb9622 --- /dev/null +++ b/src/models/processors.js @@ -0,0 +1,8 @@ +export * from './florence2/processing_florence2.js'; +export * from './janus/processing_janus.js'; +export * from './owlvit/processing_owlvit.js'; +export * from './pyannote/processing_pyannote.js'; +export * from './sam/processing_sam.js'; +export * from './speecht5/processing_speecht5.js'; +export * from './wav2vec2/processing_wav2vec2.js'; +export * from './whisper/processing_whisper.js'; diff --git a/src/models/pvt/image_processing_pvt.js b/src/models/pvt/image_processing_pvt.js new file mode 100644 index 000000000..2156dfe0d --- /dev/null +++ b/src/models/pvt/image_processing_pvt.js @@ -0,0 +1,5 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class PvtImageProcessor extends ImageProcessor { } diff --git a/src/models/pyannote/feature_extraction_pyannote.js b/src/models/pyannote/feature_extraction_pyannote.js new file mode 100644 index 000000000..74b40fec9 --- /dev/null +++ b/src/models/pyannote/feature_extraction_pyannote.js @@ -0,0 +1,28 @@ +import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js'; +import { Tensor } from '../../utils/tensor.js'; + + +export class PyAnnoteFeatureExtractor extends FeatureExtractor { + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @returns {Promise<{ input_values: Tensor; }>} The extracted input features. 
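+ *
+ * **Example:** Extract input values for a pyannote segmentation model (a sketch; the model id and audio URL are assumptions, and `read_audio` comes from the library's audio utilities).
+ * ```javascript
+ * const feature_extractor = await AutoFeatureExtractor.from_pretrained('onnx-community/pyannote-segmentation-3.0');
+ * const audio = await read_audio('https://example.com/speech.wav', 16000);
+ * const { input_values } = await feature_extractor(audio); // Tensor of shape [1, 1, audio.length]
+ * ```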
+ */ + async _call(audio) { + validate_audio_inputs(audio, 'PyAnnoteFeatureExtractor'); + + if (audio instanceof Float64Array) { + audio = new Float32Array(audio); + } + + const shape = [ + 1, /* batch_size */ + 1, /* num_channels */ + audio.length, /* num_samples */ + ]; + return { + input_values: new Tensor('float32', audio, shape), + }; + } + +} diff --git a/src/models/pyannote/processing_pyannote.js b/src/models/pyannote/processing_pyannote.js new file mode 100644 index 000000000..cf66251a8 --- /dev/null +++ b/src/models/pyannote/processing_pyannote.js @@ -0,0 +1,71 @@ +import { Processor } from '../../base/processing_utils.js'; +import { AutoFeatureExtractor } from '../auto/feature_extraction_auto.js'; +import { max, softmax } from '../../utils/maths.js'; + +export class PyAnnoteProcessor extends Processor { + static feature_extractor_class = AutoFeatureExtractor + + /** + * Calls the feature_extractor function with the given audio input. + * @param {any} audio The audio input to extract features from. + * @returns {Promise} A Promise that resolves with the extracted features. + */ + async _call(audio) { + return await this.feature_extractor(audio) + } + + /** + * NOTE: Can return fractional values. `Math.ceil` will ensure correct value. + * @param {number} samples The number of frames in the audio. + * @returns {number} The number of frames in the audio. + */ + samples_to_frames(samples) { + return ((samples - this.config.offset) / this.config.step); + } + + /** + * Post-processes the speaker diarization logits output by the model. + * @param {import('../../utils/tensor.js').Tensor} logits The speaker diarization logits output by the model. + * @param {number} num_samples Number of samples in the input audio. + * @returns {Array>} The post-processed speaker diarization results. 
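+ *
+ * **Example:** Shape of the returned value (illustrative numbers; `logits` are the raw model outputs and `audio` is the original waveform).
+ * ```javascript
+ * const [segments] = processor.post_process_speaker_diarization(logits, audio.length);
+ * // segments ≈ [ { id: 0, start: 0.0, end: 1.53, confidence: 0.82 }, { id: 1, start: 1.53, end: 2.94, confidence: 0.77 }, ... ]
+ * ```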
+ */ + post_process_speaker_diarization(logits, num_samples) { + const ratio = ( + num_samples / this.samples_to_frames(num_samples) + ) / this.config.sampling_rate; + + const results = []; + for (const scores of logits.tolist()) { + const accumulated_segments = []; + + let current_speaker = -1; + for (let i = 0; i < scores.length; ++i) { + const probabilities = softmax(scores[i]); + const [score, id] = max(probabilities); + const [start, end] = [i, i + 1]; + + if (id !== current_speaker) { + // Speaker has changed + current_speaker = id; + accumulated_segments.push({ id, start, end, score }); + } else { + // Continue the current segment + accumulated_segments.at(-1).end = end; + accumulated_segments.at(-1).score += score; + } + } + + results.push(accumulated_segments.map( + // Convert frame-space to time-space + // and compute the confidence + ({ id, start, end, score }) => ({ + id, + start: start * ratio, + end: end * ratio, + confidence: score / (end - start), + }) + )); + } + return results; + } +} diff --git a/src/models/rt_detr/image_processing_rt_detr.js b/src/models/rt_detr/image_processing_rt_detr.js new file mode 100644 index 000000000..eef753352 --- /dev/null +++ b/src/models/rt_detr/image_processing_rt_detr.js @@ -0,0 +1,12 @@ +import { + ImageProcessor, + post_process_object_detection, +} from "../../base/image_processors_utils.js"; + + +export class RTDetrImageProcessor extends ImageProcessor { + /** @type {typeof post_process_object_detection} */ + post_process_object_detection(...args) { + return post_process_object_detection(...args); + } +} diff --git a/src/models/sam/image_processing_sam.js b/src/models/sam/image_processing_sam.js new file mode 100644 index 000000000..bd71e1f43 --- /dev/null +++ b/src/models/sam/image_processing_sam.js @@ -0,0 +1,242 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; +import { calculateDimensions } from "../../utils/core.js"; + +import { + interpolate_4d, + Tensor, +} from "../../utils/tensor.js"; + + +/** + * @typedef {object} SamImageProcessorResult + * @property {Tensor} pixel_values + * @property {import("../../base/image_processors_utils.js").HeightWidth[]} original_sizes + * @property {import("../../base/image_processors_utils.js").HeightWidth[]} reshaped_input_sizes + * @property {Tensor} [input_points] + * @property {Tensor} [input_labels] + * @property {Tensor} [input_boxes] + */ + +export class SamImageProcessor extends ImageProcessor { + + /** + * + * @param {any} input_points + * @param {import("../../base/image_processors_utils.js").HeightWidth[]} original_sizes + * @param {import("../../base/image_processors_utils.js").HeightWidth[]} reshaped_input_sizes + * @returns {Tensor} + */ + reshape_input_points(input_points, original_sizes, reshaped_input_sizes, is_bounding_box = false) { + + // Make deep copy to avoid altering user's input + input_points = structuredClone(input_points); + let shape = calculateDimensions(input_points); + + // TODO: add support for 2D input_points + if (shape.length === 3) { + // Correct user's input + if (!is_bounding_box) { + shape = [1, ...shape]; + } + input_points = [input_points]; + } else if (shape.length !== 4) { + throw Error("The input_points must be a 4D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.") + } + + // Reshape input points + for (let i = 0; i < input_points.length; ++i) { // batch_size + let originalImageSize = original_sizes[i]; + let reshapedImageSize = reshaped_input_sizes[i]; + + let resizeFactors = [ + 
reshapedImageSize[0] / originalImageSize[0], + reshapedImageSize[1] / originalImageSize[1] + ] + + for (let j = 0; j < input_points[i].length; ++j) { // point_batch_size + for (let k = 0; k < input_points[i][j].length; ++k) { // nb_points_per_image + for (let w = 0; w < input_points[i][j][k].length; ++w) { // 2 or 4 + input_points[i][j][k][w] *= resizeFactors[w % 2]; + } + } + } + } + + return new Tensor( + 'float32', + Float32Array.from(input_points.flat(Infinity)), + shape + ) + + } + + /** + * + * @param {any} input_labels + * @param {Tensor} input_points + * @returns {Tensor} + */ + add_input_labels(input_labels, input_points) { + let shape = calculateDimensions(input_labels); + if (shape.length === 2) { + // Correct user's input + shape = [1, ...shape]; + input_labels = [input_labels]; + } else if (shape.length !== 3) { + throw Error("The input_labels must be a 3D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`.") + } + + if (shape.some((x, i) => x !== input_points.dims[i])) { + throw Error(`The first ${shape.length} dimensions of 'input_points' and 'input_labels' must be the same.`) + } + return new Tensor( + 'int64', + input_labels.flat(Infinity).map(BigInt), + shape, + ) + } + /** + * @param {any[]} images The URL(s) of the image(s) to extract features from. + * @param {Object} [options] Additional options for the processor. + * @param {any} [options.input_points=null] A 3D or 4D array, representing the input points provided by the user. + * - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1. + * - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`. + * @param {any} [options.input_labels=null] A 2D or 3D array, representing the input labels for the points, used by the prompt encoder to encode the prompt. + * - 2D: `[point_batch_size, nb_points_per_image]`. In this case, `batch_size` is assumed to be 1. + * - 3D: `[batch_size, point_batch_size, nb_points_per_image]`. + * @param {number[][][]} [options.input_boxes=null] A 3D array of shape `(batch_size, num_boxes, 4)`, representing the input boxes provided by the user. + * This is used by the prompt encoder to encode the prompt. Generally yields much better generated masks. + * The processor will generate a tensor, with each dimension corresponding respectively to the image batch size, + * the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+ * In the order (`x1`, `y1`, `x2`, `y2`): + * - `x1`: the x coordinate of the top left point of the input box + * - `y1`: the y coordinate of the top left point of the input box + * - `x2`: the x coordinate of the bottom right point of the input box + * - `y2`: the y coordinate of the bottom right point of the input box + * @returns {Promise} + */ + async _call(images, { + input_points = null, + input_labels = null, + input_boxes = null + } = {}) { + // TODO allow user to use preprocessed images + /** @type {SamImageProcessorResult} */ + const processed = await super._call(images); + + if (input_points) { + processed.input_points = this.reshape_input_points( + input_points, processed.original_sizes, processed.reshaped_input_sizes + ); + } + + if (input_labels) { + if (!processed.input_points) { + throw Error("`input_points` must be provided if `input_labels` are provided.") + } + processed.input_labels = this.add_input_labels(input_labels, processed.input_points); + } + + if (input_boxes) { + processed.input_boxes = this.reshape_input_points( + input_boxes, processed.original_sizes, processed.reshaped_input_sizes, true, + ); + } + + return processed; + } + + /** + * Remove padding and upscale masks to the original image size. + * @param {Tensor} masks Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + * @param {[number, number][]} original_sizes The original sizes of each image before it was resized to the model's expected input shape, in (height, width) format. + * @param {[number, number][]} reshaped_input_sizes The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + * @param {Object} options Optional parameters for post-processing. + * @param {number} [options.mask_threshold] The threshold to use for binarizing the masks. + * @param {boolean} [options.binarize] Whether to binarize the masks. + * @param {Object} [options.pad_size] The target size the images were padded to before being passed to the model. If `null`, the target size is assumed to be the processor's `pad_size`. + * @param {number} [options.pad_size.height] The height the images were padded to. + * @param {number} [options.pad_size.width] The width the images were padded to. + * @returns {Promise} Batched masks in batch_size, num_channels, height, width) format, where (height, width) is given by original_size. + */ + async post_process_masks(masks, original_sizes, reshaped_input_sizes, { + mask_threshold = 0.0, + binarize = true, + pad_size = null, + } = {}) { + // masks: [1, 1, 3, 256, 256] + + const output_masks = []; + + pad_size = pad_size ?? 
this.pad_size; + + /** @type {[number, number]} */ + const target_image_size = [pad_size.height, pad_size.width]; + + for (let i = 0; i < original_sizes.length; ++i) { + const original_size = original_sizes[i]; + const reshaped_input_size = reshaped_input_sizes[i]; + + // Upscale mask to padded size + let interpolated_mask = (await interpolate_4d( + masks[i], + { mode: 'bilinear', size: target_image_size } + )); + + // Crop mask + interpolated_mask = interpolated_mask.slice(null, null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]); + + // Downscale mask + interpolated_mask = (await interpolate_4d( + interpolated_mask, + { mode: 'bilinear', size: original_size } + )); + + if (binarize) { + const data = interpolated_mask.data; + const binarizedMaskData = new Uint8Array(data.length); + for (let i = 0; i < data.length; ++i) { + if (data[i] > mask_threshold) { + binarizedMaskData[i] = 1; + } + } + interpolated_mask = new Tensor( + 'bool', + binarizedMaskData, + interpolated_mask.dims + ) + } + + output_masks.push(interpolated_mask); + } + + return output_masks; + } + + /** + * Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + * @param {import("../../utils/image.js").RawImage} image Input original image + * @param {number} target_size Target size of the resized image + * @param {Object} options Options for generating crop boxes + * @param {number} [options.crop_n_layers] If >0, mask prediction will be run again on crops of the image. + * Sets the number of layers to run, where each layer has 2**i_layer number of image crops. + * @param {number} [options.overlap_ratio] Sets the degree to which crops overlap. In the first crop layer, + * crops will overlap by this fraction of the image length. Later layers with more crops scale down this overlap. + * @param {number} [options.points_per_crop] Number of points to sample from each crop. + * @param {number} [options.crop_n_points_downscale_factor] The number of points-per-side sampled in layer n is + * scaled down by crop_n_points_downscale_factor**n. + * @returns {Object} An object containing the crop boxes, number of points per crop, cropped images, and input labels. 
+ */ + generate_crop_boxes(image, target_size, { + crop_n_layers = 0, + overlap_ratio = 512 / 1500, + points_per_crop = 32, + crop_n_points_downscale_factor = 1, + } = {}) { + // TODO: Implement + // return { crop_boxes, points_per_crop, cropped_images, input_labels } + } +} + diff --git a/src/models/sam/processing_sam.js b/src/models/sam/processing_sam.js new file mode 100644 index 000000000..4cc0f29e5 --- /dev/null +++ b/src/models/sam/processing_sam.js @@ -0,0 +1,20 @@ +import { Processor } from "../../base/processing_utils.js"; +import { AutoImageProcessor } from "../auto/image_processing_auto.js"; + +export class SamProcessor extends Processor { + static image_processor_class = AutoImageProcessor + + async _call(...args) { + return await this.image_processor(...args); + } + + post_process_masks(...args) { + // @ts-ignore + return this.image_processor.post_process_masks(...args); + } + + reshape_input_points(...args) { + // @ts-ignore + return this.image_processor.reshape_input_points(...args); + } +} \ No newline at end of file diff --git a/src/models/sapiens/image_processing_sapiens.js b/src/models/sapiens/image_processing_sapiens.js new file mode 100644 index 000000000..df78763cf --- /dev/null +++ b/src/models/sapiens/image_processing_sapiens.js @@ -0,0 +1,13 @@ +import { + ImageProcessor, + post_process_semantic_segmentation, +} from "../../base/image_processors_utils.js"; + + +export class SapiensImageProcessor extends ImageProcessor { + /** @type {typeof post_process_semantic_segmentation} */ + post_process_semantic_segmentation(...args) { + return post_process_semantic_segmentation(...args); + } +} +export class SapiensFeatureExtractor extends SapiensImageProcessor { } diff --git a/src/models/seamless_m4t/feature_extraction_seamless_m4t.js b/src/models/seamless_m4t/feature_extraction_seamless_m4t.js new file mode 100644 index 000000000..8f02de062 --- /dev/null +++ b/src/models/seamless_m4t/feature_extraction_seamless_m4t.js @@ -0,0 +1,180 @@ +import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js'; +import { Tensor } from '../../utils/tensor.js'; +import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js'; + +export class SeamlessM4TFeatureExtractor extends FeatureExtractor { + + constructor(config) { + super(config); + + const sampling_rate = this.config.sampling_rate; + const mel_filters = mel_filter_bank( + 256, // num_frequency_bins + this.config.num_mel_bins, // num_mel_filters + 20, // min_frequency + Math.floor(sampling_rate / 2), // max_frequency + sampling_rate, // sampling_rate + null, // norm + "kaldi", // mel_scale + true, // triangularize_in_mel_space + ); + + // Do padding: + for (let i = 0; i < mel_filters.length; ++i) { + mel_filters[i].push(0); + } + this.mel_filters = mel_filters; + + this.window = window_function(400, 'povey', { + periodic: false, + }) + } + + /** + * Computes the log-Mel spectrogram of the provided audio waveform. + * @param {Float32Array|Float64Array} waveform The audio waveform to process. + * @param {number} max_length The maximum number of frames to return. + * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. 
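+ *
+ * **Example:** End-to-end usage of this feature extractor (shown here for context; a sketch in which the model id and audio URL are placeholders for any checkpoint whose `preprocessor_config.json` declares `"feature_extractor_type": "SeamlessM4TFeatureExtractor"`).
+ * ```javascript
+ * const feature_extractor = await AutoFeatureExtractor.from_pretrained('<model-id>');
+ * const audio = await read_audio('<audio-url>', feature_extractor.config.sampling_rate);
+ * const { input_features, attention_mask } = await feature_extractor(audio);
+ * ```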
+ */ + async _extract_fbank_features(waveform, max_length) { + // NOTE: We don't pad/truncate since that is passed in as `max_num_frames` + + // Kaldi compliance: 16-bit signed integers + // 32768 == 2 ** 15 + waveform = waveform.map((/** @type {number} */ x) => x * 32768) + + return spectrogram( + waveform, + this.window, // window + 400, // frame_length + 160, // hop_length + { + fft_length: 512, + power: 2.0, + center: false, + preemphasis: 0.97, + mel_filters: this.mel_filters, + log_mel: 'log', + mel_floor: 1.192092955078125e-07, + remove_dc_offset: true, + + // Custom + max_num_frames: max_length, + transpose: true, + } + ) + } + + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @param {Object} options Optional parameters for feature extraction. + * @param {boolean} [options.padding=true] Whether to pad the sequence to a multiple of `pad_to_multiple_of`. + * @param {number} [options.pad_to_multiple_of=2] The number to pad the sequence to a multiple of. + * @param {boolean} [options.do_normalize_per_mel_bins=true] Whether or not to zero-mean unit-variance normalize the input per mel-channel. + * @param {boolean} [options.return_attention_mask=true] Whether to return the attention mask. + * @returns {Promise<{ input_features: Tensor, attention_mask?: Tensor }>} A Promise resolving to an object containing the extracted input features and attention masks as Tensors. + */ + async _call(audio, { + padding = true, + pad_to_multiple_of = 2, + do_normalize_per_mel_bins = true, + return_attention_mask = true, + } = {}) { + validate_audio_inputs(audio, 'SeamlessM4TFeatureExtractor'); + + let features = await this._extract_fbank_features(audio, this.config.max_length); + + if (do_normalize_per_mel_bins) { + const [num_features, feature_size] = features.dims; + const data = features.data; + for (let i = 0; i < feature_size; ++i) { + let sum = 0; + for (let j = 0; j < num_features; ++j) { + sum += data[j * feature_size + i]; + } + + const mean = sum / num_features; + + let variance = 0; + for (let j = 0; j < num_features; ++j) { + variance += (data[j * feature_size + i] - mean) ** 2; + } + variance /= num_features - 1; // NOTE: We use ddof=1 + + const std = Math.sqrt(variance + 1e-7); + for (let j = 0; j < num_features; ++j) { + const index = j * feature_size + i; + data[index] = (data[index] - mean) / std; + } + } + } + + let padded_attention_mask; + if (padding) { + const [num_frames, num_channels] = features.dims; + const data = /** @type {Float32Array} */(features.data); + + const pad_size = num_frames % pad_to_multiple_of; + if (pad_size > 0) { + const padded_data = new Float32Array(num_channels * (num_frames + pad_size)); + padded_data.set(data) + padded_data.fill(this.config.padding_value, data.length) + + const numPaddedFrames = num_frames + pad_size; + features = new Tensor( + features.type, + padded_data, + [numPaddedFrames, num_channels], + ) + + if (return_attention_mask) { + padded_attention_mask = new Tensor( + 'int64', + new BigInt64Array(numPaddedFrames), + [1, numPaddedFrames], + ) + padded_attention_mask.data.fill(1n, 0, num_frames); + } + } + } + + const [num_frames, num_channels] = features.dims; + + const stride = this.config.stride; + const remainder = num_frames % stride; + if (remainder !== 0) { + throw new Error(`The number of frames (${num_frames}) must be a multiple of the stride (${stride}).`) + } + + const input_features = 
features.view( + 1, + Math.floor(num_frames / stride), + num_channels * stride, + ); + + const result = { input_features } + + if (return_attention_mask) { + const reshapedNumFrames = input_features.dims[1]; + + const attention_mask_data = new BigInt64Array(reshapedNumFrames); + + if (padded_attention_mask) { + const padded_attention_mask_data = padded_attention_mask.data; + for (let i = 1, j = 0; i < num_frames; i += stride, ++j) { + attention_mask_data[j] = padded_attention_mask_data[i]; + } + } else { + attention_mask_data.fill(1n); + } + result.attention_mask = new Tensor( + 'int64', + attention_mask_data, + [1, reshapedNumFrames], + ); + } + + return result; + } +} diff --git a/src/models/segformer/image_processing_segformer.js b/src/models/segformer/image_processing_segformer.js new file mode 100644 index 000000000..fe129a05a --- /dev/null +++ b/src/models/segformer/image_processing_segformer.js @@ -0,0 +1,13 @@ +import { + ImageProcessor, + post_process_semantic_segmentation, +} from "../../base/image_processors_utils.js"; + + +export class SegformerImageProcessor extends ImageProcessor { + /** @type {typeof post_process_semantic_segmentation} */ + post_process_semantic_segmentation(...args) { + return post_process_semantic_segmentation(...args); + } +} +export class SegformerFeatureExtractor extends SegformerImageProcessor { } diff --git a/src/models/siglip/image_processing_siglip.js b/src/models/siglip/image_processing_siglip.js new file mode 100644 index 000000000..5e666562b --- /dev/null +++ b/src/models/siglip/image_processing_siglip.js @@ -0,0 +1,5 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class SiglipImageProcessor extends ImageProcessor { } diff --git a/src/models/speecht5/feature_extraction_speecht5.js b/src/models/speecht5/feature_extraction_speecht5.js new file mode 100644 index 000000000..0f3f2ab38 --- /dev/null +++ b/src/models/speecht5/feature_extraction_speecht5.js @@ -0,0 +1,4 @@ + +import { FeatureExtractor } from "../../base/feature_extraction_utils.js"; + +export class SpeechT5FeatureExtractor extends FeatureExtractor { } diff --git a/src/models/speecht5/processing_speecht5.js b/src/models/speecht5/processing_speecht5.js new file mode 100644 index 000000000..08af8ba1a --- /dev/null +++ b/src/models/speecht5/processing_speecht5.js @@ -0,0 +1,17 @@ +import { Processor } from "../../base/processing_utils.js"; +import { AutoTokenizer } from "../../tokenizers.js"; +import { AutoFeatureExtractor } from "../auto/feature_extraction_auto.js"; + +export class SpeechT5Processor extends Processor { + static tokenizer_class = AutoTokenizer + static feature_extractor_class = AutoFeatureExtractor + + /** + * Calls the feature_extractor function with the given input. + * @param {any} input The input to extract features from. + * @returns {Promise} A Promise that resolves with the extracted features. + */ + async _call(input) { + return await this.feature_extractor(input) + } +} diff --git a/src/models/swin2sr/image_processing_swin2sr.js b/src/models/swin2sr/image_processing_swin2sr.js new file mode 100644 index 000000000..e53c5c4c1 --- /dev/null +++ b/src/models/swin2sr/image_processing_swin2sr.js @@ -0,0 +1,24 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class Swin2SRImageProcessor extends ImageProcessor { + pad_image(pixelData, imgDims, padSize, options = {}) { + // NOTE: In this case, `padSize` represents the size of the sliding window for the local attention. 
+ // In other words, the image is padded so that its width and height are multiples of `padSize`. + const [imageHeight, imageWidth, imageChannels] = imgDims; + + return super.pad_image(pixelData, imgDims, { + // NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already + // a multiple of `pad_size`. However, this is most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19). + // For this reason, we only add padding when the image's width/height is not a multiple of `pad_size`. + width: imageWidth + (padSize - imageWidth % padSize) % padSize, + height: imageHeight + (padSize - imageHeight % padSize) % padSize, + }, { + mode: 'symmetric', + center: false, + constant_values: -1, + ...options, + }) + } +} \ No newline at end of file diff --git a/src/models/vit/image_processing_vit.js b/src/models/vit/image_processing_vit.js new file mode 100644 index 000000000..ad07ca27e --- /dev/null +++ b/src/models/vit/image_processing_vit.js @@ -0,0 +1,7 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +export class ViTImageProcessor extends ImageProcessor { } +export class ViTFeatureExtractor extends ViTImageProcessor { } + diff --git a/src/models/vitmatte/image_processing_vitmatte.js b/src/models/vitmatte/image_processing_vitmatte.js new file mode 100644 index 000000000..274862344 --- /dev/null +++ b/src/models/vitmatte/image_processing_vitmatte.js @@ -0,0 +1,50 @@ +import { + ImageProcessor, +} from "../../base/image_processors_utils.js"; + +import { + stack, + cat, +} from "../../utils/tensor.js"; + +export class VitMatteImageProcessor extends ImageProcessor { + /** + * Calls the feature extraction process on an array of images, preprocesses + * each image, and concatenates the resulting features into a single Tensor. + * @param {import("../../utils/image.js").RawImage[]} images The image(s) to extract features from. + * @param {import("../../utils/image.js").RawImage[]} trimaps The trimaps(s) to extract features from. + * @returns {Promise} An object containing the concatenated pixel values of the preprocessed images. 
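+ *
+ * A minimal usage sketch (illustrative only; the model id, file paths, and loading via `AutoProcessor`
+ * are assumptions):
+ * ```javascript
+ * import { AutoProcessor, RawImage } from '@huggingface/transformers';
+ * const processor = await AutoProcessor.from_pretrained('Xenova/vitmatte-small-distinctions-646');
+ * const image = await RawImage.fromURL('image.png');
+ * const trimap = await RawImage.fromURL('trimap.png');
+ * // The trimap is preprocessed as a single grayscale channel and concatenated with the RGB image,
+ * // so `pixel_values` has 4 channels per image.
+ * const { pixel_values } = await processor(image, trimap);
+ * ```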
+ */ + async _call(images, trimaps) { + if (!Array.isArray(images)) { + images = [images]; + } + if (!Array.isArray(trimaps)) { + trimaps = [trimaps]; + } + + const imageData = await Promise.all(images.map(x => this.preprocess(x))); + const trimapData = await Promise.all(trimaps.map(x => this.preprocess(x, { + do_normalize: false, + do_convert_rgb: false, + do_convert_grayscale: true, + }))); + + + // Stack pixel values + const pixel_values = stack(imageData.map( + // Concatenate images and trimaps + (x, i) => cat([x.pixel_values, trimapData[i].pixel_values], 0) + ), 0); + + return { + pixel_values, + + // Original sizes of images + original_sizes: imageData.map(x => x.original_size), + + // Reshaped sizes of images, before padding or cropping + reshaped_input_sizes: imageData.map(x => x.reshaped_input_size), + } + } +} diff --git a/src/models/wav2vec2/feature_extraction_wav2vec2.js b/src/models/wav2vec2/feature_extraction_wav2vec2.js new file mode 100644 index 000000000..51f007603 --- /dev/null +++ b/src/models/wav2vec2/feature_extraction_wav2vec2.js @@ -0,0 +1,44 @@ +import { FeatureExtractor, validate_audio_inputs } from "../../base/feature_extraction_utils.js"; +import { Tensor } from "../../utils/tensor.js"; + +export class Wav2Vec2FeatureExtractor extends FeatureExtractor { + + /** + * @param {Float32Array} input_values + * @returns {Float32Array} + */ + _zero_mean_unit_var_norm(input_values) { + // TODO support batch? + const sum = input_values.reduce((a, b) => a + b, 0); + const mean = sum / input_values.length; + const variance = input_values.reduce((a, b) => a + (b - mean) ** 2, 0) / input_values.length; + return input_values.map(x => (x - mean) / Math.sqrt(variance + 1e-7)); + } + + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @returns {Promise<{ input_values: Tensor; attention_mask: Tensor }>} A Promise resolving to an object containing the extracted input features and attention mask as Tensors. + */ + async _call(audio) { + validate_audio_inputs(audio, 'Wav2Vec2FeatureExtractor'); + + if (audio instanceof Float64Array) { + audio = new Float32Array(audio); + } + + let input_values = audio; + + // zero-mean and unit-variance normalization + if (this.config.do_normalize) { + input_values = this._zero_mean_unit_var_norm(input_values); + } + + // TODO: allow user to pass in attention mask + const shape = [1, input_values.length]; + return { + input_values: new Tensor('float32', input_values, shape), + attention_mask: new Tensor('int64', new BigInt64Array(input_values.length).fill(1n), shape) + }; + } +} diff --git a/src/models/wav2vec2/processing_wav2vec2.js b/src/models/wav2vec2/processing_wav2vec2.js new file mode 100644 index 000000000..490fe2fc9 --- /dev/null +++ b/src/models/wav2vec2/processing_wav2vec2.js @@ -0,0 +1,15 @@ +import { Processor } from "../../base/processing_utils.js"; +import { AutoFeatureExtractor } from "../auto/feature_extraction_auto.js"; + +export class Wav2Vec2ProcessorWithLM extends Processor { + static feature_extractor_class = AutoFeatureExtractor + + /** + * Calls the feature_extractor function with the given audio input. + * @param {any} audio The audio input to extract features from. + * @returns {Promise} A Promise that resolves with the extracted features. 
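+ *
+ * A minimal usage sketch (illustrative only; the model id is an assumption):
+ * ```javascript
+ * import { AutoProcessor, read_audio } from '@huggingface/transformers';
+ * const processor = await AutoProcessor.from_pretrained('Xenova/wav2vec2-base-960h');
+ * const audio = await read_audio('speech.wav', 16000);
+ * // Forwards to the underlying feature extractor, which returns (optionally normalized)
+ * // `input_values` of shape [1, num_samples] and an all-ones `attention_mask` of the same shape.
+ * const { input_values, attention_mask } = await processor(audio);
+ * ```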
+ */ + async _call(audio) { + return await this.feature_extractor(audio) + } +} diff --git a/src/models/wespeaker/feature_extraction_wespeaker.js b/src/models/wespeaker/feature_extraction_wespeaker.js new file mode 100644 index 000000000..0815f9cda --- /dev/null +++ b/src/models/wespeaker/feature_extraction_wespeaker.js @@ -0,0 +1,100 @@ +import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js'; +import { Tensor } from '../../utils/tensor.js'; +import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js'; + + +export class WeSpeakerFeatureExtractor extends FeatureExtractor { + + constructor(config) { + super(config); + + const sampling_rate = this.config.sampling_rate; + const mel_filters = mel_filter_bank( + 256, // num_frequency_bins + this.config.num_mel_bins, // num_mel_filters + 20, // min_frequency + Math.floor(sampling_rate / 2), // max_frequency + sampling_rate, // sampling_rate + null, // norm + "kaldi", // mel_scale + true, // triangularize_in_mel_space + ); + + // Do padding: + for (let i = 0; i < mel_filters.length; ++i) { + mel_filters[i].push(0); + } + this.mel_filters = mel_filters; + + this.window = window_function(400, 'hamming', { + periodic: false, + }) + this.min_num_frames = this.config.min_num_frames; + } + + /** + * Computes the log-Mel spectrogram of the provided audio waveform. + * @param {Float32Array|Float64Array} waveform The audio waveform to process. + * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. + */ + async _extract_fbank_features(waveform) { + // Kaldi compliance: 16-bit signed integers + // 32768 == 2 ** 15 + waveform = waveform.map((/** @type {number} */ x) => x * 32768) + + return spectrogram( + waveform, + this.window, // window + 400, // frame_length + 160, // hop_length + { + fft_length: 512, + power: 2.0, + center: false, + preemphasis: 0.97, + mel_filters: this.mel_filters, + log_mel: 'log', + mel_floor: 1.192092955078125e-07, + remove_dc_offset: true, + + // Custom + transpose: true, + min_num_frames: this.min_num_frames, + } + ) + } + + + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. 
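+ *
+ * Note (added comment): the returned `input_features` has shape [1, num_frames, num_mel_bins]. When
+ * `fbank_centering_span` is null in the config, each mel bin is centered by subtracting its mean over
+ * all frames of the utterance (global cepstral mean normalization), as implemented below.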
+ */ + async _call(audio) { + validate_audio_inputs(audio, 'WeSpeakerFeatureExtractor'); + + const features = (await this._extract_fbank_features(audio)).unsqueeze_(0); + + if (this.config.fbank_centering_span === null) { + // center features with global average + const meanData = /** @type {Float32Array} */ (features.mean(1).data); + const featuresData = /** @type {Float32Array} */(features.data); + const [batch_size, num_frames, feature_size] = features.dims; + + for (let i = 0; i < batch_size; ++i) { + const offset1 = i * num_frames * feature_size; + const offset2 = i * feature_size; + for (let j = 0; j < num_frames; ++j) { + const offset3 = offset1 + j * feature_size; + for (let k = 0; k < feature_size; ++k) { + featuresData[offset3 + k] -= meanData[offset2 + k]; + } + } + } + } + + return { + input_features: features + }; + } +} diff --git a/src/models/whisper/feature_extraction_whisper.js b/src/models/whisper/feature_extraction_whisper.js new file mode 100644 index 000000000..f4d351f88 --- /dev/null +++ b/src/models/whisper/feature_extraction_whisper.js @@ -0,0 +1,84 @@ +import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js'; +import { Tensor } from '../../utils/tensor.js'; +import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js'; +import { max } from '../../utils/maths.js'; + +export class WhisperFeatureExtractor extends FeatureExtractor { + + constructor(config) { + super(config); + + // Prefer given `mel_filters` from preprocessor_config.json, or calculate them if they don't exist. + this.config.mel_filters ??= mel_filter_bank( + Math.floor(1 + this.config.n_fft / 2), // num_frequency_bins + this.config.feature_size, // num_mel_filters + 0.0, // min_frequency + 8000.0, // max_frequency + this.config.sampling_rate, // sampling_rate + "slaney", // norm + "slaney", // mel_scale + ); + + this.window = window_function(this.config.n_fft, 'hann'); + } + + /** + * Computes the log-Mel spectrogram of the provided audio waveform. + * @param {Float32Array|Float64Array} waveform The audio waveform to process. + * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. + */ + async _extract_fbank_features(waveform) { + const features = await spectrogram( + waveform, + this.window, // window + this.config.n_fft, // frame_length + this.config.hop_length, // hop_length + { + power: 2.0, + mel_filters: this.config.mel_filters, + log_mel: 'log10', + + // Custom + max_num_frames: this.config.nb_max_frames, // 3000 + } + ) + + const data = features.data; + const maxValue = max(data)[0]; + + for (let i = 0; i < data.length; ++i) { + data[i] = (Math.max(data[i], maxValue - 8.0) + 4.0) / 4.0; + } + + return features; + } + + /** + * Asynchronously extracts features from a given audio using the provided configuration. + * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. + * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. + */ + async _call(audio) { + validate_audio_inputs(audio, 'WhisperFeatureExtractor'); + + let waveform; + if (audio.length > this.config.n_samples) { + console.warn( + "Attempting to extract features for audio longer than 30 seconds. " + + "If using a pipeline to extract transcript from a long audio clip, " + + "remember to specify `chunk_length_s` and/or `stride_length_s`." 
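+ // NOTE (added comment): `n_samples` is the fixed 30-second window Whisper operates on
+ // (480000 samples at a 16 kHz sampling rate, assuming the standard Whisper preprocessor config);
+ // longer audio is truncated just below, and shorter audio is zero-padded up to the same length.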
+ ); + waveform = audio.slice(0, this.config.n_samples); + } else { + // pad with zeros + waveform = new Float32Array(this.config.n_samples); + waveform.set(audio); + } + + const features = await this._extract_fbank_features(waveform); + + return { + input_features: features.unsqueeze_(0) + }; + } +} diff --git a/src/models/whisper/processing_whisper.js b/src/models/whisper/processing_whisper.js new file mode 100644 index 000000000..b676273b8 --- /dev/null +++ b/src/models/whisper/processing_whisper.js @@ -0,0 +1,21 @@ +import { AutoFeatureExtractor } from "../auto/feature_extraction_auto.js" +import { AutoTokenizer } from "../../tokenizers.js" +import { Processor } from "../../base/processing_utils.js" + +/** + * Represents a WhisperProcessor that extracts features from an audio input. + */ +export class WhisperProcessor extends Processor { + static tokenizer_class = AutoTokenizer + static feature_extractor_class = AutoFeatureExtractor + + /** + * Calls the feature_extractor function with the given audio input. + * @param {any} audio The audio input to extract features from. + * @returns {Promise} A Promise that resolves with the extracted features. + */ + async _call(audio) { + return await this.feature_extractor(audio); + } +} + diff --git a/src/models/yolos/image_processing_yolos.js b/src/models/yolos/image_processing_yolos.js new file mode 100644 index 000000000..f82b08984 --- /dev/null +++ b/src/models/yolos/image_processing_yolos.js @@ -0,0 +1,12 @@ +import { + ImageProcessor, + post_process_object_detection, +} from "../../base/image_processors_utils.js"; + +export class YolosImageProcessor extends ImageProcessor { + /** @type {typeof post_process_object_detection} */ + post_process_object_detection(...args) { + return post_process_object_detection(...args); + } +} +export class YolosFeatureExtractor extends YolosImageProcessor { } diff --git a/src/pipelines.js b/src/pipelines.js index 3b7373cf9..a61cb1dde 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -45,8 +45,10 @@ import { } from './models.js'; import { AutoProcessor, - Processor -} from './processors.js'; +} from './models/auto/processing_auto.js'; +import { + Processor, +} from './base/processing_utils.js'; import { Callable, @@ -54,7 +56,6 @@ import { import { dispatchCallback, - pop, product, } from './utils/core.js'; import { @@ -158,7 +159,6 @@ function get_bounding_box(box, asInteger) { /** * The Pipeline class is the class from which all pipelines inherit. * Refer to this class for methods shared across different pipelines. 
- * @extends Callable */ export class Pipeline extends Callable { /** @@ -2131,8 +2131,8 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi fn = this.subtasks_mapping[subtask]; } else { for (let [task, func] of Object.entries(this.subtasks_mapping)) { - if (func in this.processor.feature_extractor) { - fn = this.processor.feature_extractor[func].bind(this.processor.feature_extractor); + if (func in this.processor.image_processor) { + fn = this.processor.image_processor[func].bind(this.processor.image_processor); subtask = task; break; } @@ -2362,7 +2362,7 @@ export class ObjectDetectionPipeline extends (/** @type {new (options: ImagePipe const output = await this.model({ pixel_values, pixel_mask }); // @ts-ignore - const processed = this.processor.feature_extractor.post_process_object_detection(output, threshold, imageSizes); + const processed = this.processor.image_processor.post_process_object_detection(output, threshold, imageSizes); // Add labels const id2label = this.model.config.id2label; @@ -2510,7 +2510,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T const output = await this.model({ ...text_inputs, pixel_values }); // @ts-ignore - const processed = this.processor.feature_extractor.post_process_object_detection(output, threshold, imageSize, true)[0]; + const processed = this.processor.image_processor.post_process_object_detection(output, threshold, imageSize, true)[0]; let result = processed.boxes.map((box, i) => ({ score: processed.scores[i], label: candidate_labels[processed.classes[i]], diff --git a/src/processors.js b/src/processors.js deleted file mode 100644 index 9af0791be..000000000 --- a/src/processors.js +++ /dev/null @@ -1,2655 +0,0 @@ - -/** - * @file Processors are used to prepare non-textual inputs (e.g., image or audio) for a model. - * - * **Example:** Using a `WhisperProcessor` to prepare an audio input for a model. - * ```javascript - * import { AutoProcessor, read_audio } from '@huggingface/transformers'; - * - * let processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en'); - * let audio = await read_audio('https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac', 16000); - * let { input_features } = await processor(audio); - * // Tensor { - * // data: Float32Array(240000) [0.4752984642982483, 0.5597258806228638, 0.56434166431427, ...], - * // dims: [1, 80, 3000], - * // type: 'float32', - * // size: 240000, - * // } - * ``` - * - * @module processors - */ -import { - Callable, -} from './utils/generic.js'; - -import { - calculateDimensions, - calculateReflectOffset, -} from './utils/core.js'; - -import { - getModelJSON, -} from './utils/hub.js'; - -import { - min, - max, - softmax, - bankers_round, -} from './utils/maths.js'; - - -import { Tensor, cat, interpolate, stack, interpolate_4d, full } from './utils/tensor.js'; - -import { RawImage } from './utils/image.js'; -import { - window_function, - spectrogram, - mel_filter_bank, -} from './utils/audio.js'; - - -// Helper functions - -/** - * Converts bounding boxes from center format to corners format. 
- * - * @param {number[]} arr The coordinate for the center of the box and its width, height dimensions (center_x, center_y, width, height) - * @returns {number[]} The coodinates for the top-left and bottom-right corners of the box (top_left_x, top_left_y, bottom_right_x, bottom_right_y) - */ -function center_to_corners_format([centerX, centerY, width, height]) { - return [ - centerX - width / 2, - centerY - height / 2, - centerX + width / 2, - centerY + height / 2 - ]; -} - -/** - * Post-processes the outputs of the model (for object detection). - * @param {Object} outputs The outputs of the model that must be post-processed - * @param {Tensor} outputs.logits The logits - * @param {Tensor} outputs.pred_boxes The predicted boxes. - * @param {number} [threshold=0.5] The threshold to use for the scores. - * @param {[number, number][]} [target_sizes=null] The sizes of the original images. - * @param {boolean} [is_zero_shot=false] Whether zero-shot object detection was performed. - * @return {Object[]} An array of objects containing the post-processed outputs. - * @private - */ -function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null, is_zero_shot = false) { - const out_logits = outputs.logits; - const out_bbox = outputs.pred_boxes; - const [batch_size, num_boxes, num_classes] = out_logits.dims; - - if (target_sizes !== null && target_sizes.length !== batch_size) { - throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits") - } - let toReturn = []; - for (let i = 0; i < batch_size; ++i) { - let target_size = target_sizes !== null ? target_sizes[i] : null; - let info = { - boxes: [], - classes: [], - scores: [] - } - let logits = out_logits[i]; - let bbox = out_bbox[i]; - - for (let j = 0; j < num_boxes; ++j) { - let logit = logits[j]; - - let indices = []; - let probs; - if (is_zero_shot) { - // Get indices of classes with high enough probability - probs = logit.sigmoid().data; - for (let k = 0; k < probs.length; ++k) { - if (probs[k] > threshold) { - indices.push(k); - } - } - - } else { - // Get most probable class - let maxIndex = max(logit.data)[1]; - - if (maxIndex === num_classes - 1) { - // This is the background class, skip it - continue; - } - // Compute softmax over classes - probs = softmax(logit.data); - - if (probs[maxIndex] < threshold) { - continue; - } - indices.push(maxIndex); - } - - for (const index of indices) { - - // Some class has a high enough probability - /** @type {number[]} */ - let box = bbox[j].data; - - // convert to [x0, y0, x1, y1] format - box = center_to_corners_format(box) - if (target_size !== null) { - box = box.map((x, i) => x * target_size[(i + 1) % 2]) - } - - info.boxes.push(box); - info.classes.push(index); - info.scores.push(probs[index]); - } - } - toReturn.push(info); - } - return toReturn; -} - - -/** - * Post-processes the outputs of the model (for semantic segmentation). - * @param {*} outputs Raw outputs of the model. - * @param {[number, number][]} [target_sizes=null] List of tuples corresponding to the requested final size - * (height, width) of each prediction. If unset, predictions will not be resized. - * @returns {{segmentation: Tensor; labels: number[]}[]} The semantic segmentation maps. 
- */ -function post_process_semantic_segmentation(outputs, target_sizes = null) { - - const logits = outputs.logits; - const batch_size = logits.dims[0]; - - if (target_sizes !== null && target_sizes.length !== batch_size) { - throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits") - } - - const toReturn = []; - for (let i = 0; i < batch_size; ++i) { - const target_size = target_sizes !== null ? target_sizes[i] : null; - - let data = logits[i]; - - // 1. If target_size is not null, we need to resize the masks to the target size - if (target_size !== null) { - // resize the masks to the target size - data = interpolate(data, target_size, 'bilinear', false); - } - const [height, width] = target_size ?? data.dims.slice(-2); - - const segmentation = new Tensor( - 'int32', - new Int32Array(height * width), - [height, width] - ); - - // Buffer to store current largest value - const buffer = data[0].data; - const segmentation_data = segmentation.data; - for (let j = 1; j < data.dims[0]; ++j) { - const row = data[j].data; - for (let k = 0; k < row.length; ++k) { - if (row[k] > buffer[k]) { - buffer[k] = row[k]; - segmentation_data[k] = j; - } - } - } - - // Store which objects have labels - // This is much more efficient that creating a set of the final values - const hasLabel = new Array(data.dims[0]); - for (let j = 0; j < segmentation_data.length; ++j) { - const index = segmentation_data[j]; - hasLabel[index] = index; - } - /** @type {number[]} The unique list of labels that were detected */ - const labels = hasLabel.filter(x => x !== undefined); - - toReturn.push({ segmentation, labels }); - } - return toReturn; -} - - -/** - * Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and `labels`. - * @param {Tensor} class_logits The class logits. - * @param {Tensor} mask_logits The mask logits. - * @param {number} object_mask_threshold A number between 0 and 1 used to binarize the masks. - * @param {number} num_labels The number of labels. - * @returns {[Tensor[], number[], number[]]} The binarized masks, the scores, and the labels. - * @private - */ -function remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels) { - - const mask_probs_item = []; - const pred_scores_item = []; - const pred_labels_item = []; - - for (let j = 0; j < class_logits.dims[0]; ++j) { - const cls = class_logits[j]; - const mask = mask_logits[j]; - - const pred_label = max(cls.data)[1]; - if (pred_label === num_labels) { - // Is the background, so we ignore it - continue; - } - - const scores = softmax(cls.data); - const pred_score = scores[pred_label]; - if (pred_score > object_mask_threshold) { - mask_probs_item.push(mask); - pred_scores_item.push(pred_score); - pred_labels_item.push(pred_label); - } - } - - return [mask_probs_item, pred_scores_item, pred_labels_item]; -} - -/** - * Checks whether the segment is valid or not. - * @param {Int32Array} mask_labels Labels for each pixel in the mask. - * @param {Tensor[]} mask_probs Probabilities for each pixel in the masks. - * @param {number} k The class id of the segment. - * @param {number} mask_threshold The mask threshold. - * @param {number} overlap_mask_area_threshold The overlap mask area threshold. - * @returns {[boolean, number[]]} Whether the segment is valid or not, and the indices of the valid labels. 
- * @private - */ -function check_segment_validity( - mask_labels, - mask_probs, - k, - mask_threshold = 0.5, - overlap_mask_area_threshold = 0.8 -) { - // mask_k is a 1D array of indices, indicating where the mask is equal to k - const mask_k = []; - let mask_k_area = 0; - let original_area = 0; - - const mask_probs_k_data = mask_probs[k].data; - - // Compute the area of all the stuff in query k - for (let i = 0; i < mask_labels.length; ++i) { - if (mask_labels[i] === k) { - mask_k.push(i); - ++mask_k_area; - } - - if (mask_probs_k_data[i] >= mask_threshold) { - ++original_area; - } - } - let mask_exists = mask_k_area > 0 && original_area > 0; - - // Eliminate disconnected tiny segments - if (mask_exists) { - // Perform additional check - let area_ratio = mask_k_area / original_area; - mask_exists = area_ratio > overlap_mask_area_threshold; - } - - return [mask_exists, mask_k] -} - -/** - * Computes the segments. - * @param {Tensor[]} mask_probs The mask probabilities. - * @param {number[]} pred_scores The predicted scores. - * @param {number[]} pred_labels The predicted labels. - * @param {number} mask_threshold The mask threshold. - * @param {number} overlap_mask_area_threshold The overlap mask area threshold. - * @param {Set} label_ids_to_fuse The label ids to fuse. - * @param {number[]} target_size The target size of the image. - * @returns {[Tensor, Array<{id: number, label_id: number, score: number}>]} The computed segments. - * @private - */ -function compute_segments( - mask_probs, - pred_scores, - pred_labels, - mask_threshold, - overlap_mask_area_threshold, - label_ids_to_fuse = null, - target_size = null, -) { - const [height, width] = target_size ?? mask_probs[0].dims; - - const segmentation = new Tensor( - 'int32', - new Int32Array(height * width), - [height, width] - ); - const segments = []; - - // 1. If target_size is not null, we need to resize the masks to the target size - if (target_size !== null) { - // resize the masks to the target size - for (let i = 0; i < mask_probs.length; ++i) { - mask_probs[i] = interpolate(mask_probs[i], target_size, 'bilinear', false); - } - } - - // 2. 
Weigh each mask by its prediction score - // NOTE: `mask_probs` is updated in-place - // - // Temporary storage for the best label/scores for each pixel ([height, width]): - const mask_labels = new Int32Array(mask_probs[0].data.length); - const bestScores = new Float32Array(mask_probs[0].data.length); - - for (let i = 0; i < mask_probs.length; ++i) { - let score = pred_scores[i]; - - const mask_probs_i_data = mask_probs[i].data; - - for (let j = 0; j < mask_probs_i_data.length; ++j) { - mask_probs_i_data[j] *= score - if (mask_probs_i_data[j] > bestScores[j]) { - mask_labels[j] = i; - bestScores[j] = mask_probs_i_data[j]; - } - } - } - - let current_segment_id = 0; - - // let stuff_memory_list = {} - const segmentation_data = segmentation.data; - for (let k = 0; k < pred_labels.length; ++k) { - const pred_class = pred_labels[k]; - - // TODO add `should_fuse` - // let should_fuse = pred_class in label_ids_to_fuse - - // Check if mask exists and large enough to be a segment - const [mask_exists, mask_k] = check_segment_validity( - mask_labels, - mask_probs, - k, - mask_threshold, - overlap_mask_area_threshold - ) - - if (!mask_exists) { - // Nothing to see here - continue; - } - - // TODO - // if (pred_class in stuff_memory_list) { - // current_segment_id = stuff_memory_list[pred_class] - // } else { - // current_segment_id += 1; - // } - ++current_segment_id; - - - // Add current object segment to final segmentation map - for (const index of mask_k) { - segmentation_data[index] = current_segment_id; - } - - segments.push({ - id: current_segment_id, - label_id: pred_class, - // was_fused: should_fuse, TODO - score: pred_scores[k], - }) - - // TODO - // if(should_fuse){ - // stuff_memory_list[pred_class] = current_segment_id - // } - } - - return [segmentation, segments]; -} - - -/** - * Post-process the model output to generate the final panoptic segmentation. - * @param {*} outputs The model output to post process - * @param {number} [threshold=0.5] The probability score threshold to keep predicted instance masks. - * @param {number} [mask_threshold=0.5] Threshold to use when turning the predicted masks into binary values. - * @param {number} [overlap_mask_area_threshold=0.8] The overlap mask area threshold to merge or discard small disconnected parts within each binary instance mask. - * @param {Set} [label_ids_to_fuse=null] The labels in this state will have all their instances be fused together. - * @param {[number, number][]} [target_sizes=null] The target sizes to resize the masks to. - * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>} - */ -function post_process_panoptic_segmentation( - outputs, - threshold = 0.5, - mask_threshold = 0.5, - overlap_mask_area_threshold = 0.8, - label_ids_to_fuse = null, - target_sizes = null, -) { - if (label_ids_to_fuse === null) { - console.warn("`label_ids_to_fuse` unset. No instance will be fused.") - label_ids_to_fuse = new Set(); - } - - const class_queries_logits = outputs.class_queries_logits ?? outputs.logits; // [batch_size, num_queries, num_classes+1] - const masks_queries_logits = outputs.masks_queries_logits ?? 
outputs.pred_masks; // [batch_size, num_queries, height, width] - - const mask_probs = masks_queries_logits.sigmoid() // [batch_size, num_queries, height, width] - - let [batch_size, num_queries, num_labels] = class_queries_logits.dims; - num_labels -= 1; // Remove last class (background) - - if (target_sizes !== null && target_sizes.length !== batch_size) { - throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits") - } - - let toReturn = []; - for (let i = 0; i < batch_size; ++i) { - let target_size = target_sizes !== null ? target_sizes[i] : null; - - let class_logits = class_queries_logits[i]; - let mask_logits = mask_probs[i]; - - let [mask_probs_item, pred_scores_item, pred_labels_item] = remove_low_and_no_objects(class_logits, mask_logits, threshold, num_labels); - - if (pred_labels_item.length === 0) { - // No mask found - let [height, width] = target_size ?? mask_logits.dims.slice(-2); - - let segmentation = new Tensor( - 'int32', - new Int32Array(height * width).fill(-1), - [height, width] - ) - toReturn.push({ - segmentation: segmentation, - segments_info: [] - }); - continue; - } - - - // Get segmentation map and segment information of batch item - let [segmentation, segments] = compute_segments( - mask_probs_item, - pred_scores_item, - pred_labels_item, - mask_threshold, - overlap_mask_area_threshold, - label_ids_to_fuse, - target_size, - ) - - toReturn.push({ - segmentation: segmentation, - segments_info: segments - }) - } - - return toReturn; -} - - -/** - * Post-processes the outputs of the model (for instance segmentation). - * @param {*} outputs Raw outputs of the model. - * @param {number} [threshold=0.5] The probability score threshold to keep predicted instance masks. - * @param {[number, number][]} [target_sizes=null] List of tuples corresponding to the requested final size - * (height, width) of each prediction. If unset, predictions will not be resized. - * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>} - */ -function post_process_instance_segmentation(outputs, threshold = 0.5, target_sizes = null) { - throw new Error('Not implemented yet'); - return []; -} - -/** - * Named tuple to indicate the order we are using is (height x width), even though - * the Graphics’ industry standard is (width x height). - * @typedef {[height: number, width: number]} HeightWidth - */ - -/** - * Helper function to validate audio inputs. - * @param {any} audio The audio data. - * @param {string} feature_extractor The name of the feature extractor. - * @private - */ -function validate_audio_inputs(audio, feature_extractor) { - if (!(audio instanceof Float32Array || audio instanceof Float64Array)) { - throw new Error( - `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead. ` + - `If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.` - ) - } -} - -/** - * Helper function to constrain a value to be a multiple of a number. - * @param {number} val The value to constrain. - * @param {number} multiple The number to constrain to. - * @param {number} [minVal=0] The minimum value to constrain to. - * @param {number} [maxVal=null] The maximum value to constrain to. - * @returns {number} The constrained value. 
- * @private - */ -function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) { - const a = val / multiple; - let x = bankers_round(a) * multiple; - - if (maxVal !== null && x > maxVal) { - x = Math.floor(a) * multiple; - } - - if (x < minVal) { - x = Math.ceil(a) * multiple; - } - - return x; -} - -/** - * Rounds the height and width down to the closest multiple of size_divisibility - * @param {[number, number]} size The size of the image - * @param {number} divisor The divisor to use. - * @returns {[number, number]} The rounded size. - */ -function enforce_size_divisibility([width, height], divisor) { - return [ - Math.max(Math.floor(width / divisor), 1) * divisor, - Math.max(Math.floor(height / divisor), 1) * divisor - ]; -} - - -/** - * Base class for feature extractors. - * - * @extends Callable - */ -export class FeatureExtractor extends Callable { - /** - * Constructs a new FeatureExtractor instance. - * - * @param {Object} config The configuration for the feature extractor. - */ - constructor(config) { - super(); - this.config = config - } -} - -/** - * @typedef {object} ImageFeatureExtractorResult - * @property {Tensor} pixel_values The pixel values of the batched preprocessed images. - * @property {HeightWidth[]} original_sizes Array of two-dimensional tuples like [[480, 640]]. - * @property {HeightWidth[]} reshaped_input_sizes Array of two-dimensional tuples like [[1000, 1330]]. - */ - -/** - * Feature extractor for image models. - * - * @extends FeatureExtractor - */ -export class ImageFeatureExtractor extends FeatureExtractor { - - /** - * Constructs a new ImageFeatureExtractor instance. - * - * @param {Object} config The configuration for the feature extractor. - * @param {number[]} config.image_mean The mean values for image normalization. - * @param {number[]} config.image_std The standard deviation values for image normalization. - * @param {boolean} config.do_rescale Whether to rescale the image pixel values to the [0,1] range. - * @param {number} config.rescale_factor The factor to use for rescaling the image pixel values. - * @param {boolean} config.do_normalize Whether to normalize the image pixel values. - * @param {boolean} config.do_resize Whether to resize the image. - * @param {number} config.resample What method to use for resampling. - * @param {number|Object} config.size The size to resize the image to. - * @param {boolean} [config.do_flip_channel_order=false] Whether to flip the color channels from RGB to BGR. - * Can be overridden by the `do_flip_channel_order` parameter in the `preprocess` method. - */ - constructor(config) { - super(config); - - this.image_mean = this.config.image_mean ?? this.config.mean; - this.image_std = this.config.image_std ?? this.config.std; - - this.resample = this.config.resample ?? 2; // 2 => bilinear - this.do_rescale = this.config.do_rescale ?? true; - this.rescale_factor = this.config.rescale_factor ?? (1 / 255); - this.do_normalize = this.config.do_normalize; - - this.do_resize = this.config.do_resize; - this.do_thumbnail = this.config.do_thumbnail; - this.size = this.config.size; - this.size_divisibility = this.config.size_divisibility ?? this.config.size_divisor; - - this.do_center_crop = this.config.do_center_crop; - this.crop_size = this.config.crop_size; - this.do_convert_rgb = this.config.do_convert_rgb ?? 
true; - this.do_crop_margin = this.config.do_crop_margin; - - this.pad_size = this.config.pad_size; - this.do_pad = this.config.do_pad; - - if (this.do_pad && !this.pad_size && this.size && this.size.width !== undefined && this.size.height !== undefined) { - // Should pad, but no pad size specified - // We infer the pad size from the resize size - this.pad_size = this.size - } - - this.do_flip_channel_order = this.config.do_flip_channel_order ?? false; - } - - /** - * Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any - * corresponding dimension of the specified size. - * @param {RawImage} image The image to be resized. - * @param {{height:number, width:number}} size The size `{"height": h, "width": w}` to resize the image to. - * @param {string | 0 | 1 | 2 | 3 | 4 | 5} [resample=2] The resampling filter to use. - * @returns {Promise} The resized image. - */ - async thumbnail(image, size, resample = 2) { - const input_height = image.height; - const input_width = image.width; - - const output_height = size.height; - const output_width = size.width; - - // We always resize to the smallest of either the input or output size. - let height = Math.min(input_height, output_height) - let width = Math.min(input_width, output_width) - - if (height === input_height && width === input_width) { - return image; - } - if (input_height > input_width) { - width = Math.floor(input_width * height / input_height); - } else if (input_width > input_height) { - height = Math.floor(input_height * width / input_width); - } - return await image.resize(width, height, { resample }); - } - - - /** - * Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the threshold). - * @param {RawImage} image The image to be cropped. - * @param {number} gray_threshold Value below which pixels are considered to be gray. - * @returns {Promise} The cropped image. - */ - async crop_margin(image, gray_threshold = 200) { - - const gray_image = image.clone().grayscale(); - - const minValue = min(gray_image.data)[0]; - const maxValue = max(gray_image.data)[0]; - const diff = maxValue - minValue; - - if (diff === 0) { - return image; - } - - const threshold = gray_threshold / 255; - - let x_min = gray_image.width, y_min = gray_image.height, x_max = 0, y_max = 0; - const gray_image_data = gray_image.data; - for (let j = 0; j < gray_image.height; ++j) { - const row = j * gray_image.width; - for (let i = 0; i < gray_image.width; ++i) { - if ((gray_image_data[row + i] - minValue) / diff < threshold) { - // We have a non-zero pixel, so we update the min/max values accordingly - x_min = Math.min(x_min, i); - y_min = Math.min(y_min, j); - x_max = Math.max(x_max, i); - y_max = Math.max(y_max, j); - } - } - } - - image = await image.crop([x_min, y_min, x_max, y_max]); - return image; - } - - /** - * Pad the image by a certain amount. - * @param {Float32Array} pixelData The pixel data to pad. - * @param {number[]} imgDims The dimensions of the image (height, width, channels). - * @param {{width:number; height:number}|number} padSize The dimensions of the padded image. - * @param {Object} options The options for padding. - * @param {'constant'|'symmetric'} [options.mode='constant'] The type of padding to add. - * @param {boolean} [options.center=false] Whether to center the image. - * @param {number} [options.constant_values=0] The constant value to use for padding. - * @returns {[Float32Array, number[]]} The padded pixel data and image dimensions. 
- */ - pad_image(pixelData, imgDims, padSize, { - mode = 'constant', - center = false, - constant_values = 0, - } = {}) { - const [imageHeight, imageWidth, imageChannels] = imgDims; - - let paddedImageWidth, paddedImageHeight; - if (typeof padSize === 'number') { - paddedImageWidth = padSize; - paddedImageHeight = padSize; - } else { - paddedImageWidth = padSize.width; - paddedImageHeight = padSize.height; - } - - // Only add padding if there is a difference in size - if (paddedImageWidth !== imageWidth || paddedImageHeight !== imageHeight) { - const paddedPixelData = new Float32Array(paddedImageWidth * paddedImageHeight * imageChannels); - if (Array.isArray(constant_values)) { - // Fill with constant values, cycling through the array - for (let i = 0; i < paddedPixelData.length; ++i) { - paddedPixelData[i] = constant_values[i % imageChannels]; - } - } else if (constant_values !== 0) { - paddedPixelData.fill(constant_values); - } - - const [left, top] = center - ? [Math.floor((paddedImageWidth - imageWidth) / 2), Math.floor((paddedImageHeight - imageHeight) / 2)] - : [0, 0]; - - // Copy the original image into the padded image - for (let i = 0; i < imageHeight; ++i) { - const a = (i + top) * paddedImageWidth; - const b = i * imageWidth; - for (let j = 0; j < imageWidth; ++j) { - const c = (a + j + left) * imageChannels; - const d = (b + j) * imageChannels; - for (let k = 0; k < imageChannels; ++k) { - paddedPixelData[c + k] = pixelData[d + k]; - } - } - } - - if (mode === 'symmetric') { - if (center) { - throw new Error('`center` padding is not supported when `mode` is set to `symmetric`.'); - // TODO: Implement this - } - const h1 = imageHeight - 1; - const w1 = imageWidth - 1; - for (let i = 0; i < paddedImageHeight; ++i) { - const a = i * paddedImageWidth; - const b = calculateReflectOffset(i, h1) * imageWidth; - - for (let j = 0; j < paddedImageWidth; ++j) { - if (i < imageHeight && j < imageWidth) continue; // Do not overwrite original image - const c = (a + j) * imageChannels; - const d = (b + calculateReflectOffset(j, w1)) * imageChannels; - - // Copy channel-wise - for (let k = 0; k < imageChannels; ++k) { - paddedPixelData[c + k] = pixelData[d + k]; - } - } - } - } - - - // Update pixel data and image dimensions - pixelData = paddedPixelData; - imgDims = [paddedImageHeight, paddedImageWidth, imageChannels] - } - return [pixelData, imgDims]; - } - - /** - * Rescale the image' pixel values by `this.rescale_factor`. - * @param {Float32Array} pixelData The pixel data to rescale. - * @returns {void} - */ - rescale(pixelData) { - for (let i = 0; i < pixelData.length; ++i) { - pixelData[i] = this.rescale_factor * pixelData[i]; - } - } - - /** - * Find the target (width, height) dimension of the output image after - * resizing given the input image and the desired size. - * @param {RawImage} image The image to resize. - * @param {any} size The size to use for resizing the image. - * @returns {[number, number]} The target (width, height) dimension of the output image after resizing. - */ - get_resize_output_image_size(image, size) { - // `size` comes in many forms, so we need to handle them all here: - // 1. 
`size` is an integer, in which case we resize the image to be a square - - const [srcWidth, srcHeight] = image.size; - - let shortest_edge; - let longest_edge; - - if (this.do_thumbnail) { - // NOTE: custom logic for `Donut` models - const { height, width } = size; - shortest_edge = Math.min(height, width) - } - // Support both formats for backwards compatibility - else if (Number.isInteger(size)) { - shortest_edge = size; - longest_edge = this.config.max_size ?? shortest_edge; - - } else if (size !== undefined) { - // Extract known properties from `size` - shortest_edge = size.shortest_edge; - longest_edge = size.longest_edge; - } - - // If `longest_edge` and `shortest_edge` are set, maintain aspect ratio and resize to `shortest_edge` - // while keeping the largest dimension <= `longest_edge` - if (shortest_edge !== undefined || longest_edge !== undefined) { - // http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/ - // Try resize so that shortest edge is `shortest_edge` (target) - const shortResizeFactor = shortest_edge === undefined - ? 1 // If `shortest_edge` is not set, don't upscale - : Math.max(shortest_edge / srcWidth, shortest_edge / srcHeight); - - const newWidth = srcWidth * shortResizeFactor; - const newHeight = srcHeight * shortResizeFactor; - - // The new width and height might be greater than `longest_edge`, so - // we downscale again to ensure the largest dimension is `longest_edge` - const longResizeFactor = longest_edge === undefined - ? 1 // If `longest_edge` is not set, don't downscale - : Math.min(longest_edge / newWidth, longest_edge / newHeight); - - // To avoid certain floating point precision issues, we round to 2 decimal places - let finalWidth = Math.floor(Number((newWidth * longResizeFactor).toFixed(2))); - let finalHeight = Math.floor(Number((newHeight * longResizeFactor).toFixed(2))); - - if (this.size_divisibility !== undefined) { - [finalWidth, finalHeight] = enforce_size_divisibility([finalWidth, finalHeight], this.size_divisibility) - } - return [finalWidth, finalHeight]; - - } else if (size !== undefined && size.width !== undefined && size.height !== undefined) { - // If `width` and `height` are set, resize to those dimensions - - let newWidth = size.width; - let newHeight = size.height; - - // Custom for DPT models - if (this.config.keep_aspect_ratio && this.config.ensure_multiple_of) { - - // determine new height and width - let scale_height = newHeight / srcHeight; - let scale_width = newWidth / srcWidth; - - // scale as little as possible - if (Math.abs(1 - scale_width) < Math.abs(1 - scale_height)) { - // fit width - scale_height = scale_width; - } else { - // fit height - scale_width = scale_height; - } - - newHeight = constraint_to_multiple_of(scale_height * srcHeight, this.config.ensure_multiple_of); - newWidth = constraint_to_multiple_of(scale_width * srcWidth, this.config.ensure_multiple_of); - } - - return [newWidth, newHeight]; - - } else if (this.size_divisibility !== undefined) { - return enforce_size_divisibility([srcWidth, srcHeight], this.size_divisibility); - } else { - throw new Error(`Could not resize image due to unsupported \`this.size\` option in config: ${JSON.stringify(size)}`); - } - } - - /** - * Resizes the image. - * @param {RawImage} image The image to resize. - * @returns {Promise} The resized image. 
- */ - async resize(image) { - const [newWidth, newHeight] = this.get_resize_output_image_size(image, this.size); - return await image.resize(newWidth, newHeight, { - resample: this.resample, - }); - } - - /** - * @typedef {object} PreprocessedImage - * @property {HeightWidth} original_size The original size of the image. - * @property {HeightWidth} reshaped_input_size The reshaped input size of the image. - * @property {Tensor} pixel_values The pixel values of the preprocessed image. - */ - - /** - * Preprocesses the given image. - * - * @param {RawImage} image The image to preprocess. - * @param {Object} overrides The overrides for the preprocessing options. - * @returns {Promise} The preprocessed image. - */ - async preprocess(image, { - do_normalize = null, - do_pad = null, - do_convert_rgb = null, - do_convert_grayscale = null, - do_flip_channel_order = null, - } = {}) { - if (this.do_crop_margin) { - // NOTE: Specific to nougat processors. This is done before resizing, - // and can be interpreted as a pre-preprocessing step. - image = await this.crop_margin(image); - } - - const [srcWidth, srcHeight] = image.size; // original image size - - // Convert image to RGB if specified in config. - if (do_convert_rgb ?? this.do_convert_rgb) { - image = image.rgb(); - } else if (do_convert_grayscale) { - image = image.grayscale(); - } - - // TODO: - // For efficiency reasons, it might be best to merge the resize and center crop operations into one. - - // Resize all images - if (this.do_resize) { - image = await this.resize(image); - } - - // Resize the image using thumbnail method. - if (this.do_thumbnail) { - image = await this.thumbnail(image, this.size, this.resample); - } - - if (this.do_center_crop) { - - let crop_width; - let crop_height; - if (Number.isInteger(this.crop_size)) { - crop_width = this.crop_size; - crop_height = this.crop_size; - } else { - crop_width = this.crop_size.width; - crop_height = this.crop_size.height; - } - - image = await image.center_crop(crop_width, crop_height); - } - - /** @type {HeightWidth} */ - const reshaped_input_size = [image.height, image.width]; - - // NOTE: All pixel-level manipulation (i.e., modifying `pixelData`) - // occurs with data in the hwc format (height, width, channels), - // to emulate the behavior of the original Python code (w/ numpy). - let pixelData = Float32Array.from(image.data); - let imgDims = [image.height, image.width, image.channels]; - - if (this.do_rescale) { - this.rescale(pixelData); - } - - if (do_normalize ?? this.do_normalize) { - let image_mean = this.image_mean; - if (!Array.isArray(this.image_mean)) { - image_mean = new Array(image.channels).fill(image_mean); - } - - let image_std = this.image_std; - if (!Array.isArray(this.image_std)) { - image_std = new Array(image.channels).fill(image_mean); - } - - if (image_mean.length !== image.channels || image_std.length !== image.channels) { - throw new Error(`When set to arrays, the length of \`image_mean\` (${image_mean.length}) and \`image_std\` (${image_std.length}) must match the number of channels in the image (${image.channels}).`); - } - - for (let i = 0; i < pixelData.length; i += image.channels) { - for (let j = 0; j < image.channels; ++j) { - pixelData[i + j] = (pixelData[i + j] - image_mean[j]) / image_std[j]; - } - } - } - - // do padding after rescaling/normalizing - if (do_pad ?? 
this.do_pad) { - if (this.pad_size) { - const padded = this.pad_image(pixelData, [image.height, image.width, image.channels], this.pad_size); - [pixelData, imgDims] = padded; // Update pixel data and image dimensions - } else if (this.size_divisibility) { - const [paddedWidth, paddedHeight] = enforce_size_divisibility([imgDims[1], imgDims[0]], this.size_divisibility); - [pixelData, imgDims] = this.pad_image(pixelData, imgDims, { width: paddedWidth, height: paddedHeight }); - } - } - - if (do_flip_channel_order ?? this.do_flip_channel_order) { - if (imgDims[2] !== 3) { - throw new Error('Flipping channel order is only supported for RGB images.'); - } - // Convert RGB to BGR - for (let i = 0; i < pixelData.length; i += 3) { - const temp = pixelData[i]; - pixelData[i] = pixelData[i + 2]; - pixelData[i + 2] = temp; - } - } - - const pixel_values = new Tensor('float32', pixelData, imgDims) - .permute(2, 0, 1); // convert to channel dimension format (hwc -> chw) - - return { - original_size: [srcHeight, srcWidth], - reshaped_input_size: reshaped_input_size, - pixel_values, - } - } - - /** - * Calls the feature extraction process on an array of images, - * preprocesses each image, and concatenates the resulting - * features into a single Tensor. - * @param {RawImage[]} images The image(s) to extract features from. - * @param {...any} args Additional arguments. - * @returns {Promise} An object containing the concatenated pixel values (and other metadata) of the preprocessed images. - */ - async _call(images, ...args) { - if (!Array.isArray(images)) { - images = [images]; - } - /** @type {PreprocessedImage[]} */ - const imageData = await Promise.all(images.map(x => this.preprocess(x))); - - // Stack pixel values - const pixel_values = stack(imageData.map(x => x.pixel_values), 0); - - return { - pixel_values, - - // Original sizes of images - original_sizes: imageData.map(x => x.original_size), - - // Reshaped sizes of images, before padding or cropping - reshaped_input_sizes: imageData.map(x => x.reshaped_input_size), - } - } - -} - -export class SapiensFeatureExtractor extends ImageFeatureExtractor { - /** @type {typeof post_process_semantic_segmentation} */ - post_process_semantic_segmentation(...args) { - return post_process_semantic_segmentation(...args); - } -} -export class SegformerFeatureExtractor extends ImageFeatureExtractor { - /** @type {typeof post_process_semantic_segmentation} */ - post_process_semantic_segmentation(...args) { - return post_process_semantic_segmentation(...args); - } -} -export class PvtImageProcessor extends ImageFeatureExtractor { } -export class DPTFeatureExtractor extends ImageFeatureExtractor { } -export class DPTImageProcessor extends DPTFeatureExtractor { } // NOTE: extends DPTFeatureExtractor -export class BitImageProcessor extends ImageFeatureExtractor { } -export class GLPNFeatureExtractor extends ImageFeatureExtractor { } -export class CLIPFeatureExtractor extends ImageFeatureExtractor { } -export class CLIPImageProcessor extends CLIPFeatureExtractor { } // NOTE: extends CLIPFeatureExtractor -export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { } -export class SiglipImageProcessor extends ImageFeatureExtractor { } -export class ConvNextFeatureExtractor extends ImageFeatureExtractor { - constructor(config) { - super(config); - - /** - * Percentage of the image to crop. Only has an effect if this.size < 384. - */ - this.crop_pct = this.config.crop_pct ?? 
(224 / 256); - } - - async resize(image) { - const shortest_edge = this.size?.shortest_edge; - if (shortest_edge === undefined) { - throw new Error(`Size dictionary must contain 'shortest_edge' key.`); - } - - if (shortest_edge < 384) { - // maintain same ratio, resizing shortest edge to shortest_edge/crop_pct - const resize_shortest_edge = Math.floor(shortest_edge / this.crop_pct); - - const [newWidth, newHeight] = this.get_resize_output_image_size(image, { - shortest_edge: resize_shortest_edge, - }); - - image = await image.resize(newWidth, newHeight, { - resample: this.resample, - }); - - // then crop to (shortest_edge, shortest_edge) - image = await image.center_crop(shortest_edge, shortest_edge); - } else { - // warping (no cropping) when evaluated at 384 or larger - image = await image.resize(shortest_edge, shortest_edge, { - resample: this.resample, - }); - } - - return image; - } -} -export class ConvNextImageProcessor extends ConvNextFeatureExtractor { } // NOTE extends ConvNextFeatureExtractor -export class ViTFeatureExtractor extends ImageFeatureExtractor { } -export class ViTImageProcessor extends ImageFeatureExtractor { } - -export class EfficientNetImageProcessor extends ImageFeatureExtractor { - constructor(config) { - super(config); - this.include_top = this.config.include_top ?? true; - if (this.include_top) { - this.image_std = this.image_std.map(x => x * x); - } - } -} - -export class MobileNetV1FeatureExtractor extends ImageFeatureExtractor { } -export class MobileNetV2FeatureExtractor extends ImageFeatureExtractor { } -export class MobileNetV3FeatureExtractor extends ImageFeatureExtractor { } -export class MobileNetV4FeatureExtractor extends ImageFeatureExtractor { } - -export class MobileViTFeatureExtractor extends ImageFeatureExtractor { } -export class MobileViTImageProcessor extends MobileViTFeatureExtractor { } // NOTE extends MobileViTFeatureExtractor -export class OwlViTFeatureExtractor extends ImageFeatureExtractor { - /** @type {typeof post_process_object_detection} */ - post_process_object_detection(...args) { - return post_process_object_detection(...args); - } -} -export class Owlv2ImageProcessor extends OwlViTFeatureExtractor { } // NOTE extends OwlViTFeatureExtractor - -export class RTDetrImageProcessor extends ImageFeatureExtractor { - /** @type {typeof post_process_object_detection} */ - post_process_object_detection(...args) { - return post_process_object_detection(...args); - } -} - -export class DeiTFeatureExtractor extends ImageFeatureExtractor { } -export class BeitFeatureExtractor extends ImageFeatureExtractor { } -export class DonutFeatureExtractor extends ImageFeatureExtractor { - pad_image(pixelData, imgDims, padSize, options = {}) { - const [imageHeight, imageWidth, imageChannels] = imgDims; - - let image_mean = this.image_mean; - if (!Array.isArray(this.image_mean)) { - image_mean = new Array(imageChannels).fill(image_mean); - } - - let image_std = this.image_std; - if (!Array.isArray(image_std)) { - image_std = new Array(imageChannels).fill(image_mean); - } - - const constant_values = image_mean.map((x, i) => - x / image_std[i]); - - return super.pad_image(pixelData, imgDims, padSize, { - center: true, - - // Since normalization is done after padding, we need to use certain constant values to ensure the same behaviour is observed. 
- // For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451 - constant_values: constant_values, - ...options, - }); - } -} -export class DonutImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor -export class NougatImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor - -/** - * @typedef {object} DetrFeatureExtractorResultProps - * @property {Tensor} pixel_mask - * @typedef {ImageFeatureExtractorResult & DetrFeatureExtractorResultProps} DetrFeatureExtractorResult - */ - -/** - * Detr Feature Extractor. - * - * @extends ImageFeatureExtractor - */ -export class DetrFeatureExtractor extends ImageFeatureExtractor { - /** - * Calls the feature extraction process on an array of images, preprocesses - * each image, and concatenates the resulting features into a single Tensor. - * @param {RawImage[]} images The image(s) to extract features from. - * @returns {Promise} An object containing the concatenated pixel values of the preprocessed images. - */ - async _call(images) { - const result = await super._call(images); - - // TODO support differently-sized images, for now assume all images are the same size. - // TODO support different mask sizes (not just 64x64) - // Currently, just fill pixel mask with 1s - const maskSize = [result.pixel_values.dims[0], 64, 64]; - const pixel_mask = full(maskSize, 1n); - - return { ...result, pixel_mask }; - } - - /** @type {typeof post_process_object_detection} */ - post_process_object_detection(...args) { - return post_process_object_detection(...args); - } - - /** @type {typeof post_process_panoptic_segmentation} */ - post_process_panoptic_segmentation(...args) { - return post_process_panoptic_segmentation(...args); - } - - post_process_instance_segmentation() { - // TODO - throw Error("Not implemented yet"); - } -} - -export class MaskFormerFeatureExtractor extends ImageFeatureExtractor { - - /** @type {typeof post_process_panoptic_segmentation} */ - post_process_panoptic_segmentation(...args) { - return post_process_panoptic_segmentation(...args); - } - - post_process_instance_segmentation() { - // TODO - throw Error("Not implemented yet"); - } -} - - -export class YolosFeatureExtractor extends ImageFeatureExtractor { - /** @type {typeof post_process_object_detection} */ - post_process_object_detection(...args) { - return post_process_object_detection(...args); - } -} - -/** - * @typedef {object} SamImageProcessorResult - * @property {Tensor} pixel_values - * @property {HeightWidth[]} original_sizes - * @property {HeightWidth[]} reshaped_input_sizes - * @property {Tensor} [input_points] - * @property {Tensor} [input_labels] - * @property {Tensor} [input_boxes] - */ - -export class SamImageProcessor extends ImageFeatureExtractor { - - /** - * - * @param {any} input_points - * @param {HeightWidth[]} original_sizes - * @param {HeightWidth[]} reshaped_input_sizes - * @returns {Tensor} - */ - reshape_input_points(input_points, original_sizes, reshaped_input_sizes, is_bounding_box = false) { - - // Make deep copy to avoid altering user's input - input_points = structuredClone(input_points); - let shape = calculateDimensions(input_points); - - // TODO: add support for 2D input_points - if (shape.length === 3) { - // Correct user's input - if (!is_bounding_box) { - shape = [1, ...shape]; - } - input_points = [input_points]; - } else if (shape.length !== 4) { - throw Error("The input_points must be a 4D tensor of 
shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.") - } - - // Reshape input points - for (let i = 0; i < input_points.length; ++i) { // batch_size - let originalImageSize = original_sizes[i]; - let reshapedImageSize = reshaped_input_sizes[i]; - - let resizeFactors = [ - reshapedImageSize[0] / originalImageSize[0], - reshapedImageSize[1] / originalImageSize[1] - ] - - for (let j = 0; j < input_points[i].length; ++j) { // point_batch_size - for (let k = 0; k < input_points[i][j].length; ++k) { // nb_points_per_image - for (let w = 0; w < input_points[i][j][k].length; ++w) { // 2 or 4 - input_points[i][j][k][w] *= resizeFactors[w % 2]; - } - } - } - } - - return new Tensor( - 'float32', - Float32Array.from(input_points.flat(Infinity)), - shape - ) - - } - - /** - * - * @param {any} input_labels - * @param {Tensor} input_points - * @returns {Tensor} - */ - add_input_labels(input_labels, input_points) { - let shape = calculateDimensions(input_labels); - if (shape.length === 2) { - // Correct user's input - shape = [1, ...shape]; - input_labels = [input_labels]; - } else if (shape.length !== 3) { - throw Error("The input_points must be a 4D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.") - } - - if (shape.some((x, i) => x !== input_points.dims[i])) { - throw Error(`The first ${shape.length} dimensions of 'input_points' and 'input_labels' must be the same.`) - } - return new Tensor( - 'int64', - input_labels.flat(Infinity).map(BigInt), - shape, - ) - } - /** - * @param {any[]} images The URL(s) of the image(s) to extract features from. - * @param {Object} [options] Additional options for the processor. - * @param {any} [options.input_points=null] A 3D or 4D array, representing the input points provided by the user. - * - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1. - * - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`. - * @param {any} [options.input_labels=null] A 2D or 3D array, representing the input labels for the points, used by the prompt encoder to encode the prompt. - * - 2D: `[point_batch_size, nb_points_per_image]`. In this case, `batch_size` is assumed to be 1. - * - 3D: `[batch_size, point_batch_size, nb_points_per_image]`. - * @param {number[][][]} [options.input_boxes=null] A 3D array of shape `(batch_size, num_boxes, 4)`, representing the input boxes provided by the user. - * This is used by the prompt encoder to encode the prompt. Generally yields to much better generated masks. - * The processor will generate a tensor, with each dimension corresponding respectively to the image batch size, - * the number of boxes per image and the coordinates of the top left and botton right point of the box. 
- * In the order (`x1`, `y1`, `x2`, `y2`): - * - `x1`: the x coordinate of the top left point of the input box - * - `y1`: the y coordinate of the top left point of the input box - * - `x2`: the x coordinate of the bottom right point of the input box - * - `y2`: the y coordinate of the bottom right point of the input box - * @returns {Promise} - */ - async _call(images, { - input_points = null, - input_labels = null, - input_boxes = null - } = {}) { - // TODO allow user to use preprocessed images - /** @type {SamImageProcessorResult} */ - const processed = await super._call(images); - - if (input_points) { - processed.input_points = this.reshape_input_points( - input_points, processed.original_sizes, processed.reshaped_input_sizes - ); - } - - if (input_labels) { - if (!processed.input_points) { - throw Error("`input_points` must be provided if `input_labels` are provided.") - } - processed.input_labels = this.add_input_labels(input_labels, processed.input_points); - } - - if (input_boxes) { - processed.input_boxes = this.reshape_input_points( - input_boxes, processed.original_sizes, processed.reshaped_input_sizes, true, - ); - } - - return processed; - } - - /** - * Remove padding and upscale masks to the original image size. - * @param {Tensor} masks Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. - * @param {[number, number][]} original_sizes The original sizes of each image before it was resized to the model's expected input shape, in (height, width) format. - * @param {[number, number][]} reshaped_input_sizes The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. - * @param {Object} options Optional parameters for post-processing. - * @param {number} [options.mask_threshold] The threshold to use for binarizing the masks. - * @param {boolean} [options.binarize] Whether to binarize the masks. - * @param {Object} [options.pad_size] The target size the images were padded to before being passed to the model. If `null`, the target size is assumed to be the processor's `pad_size`. - * @param {number} [options.pad_size.height] The height the images were padded to. - * @param {number} [options.pad_size.width] The width the images were padded to. - * @returns {Promise} Batched masks in batch_size, num_channels, height, width) format, where (height, width) is given by original_size. - */ - async post_process_masks(masks, original_sizes, reshaped_input_sizes, { - mask_threshold = 0.0, - binarize = true, - pad_size = null, - } = {}) { - // masks: [1, 1, 3, 256, 256] - - const output_masks = []; - - pad_size = pad_size ?? 
this.pad_size; - - /** @type {[number, number]} */ - const target_image_size = [pad_size.height, pad_size.width]; - - for (let i = 0; i < original_sizes.length; ++i) { - const original_size = original_sizes[i]; - const reshaped_input_size = reshaped_input_sizes[i]; - - // Upscale mask to padded size - let interpolated_mask = (await interpolate_4d( - masks[i], - { mode: 'bilinear', size: target_image_size } - )); - - // Crop mask - interpolated_mask = interpolated_mask.slice(null, null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]); - - // Downscale mask - interpolated_mask = (await interpolate_4d( - interpolated_mask, - { mode: 'bilinear', size: original_size } - )); - - if (binarize) { - const data = interpolated_mask.data; - const binarizedMaskData = new Uint8Array(data.length); - for (let i = 0; i < data.length; ++i) { - if (data[i] > mask_threshold) { - binarizedMaskData[i] = 1; - } - } - interpolated_mask = new Tensor( - 'bool', - binarizedMaskData, - interpolated_mask.dims - ) - } - - output_masks.push(interpolated_mask); - } - - return output_masks; - } - - /** - * Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. - * @param {RawImage} image Input original image - * @param {number} target_size Target size of the resized image - * @param {Object} options Options for generating crop boxes - * @param {number} [options.crop_n_layers] If >0, mask prediction will be run again on crops of the image. - * Sets the number of layers to run, where each layer has 2**i_layer number of image crops. - * @param {number} [options.overlap_ratio] Sets the degree to which crops overlap. In the first crop layer, - * crops will overlap by this fraction of the image length. Later layers with more crops scale down this overlap. - * @param {number} [options.points_per_crop] Number of points to sample from each crop. - * @param {number} [options.crop_n_points_downscale_factor] The number of points-per-side sampled in layer n is - * scaled down by crop_n_points_downscale_factor**n. - * @returns {Object} An object containing the crop boxes, number of points per crop, cropped images, and input labels. - */ - generate_crop_boxes(image, target_size, { - crop_n_layers = 0, - overlap_ratio = 512 / 1500, - points_per_crop = 32, - crop_n_points_downscale_factor = 1, - } = {}) { - // TODO: Implement - // return { crop_boxes, points_per_crop, cropped_images, input_labels } - } -} - -export class Swin2SRImageProcessor extends ImageFeatureExtractor { - pad_image(pixelData, imgDims, padSize, options = {}) { - // NOTE: In this case, `padSize` represents the size of the sliding window for the local attention. - // In other words, the image is padded so that its width and height are multiples of `padSize`. - const [imageHeight, imageWidth, imageChannels] = imgDims; - - return super.pad_image(pixelData, imgDims, { - // NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already - // a multiple of `pad_size`. However, this is most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19). - // For this reason, we only add padding when the image's width/height is not a multiple of `pad_size`. 
- width: imageWidth + (padSize - imageWidth % padSize) % padSize, - height: imageHeight + (padSize - imageHeight % padSize) % padSize, - }, { - mode: 'symmetric', - center: false, - constant_values: -1, - ...options, - }) - } -} - -export class VitMatteImageProcessor extends ImageFeatureExtractor { - /** - * Calls the feature extraction process on an array of images, preprocesses - * each image, and concatenates the resulting features into a single Tensor. - * @param {RawImage[]} images The image(s) to extract features from. - * @param {RawImage[]} trimaps The trimaps(s) to extract features from. - * @returns {Promise} An object containing the concatenated pixel values of the preprocessed images. - */ - async _call(images, trimaps) { - if (!Array.isArray(images)) { - images = [images]; - } - if (!Array.isArray(trimaps)) { - trimaps = [trimaps]; - } - - const imageData = await Promise.all(images.map(x => this.preprocess(x))); - const trimapData = await Promise.all(trimaps.map(x => this.preprocess(x, { - do_normalize: false, - do_convert_rgb: false, - do_convert_grayscale: true, - }))); - - - // Stack pixel values - const pixel_values = stack(imageData.map( - // Concatenate images and trimaps - (x, i) => cat([x.pixel_values, trimapData[i].pixel_values], 0) - ), 0); - - return { - pixel_values, - - // Original sizes of images - original_sizes: imageData.map(x => x.original_size), - - // Reshaped sizes of images, before padding or cropping - reshaped_input_sizes: imageData.map(x => x.reshaped_input_size), - } - } -} - -export class WhisperFeatureExtractor extends FeatureExtractor { - - constructor(config) { - super(config); - - // Prefer given `mel_filters` from preprocessor_config.json, or calculate them if they don't exist. - this.config.mel_filters ??= mel_filter_bank( - Math.floor(1 + this.config.n_fft / 2), // num_frequency_bins - this.config.feature_size, // num_mel_filters - 0.0, // min_frequency - 8000.0, // max_frequency - this.config.sampling_rate, // sampling_rate - "slaney", // norm - "slaney", // mel_scale - ); - - this.window = window_function(this.config.n_fft, 'hann'); - } - - /** - * Computes the log-Mel spectrogram of the provided audio waveform. - * @param {Float32Array|Float64Array} waveform The audio waveform to process. - * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. - */ - async _extract_fbank_features(waveform) { - const features = await spectrogram( - waveform, - this.window, // window - this.config.n_fft, // frame_length - this.config.hop_length, // hop_length - { - power: 2.0, - mel_filters: this.config.mel_filters, - log_mel: 'log10', - - // Custom - max_num_frames: this.config.nb_max_frames, // 3000 - } - ) - - const data = features.data; - const maxValue = max(data)[0]; - - for (let i = 0; i < data.length; ++i) { - data[i] = (Math.max(data[i], maxValue - 8.0) + 4.0) / 4.0; - } - - return features; - } - - /** - * Asynchronously extracts features from a given audio using the provided configuration. - * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. - * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. - */ - async _call(audio) { - validate_audio_inputs(audio, 'WhisperFeatureExtractor'); - - let waveform; - if (audio.length > this.config.n_samples) { - console.warn( - "Attempting to extract features for audio longer than 30 seconds. 
" + - "If using a pipeline to extract transcript from a long audio clip, " + - "remember to specify `chunk_length_s` and/or `stride_length_s`." - ); - waveform = audio.slice(0, this.config.n_samples); - } else { - // pad with zeros - waveform = new Float32Array(this.config.n_samples); - waveform.set(audio); - } - - const features = await this._extract_fbank_features(waveform); - - return { - input_features: features.unsqueeze_(0) - }; - } -} - -export class Wav2Vec2FeatureExtractor extends FeatureExtractor { - - /** - * @param {Float32Array} input_values - * @returns {Float32Array} - */ - _zero_mean_unit_var_norm(input_values) { - // TODO support batch? - const sum = input_values.reduce((a, b) => a + b, 0); - const mean = sum / input_values.length; - const variance = input_values.reduce((a, b) => a + (b - mean) ** 2, 0) / input_values.length; - return input_values.map(x => (x - mean) / Math.sqrt(variance + 1e-7)); - } - - /** - * Asynchronously extracts features from a given audio using the provided configuration. - * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. - * @returns {Promise<{ input_values: Tensor; attention_mask: Tensor }>} A Promise resolving to an object containing the extracted input features and attention mask as Tensors. - */ - async _call(audio) { - validate_audio_inputs(audio, 'Wav2Vec2FeatureExtractor'); - - if (audio instanceof Float64Array) { - audio = new Float32Array(audio); - } - - let input_values = audio; - - // zero-mean and unit-variance normalization - if (this.config.do_normalize) { - input_values = this._zero_mean_unit_var_norm(input_values); - } - - // TODO: allow user to pass in attention mask - const shape = [1, input_values.length]; - return { - input_values: new Tensor('float32', input_values, shape), - attention_mask: new Tensor('int64', new BigInt64Array(input_values.length).fill(1n), shape) - }; - } -} - -export class SeamlessM4TFeatureExtractor extends FeatureExtractor { - - constructor(config) { - super(config); - - const sampling_rate = this.config.sampling_rate; - const mel_filters = mel_filter_bank( - 256, // num_frequency_bins - this.config.num_mel_bins, // num_mel_filters - 20, // min_frequency - Math.floor(sampling_rate / 2), // max_frequency - sampling_rate, // sampling_rate - null, // norm - "kaldi", // mel_scale - true, // triangularize_in_mel_space - ); - - // Do padding: - for (let i = 0; i < mel_filters.length; ++i) { - mel_filters[i].push(0); - } - this.mel_filters = mel_filters; - - this.window = window_function(400, 'povey', { - periodic: false, - }) - } - - /** - * Computes the log-Mel spectrogram of the provided audio waveform. - * @param {Float32Array|Float64Array} waveform The audio waveform to process. - * @param {number} max_length The maximum number of frames to return. - * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. 
- */ - async _extract_fbank_features(waveform, max_length) { - // NOTE: We don't pad/truncate since that is passed in as `max_num_frames` - - // Kaldi compliance: 16-bit signed integers - // 32768 == 2 ** 15 - waveform = waveform.map((/** @type {number} */ x) => x * 32768) - - return spectrogram( - waveform, - this.window, // window - 400, // frame_length - 160, // hop_length - { - fft_length: 512, - power: 2.0, - center: false, - preemphasis: 0.97, - mel_filters: this.mel_filters, - log_mel: 'log', - mel_floor: 1.192092955078125e-07, - remove_dc_offset: true, - - // Custom - max_num_frames: max_length, - transpose: true, - } - ) - } - - /** - * Asynchronously extracts features from a given audio using the provided configuration. - * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. - * @param {Object} options Optional parameters for feature extraction. - * @param {boolean} [options.padding=true] Whether to pad the sequence to a multiple of `pad_to_multiple_of`. - * @param {number} [options.pad_to_multiple_of=2] The number to pad the sequence to a multiple of. - * @param {boolean} [options.do_normalize_per_mel_bins=true] Whether or not to zero-mean unit-variance normalize the input per mel-channel. - * @param {boolean} [options.return_attention_mask=true] Whether to return the attention mask. - * @returns {Promise<{ input_features: Tensor, attention_mask?: Tensor }>} A Promise resolving to an object containing the extracted input features and attention masks as Tensors. - */ - async _call(audio, { - padding = true, - pad_to_multiple_of = 2, - do_normalize_per_mel_bins = true, - return_attention_mask = true, - } = {}) { - validate_audio_inputs(audio, 'SeamlessM4TFeatureExtractor'); - - let features = await this._extract_fbank_features(audio, this.config.max_length); - - if (do_normalize_per_mel_bins) { - const [num_features, feature_size] = features.dims; - const data = features.data; - for (let i = 0; i < feature_size; ++i) { - let sum = 0; - for (let j = 0; j < num_features; ++j) { - sum += data[j * feature_size + i]; - } - - const mean = sum / num_features; - - let variance = 0; - for (let j = 0; j < num_features; ++j) { - variance += (data[j * feature_size + i] - mean) ** 2; - } - variance /= num_features - 1; // NOTE: We use ddof=1 - - const std = Math.sqrt(variance + 1e-7); - for (let j = 0; j < num_features; ++j) { - const index = j * feature_size + i; - data[index] = (data[index] - mean) / std; - } - } - } - - let padded_attention_mask; - if (padding) { - const [num_frames, num_channels] = features.dims; - const data = /** @type {Float32Array} */(features.data); - - const pad_size = num_frames % pad_to_multiple_of; - if (pad_size > 0) { - const padded_data = new Float32Array(num_channels * (num_frames + pad_size)); - padded_data.set(data) - padded_data.fill(this.config.padding_value, data.length) - - const numPaddedFrames = num_frames + pad_size; - features = new Tensor( - features.type, - padded_data, - [numPaddedFrames, num_channels], - ) - - if (return_attention_mask) { - padded_attention_mask = new Tensor( - 'int64', - new BigInt64Array(numPaddedFrames), - [1, numPaddedFrames], - ) - padded_attention_mask.data.fill(1n, 0, num_frames); - } - } - } - - const [num_frames, num_channels] = features.dims; - - const stride = this.config.stride; - const remainder = num_frames % stride; - if (remainder !== 0) { - throw new Error(`The number of frames (${num_frames}) must be a multiple of the stride (${stride}).`) - } - - const input_features = 
features.view( - 1, - Math.floor(num_frames / stride), - num_channels * stride, - ); - - const result = { input_features } - - if (return_attention_mask) { - const reshapedNumFrames = input_features.dims[1]; - - const attention_mask_data = new BigInt64Array(reshapedNumFrames); - - if (padded_attention_mask) { - const padded_attention_mask_data = padded_attention_mask.data; - for (let i = 1, j = 0; i < num_frames; i += stride, ++j) { - attention_mask_data[j] = padded_attention_mask_data[i]; - } - } else { - attention_mask_data.fill(1n); - } - result.attention_mask = new Tensor( - 'int64', - attention_mask_data, - [1, reshapedNumFrames], - ); - } - - return result; - } -} - -export class ASTFeatureExtractor extends FeatureExtractor { - - - constructor(config) { - super(config); - - const sampling_rate = this.config.sampling_rate; - const mel_filters = mel_filter_bank( - 256, // num_frequency_bins - this.config.num_mel_bins, // num_mel_filters - 20, // min_frequency - Math.floor(sampling_rate / 2), // max_frequency - sampling_rate, // sampling_rate - null, // norm - "kaldi", // mel_scale - true, // triangularize_in_mel_space - ); - - // Do padding: - for (let i = 0; i < mel_filters.length; ++i) { - mel_filters[i].push(0); - } - this.mel_filters = mel_filters; - - this.window = window_function(400, 'hann', { - periodic: false, - }) - - this.mean = this.config.mean; - this.std = this.config.std; - } - - /** - * Computes the log-Mel spectrogram of the provided audio waveform. - * @param {Float32Array|Float64Array} waveform The audio waveform to process. - * @param {number} max_length The maximum number of frames to return. - * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. - */ - async _extract_fbank_features(waveform, max_length) { - // NOTE: We don't pad/truncate since that is passed in as `max_num_frames` - return spectrogram( - waveform, - this.window, // window - 400, // frame_length - 160, // hop_length - { - fft_length: 512, - power: 2.0, - center: false, - preemphasis: 0.97, - mel_filters: this.mel_filters, - log_mel: 'log', - mel_floor: 1.192092955078125e-07, - remove_dc_offset: true, - - // Custom - max_num_frames: max_length, - transpose: true, - } - ) - } - - - /** - * Asynchronously extracts features from a given audio using the provided configuration. - * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. - * @returns {Promise<{ input_values: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. 
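The stride-based `view` in `SeamlessM4TFeatureExtractor._call` above merges every `stride` consecutive frames into a single feature vector. A standalone toy illustration of that reshape, assuming the top-level `Tensor` export and arbitrarily chosen dimensions:

```javascript
import { Tensor } from '@huggingface/transformers';

// 4 frames x 3 mel channels, filled with 0..11 for readability; stride of 2 chosen for illustration.
const num_frames = 4, num_channels = 3, stride = 2;
const data = Float32Array.from({ length: num_frames * num_channels }, (_, i) => i);
const features = new Tensor('float32', data, [num_frames, num_channels]);

// Same reshape as above: each output frame now holds the features of `stride` consecutive input frames.
const input_features = features.view(1, num_frames / stride, num_channels * stride);
console.log(input_features.dims);           // [1, 2, 6]
console.log(input_features.tolist()[0][0]); // [0, 1, 2, 3, 4, 5]
```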
- */ - async _call(audio) { - validate_audio_inputs(audio, 'ASTFeatureExtractor'); - - const features = await this._extract_fbank_features(audio, this.config.max_length); - if (this.config.do_normalize) { - // Normalize the input audio spectrogram to have mean=0, std=0.5 - const denom = this.std * 2; - const features_data = features.data; - for (let i = 0; i < features_data.length; ++i) { - features_data[i] = (features_data[i] - this.mean) / denom; - } - } - - return { - input_values: features.unsqueeze_(0) - }; - } -} - -export class ClapFeatureExtractor extends FeatureExtractor { - - constructor(config) { - super(config); - - this.mel_filters = mel_filter_bank( - this.config.nb_frequency_bins, // num_frequency_bins - this.config.feature_size, // num_mel_filters - this.config.frequency_min, // min_frequency - this.config.frequency_max, // max_frequency - this.config.sampling_rate, // sampling_rate - null, // norm - "htk", // mel_scale - ); - - this.mel_filters_slaney = mel_filter_bank( - this.config.nb_frequency_bins, // num_frequency_bins - this.config.feature_size, // num_mel_filters - this.config.frequency_min, // min_frequency - this.config.frequency_max, // max_frequency - this.config.sampling_rate, // sampling_rate - "slaney", // norm - "slaney", // mel_scale - ); - - this.window = window_function(this.config.fft_window_size, 'hann') - - } - - - /** - * Extracts the mel spectrogram and prepares it for the mode based on the `truncation` and `padding` arguments. - * - * Four different path are possible: - * - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram - * will be computed on the entire audio. 3 random crops and a dowsampled version of the full mel spectrogram - * are then stacked together. They will later be used for `feature_fusion`. - * - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is - * padded based on `padding`. - * - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded - * based on `padding`, and is repeated `4` times. - * - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel - * spectrogram will be computed on a random crop of the waveform. - * - * @param {Float32Array|Float64Array} waveform The input waveform. - * @param {number} max_length The maximum length of the waveform. - * @param {string} truncation The truncation strategy to use. - * @param {string} padding The padding strategy to use. - * @returns {Promise} An object containing the mel spectrogram data as a Float32Array, its dimensions as an array of numbers, and a boolean indicating whether the waveform was longer than the max length. 
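The two padding strategies referenced in the docstring above differ only in how the tail of a short waveform is filled. A standalone re-run of the same branch logic from `_get_input_mel` on a toy waveform (values chosen arbitrarily) makes the difference concrete:

```javascript
const waveform = Float32Array.from([1, 2, 3]);
const max_length = 8;

// 'repeat': tile the waveform until the buffer is completely full.
const repeat = new Float64Array(max_length);
repeat.set(waveform);
for (let i = waveform.length; i < max_length; i += waveform.length) {
    repeat.set(waveform.subarray(0, Math.min(waveform.length, max_length - i)), i);
}
console.log(Array.from(repeat));    // [1, 2, 3, 1, 2, 3, 1, 2]

// 'repeatpad': repeat the waveform into the leading portion and leave the tail zero-padded
// (mirrors the `repeatpad` branch above).
const repeatpad = new Float64Array(max_length);
repeatpad.set(waveform);
for (let i = waveform.length; i < max_length - waveform.length; i += waveform.length) {
    repeatpad.set(waveform, i);
}
console.log(Array.from(repeatpad)); // [1, 2, 3, 1, 2, 3, 0, 0]
```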
- * @private - */ - async _get_input_mel(waveform, max_length, truncation, padding) { - - /** @type {Tensor} */ - let input_mel; - let longer = false; - const diff = waveform.length - max_length; - if (diff > 0) { - if (truncation === 'rand_trunc') { - longer = true; - const idx = Math.floor(Math.random() * (diff + 1)); - waveform = waveform.subarray(idx, idx + max_length); - - input_mel = await this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples); - } else { - // TODO implement fusion strategy - throw new Error(`Truncation strategy "${truncation}" not implemented`) - } - } else { - if (diff < 0) { - let padded = new Float64Array(max_length); // already padded with zeros - padded.set(waveform); - - if (padding === 'repeat') { - for (let i = waveform.length; i < max_length; i += waveform.length) { - padded.set(waveform.subarray(0, Math.min(waveform.length, max_length - i)), i); - } - } else if (padding === 'repeatpad') { - for (let i = waveform.length; i < -diff; i += waveform.length) { - padded.set(waveform, i); - } - } - waveform = padded; - } - - if (truncation === 'fusion') { - throw new Error(`Truncation strategy "${truncation}" not implemented`) - } - - input_mel = await this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples); - } - - return input_mel.unsqueeze_(0); - } - - /** - * Compute the log-mel spectrogram of the provided `waveform` using the Hann window. - * In CLAP, two different filter banks are used depending on the truncation pattern: - * - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from - * calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation` - * is set to `"fusion"`. - * - `self.mel_filteres_slaney` : they correspond to the default parameters of `librosa` which used - * `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original - * implementation when the truncation mode is not `"fusion"`. - * - * @param {Float32Array|Float64Array} waveform The audio waveform to process. - * @param {number[][]} mel_filters The mel filters to use. - * @param {number} [max_length=null] The maximum number of frames to return. - * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. - */ - async _extract_fbank_features(waveform, mel_filters, max_length = null) { - // NOTE: We don't pad/truncate since that is passed in as `max_num_frames` - return spectrogram( - waveform, - this.window, // window - this.config.fft_window_size, // frame_length - this.config.hop_length, // hop_length - { - power: 2.0, - mel_filters, - log_mel: 'dB', - - // Custom - max_num_frames: max_length, - do_pad: false, - transpose: true, - } - ) - } - - - /** - * Asynchronously extracts features from a given audio using the provided configuration. - * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. - * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. - */ - async _call(audio, { - max_length = null, - } = {}) { - validate_audio_inputs(audio, 'ClapFeatureExtractor'); - - // convert to mel spectrogram, truncate and pad if needed. - const padded_inputs = await this._get_input_mel( - audio, - max_length ?? 
this.config.nb_max_samples, - this.config.truncation, - this.config.padding, - ); - - return { - input_features: padded_inputs.unsqueeze_(0), - } - } -} - - -export class PyAnnoteFeatureExtractor extends FeatureExtractor { - /** - * Asynchronously extracts features from a given audio using the provided configuration. - * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. - * @returns {Promise<{ input_values: Tensor; }>} The extracted input features. - */ - async _call(audio) { - validate_audio_inputs(audio, 'PyAnnoteFeatureExtractor'); - - if (audio instanceof Float64Array) { - audio = new Float32Array(audio); - } - - const shape = [ - 1, /* batch_size */ - 1, /* num_channels */ - audio.length, /* num_samples */ - ]; - return { - input_values: new Tensor('float32', audio, shape), - }; - } - - /** - * NOTE: Can return fractional values. `Math.ceil` will ensure correct value. - * @param {number} samples The number of frames in the audio. - * @returns {number} The number of frames in the audio. - */ - samples_to_frames(samples) { - return ((samples - this.config.offset) / this.config.step); - } - - /** - * Post-processes the speaker diarization logits output by the model. - * @param {Tensor} logits The speaker diarization logits output by the model. - * @param {number} num_samples Number of samples in the input audio. - * @returns {Array>} The post-processed speaker diarization results. - */ - post_process_speaker_diarization(logits, num_samples) { - const ratio = ( - num_samples / this.samples_to_frames(num_samples) - ) / this.config.sampling_rate; - - const results = []; - for (const scores of logits.tolist()) { - const accumulated_segments = []; - - let current_speaker = -1; - for (let i = 0; i < scores.length; ++i) { - const probabilities = softmax(scores[i]); - const [score, id] = max(probabilities); - const [start, end] = [i, i + 1]; - - if (id !== current_speaker) { - // Speaker has changed - current_speaker = id; - accumulated_segments.push({ id, start, end, score }); - } else { - // Continue the current segment - accumulated_segments.at(-1).end = end; - accumulated_segments.at(-1).score += score; - } - } - - results.push(accumulated_segments.map( - // Convert frame-space to time-space - // and compute the confidence - ({ id, start, end, score }) => ({ - id, - start: start * ratio, - end: end * ratio, - confidence: score / (end - start), - }) - )); - } - return results; - } - -} - -export class WeSpeakerFeatureExtractor extends FeatureExtractor { - - constructor(config) { - super(config); - - const sampling_rate = this.config.sampling_rate; - const mel_filters = mel_filter_bank( - 256, // num_frequency_bins - this.config.num_mel_bins, // num_mel_filters - 20, // min_frequency - Math.floor(sampling_rate / 2), // max_frequency - sampling_rate, // sampling_rate - null, // norm - "kaldi", // mel_scale - true, // triangularize_in_mel_space - ); - - // Do padding: - for (let i = 0; i < mel_filters.length; ++i) { - mel_filters[i].push(0); - } - this.mel_filters = mel_filters; - - this.window = window_function(400, 'hamming', { - periodic: false, - }) - this.min_num_frames = this.config.min_num_frames; - } - - /** - * Computes the log-Mel spectrogram of the provided audio waveform. - * @param {Float32Array|Float64Array} waveform The audio waveform to process. - * @returns {Promise} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers. 
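A hedged end-to-end sketch of how the `PyAnnoteFeatureExtractor` diarization post-processing above is typically wired up. The checkpoint id, audio URL, and the audio-frame-classification auto class are written from memory of the library's examples and should be treated as assumptions:

```javascript
import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@huggingface/transformers';

// Assumed PyAnnote segmentation export; substitute any compatible checkpoint.
const model_id = 'onnx-community/pyannote-segmentation-3.0';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForAudioFrameClassification.from_pretrained(model_id);

// Assumed sample clip; any mono audio at the model's sampling rate works.
const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.wav', processor.feature_extractor.config.sampling_rate);
const { input_values } = await processor(audio);   // shape [1, 1, num_samples]
const { logits } = await model({ input_values });

// Convert per-frame speaker logits into time-stamped segments.
const segments = processor.feature_extractor.post_process_speaker_diarization(logits, audio.length);
// e.g. [[ { id: 0, start: 1.02, end: 3.51, confidence: 0.87 }, ... ]]
```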
- */ - async _extract_fbank_features(waveform) { - // Kaldi compliance: 16-bit signed integers - // 32768 == 2 ** 15 - waveform = waveform.map((/** @type {number} */ x) => x * 32768) - - return spectrogram( - waveform, - this.window, // window - 400, // frame_length - 160, // hop_length - { - fft_length: 512, - power: 2.0, - center: false, - preemphasis: 0.97, - mel_filters: this.mel_filters, - log_mel: 'log', - mel_floor: 1.192092955078125e-07, - remove_dc_offset: true, - - // Custom - transpose: true, - min_num_frames: this.min_num_frames, - } - ) - } - - - /** - * Asynchronously extracts features from a given audio using the provided configuration. - * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. - * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor. - */ - async _call(audio) { - validate_audio_inputs(audio, 'WeSpeakerFeatureExtractor'); - - const features = (await this._extract_fbank_features(audio)).unsqueeze_(0); - - if (this.config.fbank_centering_span === null) { - // center features with global average - const meanData = /** @type {Float32Array} */ (features.mean(1).data); - const featuresData = /** @type {Float32Array} */(features.data); - const [batch_size, num_frames, feature_size] = features.dims; - - for (let i = 0; i < batch_size; ++i) { - const offset1 = i * num_frames * feature_size; - const offset2 = i * feature_size; - for (let j = 0; j < num_frames; ++j) { - const offset3 = offset1 + j * feature_size; - for (let k = 0; k < feature_size; ++k) { - featuresData[offset3 + k] -= meanData[offset2 + k]; - } - } - } - } - - return { - input_features: features - }; - } -} - -export class SpeechT5FeatureExtractor extends FeatureExtractor { } - -/** - * Represents a Processor that extracts features from an input. - * @extends Callable - */ -export class Processor extends Callable { - /** - * Creates a new Processor with the given feature extractor. - * @param {FeatureExtractor} feature_extractor The function used to extract features from the input. - */ - constructor(feature_extractor) { - super(); - this.feature_extractor = feature_extractor; - // TODO use tokenizer here? - } - - /** - * Calls the feature_extractor function with the given input. - * @param {any} input The input to extract features from. - * @param {...any} args Additional arguments. - * @returns {Promise} A Promise that resolves with the extracted features. - */ - async _call(input, ...args) { - return await this.feature_extractor(input, ...args); - } -} - -export class SamProcessor extends Processor { - /** - * @borrows SamImageProcessor#_call as _call - */ - async _call(...args) { - return await this.feature_extractor(...args); - } - - /** - * @borrows SamImageProcessor#post_process_masks as post_process_masks - */ - post_process_masks(...args) { - // @ts-ignore - return this.feature_extractor.post_process_masks(...args); - } - /** - * @borrows SamImageProcessor#reshape_input_points as reshape_input_points - */ - reshape_input_points(...args) { - // @ts-ignore - return this.feature_extractor.reshape_input_points(...args); - } -} - -/** - * Represents a WhisperProcessor that extracts features from an audio input. - * @extends Processor - */ -export class WhisperProcessor extends Processor { - /** - * Calls the feature_extractor function with the given audio input. - * @param {any} audio The audio input to extract features from. 
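A sketch of the full `SamProcessor` flow defined above, under the assumption that a SAM-style checkpoint such as `Xenova/slimsam-77-uniform` is available; the checkpoint id, image URL, and point coordinates are illustrative:

```javascript
import { SamModel, AutoProcessor, RawImage } from '@huggingface/transformers';

const model_id = 'Xenova/slimsam-77-uniform'; // example checkpoint id (assumption)
const model = await SamModel.from_pretrained(model_id);
const processor = await AutoProcessor.from_pretrained(model_id);

const raw_image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg');
const input_points = [[[340, 250]]]; // 3D: [point_batch_size, nb_points_per_image, 2]

// SamImageProcessor resizes the image and rescales the points to the reshaped input size.
const inputs = await processor(raw_image, { input_points });
const outputs = await model(inputs);

// Remove padding and upscale the predicted masks back to the original image size.
const masks = await processor.post_process_masks(
    outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes
);
console.log(masks[0].dims); // e.g. [1, 3, originalHeight, originalWidth]
```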
- * @returns {Promise} A Promise that resolves with the extracted features. - */ - async _call(audio) { - return await this.feature_extractor(audio) - } -} - - -export class Wav2Vec2ProcessorWithLM extends Processor { - /** - * Calls the feature_extractor function with the given audio input. - * @param {any} audio The audio input to extract features from. - * @returns {Promise} A Promise that resolves with the extracted features. - */ - async _call(audio) { - return await this.feature_extractor(audio) - } -} - -export class PyAnnoteProcessor extends Processor { - /** - * Calls the feature_extractor function with the given audio input. - * @param {any} audio The audio input to extract features from. - * @returns {Promise} A Promise that resolves with the extracted features. - */ - async _call(audio) { - return await this.feature_extractor(audio) - } - - post_process_speaker_diarization(...args) { - // @ts-ignore - return this.feature_extractor.post_process_speaker_diarization(...args); - } - -} - -export class SpeechT5Processor extends Processor { - /** - * Calls the feature_extractor function with the given input. - * @param {any} input The input to extract features from. - * @returns {Promise} A Promise that resolves with the extracted features. - */ - async _call(input) { - return await this.feature_extractor(input) - } -} - -export class OwlViTProcessor extends Processor { } - -export class Florence2Processor extends Processor { - constructor(feature_extractor) { - super(feature_extractor); - - const { - tasks_answer_post_processing_type, - task_prompts_without_inputs, - task_prompts_with_input, - } = feature_extractor.config; - - /** @type {Map} */ - this.tasks_answer_post_processing_type = new Map(Object.entries(tasks_answer_post_processing_type ?? {})); - - /** @type {Map} */ - this.task_prompts_without_inputs = new Map(Object.entries(task_prompts_without_inputs ?? {})); - - /** @type {Map} */ - this.task_prompts_with_input = new Map(Object.entries(task_prompts_with_input ?? {})); - - this.regexes = { - quad_boxes: /(.+?)/gm, - bboxes: /([^<]+)?/gm, - } - this.size_per_bin = 1000; - } - - /** - * Helper function to construct prompts from input texts - * @param {string|string[]} text - * @returns {string[]} - */ - construct_prompts(text) { - if (typeof text === 'string') { - text = [text]; - } - - const prompts = []; - for (const t of text) { - // 1. fixed task prompts without additional inputs - if (this.task_prompts_without_inputs.has(t)) { - prompts.push(this.task_prompts_without_inputs.get(t)); - } - // 2. task prompts with additional inputs - else { - for (const [task, prompt] of this.task_prompts_with_input) { - if (t.includes(task)) { - prompts.push(prompt.replaceAll('{input}', t).replaceAll(task, '')); - break; - } - } - - // 3. default prompt - if (prompts.length !== text.length) { - prompts.push(t); - } - } - } - return prompts; - } - - /** - * Post-process the output of the model to each of the task outputs. - * @param {string} text The text to post-process. - * @param {string} task The task to post-process the text for. - * @param {[number, number]} image_size The size of the image. height x width. - */ - post_process_generation(text, task, image_size) { - const task_answer_post_processing_type = this.tasks_answer_post_processing_type.get(task) ?? 
'pure_text'; - - // remove the special tokens - text = text.replaceAll('<s>', '').replaceAll('</s>', ''); - - let final_answer; - switch (task_answer_post_processing_type) { - case 'pure_text': - final_answer = text; - break; - - case 'description_with_bboxes': - case 'bboxes': - case 'phrase_grounding': - case 'ocr': - const key = task_answer_post_processing_type === 'ocr' ? 'quad_boxes' : 'bboxes'; - const matches = text.matchAll(this.regexes[key]); - const labels = []; - const items = []; - for (const [_, label, ...locations] of matches) { - // Push new label, or duplicate the last label - labels.push(label ? label.trim() : labels.at(-1) ?? ''); - items.push(locations.map((x, i) => - // NOTE: Add 0.5 to use the center position of the bin as the coordinate. - (Number(x) + 0.5) / this.size_per_bin * image_size[i % 2]) - ); - } - final_answer = { labels, [key]: items }; - break; - - default: - throw new Error(`Task "${task}" (of type "${task_answer_post_processing_type}") not yet implemented.`); - } - - return { [task]: final_answer } - } -} - -////////////////////////////////////////////////// -/** - * Helper class which is used to instantiate pretrained processors with the `from_pretrained` function. - * The chosen processor class is determined by the type specified in the processor config. - * - * **Example:** Load a processor using `from_pretrained`. - * ```javascript - * let processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en'); - * ``` - * - * **Example:** Run an image through a processor. - * ```javascript - * let processor = await AutoProcessor.from_pretrained('Xenova/clip-vit-base-patch16'); - * let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg'); - * let image_inputs = await processor(image); - * // { - * // "pixel_values": { - * // "dims": [ 1, 3, 224, 224 ], - * // "type": "float32", - * // "data": Float32Array [ -1.558687686920166, -1.558687686920166, -1.5440893173217773, ... 
], - * // "size": 150528 - * // }, - * // "original_sizes": [ - * // [ 533, 800 ] - * // ], - * // "reshaped_input_sizes": [ - * // [ 224, 224 ] - * // ] - * // } - * ``` - */ -export class AutoProcessor { - static FEATURE_EXTRACTOR_CLASS_MAPPING = { - ImageFeatureExtractor, - WhisperFeatureExtractor, - ViTFeatureExtractor, - MobileViTFeatureExtractor, - MobileViTImageProcessor, - MobileNetV1FeatureExtractor, - MobileNetV2FeatureExtractor, - MobileNetV3FeatureExtractor, - MobileNetV4FeatureExtractor, - OwlViTFeatureExtractor, - Owlv2ImageProcessor, - CLIPFeatureExtractor, - CLIPImageProcessor, - Florence2Processor, - ChineseCLIPFeatureExtractor, - SiglipImageProcessor, - ConvNextFeatureExtractor, - ConvNextImageProcessor, - SegformerFeatureExtractor, - SapiensFeatureExtractor, - BitImageProcessor, - DPTImageProcessor, - DPTFeatureExtractor, - PvtImageProcessor, - GLPNFeatureExtractor, - BeitFeatureExtractor, - DeiTFeatureExtractor, - DetrFeatureExtractor, - RTDetrImageProcessor, - MaskFormerFeatureExtractor, - YolosFeatureExtractor, - DonutFeatureExtractor, - DonutImageProcessor, - NougatImageProcessor, - EfficientNetImageProcessor, - - ViTImageProcessor, - VitMatteImageProcessor, - SamImageProcessor, - Swin2SRImageProcessor, - Wav2Vec2FeatureExtractor, - SeamlessM4TFeatureExtractor, - SpeechT5FeatureExtractor, - ASTFeatureExtractor, - ClapFeatureExtractor, - PyAnnoteFeatureExtractor, - WeSpeakerFeatureExtractor, - } - - static PROCESSOR_CLASS_MAPPING = { - WhisperProcessor, - Wav2Vec2ProcessorWithLM, - PyAnnoteProcessor, - SamProcessor, - SpeechT5Processor, - OwlViTProcessor, - Florence2Processor, - } - - /** - * Instantiate one of the processor classes of the library from a pretrained model. - * - * The processor class to instantiate is selected based on the `feature_extractor_type` property of the config object - * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) - * - * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: - * - A string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co. - * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a - * user or organization name, like `dbmdz/bert-base-german-cased`. - * - A path to a *directory* containing processor files, e.g., `./my_model_directory/`. - * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the processor. - * - * @returns {Promise} A new instance of the Processor class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - - let preprocessorConfig = config ?? await getModelJSON(pretrained_model_name_or_path, 'preprocessor_config.json', true, { - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }) - - // Determine feature extractor class - // TODO: Ensure backwards compatibility with old configs - let key = preprocessorConfig.feature_extractor_type ?? 
preprocessorConfig.image_processor_type; - let feature_extractor_class = this.FEATURE_EXTRACTOR_CLASS_MAPPING[key]; - - if (!feature_extractor_class) { - if (preprocessorConfig.size !== undefined) { - // Assume ImageFeatureExtractor - console.warn(`Feature extractor type "${key}" not found, assuming ImageFeatureExtractor due to size parameter in config.`); - feature_extractor_class = ImageFeatureExtractor; - } else { - throw new Error(`Unknown Feature Extractor type: ${key}`); - } - } - - // If no associated processor class, use default - let processor_class = this.PROCESSOR_CLASS_MAPPING[preprocessorConfig.processor_class] ?? Processor; - - // Instantiate processor and feature extractor - let feature_extractor = new feature_extractor_class(preprocessorConfig); - return new processor_class(feature_extractor); - } -} -////////////////////////////////////////////////// - diff --git a/src/transformers.js b/src/transformers.js index be7ad176e..4ef2704fc 100644 --- a/src/transformers.js +++ b/src/transformers.js @@ -12,10 +12,10 @@ */ export { env } from './env.js'; + export * from './pipelines.js'; export * from './models.js'; export * from './tokenizers.js'; -export * from './processors.js'; export * from './configs.js'; export * from './utils/audio.js'; @@ -23,6 +23,19 @@ export * from './utils/image.js'; export * from './utils/tensor.js'; export * from './utils/maths.js'; + +export { FeatureExtractor } from './base/feature_extraction_utils.js'; +export * from './models/feature_extractors.js'; +export * from './models/auto/feature_extraction_auto.js'; + +export { ImageProcessor } from './base/image_processors_utils.js'; +export * from './models/image_processors.js'; +export * from './models/auto/image_processing_auto.js'; + +export { Processor } from './base/processing_utils.js'; +export * from './models/processors.js'; +export * from './models/auto/processing_auto.js'; + export * from './generation/streamers.js'; export * from './generation/stopping_criteria.js'; diff --git a/src/utils/constants.js b/src/utils/constants.js index 9d0e9ee42..ed456a56b 100644 --- a/src/utils/constants.js +++ b/src/utils/constants.js @@ -1,2 +1,9 @@ -export const GITHUB_ISSUE_URL = 'https://github.com/huggingface/transformers.js/issues/new/choose'; \ No newline at end of file +export const GITHUB_ISSUE_URL = 'https://github.com/huggingface/transformers.js/issues/new/choose'; + +export const CONFIG_NAME = "config.json" +export const FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" +export const IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME +export const PROCESSOR_NAME = "processor_config.json" +export const CHAT_TEMPLATE_NAME = "chat_template.json" +export const GENERATION_CONFIG_NAME = "generation_config.json" diff --git a/tests/tiny_random.test.js b/tests/tiny_random.test.js index 634924f32..e23558e25 100644 --- a/tests/tiny_random.test.js +++ b/tests/tiny_random.test.js @@ -750,7 +750,10 @@ describe("Tiny random models", () => { }); describe("florence2", () => { - const texts = ["Describe with a paragraph what is shown in the image.", "Locate the objects with category name in the image."]; + const texts = [ + "Describe with a paragraph what is shown in the image.", + "Locate the objects with category name in the image.", + ]; // Empty white image const dims = [224, 224, 3]; @@ -761,8 +764,6 @@ describe("Tiny random models", () => { /** @type {Florence2ForConditionalGeneration} */ let model; - /** @type {BartTokenizer} */ - let tokenizer; /** @type {Florence2Processor} */ let processor; beforeAll(async () 
=> { @@ -770,22 +771,18 @@ describe("Tiny random models", () => { // TODO move to config ...DEFAULT_MODEL_OPTIONS, }); - tokenizer = await BartTokenizer.from_pretrained(model_id); processor = await AutoProcessor.from_pretrained(model_id); }, MAX_MODEL_LOAD_TIME); it( "forward", async () => { - const text_inputs = tokenizer(texts[0]); - const vision_inputs = await processor(image); - const inputs = { - ...text_inputs, - ...vision_inputs, - decoder_input_ids: full([1, 1], 2n), - }; + const inputs = await processor(image, texts[0]); - const { logits } = await model(inputs); + const { logits } = await model({ + ...inputs, + decoder_input_ids: full([1, 1], 2n), + }); expect(logits.dims).toEqual([1, 1, 51289]); }, MAX_TEST_EXECUTION_TIME, @@ -794,15 +791,13 @@ describe("Tiny random models", () => { it( "batch_size=1", async () => { - const text_inputs = tokenizer(texts[0]); { + const text_inputs = processor.tokenizer(texts[0]); const generate_ids = await model.generate({ ...text_inputs, max_new_tokens: 10 }); expect(generate_ids.tolist()).toEqual([[2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n]]); } { - const vision_inputs = await processor(image); - const inputs = { ...text_inputs, ...vision_inputs }; - + const inputs = await processor(image, texts[0]); const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); expect(generate_ids.tolist()).toEqual([[2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n]]); } @@ -813,8 +808,8 @@ describe("Tiny random models", () => { it( "batch_size>1", async () => { - const text_inputs = tokenizer(texts, { padding: true }); { + const text_inputs = processor.tokenizer(texts, { padding: true }); const generate_ids = await model.generate({ ...text_inputs, max_new_tokens: 10 }); expect(generate_ids.tolist()).toEqual([ [2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n], @@ -822,8 +817,7 @@ describe("Tiny random models", () => { ]); } { - const vision_inputs = await processor([image, image]); - const inputs = { ...text_inputs, ...vision_inputs }; + const inputs = await processor([image, image], texts, { padding: true }); const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); expect(generate_ids.tolist()).toEqual([