Merged
71 commits
257da7c
updated
Feb 17, 2023
cd13fc5
updated
Feb 20, 2023
9dc6e10
upadted
Feb 20, 2023
c6ac834
changed special token fule
Feb 22, 2023
62dbded
optimize tokenizer
Feb 23, 2023
d09f1d5
updated
Feb 27, 2023
5c3258d
Merge pull request #232 from 920232796/master
ftgreat Feb 27, 2023
c830a02
Update setup.py
ftgreat Feb 28, 2023
02ded3b
Create __init__.py
xuanricheng Feb 28, 2023
21a2090
bminf added
ftgreat Feb 28, 2023
3f43f1f
Merge branch 'master' into add_bminf
ftgreat Feb 28, 2023
bcdf9b1
updated
ftgreat Mar 1, 2023
3ab5e8d
Merge pull request #236 from xuanricheng/patch-1
ftgreat Mar 1, 2023
0811530
added bminf
ftgreat Mar 2, 2023
66aa41a
fixed conflicts
ftgreat Mar 2, 2023
0fe927f
test
ftgreat Mar 2, 2023
95dfe19
test
ftgreat Mar 2, 2023
2093bb9
updated
ftgreat Mar 2, 2023
8d4e86a
fixed error
ftgreat Mar 2, 2023
0abaed6
upadted
ftgreat Mar 2, 2023
5279f8a
updated
ftgreat Mar 2, 2023
c50c662
fixed inconsistency
ftgreat Mar 2, 2023
3c99037
updated
ftgreat Mar 2, 2023
173ce99
new version
ftgreat Mar 2, 2023
948e3f9
removed unused bminf
ftgreat Mar 2, 2023
df91be4
modified according to comments
ftgreat Mar 2, 2023
f27f9e9
updated
ftgreat Mar 2, 2023
c61d9ca
Merge pull request #238 from Anhforth/add_bminf
ftgreat Mar 2, 2023
bd8d657
updated
ftgreat Mar 2, 2023
ae728a8
Merge pull request #239 from Anhforth/add_bminf
ftgreat Mar 2, 2023
4d9c638
Update README.md
ftgreat Mar 6, 2023
257d79e
Update README_zh.md
ftgreat Mar 7, 2023
32c4fce
add llama model
920232796 Mar 7, 2023
254226a
add a weight merging tool file
920232796 Mar 7, 2023
e8bd3b4
modify the file name
920232796 Mar 7, 2023
ff75c3d
Merge pull request #243 from 920232796/master
ftgreat Mar 7, 2023
06af33f
Revert "add llama model"
ftgreat Mar 7, 2023
7ebc45d
Merge pull request #247 from FlagAI-Open/revert-243-master
ftgreat Mar 7, 2023
249b127
updated
ftgreat Mar 7, 2023
b01e285
fixed issue246
ftgreat Mar 8, 2023
c827d1e
Merge branch 'master' into fix_issue246
ftgreat Mar 8, 2023
a792538
updated
ftgreat Mar 8, 2023
92e39e0
Merge pull request #249 from Anhforth/fix_issue246
ftgreat Mar 8, 2023
ff5028b
updated
ftgreat Mar 10, 2023
523bb61
ignore file updated
ftgreat Mar 10, 2023
c8c3e60
saved work
ftgreat Mar 10, 2023
1a3bc5e
fixed
ftgreat Mar 14, 2023
4cb9cf8
Merge pull request #263 from Anhforth/fix_issue262
ftgreat Mar 14, 2023
dbd5a4e
added new tokenizer and test_tokenizer
ftgreat Mar 15, 2023
1238ec4
Merge pull request #13 from FlagAI-Open/master
shunxing1234 Mar 15, 2023
3f46c38
add optimzier
ftgreat Mar 15, 2023
12433c8
updated
ftgreat Mar 15, 2023
ade0895
updated
ftgreat Mar 15, 2023
3560b60
can assert new special tokens
ftgreat Mar 15, 2023
5f7edd3
add optimzier
shunxing1234 Mar 15, 2023
0ce3898
removed testing cpm tokenizer
ftgreat Mar 15, 2023
269ca89
fix
shunxing1234 Mar 15, 2023
8d82c58
Merge pull request #264 from Anhforth/opt_tokenizer
ftgreat Mar 16, 2023
d74a612
Merge pull request #266 from shunxing1234/master
ftgreat Mar 16, 2023
b8f8639
fix bug in setting for mp size >1
jongjyh Mar 16, 2023
ad9e3a3
add optimizer tutorial
shunxing1234 Mar 17, 2023
d71563b
add optimizer tutorial
shunxing1234 Mar 17, 2023
5f816cf
fix optimizer zh tutorial
shunxing1234 Mar 17, 2023
9900ed4
add optimizer tutorial
shunxing1234 Mar 17, 2023
84993b6
add tutorial
shunxing1234 Mar 17, 2023
9967b3c
add tutorial
shunxing1234 Mar 17, 2023
a202792
Merge pull request #271 from shunxing1234/master
ftgreat Mar 17, 2023
7715186
Merge pull request #269 from marscrazy/master
ftgreat Mar 17, 2023
5352f6a
env_trainer fix data_loader sampler when ds&mpu
ftgreat Mar 20, 2023
059603e
fix trainer dataloader sampler when ds&mpu
ftgreat Mar 20, 2023
590d178
Merge branch 'gpm_dev' into master
ftgreat Mar 21, 2023
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ datasets
qqp
glm_large_qqp_pytorch
wandb
clip_benchmark_datasets
examples/AltCLIP/clip_benchmark_datasets
examples/glm_pretrain/data.lazy
examples/glm_pretrain/examples/glm_pretrain/data.lazy
examples/vit_cifar100/cifar100
examples/vit_cifar100/data
examples/vit_cifar100/data
5 changes: 3 additions & 2 deletions README.md
Expand Up @@ -15,11 +15,12 @@ FlagAI (Fast LArge-scale General AI models) is a fast, easy-to-use and extensibl

* These models can be applied to (Chinese/English) Text, for tasks like text classification, information extraction, question answering, summarization, and text generation.

* FlagAI is backed by the three most popular data/model parallel libraries — [PyTorch](https://pytorch.org/)/[Deepspeed](https://www.deepspeed.ai/)/[Megatron-LM](https://github.com/NVIDIA/Megatron-LM)/[BMTrain](https://github.com/OpenBMB/BMTrain) — with seamless integration between them. Users can parallel their training/testing process with less than ten lines of code.
* FlagAI is backed by the four most popular data/model parallel libraries — [PyTorch](https://pytorch.org/)/[Deepspeed](https://www.deepspeed.ai/)/[Megatron-LM](https://github.com/NVIDIA/Megatron-LM)/[BMTrain](https://github.com/OpenBMB/BMTrain) — with seamless integration between them. Users can parallel their training/testing process with less than ten lines of code.

The code is partially based on [GLM](https://github.com/THUDM/GLM), [Transformers](https://github.com/huggingface/transformers),[timm](https://github.com/rwightman/pytorch-image-models) and [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM).

## News
- [2 Mar 2023] release v1.6.1, support the Galactica model [#234](https://github.com/FlagAI-Open/FlagAI/pull/234); BMInf, a low-resource inference package [#238](https://github.com/FlagAI-Open/FlagAI/pull/238); and examples for p-tuning [#227](https://github.com/FlagAI-Open/FlagAI/pull/227)
- [12 Jan 2023] release v1.6.0, support a new parallel lib called [**BMTrain**](https://github.com/OpenBMB/BMTrain) and integrate [**Flash Attention**](https://github.com/HazyResearch/flash-attention) to speed up training of Bert and Vit models, examples in [FlashAttentionBERT](https://github.com/FlagAI-Open/FlagAI/blob/master/examples/bert_title_generation_english/train_flash_atten.py) and [FlashAttentionViT](https://github.com/FlagAI-Open/FlagAI/blob/master/examples/vit_cifar100/train_single_gpu_flash_atten.py). Also add the contrastive search based text generation method [**SimCTG**](https://github.com/yxuansu/SimCTG) and DreamBooth finetuning based on AltDiffusion, examples in [AltDiffusionNaruto](https://github.com/FlagAI-Open/FlagAI/blob/master/examples/AltDiffusion/dreambooth.py).
- [28 Nov 2022] release v1.5.0, support 1.1B [**EVA-CLIP**](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/EVA_CLIP) and [ALM: A large Arabic Language Model based on GLM], examples in [**ALM**](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/ALM)
- [10 Nov 2022] release v1.4.0, support [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679v1), examples in [**AltCLIP**](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltCLIP) and [**AltDiffusion**](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltDiffusion)
Expand Down Expand Up @@ -259,6 +260,6 @@ The majority of FlagAI is licensed under the [Apache 2.0 license](LICENSE), howe
### ↳ Star History
<div align="center">

[![Star History Chart](https://api.star-history.com/svg?repos=FlagAI-Open/FlagAI&type=Date)](https://star-history.com/#baaivision/EVA&Date)
![Star History Chart](https://api.star-history.com/svg?repos=FlagAI-Open/FlagAI&type=Date)

</div>
3 changes: 2 additions & 1 deletion README_zh.md
Expand Up @@ -15,12 +15,13 @@

* These models can be applied to text for tasks such as text classification, information extraction, question answering, summarization, and text generation, especially in Chinese.

* FlagAI is backed by the three most popular data/model parallel libraries ([PyTorch](https://pytorch.org/)/[Deepspeed](https://www.deepspeed.ai/)/[Megatron-LM](https://github.com/NVIDIA/Megatron-LM)/[BMTrain](https://github.com/OpenBMB/BMTrain)), with seamless integration between them. You can parallelize your training/testing process with fewer than ten lines of code.
* FlagAI is backed by the four most popular data/model parallel libraries ([PyTorch](https://pytorch.org/)/[Deepspeed](https://www.deepspeed.ai/)/[Megatron-LM](https://github.com/NVIDIA/Megatron-LM)/[BMTrain](https://github.com/OpenBMB/BMTrain)), with seamless integration between them. You can parallelize your training/testing process with fewer than ten lines of code.


Part of the code in this project is based on [GLM](https://github.com/THUDM/GLM), [Transformers](https://github.com/huggingface/transformers), [timm](https://github.com/rwightman/pytorch-image-models) and [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM).

## News
- [2 Mar 2023] Released v1.6.1, adding the Galactica model [#234](https://github.com/FlagAI-Open/FlagAI/pull/234), BMInf, a low-resource toolkit for large-model inference [#238](https://github.com/FlagAI-Open/FlagAI/pull/238), and P-tuning examples [#227](https://github.com/FlagAI-Open/FlagAI/pull/227)
- [12 Jan 2023] Released v1.6.0, adding support for the parallel training library [**BMTrain**](https://github.com/OpenBMB/BMTrain) and integrating [**Flash Attention**](https://github.com/HazyResearch/flash-attention) into the Bert and Vit models to speed up end-to-end training; see the examples [FlashAttentionBERT](https://github.com/FlagAI-Open/FlagAI/blob/master/examples/bert_title_generation_english/train_flash_atten.py) and [FlashAttentionViT](https://github.com/FlagAI-Open/FlagAI/blob/master/examples/vit_cifar100/train_single_gpu_flash_atten.py). Also added the contrastive-search-based text generation method [**SimCTG**](https://github.com/yxuansu/SimCTG) and DreamBooth personalized fine-tuning based on AltDiffusion; see the example [AltDiffusionNaruto](https://github.com/FlagAI-Open/FlagAI/blob/master/examples/AltDiffusion/dreambooth.py).
- [28 Nov 2022] Released v1.5.0, supporting the 1.1B-parameter [**EVA-CLIP**](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/EVA_CLIP) and [ALM: a large Arabic language model based on GLM]; see the example [**ALM**](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/ALM)
- [10 Nov 2022] Released v1.4.0, supporting [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679v1); see the examples [**AltCLIP**](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltCLIP) and [**AltDiffusion**](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltDiffusion)
2 changes: 1 addition & 1 deletion doc_zh/TUTORIAL_15_BERT_EXAMPLE_TITLE_GENERATION.md
Expand Up @@ -24,7 +24,7 @@
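### 1. Data loading
The sample data is located in /examples/bert_title_generation/data/

The data-reading process needs to be defined in the ```trianer.py``` file, for example:
The data-reading process needs to be defined in the ```trainer.py``` file, for example: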
```python
def read_file():
src = []
54 changes: 54 additions & 0 deletions doc_zh/TUTORIAL_21_OPTIMIZER.md
@@ -0,0 +1,54 @@
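# How to use an optimizer

## What is an optimizer?
In the context of machine learning and deep learning,
an optimizer is an algorithm or method used to update a model's parameters so as to minimize the error between the predicted output and the actual output.

The goal of an optimizer is to find the combination of parameters that achieves the best performance on a given task.
This process is usually carried out during the training phase of a machine learning model.

An optimizer computes the gradient of the loss function with respect to the model parameters and uses that information to update the parameters so that the loss decreases.
Many optimization algorithms are available, such as stochastic gradient descent (SGD), Adagrad, Adam, and RMSprop, each with its own advantages and disadvantages.

The choice of optimizer depends on the specific problem, the size of the dataset, the complexity of the model, and other factors.
A good optimizer can significantly improve a model's training speed and accuracy.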




## Loading an optimizer

### Dependencies
#### adan
```
python3 -m pip install git+https://github.com/sail-sg/Adan.git
```
#### lion
```
$ pip install lion-pytorch
```
#### lamb
```
$ pip install torch_optimizer
```
#### Example
```python
>>> from flagai.trainer import Trainer
>>> # FlagAI currently supports adam, adamw, lion, adan, adafactor and lamb;
>>> # choose one by setting optimizer_type when defining the Trainer
>>> trainer = Trainer(env_type='pytorch',
...                   epochs=1,
...                   batch_size=2,
...                   eval_interval=100,
...                   log_interval=10,
...                   experiment_name='glm_large_bmtrain',
...                   pytorch_device='cuda',
...                   load_dir=None,
...                   lr=1e-4,
...                   num_gpus=1,
...                   weight_decay=1e-2,
...                   save_interval=1000,
...                   hostfile='./hostfile',
...                   training_script=__file__,
...                   deepspeed_config='./deepspeed.json',
...                   optimizer_type='lion')  # load optimizer
```

2 changes: 1 addition & 1 deletion doc_zh/TUTORIAL_3_MODEL.md
Expand Up @@ -20,7 +20,7 @@
## From_pretrain

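The `From_pretrain` function is used to load models. Models that share the same architecture can be loaded with the same class; for example, both `BERT-base-ch` and `Roberta-base-ch` can be loaded with the `BertModel` class. `From_pretrain` is specifically optimized for loading models under data/model parallelism, avoiding the resource waste caused by repeated downloads.
Load a model by calling `ClassName.from_pretrian()`.
Load a model by calling `ClassName.from_pretrain()`.
### Loading from modelhub
We currently support downloading [commonly used models](#所有支持模型) from modelhub; `from_pretrain` directly downloads the model configuration file `config.json`, the model weights `pytorch_model.bin`, and the vocabulary file `vocab.txt`. Example: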
```python
57 changes: 57 additions & 0 deletions docs/TUTORIAL_21_OPTIMIZER.md
@@ -0,0 +1,57 @@
# How to use an optimizer

## What is an optimizer?
In the context of machine learning and deep learning,
an optimizer is an algorithm or method used to update the parameters of a model in order to minimize the error between the predicted output and the actual output.

The goal of an optimizer is to find the set of parameters that achieves the best performance on a given task.
This process is typically performed during the training phase of a machine learning model.

Optimizers work by computing the gradients of the loss function with respect to the model parameters,
and using this information to update the parameters in the direction that reduces the loss.
There are various optimization algorithms available,
such as stochastic gradient descent (SGD), Adagrad, Adam, and RMSprop, each with its own advantages and disadvantages.

The choice of optimizer depends on the specific problem, the size of the dataset,
the complexity of the model, and other factors.
A good optimizer can significantly improve the training speed and accuracy of a model.
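
The update loop described above can be sketched in a few lines of plain Python. This is a minimal illustration of gradient descent on a one-parameter quadratic loss, not code from FlagAI; the loss function, learning rate, and the names `grad` and `gradient_descent` are assumptions chosen for the example.

```python
def grad(w):
    """Gradient of the illustrative loss (w - 3)^2 with respect to w."""
    return 2.0 * (w - 3.0)


def gradient_descent(w, lr=0.1, steps=100):
    """Repeatedly move w against the gradient, scaled by the learning rate."""
    for _ in range(steps):
        w = w - lr * grad(w)
    return w


# Starting from 0.0, the parameter approaches the minimizer at w = 3.
w_final = gradient_descent(w=0.0)
print(w_final)
```

Optimizers such as Adam or Lion differ mainly in how they turn the raw gradient into the step that is applied to the parameter.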




## Loading an optimizer

### Dependencies
#### adan
```
python3 -m pip install git+https://github.com/sail-sg/Adan.git
```
#### lion
```
$ pip install lion-pytorch
```
#### lamb
```
$ pip install torch_optimizer
```
#### Example
```python
>>> from flagai.trainer import Trainer
>>> # FlagAI currently supports adam, adamw, lion, adan, adafactor and lamb;
>>> # choose one by setting optimizer_type when defining the Trainer
>>> trainer = Trainer(env_type='pytorch',
...                   epochs=1,
...                   batch_size=2,
...                   eval_interval=100,
...                   log_interval=10,
...                   experiment_name='glm_large_bmtrain',
...                   pytorch_device='cuda',
...                   load_dir=None,
...                   lr=1e-4,
...                   num_gpus=1,
...                   weight_decay=1e-2,
...                   save_interval=1000,
...                   hostfile='./hostfile',
...                   training_script=__file__,
...                   deepspeed_config='./deepspeed.json',
...                   optimizer_type='lion')  # load optimizer
```
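
For readers curious what `optimizer_type='lion'` selects, here is a rough pure-Python sketch of a single Lion update step as the rule is commonly described: the step direction is the sign of an interpolation between the momentum and the current gradient, with decoupled weight decay. It is an illustration only, not the `lion-pytorch` or FlagAI implementation; `lion_step` is a hypothetical helper, and the `beta1`/`beta2` values are assumed defaults, while `lr` and `weight_decay` mirror the Trainer example.

```python
def sign(x):
    """Return -1, 0 or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def lion_step(params, grads, momentum, lr=1e-4, beta1=0.9, beta2=0.99,
              weight_decay=1e-2):
    """Apply one Lion-style update to flat lists of floats (illustrative)."""
    new_params, new_momentum = [], []
    for p, g, m in zip(params, grads, momentum):
        # Direction: only the sign of the interpolated gradient is used.
        update = sign(beta1 * m + (1 - beta1) * g)
        # Decoupled weight decay is added to the signed step.
        new_params.append(p - lr * (update + weight_decay * p))
        # The momentum is refreshed with a second interpolation factor.
        new_momentum.append(beta2 * m + (1 - beta2) * g)
    return new_params, new_momentum


params, momentum = lion_step([1.0, -2.0], [0.5, -0.25], [0.0, 0.0])
print(params, momentum)
```

Because every coordinate moves by the same magnitude `lr` (before weight decay), Lion is usually run with a smaller learning rate than Adam-style optimizers.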

2 changes: 1 addition & 1 deletion docs/TUTORIAL_3_MODEL.md
Expand Up @@ -25,7 +25,7 @@ All supported models now support the three most common model types [encoder, dec

### load model from modelhub

By calling `ClassName.from_pretrian()` to load following [supported models](#all-supported-models), it will automatically download the model configuration file `config.json`, model weights `pytorch_model.bin`, and dictionary files `vocab .txt`.
Calling `ClassName.from_pretrain()` to load one of the [supported models](#all-supported-models) automatically downloads the model configuration file `config.json`, the model weights `pytorch_model.bin`, and the dictionary file `vocab.txt`.

```python
>>> # Downloading GLM-large-ch from modelhub
2 changes: 1 addition & 1 deletion examples/bert_title_generation_english/generate.py
Expand Up @@ -7,7 +7,7 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dir = "../state_dict/"
model_dir = "./checkpoints/"

# Note "./checkpoints_seq2seq/{}/mp_rank_00_model_states.pt", {} is a directory in the checkpoints_seq2seq.
model_save_path = "./checkpoints_seq2seq/7079/mp_rank_00_model_states.pt"
1 change: 0 additions & 1 deletion examples/bert_title_generation_english/train.py
@@ -1,7 +1,6 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import sys
import os
import torch
from torch.utils.data import Dataset
45 changes: 45 additions & 0 deletions examples/bminf_generate/README.md
@@ -0,0 +1,45 @@

# BMInf

## Overview

BMInf (Big Model Inference) is a low-resource inference package for large-scale pretrained language models (PLMs).

At minimum, BMInf supports running models with more than 10 billion parameters on a single NVIDIA GTX 1060 GPU; better GPUs give better performance. Even when the GPU has enough memory for large-model inference (such as a V100 or A100), BMInf still offers a significant performance improvement over the existing PyTorch implementation.

BMInf GitHub repository: https://github.com/OpenBMB/BMInf

## Application

After the model has loaded its parameters, use the following code to convert it with BMInf:

```python
with torch.cuda.device(0):
    model = bminf.wrapper(model, quantization=False, memory_limit=20 << 30)
```
The `quantization` parameter controls whether model quantization is used; for generative models it needs to be set to `False`.

You can use the `memory_limit` parameter to set the maximum available memory, in bytes; the `20 << 30` used above corresponds to 20 GiB.
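
As a quick check of the arithmetic, the sketch below converts a GiB figure into the shifted-literal form. It assumes `memory_limit` is interpreted as a byte count, which is what a value like `20 << 30` suggests; `gib_to_bytes` is a hypothetical helper, not part of BMInf.

```python
def gib_to_bytes(gib):
    """Convert a whole number of GiB to bytes; shifting by 30 multiplies by 2**30."""
    return gib << 30


memory_limit = gib_to_bytes(20)
print(memory_limit)  # 21474836480
```

Writing the limit as `n << 30` keeps the number of GiB readable at the call site.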

If `bminf.wrapper` does not fit your model well, you can adapt it manually as follows.

* Replace `torch.nn.ModuleList` with `bminf.TransformerBlockList`.
```python
module_list = bminf.TransformerBlockList([
    # the model's transformer blocks go here
], [CUDA_DEVICE_INDEX])
```

* Replace `torch.nn.Linear` with `bminf.QuantizedLinear`.
```python
linear = bminf.QuantizedLinear(torch.nn.Linear(...))
```
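
To give a feel for what swapping in a quantized linear layer involves, here is a minimal pure-Python sketch of symmetric per-row int8 quantization. It illustrates the general technique only and is not BMInf's implementation; `quantize_row` and `quantized_matvec` are hypothetical helpers.

```python
def quantize_row(row):
    """Map a row of floats to integers in [-127, 127] plus one float scale."""
    max_abs = max(abs(x) for x in row)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [round(x / scale) for x in row], scale


def quantized_matvec(q_rows, scales, x):
    """Multiply the quantized matrix by x, rescaling each output by its row scale."""
    return [
        scale * sum(q * xi for q, xi in zip(q_row, x))
        for q_row, scale in zip(q_rows, scales)
    ]


weights = [[0.5, -1.0], [0.25, 0.75]]
x = [2.0, 1.0]

quantized = [quantize_row(row) for row in weights]
q_rows = [q for q, _ in quantized]
scales = [s for _, s in quantized]

approx = quantized_matvec(q_rows, scales, x)
exact = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
print(approx, exact)
```

The quantized result is close to, but not identical to, the full-precision product; that small error is the price paid for storing each weight in one byte.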
35 changes: 35 additions & 0 deletions examples/bminf_generate/cpm1_generate.py
@@ -0,0 +1,35 @@
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
import bminf
import time


if __name__ == '__main__':

    text = '''默写古诗:
白日依山尽,黄河入海流。
床前明月光,'''

    loader = AutoLoader(task_name="lm",
                        model_name="CPM-large-ch",
                        model_dir="./checkpoints",
                        device="cpu")

    model = loader.get_model()
    time_start = time.time()
    with torch.cuda.device(0):
        model = bminf.wrapper(model, quantization=False, memory_limit=20 << 30)
    tokenizer = loader.get_tokenizer()

    predictor = Predictor(model=model,
                          tokenizer=tokenizer,
                          )

    out = predictor.predict_generate_randomsample(text,
                                                  top_p=0.9,
                                                  out_max_length=50)
    time_end = time.time()
    print('time cost', time_end - time_start, 's')

    print(out)
37 changes: 37 additions & 0 deletions examples/bminf_generate/galactica_6.7b_generate.py
@@ -0,0 +1,37 @@

from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader
import torch
import bminf
import time
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


loader = AutoLoader(task_name="lm",
model_name="galactica-6.7b-en",
model_dir="./checkpoints/")

model = loader.get_model()
with torch.cuda.device(0):
    model = bminf.wrapper(model, quantization=False, memory_limit=20 << 30)
model.to(device)
model.eval()
tokenizer = loader.get_tokenizer()
predictor = Predictor(model, tokenizer)
print("model loaded")
time_start=time.time()

text = "Please write a abstract about the computer vision. \n"
out = predictor.predict_generate_randomsample(text,
                                              out_max_length=700,
                                              top_k=50,
                                              repetition_penalty=1.2,
                                              temperature=0.7
                                              )

time_end=time.time()
print('time cost',time_end-time_start,'s')
print(out)



20 changes: 20 additions & 0 deletions examples/bminf_generate/glm_generate.py
@@ -0,0 +1,20 @@
from flagai.model.glm_model import GLMModel
from flagai.data.tokenizer import Tokenizer
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
import torch
import bminf

model_name = 'GLM-10b-ch'
loader = AutoLoader("lm", 'GLM-10b-ch', model_dir="./checkpoints/")
model = loader.get_model()
tokenizer = loader.get_tokenizer()
with torch.cuda.device(0):
    # 30 << 30 is 30 GiB expressed in bytes
    model = bminf.wrapper(model, quantization=False, memory_limit=30 << 30)

tokenizer = Tokenizer.from_pretrained(model_name)
predictor = Predictor(model, tokenizer)

text = "今天天气不错[gMASK]"
output = predictor.predict_generate_randomsample(text, out_max_length=10)
print(text, '\n', output)
35 changes: 35 additions & 0 deletions examples/bminf_generate/gpt2_generate.py
@@ -0,0 +1,35 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
import bminf
import time

if __name__ == '__main__':

    loader = AutoLoader("seq2seq",
                        "GPT2-base-ch",
                        model_dir="./checkpoints/")
    model = loader.get_model()
    model = model.to('cpu')
    tokenizer = loader.get_tokenizer()
    time_start = time.time()
    with torch.cuda.device(0):
        model = bminf.wrapper(model, quantization=False, memory_limit=20 << 30)
    predictor = Predictor(model, tokenizer)

    text = "今天天气不错"

    out_2 = predictor.predict_generate_randomsample(text,
                                                    input_max_length=512,
                                                    out_max_length=100,
                                                    repetition_penalty=1.5,
                                                    top_k=20,
                                                    top_p=0.8)

    time_end = time.time()
    print('time cost', time_end - time_start, 's')
    # print(f"out_1 is {out_1}")
    print(f"out_2 is {out_2}")
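
The generation scripts in this change pass `top_k` and `top_p` to `predict_generate_randomsample`. As a rough illustration of what such parameters usually mean, the sketch below filters a toy probability distribution with top-k and then nucleus (top-p) truncation before sampling. It is a generic sketch, not FlagAI's `Predictor` code; `top_k_top_p_filter` and `sample` are hypothetical helpers, and the repetition penalty is left out.

```python
import random


def top_k_top_p_filter(probs, top_k=20, top_p=0.8):
    """Keep the top_k most likely tokens, then the smallest prefix of them
    whose cumulative probability reaches top_p, and renormalize."""
    ranked = sorted(enumerate(probs), key=lambda kv: kv[1], reverse=True)[:top_k]
    kept, cumulative = [], 0.0
    for token_id, p in ranked:
        kept.append((token_id, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {token_id: p / total for token_id, p in kept}


def sample(filtered, rng):
    """Draw one token id according to the renormalized probabilities."""
    ids = list(filtered)
    weights = [filtered[i] for i in ids]
    return rng.choices(ids, weights=weights, k=1)[0]


probs = [0.5, 0.2, 0.15, 0.1, 0.05]
filtered = top_k_top_p_filter(probs, top_k=3, top_p=0.8)
token = sample(filtered, random.Random(0))
print(filtered, token)
```

Lower `top_k` or `top_p` values restrict sampling to fewer high-probability tokens, which tends to make output more conservative.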