diff --git a/README.md b/README.md index 8a4af79c6..2e457f7f8 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,10 @@ +**🎉 2024.12.17 支持[GOT-OCR2_0](./paddlemix/examples/GOT_OCR_2_0)推理和训练** + +**🎉 2024.12.17 支持[InternVL2_5(1B、2B、4B、8B)](./paddlemix/examples/internvl2)推理** + **🎉 2024.11.27 支持[Janus/JanusFlow](./paddlemix/examples/janus)推理** **🎉 2024.11.21 支持[MiniCPM-V-2_6](./paddlemix/examples/minicpm-v-2_6)推理** diff --git a/README_EN.md b/README_EN.md index 6d109cf9e..f318e135d 100644 --- a/README_EN.md +++ b/README_EN.md @@ -48,6 +48,7 @@ Welcome your submissions! ## 📣 Latest Developments +**🎉 2024.12.17 Support for [InternVL2_5 (1B, 2B, 4B, 8B)](./paddlemix/examples/internvl2) inference** **🎉 2024.11.27 Added support for [Janus/JanusFlow](./paddlemix/examples/janus) inference** diff --git a/paddlemix/examples/GOT_OCR_2_0/README.md b/paddlemix/examples/GOT_OCR_2_0/README.md index c74efd13b..19425c35d 100644 --- a/paddlemix/examples/GOT_OCR_2_0/README.md +++ b/paddlemix/examples/GOT_OCR_2_0/README.md @@ -2,7 +2,15 @@ ## 1. 模型介绍 -[GOT-OCR2.0](https://arxiv.org/abs/2409.01704)是一款极具突破性的通用OCR模型,旨在解决传统OCR系统(OCR-1.0)和当前大规模视觉语言模型(LVLMs)在OCR任务中的局限性。本仓库提供paddle版本的`GOT-OCR2.0`模型。 +[GOT-OCR2.0](https://arxiv.org/abs/2409.01704)是由 StepFun 和中国科学院大学推出的专用于通用 OCR 任务的多模态大模型,参数量 0.6B,是一款极具突破性的通用OCR多模态模型,旨在解决传统OCR系统(OCR-1.0)和当前大规模视觉语言模型(LVLMs)在OCR任务中的局限性。 + +**本仓库支持的模型权重:** + +| Model | +|--------------------| +| stepfun-ai/GOT-OCR2_0 | + +注意:与huggingface权重同名,但权重为paddle框架的Tensor,使用`xxx.from_pretrained("stepfun-ai/GOT-OCR2_0")`即可自动下载该权重文件夹到缓存目录。 ## 2. 环境要求 @@ -36,11 +44,39 @@ python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ --ocr_type format \ ``` +### 3.3. multi_crop plain texts OCR: +```bash +python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ + --model_name_or_path stepfun-ai/GOT-OCR2_0 \ + --image_file paddlemix/demo_images/hospital.jpeg \ + --ocr_type ocr \ + --multi_crop \ +``` + ## 4 训练 + +与[官方github代码库](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/?tab=readme-ov-file#train)一样,目前仅支持基于GOT权重的post-training(stage-2/stage-3),其中stage2是全参数微调,stage3是冻结vision encoder后微调,默认训练方式是stage2全参数微调,训练显存约10GB每卡。 + +### 数据集下载 +PaddleMIX团队提供了一个改版的SynthDoG-EN数据集,统一修改了其原先的question为```\nOCR:```,下载链接为: +``` +wget https://paddlenlp.bj.bcebos.com/datasets/paddlemix/playground/synthdog_en.tar # 2.4G +``` +synthdog_en.tar包括了图片images文件夹和标注json文件,需下载解压或软链接在PaddleMIX/目录下。 + +### 数据集格式 + +同[官方例子](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/train_sample.jpg),其中question统一为```\nOCR:```,answer是其OCR结果。 + + +### 训练命令 + ```bash sh paddlemix/examples/GOT_OCR_2_0/run_train.sh ``` +注意:默认训练方式是stage2全参数微调,训练显存约10GB每卡。也可通过设置```--freeze_vision_tower True```冻结vision encoder后微调。 + ## 参考文献 ```BibTeX diff --git a/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json b/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json index 3fe8acb5e..c021a46b4 100644 --- a/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json +++ b/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json @@ -1,6 +1,6 @@ { "synthdog_en": { "images": "synthdog_en/", - "annotations": "synthdog_en/synthdog_en_29765_ocr_1k.json" + "annotations": "synthdog_en/synthdog_en_29765_ocr.json" } } diff --git a/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py b/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py index d71f5eac1..4775a4f6e 100644 --- a/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py +++ b/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py @@ -13,12 +13,16 @@ # limitations under the License. 
import argparse + import paddle from paddlenlp.transformers import QWenTokenizer + from paddlemix.models.GOT.GOT_ocr_2_0 import GOTQwenForCausalLM parser = argparse.ArgumentParser() -parser.add_argument("--model_name_or_path", type=str, default="stepfun-ai/GOT-OCR2_0", help="pretrained ckpt and tokenizer") +parser.add_argument( + "--model_name_or_path", type=str, default="stepfun-ai/GOT-OCR2_0", help="pretrained ckpt and tokenizer" +) parser.add_argument("--image_file", type=str, default="paddlemix/demo_images/hospital.jpeg") parser.add_argument("--multi_crop", action="store_true") parser.add_argument("--ocr_type", type=str, default="plain", choices=["ocr", "format"]) @@ -38,41 +42,9 @@ with paddle.no_grad(): if args.multi_crop: # multi-crop OCR: - res = model.chat_crop( - tokenizer, image_file, ocr_type=args.ocr_type, render=args.render, save_render_file="./demo.html" - ) + res = model.chat_crop(tokenizer, image_file, ocr_type=args.ocr_type) else: # plain texts OCR # format texts OCR - # fine-grained OCR - # render the formatted OCR results - res = model.chat( - tokenizer, - image_file, - ocr_type=args.ocr_type, - ocr_box=args.box, - ocr_color=args.color, - render=args.render, - save_render_file="./demo.html", - ) - - # plain texts OCR - # res = model.chat(tokenizer, image_file, ocr_type='ocr') - - # format texts OCR: - # res = model.chat(tokenizer, image_file, ocr_type='format') - - # fine-grained OCR: - # res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_box='') - # res = model.chat(tokenizer, image_file, ocr_type='format', ocr_box='') - # res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_color='') - # res = model.chat(tokenizer, image_file, ocr_type='format', ocr_color='') - - # multi-crop OCR: - # res = model.chat_crop(tokenizer, image_file, ocr_type='ocr') - # res = model.chat_crop(tokenizer, image_file, ocr_type='format') - - # render the formatted OCR results: - # res = model.chat(tokenizer, image_file, ocr_type='format', render=True, save_render_file = './demo.html') - - print(res) + res = model.chat(tokenizer, image_file, ocr_type=args.ocr_type) + print("output:\n", res) diff --git a/paddlemix/examples/GOT_OCR_2_0/run_train.sh b/paddlemix/examples/GOT_OCR_2_0/run_train.sh index b1ec2d19e..4d94222a8 100644 --- a/paddlemix/examples/GOT_OCR_2_0/run_train.sh +++ b/paddlemix/examples/GOT_OCR_2_0/run_train.sh @@ -15,7 +15,7 @@ set -x GPUS=${GPUS:-8} -BATCH_SIZE=${BATCH_SIZE:-8} +BATCH_SIZE=${BATCH_SIZE:-32} PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE:-1} GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS)) @@ -28,8 +28,6 @@ export TF_CPP_MIN_LOG_LEVEL=3 OUTPUT_DIR='work_dirs/got_ocr_20' -# meta='pdf-ocr+scence' - if [ ! -d "$OUTPUT_DIR" ]; then mkdir -p "$OUTPUT_DIR" fi @@ -38,6 +36,8 @@ TRAINING_MODEL_RESUME="None" TRAINER_INSTANCES='127.0.0.1' MASTER='127.0.0.1:8080' +# --freeze_vision_tower False \ # True for stage3 + TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes 1 --nproc_per_node ${GPUS} --rank 0 --ips ${TRAINER_INSTANCES} --run_mode=collective" ${TRAINING_PYTHON} --log_dir ${OUTPUT_DIR}/paddle_distributed_logs \ paddlemix/examples/GOT_OCR_2_0/train_GOT.py \ diff --git a/paddlemix/examples/GOT_OCR_2_0/train_GOT.py b/paddlemix/examples/GOT_OCR_2_0/train_GOT.py index 9fdee3c86..a32529b8d 100644 --- a/paddlemix/examples/GOT_OCR_2_0/train_GOT.py +++ b/paddlemix/examples/GOT_OCR_2_0/train_GOT.py @@ -14,22 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging import os import sys -import paddle.distributed as dist -import paddle -import paddlenlp -from paddlemix.datasets.got_dataset import make_supervised_data_module -from paddlemix.models.GOT.GOT_ocr_2_0 import GOTQwenForCausalLM -from paddlenlp.trainer.trainer_utils import get_last_checkpoint +from dataclasses import dataclass, field +from typing import Optional -from paddlemix.models.GOT.utils.utils import smart_tokenizer_and_embedding_resize +import paddle +import paddle.distributed as dist from paddlenlp.trainer import PdArgumentParser, TrainingArguments, set_seed from paddlenlp.trainer.trainer import Trainer -from dataclasses import dataclass, field -from typing import Dict, Optional +from paddlenlp.trainer.trainer_utils import get_last_checkpoint from paddlenlp.transformers import QWenTokenizer -import logging + +from paddlemix.datasets.got_dataset import make_supervised_data_module +from paddlemix.models.GOT.GOT_ocr_2_0 import GOTQwenForCausalLM +from paddlemix.models.GOT.utils.utils import smart_tokenizer_and_embedding_resize + logger = logging.getLogger(__name__) @@ -57,8 +58,8 @@ class ModelArguments: vision_tower: Optional[str] = field(default="openai/clip-vit-large-patch14") freeze_vision_tower: bool = field(default=False) freeze_lm_model: bool = field(default=False) - pretrained_stage1_model: Optional[str] = field(default=None) # mlp &/ vision tower - vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrained_stage1_model: Optional[str] = field(default=None) # mlp &/ vision tower + vision_select_layer: Optional[int] = field(default=-1) # default to the last layer use_im_start_end: bool = field(default=False) @@ -71,14 +72,14 @@ class DataArguments: ) sep_image_conv_front: bool = False image_token_len: int = 256 - image_aspect_ratio: str = 'square' - conversation_version: str = 'mpt' + image_aspect_ratio: str = "square" + conversation_version: str = "mpt" box_limit: int = 0 max_seq_length: int = 8192 @dataclass -class TrainingArguments(paddlenlp.trainer.TrainingArguments): +class GOTTrainingArguments(TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: str = field(default="adamw_torch") remove_unused_columns: bool = field(default=False) @@ -87,10 +88,7 @@ class TrainingArguments(paddlenlp.trainer.TrainingArguments): with_box: bool = field(default=False) model_max_length: int = field( default=512, - metadata={ - "help": - "Maximum sequence length. Sequences will be right padded (and possibly truncated)." - }, + metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, ) lora_enable: bool = False lora_r: int = 8 @@ -101,9 +99,7 @@ class TrainingArguments(paddlenlp.trainer.TrainingArguments): def train(): - # parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) - # model_args, data_args, training_args = parser.parse_args_into_dataclasses() - parser = PdArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + parser = PdArgumentParser((ModelArguments, DataArguments, GOTTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script, and it's the path to a json file, # let's parse it to get our arguments. 
@@ -147,19 +143,17 @@ def train(): print(f"Loading Tokenizer: {tokenizer_path}") tokenizer = QWenTokenizer.from_pretrained( - model_args.model_name_or_path, - padding_side="right", - model_max_length=training_args.model_max_length) + model_args.model_name_or_path, padding_side="right", model_max_length=training_args.model_max_length + ) print("tokenizer", tokenizer) - print("len(tokenizer)", len(tokenizer)) - print("tokenizer.added_tokens_encoder", tokenizer.added_tokens_encoder) - print("tokenizer.added_tokens_decoder", tokenizer.added_tokens_decoder) + # print("len(tokenizer)", len(tokenizer)) + # print("tokenizer.added_tokens_encoder", tokenizer.added_tokens_encoder) + # print("tokenizer.added_tokens_decoder", tokenizer.added_tokens_decoder) - model = GOTQwenForCausalLM.from_pretrained( - model_args.model_name_or_path, dtype=dtype) + model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, dtype=dtype) smart_tokenizer_and_embedding_resize( - special_tokens_dict=dict(pad_token='<|endoftext|>'), + special_tokens_dict=dict(pad_token="<|endoftext|>"), tokenizer=tokenizer, model=model, ) @@ -174,16 +168,15 @@ def train(): ) model.initialize_vision_tokenizer( - tokenizer=tokenizer, - freeze_lm_model=model_args.freeze_lm_model, + tokenizer=tokenizer, + freeze_lm_model=model_args.freeze_lm_model, pretrained_stage1_model=model_args.pretrained_stage1_model, ) # 'image_processor_high - # data_args.image_token_len = vision_tower_dict['image_token_len'] data_args.image_token_len = 256 - data_args.image_processor = vision_tower_dict['image_processor'] - data_args.image_processor_high = vision_tower_dict['image_processor_high'] + data_args.image_processor = vision_tower_dict["image_processor"] + data_args.image_processor_high = vision_tower_dict["image_processor_high"] data_args.use_im_start_end = model_args.use_im_start_end def _freeze_params(module): @@ -199,11 +192,9 @@ def _freeze_params(module): if model_args.freeze_vision_tower: _freeze_params(model.qwen2.vision_tower_high) - # params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad] - # print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M") print_trainable_params(model) - # trainable params: 464959488 || all params: 560528640 || trainable%: 82.9502 - + # trainable params: 464959488 || all params: 560528640 || trainable%: 82.9502 # stage3 + # trainable params: 560528640 || all params: 560528640 || trainable%: 100 # stage2 params_grad = [p.numel() for n, p in model.named_parameters() if not p.stop_gradient] print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M") @@ -217,13 +208,9 @@ def _freeze_params(module): set_seed(training_args.seed) data_module = make_supervised_data_module( - interleave=training_args.interleave, - with_box=training_args.with_box, - tokenizer=tokenizer, - data_args=data_args + interleave=training_args.interleave, with_box=training_args.with_box, tokenizer=tokenizer, data_args=data_args ) - #trainer = GOTTrainer( trainer = Trainer( model=model, args=training_args, diff --git a/paddlemix/examples/internvl2/README.md b/paddlemix/examples/internvl2/README.md index 61b57c70a..ca1b50c92 100644 --- a/paddlemix/examples/internvl2/README.md +++ b/paddlemix/examples/internvl2/README.md @@ -26,8 +26,12 @@ | Model | |--------------------| | OpenGVLab/InternVL2-1B | +| OpenGVLab/InternVL2_5-1B | | OpenGVLab/InternVL2-2B | +| OpenGVLab/InternVL2_5-2B | +| OpenGVLab/InternVL2_5-4B | | OpenGVLab/InternVL2-8B | +| 
OpenGVLab/InternVL2_5-8B | | OpenGVLab/InternVL2-26B | | OpenGVLab/InternVL2-40B | | OpenGVLab/InternVL2-8B-MPO | diff --git a/paddlemix/examples/internvl2/chat_demo.py b/paddlemix/examples/internvl2/chat_demo.py index ff9165828..14d6bd7a2 100644 --- a/paddlemix/examples/internvl2/chat_demo.py +++ b/paddlemix/examples/internvl2/chat_demo.py @@ -99,12 +99,14 @@ def load_tokenizer(model_path): import re match = re.search(r"\d+B", model_path) + model2_5 = "InternVL2_5" in model_path if match: model_size = match.group() else: model_size = "2B" - - if model_size in ["1B"]: + if model2_5 and model_size in ["1B", "4B"]: + tokenizer = Qwen2Tokenizer.from_pretrained(model_path) + elif model_size in ["1B"]: tokenizer = Qwen2Tokenizer.from_pretrained(model_path) elif model_size in ["2B", "8B", "26B"]: tokenizer = InternLM2Tokenizer.from_pretrained(model_path) @@ -135,8 +137,7 @@ def main(args): print("len(tokenizer): ", len(tokenizer)) model = InternVLChatModel.from_pretrained(MODEL_PATH, dtype=args.dtype).eval() - - generation_config = dict(max_new_tokens=1024, do_sample=False) + generation_config = dict(max_new_tokens=1024, do_sample=False, top_p=0.01) with paddle.no_grad(): response, history = model.chat( diff --git a/paddlemix/models/GOT/GOT_ocr_2_0.py b/paddlemix/models/GOT/GOT_ocr_2_0.py index 66ae85824..f1ab5336f 100644 --- a/paddlemix/models/GOT/GOT_ocr_2_0.py +++ b/paddlemix/models/GOT/GOT_ocr_2_0.py @@ -12,30 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses +from enum import Enum, auto from io import BytesIO from typing import List, Optional import paddle import paddle.nn as nn import requests -from paddlenlp.generation.stopping_criteria import ( - StoppingCriteriaList, -) +from paddle.vision import transforms +from paddlenlp.generation.stopping_criteria import StoppingCriteriaList from paddlenlp.transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Model from paddlenlp.transformers.model_outputs import CausalLMOutputWithPast from PIL import Image +from ...processors.got_process import BlipImageEvalProcessor +from .got_vision_b import build_GOT_vit_b + DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" -import dataclasses -from enum import Enum, auto -from paddle.vision import transforms -from ...processors.got_process import BlipImageEvalProcessor -from .got_vision_b import build_GOT_vit_b - class Qwen2LMHead(nn.Layer): def __init__(self, config, embedding_weights=None, transpose_y=False, tensor_parallel_output=1): @@ -263,11 +261,9 @@ def forward( if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: use_im_start_end = getattr(self.config, "use_im_start_end", -1) - # vision_select_layer = getattr(self.config, "vision_select_layer", -1) im_patch_token = getattr(self.config, "im_patch_token", -1) im_start_token = getattr(self.config, "im_start_token", -1) im_end_token = getattr(self.config, "im_end_token", -1) - # freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False) im_patch_token = 151859 im_start_token = 151857 @@ -300,7 +296,6 @@ def forward( image_features.append(image_feature) dummy_image_features_2 = paddle.zeros([256, 1024], dtype=inputs_embeds.dtype) - # dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2) dummy_image_features = dummy_image_features_2 use_im_start_end = True new_input_embeds = [] @@ -339,11 +334,11 @@ def forward( return 
super().forward( input_ids=None, - attention_mask=attention_mask, # [1, 1, 1, 800] + attention_mask=attention_mask, # past_key_values=past_key_values, # None inputs_embeds=inputs_embeds, # [1, 800, 1024] use_cache=use_cache, # True - position_ids=position_ids, # [1, 1, 1, 800] + position_ids=position_ids, # output_attentions=output_attentions, # False output_hidden_states=output_hidden_states, # False return_dict=return_dict, # False @@ -358,7 +353,6 @@ def __init__(self, config): self.qwen2 = GOTQwenModel(config) self.vocab_size = config.vocab_size - # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False) if config.tie_word_embeddings: self.lm_head = Qwen2LMHead(config, embedding_weights=self.qwen2.embed_tokens.weight, transpose_y=True) @@ -416,14 +410,14 @@ def forward( shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] # Flatten the tokens - #loss_fct = nn.CrossEntropyLoss() + # loss_fct = nn.CrossEntropyLoss() loss_fct = nn.CrossEntropyLoss(reduction="sum") shift_logits = shift_logits.reshape([-1, self.config.vocab_size]) shift_labels = shift_labels.reshape([-1]) # Enable model parallelism loss = loss_fct(shift_logits, shift_labels) - label_sum = paddle.sum(shift_labels != -100) #.cast("float32") + label_sum = paddle.sum(shift_labels != -100) # .cast("float32") loss = loss / label_sum if not return_dict: @@ -441,48 +435,14 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - # input_ids [1, 287], past_key_values=None, attention_mask [1, 287], inputs_embeds=None - # kwargs ['images', 'use_cache', 'cache_position'] - # [1, 3, 1024, 1024], True, [0,,,,286] - - # input_ids [1, 288], past_key_values len(past_key_values)=24, attention_mask [1, 288], inputs_embeds=None - # kwargs ['images', 'use_cache', 'cache_position'] - # [1, 3, 1024, 1024], True, [287] - batch_size, seq_length = input_ids.shape attention_mask = paddle.ones((batch_size, seq_length), dtype=paddle.bool) # Omit tokens covered by past_key_values if past_key_values is not None: - # if isinstance(past_key_values, Cache): ### - # cache_length = past_key_values.get_seq_length() - # past_length = past_key_values.seen_tokens - # max_cache_length = past_key_values.get_max_length() - # else: past_length = past_key_values[0][0].shape[1] # [1, 800, 16, 64] - # max_cache_length = None - # cache_length = past_length - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - # if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - # input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # # input_ids based on the past_length. - # el if past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- # if ( - # max_cache_length is not None - # and attention_mask is not None - # and cache_length + input_ids.shape[1] > max_cache_length - # ): - # attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -612,18 +572,11 @@ def chat( conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() - if print_prompt: - print("prompt", prompt) - inputs = tokenizer([prompt]) image_tensor_1 = image_processor_high(image) - input_ids = paddle.to_tensor(inputs.input_ids) - # print('input_ids', input_ids.shape, input_ids.sum().item(), input_ids) - # [1, 287] - stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) @@ -637,11 +590,8 @@ def chat( max_new_tokens=4096, stopping_criteria=stopping_criteria, # list of stopping criteria )[0] - # print('output_ids:\n', output_ids.shape, output_ids.sum().item(), output_ids) - # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() outputs = tokenizer.decode(output_ids[0]).strip() - # print('outputs', outputs) if outputs.endswith(stop_str): outputs = outputs[: -len(stop_str)] @@ -663,7 +613,6 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_ elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio - # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') return best_ratio orig_width, orig_height = image.size @@ -677,7 +626,6 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_ for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num ) - # print(target_ratios) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target @@ -685,7 +633,6 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_ aspect_ratio, target_ratios, orig_width, orig_height, image_size ) - # print(target_aspect_ratio) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] @@ -732,21 +679,12 @@ def chat_crop( image_list = [] - # if len(image_file_list)>1: - # multi_page = True - if multi_page: qs = "OCR with format across multi pages: " - # only for png files - # import glob - # from natsort import natsorted - # patches = glob.glob(image_file + '/*png') patches = image_file - # patches = natsorted(patches) sub_images = [] for sub_image in patches: sub_images.append(self.load_image(sub_image)) - ll = len(patches) else: @@ -797,13 +735,9 @@ def chat_crop( conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() - if print_prompt: - print(prompt) - inputs = tokenizer([prompt]) input_ids = paddle.to_tensor(inputs.input_ids) - # print('input_ids', input_ids.shape, input_ids.sum().item(), input_ids) stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] @@ -819,7 +753,6 @@ def chat_crop( stopping_criteria=stopping_criteria, )[0] - # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() outputs = tokenizer.decode(output_ids[0]).strip() if outputs.endswith(stop_str): diff --git a/paddlemix/models/GOT/got_vision_b.py b/paddlemix/models/GOT/got_vision_b.py index a005a7f48..c88fd6eb4 100644 --- a/paddlemix/models/GOT/got_vision_b.py +++ b/paddlemix/models/GOT/got_vision_b.py @@ 
-198,7 +198,6 @@ def __init__( def forward(self, x: paddle.Tensor) -> paddle.Tensor: shortcut = x - # import pdb; pdb.set_trace() x = self.norm1(x) # Window partition if self.window_size > 0: @@ -443,7 +442,6 @@ def _build_GOT_vision( prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 - # image_embedding_size = image_size // vit_patch_size image_encoder = ImageEncoderViT( depth=encoder_depth, embed_dim=encoder_embed_dim, @@ -459,4 +457,4 @@ def _build_GOT_vision( out_chans=prompt_embed_dim, ) - return image_encoder \ No newline at end of file + return image_encoder diff --git a/paddlemix/models/GOT/utils/constants.py b/paddlemix/models/GOT/utils/constants.py index 5caa54e01..18018a41a 100644 --- a/paddlemix/models/GOT/utils/constants.py +++ b/paddlemix/models/GOT/utils/constants.py @@ -18,7 +18,6 @@ LOGDIR = "log" IGNORE_INDEX = -100 -# DEFAULT_PAD_TOKEN = "[PAD]" DEFAULT_PAD_TOKEN = "<|endoftext|>" DEFAULT_EOS_TOKEN = "" diff --git a/paddlemix/models/GOT/utils/conversation.py b/paddlemix/models/GOT/utils/conversation.py index 54b08b4a5..0572f900f 100644 --- a/paddlemix/models/GOT/utils/conversation.py +++ b/paddlemix/models/GOT/utils/conversation.py @@ -77,25 +77,6 @@ def get_prompt(self): return ret else: raise ValueError(f"Invalid style: {self.sep_style}") - # if self.sep_style == SeparatorStyle.MPT: - # if self.system: - # ret = self.system + self.sep - # else: - # ret = '' - # for role, message in self.messages: - # if message: - # if type(message) is tuple: - # message, _, _ = message - # ret += role + message + self.sep - # # if 'user' in role: - # # ret += role + message + self.sep + "\n" - # # else: - # # ret += role + message + self.sep - # else: - # ret += role - # return ret - # else: - # raise ValueError(f"Invalid style: {self.sep_style}") def append_message(self, role, message): self.messages.append([role, message]) diff --git a/paddlemix/models/GOT/utils/utils.py b/paddlemix/models/GOT/utils/utils.py index c16e2e2e0..abe46fdd3 100644 --- a/paddlemix/models/GOT/utils/utils.py +++ b/paddlemix/models/GOT/utils/utils.py @@ -15,7 +15,6 @@ import paddle from paddlenlp.generation.stopping_criteria import StoppingCriteria - server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." handler = None @@ -51,79 +50,14 @@ def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model): Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
""" - # num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) - # # num_new_tokens = 1 - # # tokenizer.add_tokens(special_tokens_dict, special_tokens=True) - # model.resize_token_embeddings(len(tokenizer)) - num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: - input_embeddings = model.get_input_embeddings().weight # .data - output_embeddings = model.get_output_embeddings().weight # .data - + input_embeddings = model.get_input_embeddings().weight + output_embeddings = model.get_output_embeddings().weight input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg - - -# def maybe_zero_3(param, ignore_status=False, name=None): -# from deepspeed import zero -# from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus -# if hasattr(param, "ds_id"): -# if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: -# if not ignore_status: -# logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") -# with zero.GatheredParameters([param]): -# param = param.data.detach().cpu().clone() -# else: -# param = param.detach().cpu().clone() -# return param - - -# # Borrowed from peft.utils.get_peft_model_state_dict -# def get_peft_state_maybe_zero_3(named_params, bias): -# if bias == "none": -# to_return = {k: t for k, t in named_params if "lora_" in k} -# elif bias == "all": -# to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} -# elif bias == "lora_only": -# to_return = {} -# maybe_lora_bias = {} -# lora_bias_names = set() -# for k, t in named_params: -# if "lora_" in k: -# to_return[k] = t -# bias_name = k.split("lora_")[0] + "bias" -# lora_bias_names.add(bias_name) -# elif "bias" in k: -# maybe_lora_bias[k] = t -# for k, t in maybe_lora_bias: -# if bias_name in lora_bias_names: -# to_return[bias_name] = t -# else: -# raise NotImplementedError -# to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()} -# return to_return - - -# def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): -# to_return = {k: t for k, t in named_params if "lora_" not in k} -# if require_grad_only: -# to_return = {k: t for k, t in to_return.items() if t.requires_grad} -# to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} -# return to_return - - -# def find_all_linear_names(model): -# cls = torch.nn.Linear -# lora_module_names = set() -# for name, module in model.named_modules(): -# if isinstance(module, cls) and 'vision_model' not in name and 'mm_projector' not in name and 'vision_encoder' not in name and 'conv_final' not in name and'lm_head' not in name: -# lora_module_names.add(name) - -# print(lora_module_names) -# return list(lora_module_names) diff --git a/paddlemix/models/internvl2/conversation.py b/paddlemix/models/internvl2/conversation.py index a40a94101..80765d747 100644 --- a/paddlemix/models/internvl2/conversation.py +++ b/paddlemix/models/internvl2/conversation.py @@ -399,3 +399,18 @@ def get_conv_template(name: str) -> Conversation: ] ) ) + +register_conv_template( + Conversation( + name='internvl2_5', + system_template='<|im_start|>system\n{system_message}', + system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + 
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>\n', + stop_token_ids=[ + 151645, + ] + + ) +) diff --git a/paddlemix/models/internvl2/internvl_chat/modeling_intern_vit.py b/paddlemix/models/internvl2/internvl_chat/modeling_intern_vit.py index 9255bbab9..888f0f4bf 100644 --- a/paddlemix/models/internvl2/internvl_chat/modeling_intern_vit.py +++ b/paddlemix/models/internvl2/internvl_chat/modeling_intern_vit.py @@ -434,7 +434,6 @@ def forward( encoder_states = () if output_hidden_states else None hidden_states = inputs_embeds - for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) @@ -511,7 +510,6 @@ def forward( hidden_states = self.embeddings(pixel_values) else: raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') - encoder_outputs = self.encoder( inputs_embeds=hidden_states, output_hidden_states=output_hidden_states, diff --git a/paddlemix/models/internvl2/internvl_chat/modeling_internvl_chat.py b/paddlemix/models/internvl2/internvl_chat/modeling_internvl_chat.py index ed93d40d4..70f4d2d1b 100644 --- a/paddlemix/models/internvl2/internvl_chat/modeling_internvl_chat.py +++ b/paddlemix/models/internvl2/internvl_chat/modeling_internvl_chat.py @@ -295,7 +295,6 @@ def chat( IMG_CONTEXT_TOKEN="", verbose=False, ): - if history is None and pixel_values is not None and "" not in question: question = "\n" + question @@ -330,7 +329,6 @@ def chat( input_ids = model_inputs["input_ids"] attention_mask = model_inputs["attention_mask"] generation_config["eos_token_id"] = eos_token_id - generation_output = self.generate( pixel_values=pixel_values, # [7, 3, 448, 448] input_ids=input_ids, # [1, 1847] @@ -361,7 +359,6 @@ def generate( return_dict: Optional[bool] = None, **generate_kwargs, ) -> paddle.Tensor: - assert self.img_context_token_id is not None if pixel_values is not None: if visual_features is not None: @@ -386,14 +383,30 @@ def generate( else: input_embeds = self.language_model.get_input_embeddings()(input_ids) - outputs = self.language_model.generate( - inputs_embeds=input_embeds, - attention_mask=attention_mask, - generation_config=generation_config, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - use_cache=True, - **generate_kwargs, - ) + ### must add position_ids, paddlenlp bug + if isinstance(self.language_model, Qwen2ForCausalLM): + batch_size, seq_length = attention_mask.shape + position_ids = paddle.arange(seq_length).expand((batch_size, seq_length)) + outputs = self.language_model.generate( + position_ids=position_ids, + inputs_embeds=input_embeds, + attention_mask=attention_mask, + generation_config=generation_config, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=True, + **generate_kwargs, + ) + ### + else: + outputs = self.language_model.generate( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + generation_config=generation_config, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=True, + **generate_kwargs, + ) return outputs
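
---

For reference, the inference entry points this patch settles on reduce to the `chat` / `chat_crop` calls used in `paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py` above. Below is a minimal standalone sketch of that flow, assuming a working PaddleMIX install and the `stepfun-ai/GOT-OCR2_0` weights; the `dtype` value and the explicit `.eval()` call are assumptions not shown in the hunks, so adjust them to your environment.

```python
# Minimal sketch of GOT-OCR2.0 inference with PaddleMIX, following the API used in
# paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py in this patch.
import paddle
from paddlenlp.transformers import QWenTokenizer

from paddlemix.models.GOT.GOT_ocr_2_0 import GOTQwenForCausalLM

model_name_or_path = "stepfun-ai/GOT-OCR2_0"  # Paddle-format weights, auto-downloaded to the cache
image_file = "paddlemix/demo_images/hospital.jpeg"

tokenizer = QWenTokenizer.from_pretrained(model_name_or_path)
model = GOTQwenForCausalLM.from_pretrained(model_name_or_path, dtype="bfloat16")  # dtype is an assumption
model.eval()

with paddle.no_grad():
    # Plain-text OCR; pass ocr_type="format" for formatted output.
    res = model.chat(tokenizer, image_file, ocr_type="ocr")
    # For large or dense images, the multi-crop path is exposed as:
    # res = model.chat_crop(tokenizer, image_file, ocr_type="ocr")

print("output:\n", res)
```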