Skip to content

Commit

Permalink
Add support for JAIS models (#906)
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Aug 28, 2024
1 parent 535cdfe commit 4e1acf0
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
1. **[GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode)** (from BigCode) released with the paper [SantaCoder: don't reach for the stars!](https://arxiv.org/abs/2301.03988) by Loubna Ben Allal, Raymond Li, Denis Kocetkov, Chenghao Mou, Christopher Akiki, Carlos Munoz Ferrandis, Niklas Muennighoff, Mayank Mishra, Alex Gu, Manan Dey, Logesh Kumar Umapathi, Carolyn Jane Anderson, Yangtian Zi, Joel Lamy Poirier, Hailey Schoelkopf, Sergey Troshin, Dmitry Abulkhanov, Manuel Romero, Michael Lappert, Francesco De Toni, Bernardo García del Río, Qian Liu, Shamik Bose, Urvashi Bhattacharyya, Terry Yue Zhuo, Ian Yu, Paulo Villegas, Marco Zocca, Sourab Mangrulkar, David Lansky, Huu Nguyen, Danish Contractor, Luis Villa, Jia Li, Dzmitry Bahdanau, Yacine Jernite, Sean Hughes, Daniel Fried, Arjun Guha, Harm de Vries, Leandro von Werra.
1. **[HerBERT](https://huggingface.co/docs/transformers/model_doc/herbert)** (from Allegro.pl, AGH University of Science and Technology) released with the paper [KLEJ: Comprehensive Benchmark for Polish Language Understanding](https://www.aclweb.org/anthology/2020.acl-main.111.pdf) by Piotr Rybak, Robert Mroczkowski, Janusz Tracz, Ireneusz Gawlik.
1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
1. **JAIS** (from Core42) released with the paper [Jais and Jais-chat: Arabic-Centric Foundation and Instruction-Tuned Open Generative Large Language Models](https://arxiv.org/pdf/2308.16149) by Neha Sengupta, Sunil Kumar Sahu, Bokang Jia, Satheesh Katipomu, Haonan Li, Fajri Koto, William Marshall, Gurpreet Gosal, Cynthia Liu, Zhiming Chen, Osama Mohammed Afzal, Samta Kamboj, Onkar Pandit, Rahul Pal, Lalit Pradhan, Zain Muhammad Mujahid, Massa Baali, Xudong Han, Sondos Mahmoud Bsharat, Alham Fikri Aji, Zhiqiang Shen, Zhengzhong Liu, Natalia Vassilieva, Joel Hestness, Andy Hock, Andrew Feldman, Jonathan Lee, Andrew Jackson, Hector Xuguang Ren, Preslav Nakov, Timothy Baldwin, Eric Xing.
1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom.
Expand Down
1 change: 1 addition & 0 deletions src/configs.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function getNormalizedConfig(config) {
// Decoder-only models
case 'gpt2':
case 'gptj':
case 'jais':
case 'codegen':
case 'gpt_bigcode':
mapping['num_heads'] = 'n_head';
Expand Down
29 changes: 29 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -3783,6 +3783,33 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { }
// }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// JAIS models
export class JAISPreTrainedModel extends PreTrainedModel {
/**
* Creates a new instance of the `JAISPreTrainedModel` class.
* @param {Object} config The model configuration.
* @param {Record<string, any>} sessions The inference sessions for the model.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, sessions, generation_config) {
super(config, sessions);
this.generation_config = generation_config;
}
}

/**
* The bare JAIS Model transformer outputting raw hidden-states without any specific head on top.
*/
export class JAISModel extends JAISPreTrainedModel { }

/**
* The JAIS Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
*/
export class JAISLMHeadModel extends JAISPreTrainedModel { }
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// GPTNeo models
export class GPTNeoPreTrainedModel extends PreTrainedModel {
Expand Down Expand Up @@ -6345,6 +6372,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([

const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
['bloom', ['BloomModel', BloomModel]],
['jais', ['JAISModel', JAISModel]],
['gpt2', ['GPT2Model', GPT2Model]],
['gptj', ['GPTJModel', GPTJModel]],
['gpt_bigcode', ['GPTBigCodeModel', GPTBigCodeModel]],
Expand Down Expand Up @@ -6433,6 +6461,7 @@ const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([
const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
['bloom', ['BloomForCausalLM', BloomForCausalLM]],
['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]],
['jais', ['JAISLMHeadModel', JAISLMHeadModel]],
['gptj', ['GPTJForCausalLM', GPTJForCausalLM]],
['gpt_bigcode', ['GPTBigCodeForCausalLM', GPTBigCodeForCausalLM]],
['gpt_neo', ['GPTNeoForCausalLM', GPTNeoForCausalLM]],
Expand Down
46 changes: 46 additions & 0 deletions tests/tiny_random.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {
BloomForCausalLM,
GPTBigCodeForCausalLM,
GPT2LMHeadModel,
JAISLMHeadModel,
MptForCausalLM,
CodeGenForCausalLM,
MistralForCausalLM,
Expand Down Expand Up @@ -1345,6 +1346,51 @@ describe('Tiny random models', () => {
});
});

describe('jais', () => {
describe('JAISLMHeadModel', () => {
const model_id = 'onnx-community/tiny-random-jais';
/** @type {JAISLMHeadModel} */
let model;
/** @type {PreTrainedTokenizer} */
let tokenizer;
beforeAll(async () => {
model = await JAISLMHeadModel.from_pretrained(model_id, {
// TODO move to config
...DEFAULT_MODEL_OPTIONS,
});
tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
tokenizer.padding_side = 'left';
}, MAX_MODEL_LOAD_TIME);

it('batch_size=1', async () => {
const inputs = tokenizer('hello');
const outputs = await model.generate({
...inputs,
max_length: 10,
});
expect(outputs.tolist()).toEqual([
[55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n],
]);
}, MAX_TEST_EXECUTION_TIME);

it('batch_size>1', async () => {
const inputs = tokenizer(['hello', 'hello world'], { padding: true });
const outputs = await model.generate({
...inputs,
max_length: 10,
});
expect(outputs.tolist()).toEqual([
[0n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n],
[55422n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n],
]);
}, MAX_TEST_EXECUTION_TIME);

afterAll(async () => {
await model?.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});
});

describe('mpt', () => {
describe('MptForCausalLM', () => {
const model_id = 'hf-internal-testing/tiny-random-MptForCausalLM';
Expand Down

0 comments on commit 4e1acf0

Please sign in to comment.