Add support for BLOOM models (#273)
* Add support for Bloom models

* Update `BloomTokenizer` to fix the default (invalid) regex

* Update supported models

* Update default quantization settings for BLOOM models

* Fix `use_cache_branch`
xenova authored Sep 1, 2023
1 parent 62159eb commit 9077c21
Showing 6 changed files with 125 additions and 6 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -256,6 +256,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[ALBERT](https://huggingface.co/docs/transformers/model_doc/albert)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
@@ -4,6 +4,7 @@
1. **[ALBERT](https://huggingface.co/docs/transformers/model_doc/albert)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
4 changes: 4 additions & 0 deletions scripts/convert.py
@@ -29,6 +29,10 @@
'whisper': {
'per_channel': False,
'reduce_range': False,
},
'bloom': {
'per_channel': False,
'reduce_range': False,
}
}

8 changes: 4 additions & 4 deletions scripts/supported_models.py
@@ -67,10 +67,10 @@
'unitary/toxic-bert',
],
# TODO:
-    # 'bloom':[
-    #     'bigscience/bloom-560m',
-    #     'bigscience/bloomz-560m',
-    # ],
+    'bloom': [
+        'bigscience/bloom-560m',
+        # 'bigscience/bloomz-560m',
+    ],
# TODO:
# 'blenderbot-small': [
# 'facebook/blenderbot_small-90M',
104 changes: 103 additions & 1 deletion src/models.js
@@ -462,7 +462,7 @@ async function decoderForward(self, model_inputs) {
let decoderFeeds = {
input_ids: input_ids,
attention_mask: attention_mask ?? prepareAttentionMask(self, input_ids),
-    use_cache_branch: boolTensor(past_key_values !== null)
+    use_cache_branch: boolTensor(!!past_key_values)
}

self.addPastKeyValues(decoderFeeds, past_key_values);
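This one-line change is the `use_cache_branch` fix from the commit message: on the first generation step `past_key_values` is `undefined`, not `null`, so the strict comparison `past_key_values !== null` evaluated to `true` and enabled the cache branch before any cache existed. Coercing with `!!` treats both `null` and `undefined` as "no cache yet". A minimal sketch of the difference:

```js
// Plain-JS illustration of the bug fixed above (no library code involved).
let past_key_values; // undefined on the first generation step

console.log(past_key_values !== null); // true  -- cache branch wrongly enabled
console.log(!!past_key_values);        // false -- cache branch correctly skipped

past_key_values = { layer0: 'populated after the first step' };
console.log(!!past_key_values);        // true  -- cache branch enabled
```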
@@ -1178,6 +1178,17 @@ export class PreTrainedModel extends Callable {
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
}
} else if (this.config.model_type === 'bloom') {
// Custom implementation for Bloom
// @ts-ignore
let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
// @ts-ignore
let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
}
} else {
// @ts-ignore
let dims = [1, this.num_heads, 0, this.dim_kv]
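Unlike the default layout `[1, num_heads, 0, dim_kv]` in the `else` branch, BLOOM's ONNX export flattens batch and heads into a single leading axis and stores keys transposed relative to values, hence the custom branch above. A sketch of the resulting empty-cache shapes, assuming the `bigscience/bloom-560m` configuration (`n_head = 16`, `hidden_size = 1024`):

```js
// Hypothetical walk-through of the shapes built above, for bloom-560m.
const batch_size = 1;
const num_heads = 16;            // config.n_head
const dim_kv = 1024 / num_heads; // config.hidden_size / n_head = 64

const keyDims = [batch_size * num_heads, dim_kv, 0];   // [16, 64, 0] -- keys: seq axis last
const valueDims = [batch_size * num_heads, 0, dim_kv]; // [16, 0, 64] -- values: seq axis second

// The 0-sized axis is past_sequence_length, which grows by one per generated token.
console.log(keyDims, valueDims);
```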
@@ -2660,6 +2671,9 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel {
// TODO
// }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// GPTNeo models
export class GPTNeoPreTrainedModel extends PreTrainedModel {
/**
* Creates a new instance of the `GPTNeoPreTrainedModel` class.
@@ -2985,6 +2999,92 @@ export class LlamaForCausalLM extends LlamaPreTrainedModel {
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// Bloom models
/**
 * The bare Bloom base model class, which handles configuration and computes the attention-head, layer, and key/value dimensions used by the Bloom model classes below.
*/
export class BloomPreTrainedModel extends PreTrainedModel {
/**
* Creates a new instance of the `BloomPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
*/
constructor(config, session) {
super(config, session);

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id

this.num_heads = this.config.n_head
this.num_layers = this.config.n_layer
this.dim_kv = this.config.hidden_size / this.num_heads;
}
}

/**
* The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.
*/
export class BloomModel extends BloomPreTrainedModel {

/**
* BloomModel is not compatible with `.generate()`, as it doesn't have a language model head.
* @param {...any} args
* @throws {Error}
* @returns {Promise<any>}
*/
async generate(...args) {
throw Error(
"The current model class (BloomModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'BloomForCausalLM'}"
)
}
}

/**
* The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
*/
export class BloomForCausalLM extends BloomPreTrainedModel {

/**
* Initializes and returns the beam for text generation task
* @param {Tensor} inputTokenIds The input token ids.
* @param {number} numOutputTokens The number of tokens to be generated.
* @param {Tensor} inputs_attention_mask Optional input attention mask.
* @returns {any} A Beam object representing the initialized beam.
*/
getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) {
return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask)
}

/**
* Runs a single step of the beam search generation algorithm.
* @param {any} beam The current beam being generated.
* @returns {Promise<any>} The updated beam after a single generation step.
*/
async runBeam(beam) {
return await decoderRunBeam(this, beam);
}

/**
* Updates the given beam with the new generated token id.
* @param {any} beam The Beam object representing the beam.
* @param {number} newTokenId The new generated token id to be added to the beam.
*/
updateBeam(beam, newTokenId) {
return decoderUpdatebeam(beam, newTokenId);
}

/**
* Forward pass for the model.
* @param {Object} model_inputs The inputs for the model.
* @returns {Promise<any>} The output tensor of the model.
*/
async forward(model_inputs) {
return await decoderForward(this, model_inputs);
}
}
//////////////////////////////////////////////////
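Once registered in the model mappings below, the new classes work through the high-level pipeline API. A usage sketch — the model id `Xenova/bloom-560m` is an assumed ONNX-converted checkpoint on the Hub, not something this diff guarantees:

```js
import { pipeline } from '@xenova/transformers';

// Assumes an ONNX-converted BLOOM checkpoint exists under this id on the Hub.
const generator = await pipeline('text-generation', 'Xenova/bloom-560m');
const output = await generator('Once upon a time,', { max_new_tokens: 20 });
console.log(output);
```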

//////////////////////////////////////////////////
export class ViTPreTrainedModel extends PreTrainedModel { }
export class ViTModel extends ViTPreTrainedModel { }
@@ -3478,6 +3578,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([


const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
['bloom', BloomModel],
['gpt2', GPT2Model],
['gpt_bigcode', GPTBigCodeModel],
['gpt_neo', GPTNeoModel],
@@ -3519,6 +3620,7 @@ const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([
]);

const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
['bloom', BloomForCausalLM],
['gpt2', GPT2LMHeadModel],
['gpt_bigcode', GPTBigCodeForCausalLM],
['gpt_neo', GPTNeoForCausalLM],
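These `Map` entries are what let the `Auto` classes resolve the `model_type: "bloom"` field of a checkpoint's `config.json` to the new classes. A brief sketch, reusing the assumed `Xenova/bloom-560m` id from above:

```js
import { AutoTokenizer, AutoModelForCausalLM } from '@xenova/transformers';

// 'bloom' in config.json now maps to BloomForCausalLM via the LM-head mapping.
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bloom-560m');
const model = await AutoModelForCausalLM.from_pretrained('Xenova/bloom-560m');
```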
13 changes: 12 additions & 1 deletion src/tokenizers.js
@@ -2566,7 +2566,18 @@ export class GPT2Tokenizer extends PreTrainedTokenizer { }
export class BartTokenizer extends PreTrainedTokenizer { }
export class RobertaTokenizer extends PreTrainedTokenizer { }

-export class BloomTokenizer extends PreTrainedTokenizer { }
+export class BloomTokenizer extends PreTrainedTokenizer {
+    constructor(tokenizerJSON, tokenizerConfig) {
+        // Override the default (invalid) regex of the pretokenizer.
+        // For more information, see https://github.com/xenova/transformers.js/issues/94
+        const splitChars = '.,!?\u2026\u3002\uff0c\u3001\u0964\u06d4\u060c';
+        const patternObject = tokenizerJSON.pre_tokenizer?.pretokenizers[0]?.pattern;
+        if (patternObject && patternObject.Regex === ` ?[^(\\s|[${splitChars}])]+`) {
+            patternObject.Regex = ` ?[^\\s${splitChars}]+`;
+        }
+        super(tokenizerJSON, tokenizerConfig);
+    }
+}
export class LlamaTokenizer extends PreTrainedTokenizer { }

export class XLMRobertaTokenizer extends PreTrainedTokenizer { }
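The guard in `BloomTokenizer` above exists because the default pattern shipped in BLOOM's tokenizer.json is not a valid JavaScript regex: inside the nested brackets, the first `]` closes the character class early, leaving an unmatched `)`. The override flattens it into a single negated class. A standalone check of both patterns (illustrative only, not library code):

```js
const splitChars = '.,!?\u2026\u3002\uff0c\u3001\u0964\u06d4\u060c';

try {
    // The original pattern: the first ']' ends the class, so ')' is unmatched.
    new RegExp(` ?[^(\\s|[${splitChars}])]+`, 'g');
} catch (e) {
    console.log(e.message); // "Invalid regular expression: ... Unmatched ')'"
}

// The corrected pattern: split on whitespace and the listed punctuation.
const fixed = new RegExp(` ?[^\\s${splitChars}]+`, 'g');
console.log('Hello world, bonjour!'.match(fixed));
// [ 'Hello', ' world', ' bonjour' ]
```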
