diff --git a/.gitignore b/.gitignore index b2a41d03f..89e553033 100644 --- a/.gitignore +++ b/.gitignore @@ -195,8 +195,7 @@ cache/ outputs/ wandb/ .idea +.vscode/ # macOS .DS_Store - -.vscode/ diff --git a/configs/qwen2-5-vl-eagle3.json b/configs/qwen2-5-vl-eagle3.json new file mode 100644 index 000000000..672193e3b --- /dev/null +++ b/configs/qwen2-5-vl-eagle3.json @@ -0,0 +1,40 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 3584, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 8192, + "max_window_layers": 28, + "model_type": "llama", + "target_model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "pretraining_tp": 1, + "rope_scaling": { + "type": "mrope", + "mrope_section": [ + 16, + 24, + 24 + ] + }, + "rope_theta": 1000000, + "sliding_window": 32768, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 32000 + } diff --git a/examples/run_qwen2_5_vl_eagle3_online.sh b/examples/run_qwen2_5_vl_eagle3_online.sh new file mode 100644 index 000000000..0e64b41ef --- /dev/null +++ b/examples/run_qwen2_5_vl_eagle3_online.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# support tp1 train eagle3 for qwen2.5-vl-7b-instruct +NUM_GPUS=${1:-1} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3_online.py \ + --target-model-path Qwen/Qwen2.5-VL-7B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen2-5-vl-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/allava4v_train.jsonl \ + 
--output-dir $ROOT_DIR/outputs/Qwen2.5-VL-7B-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 8192 \ + --dist-timeout 360 \ + --chat-template qwen2-vl \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size 1 \ + --is-vlm \ + --min-pixels 50176 \ + --max-pixels 802816 diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 63baa7261..95162ab54 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -33,7 +33,7 @@ def parse_args(): parser.add_argument( "--dataset", type=str, - choices=["ultrachat", "sharegpt", "opc"], + choices=["ultrachat", "sharegpt", "sharegpt4v", "allava4v", "opc"], help="The demo dataset to quickly run the training for speculative decoding", ) parser.add_argument( @@ -48,6 +48,17 @@ def parse_args(): default=None, help="The path to the custom dataset, if not specified, the default dataset will be loaded", ) + parser.add_argument( + "--sample-size", + type=int, + default=None, + help="The number of samples to process from the dataset, if not specified, all samples will be processed", + ) + parser.add_argument( + "--split-eval", + action="store_true", + help="Whether to split the dataset into train and eval sets, default is False", + ) return parser.parse_args() @@ -101,12 +112,83 @@ def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]: return row, skipped_count +def process_sharegpt4v_row(row) -> Dict: + """ + sharegpt4v dataset schema: + { + "id": str, + "image": str, # path to the image + "conversations": [ + { + "from": , + "value": , + }, + ... 
+ ] + } + """ + conversations = row["conversations"] + image = f'FreedomIntelligence/ALLaVA-4V/{row["image"]}' + if not os.path.exists(image): + print(f"Image path {image} does not exist, skipping this sample.") + return None, None + formatted_conversations = [] + skipped_count = 0 + for message in conversations: + if message["from"] not in ROLE_MAPPING: + skipped_count += 1 + continue + new_role = ROLE_MAPPING[message["from"]] + if new_role == "user": + text_content = message["value"].replace("\n", "") + content = text_content + else: + content = message["value"] + formatted_conversations.append({"role": new_role, "content": content}) + + row = {"id": row["id"], "image": image, "conversations": formatted_conversations} + return row, skipped_count + + def load_dataset_from_path(data_path: Path): suffix = data_path.suffix.split(".")[1] ds = load_dataset(suffix, data_files=str(data_path), split="train") return ds +def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name): + train_output_jsonl_path = output_path.joinpath(f"{dataset_name}_train.jsonl") + if train_output_jsonl_path.exists(): + print( + f"The dataset {dataset_name} has already been processed and saved in {train_output_jsonl_path}, skipping..." 
+ ) + return + + total_skipped_count = 0 + with open(train_output_jsonl_path, "w") as f: + for item in tqdm(train_ds, desc=f"Processing {dataset_name} dataset"): + row, skipped_count = proc_fn(item) + if row is None: + continue + total_skipped_count += skipped_count + f.write(json.dumps(row) + "\n") + + if test_ds is not None: + test_output_jsonl_path = output_path.joinpath(f"{dataset_name}_test.jsonl") + with open(test_output_jsonl_path, "w") as f: + for item in tqdm(test_ds, desc=f"Processing {dataset_name} test dataset"): + row, skipped_count = proc_fn(item) + if row is None: + continue + total_skipped_count += skipped_count + f.write(json.dumps(row) + "\n") + + if total_skipped_count > 0: + print( + f"Skipped {total_skipped_count}/{len(train_ds)+len(test_ds)} messages for {dataset_name}" + ) + + import hashlib @@ -135,6 +217,14 @@ def main(): print("Loading dataset from custom data path: ", args.data_path) ds = load_dataset_from_path(Path(args.data_path)) proc_fn = process_sharegpt_row + elif args.dataset == "sharegpt4v": + ds = load_dataset("Lin-Chen/ShareGPT4V")["train"] + proc_fn = process_sharegpt4v_row + elif args.dataset == "allava4v": + ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[ + "instruct" + ] + proc_fn = process_sharegpt4v_row elif args.dataset == "opc": ds = load_dataset( "OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct" @@ -145,30 +235,27 @@ def main(): "This script only supports ultrachat_200k and sharegpt datasets for demo purpose, if you wish to use other datasets, please modify this script." 
) + # filter and split dataset + if args.sample_size is not None and args.sample_size < len(ds): + ds = ds.select(range(args.sample_size)) + print(f"Processing {args.sample_size} samples from the dataset {args.dataset}") + if args.split_eval: + ds = ds.train_test_split(test_size=0.05) + train_ds = ds["train"] + test_ds = ds["test"] + else: + train_ds = ds + test_ds = None + if args.output_path is None: root_path = Path(__file__).parent.parent output_path = root_path.joinpath("cache", "dataset") output_path.mkdir(parents=True, exist_ok=True) else: output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) - output_jsonl_path = output_path.joinpath(f"{args.dataset}.jsonl") - - if output_jsonl_path.exists(): - print( - f"The dataset {args.dataset} has already been processed and saved in {output_jsonl_path}, skipping..." - ) - return - - total_skipped_count = 0 - with open(output_jsonl_path, "w") as f: - for item in tqdm(ds, desc=f"Processing {args.dataset} dataset"): - row, skipped_count = proc_fn(item) - total_skipped_count += skipped_count - f.write(json.dumps(row) + "\n") - - if total_skipped_count > 0: - print(f"Skipped {total_skipped_count}/{len(ds)} messages for {args.dataset}") + process_and_save_ds(train_ds, test_ds, output_path, proc_fn, args.dataset) if __name__ == "__main__": diff --git a/scripts/train_eagle3_online.py b/scripts/train_eagle3_online.py index 88922fd58..3d429521f 100644 --- a/scripts/train_eagle3_online.py +++ b/scripts/train_eagle3_online.py @@ -10,13 +10,14 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer from specforge import ( AutoDistributedTargetModel, AutoDraftModelConfig, AutoEagle3DraftModel, OnlineEagle3Model, + QwenVLOnlineEagle3Model, ) from 
specforge.data import ( build_eagle3_dataset, @@ -41,6 +42,9 @@ def parse_args(): default="model.embed_tokens.weight", help="The key of the embedding weight to load from the target model", ) + parser.add_argument( + "--is-vlm", action="store_true", help="Whether the target model is a VLM" + ) # add training-related arguments parser.add_argument("--train-data-path", type=str, required=True) @@ -112,6 +116,14 @@ def parse_args(): help="The API key for swanlab non-interactive login.", ) + # vlm related args + parser.add_argument( + "--min-pixels", type=int, default=50176 + ) # 64*28*28 for qwen2.5-vl + parser.add_argument( + "--max-pixels", type=int, default=802816 + ) # 1024*28*28 for qwen2.5-vl + parser.add_argument("--build-dataset-num-proc", type=int, default=8) parser.add_argument("--verbose", action="store_true") parser.add_argument("--profile", action="store_true") @@ -143,6 +155,8 @@ def main(): parser.error(f"Unknown tracker: {args.report_to}") tracker = create_tracker(args, args.output_dir) + # load draft model config + draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) # detecting last ckpt for draft model draft_model_last_checkpoint = None @@ -161,18 +175,29 @@ def main(): device="cuda", ).eval() else: - target_model = ( - AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=args.target_model_path, - torch_dtype=torch.bfloat16, - cache_dir=args.cache_dir, + if args.is_vlm and draft_model_config.target_model_type == "qwen2_5_vl": + from transformers import Qwen2_5_VLForConditionalGeneration + + target_model = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + ) + .eval() + .cuda() + ) + else: + target_model = ( + AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + cache_dir=args.cache_dir, + ) + .eval() + .cuda() ) - .eval() - .cuda() - ) 
print_with_rank("Initialized target model") # load model with resume - draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) if draft_model_last_checkpoint: draft_model = ( AutoEagle3DraftModel.from_pretrained( @@ -195,6 +220,14 @@ def main(): # build dataloaders tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) + if args.is_vlm: + processor = AutoProcessor.from_pretrained( + args.target_model_path, + min_pixels=args.min_pixels, + max_pixels=args.max_pixels, + ) + else: + processor = None # convert to dataloader cache_params_string = ( @@ -213,6 +246,8 @@ def main(): max_length=args.max_length, cache_dir=os.path.join(args.cache_dir, "processed_dataset"), cache_key=cache_key, + is_vlm=args.is_vlm, + processor=processor, num_proc=args.build_dataset_num_proc, ) vocab_mapping_path = generate_vocab_mapping_file( @@ -228,6 +263,7 @@ def main(): num_workers=4, shuffle=True, process_group=get_dp_group(), + is_vlm=args.is_vlm, ) print_with_rank("Initialized train dataloader") @@ -242,6 +278,8 @@ def main(): tokenizer, args.chat_template, args.max_length, + is_vlm=args.is_vlm, + processor=processor, num_proc=args.build_dataset_num_proc, ) eval_dataloader = prepare_dp_dataloaders( @@ -250,17 +288,26 @@ def main(): num_workers=4, shuffle=False, process_group=get_dp_group(), + is_vlm=args.is_vlm, ) print_with_rank("Initialized eval dataloader") # build Eagle3 model # broadcast draft model - eagle3_model = OnlineEagle3Model( - target_model=target_model, - draft_model=draft_model, - length=args.ttt_length, - attention_backend=args.attention_backend, - ) + if args.is_vlm and draft_model_config.target_model_type == "qwen2_5_vl": + eagle3_model = QwenVLOnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + processor=processor, + length=args.ttt_length, + ) + else: + eagle3_model = OnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + length=args.ttt_length, + attention_backend=args.attention_backend, 
+ ) # eagle3_model = DDP(eagle3_model, find_unused_parameters=True) eagle3_model = FSDP( eagle3_model, @@ -348,11 +395,20 @@ def main(): torch_profiler.export_chrome_trace(output_path) optimizer.zero_grad() - plosses, _, acces = eagle3_model( - input_ids=data["input_ids"].cuda(), - attention_mask=data["attention_mask"].cuda(), - loss_mask=data["loss_mask"].cuda(), - ) + if args.is_vlm: + plosses, _, acces = eagle3_model( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + pixel_values=data["pixel_values"].cuda(), + image_grid_thw=data["image_grid_thw"].cuda(), + ) + else: + plosses, _, acces = eagle3_model( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + ) # calculate weighted loss ploss_weight = [0.8**i for i in range(len(plosses))] @@ -409,11 +465,21 @@ def main(): eval_plosses = [[] for _ in range(eagle3_model.length)] for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"): - plosses, _, acces = eagle3_model( - input_ids=data["input_ids"].cuda(), - attention_mask=data["attention_mask"].cuda(), - loss_mask=data["loss_mask"].cuda(), - ) + if args.is_vlm: + plosses, _, acces = eagle3_model( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + pixel_values=data["pixel_values"].cuda(), + image_grid_thw=data["image_grid_thw"].cuda(), + ) + else: + plosses, _, acces = eagle3_model( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + ) + eval_acces = [eval_acces[i] + [acces[i]] for i in range(len(acces))] eval_plosses = [ eval_plosses[i] + [plosses[i].item()] for i in range(len(plosses)) diff --git a/specforge/__init__.py b/specforge/__init__.py index 207fc554f..0003e47bd 100644 --- a/specforge/__init__.py +++ b/specforge/__init__.py @@ -7,4 +7,5 @@ 
"AutoEagle3DraftModel", "AutoDistributedTargetModel", "AutoDraftModelConfig", + "QwenVLOnlineEagle3Model", ] diff --git a/specforge/core/__init__.py b/specforge/core/__init__.py index 9baaf8241..9587d90a7 100644 --- a/specforge/core/__init__.py +++ b/specforge/core/__init__.py @@ -1,3 +1,3 @@ -from .eagle3 import OfflineEagle3Model, OnlineEagle3Model +from .eagle3 import OfflineEagle3Model, OnlineEagle3Model, QwenVLOnlineEagle3Model -__all__ = ["OnlineEagle3Model", "OfflineEagle3Model"] +__all__ = ["OnlineEagle3Model", "OfflineEagle3Model", "QwenVLOnlineEagle3Model"] diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index c0aefee14..68894fa7e 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -442,6 +442,293 @@ def forward( return plosses, vlosses, acces +class QwenVLOnlineEagle3Model(Eagle3Model): + """ + In sgl-spec, we implement offline/online training. + Online training means we have the target hidden_states available during training. + Eagle3 using test time training technique (TTT) to train the draft model. + 1. We first extract the hidden states from the target model. + 2. Then concatenate the hidden states from 3 aux layers (layer 1, layer num_layers//2, layer num_layers-4). + 3. We project the concatenated hidden states to the target hidden size. from (batch, seq_len, 3*hidden_size) to (batch, seq_len, hidden_size) + 4. We concat the projected hidden states and embedding output as the input for the draft model. + 5. finally, we run TTT to train the draft model. input size is (batch, seq_len, hidden_size * 2) + """ + + def __init__( + self, target_model, draft_model: Eagle3DraftModel, processor, length: int = 7 + ): + """ + Args: + target_model: the target model to extract hidden states. + draft_model: the draft model to be trained. + length: TTT length, it means how many turns to unroll during TTT. 
+ """ + super().__init__() + self.target_model = target_model + self.draft_model = draft_model + self.processor = processor + self.length = length + + @torch.no_grad() + def _prepare_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L692 + Extract the hidden states from the target model outputs. + + Args: + input_ids: (batch, seq_len) + attention_mask: (batch, seq_len) + loss_mask: (batch, seq_len) + device: the device to run the target model, if None, use the input_ids device + pixel_values: image pixel values, used for VLM models + image_grid_thw: image grid thw, used for VLM models + + Returns: + hidden_states: (batch, seq_len, 3*hidden_size) + target: (batch, seq_len, vocab_size) + loss_mask: (batch, seq_len) + input_ids: (batch, seq_len) + """ + + if device is None: + device = input_ids.device + + # run the target model to get the hidden states + outputs = self.target_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + use_cache=False, + ) + + # extract the aux hidden states + # output_hidden_states = True will return the embedding output as well + # so we have an offset of 1 + num_hidden_states = len(outputs.hidden_states) + offset = 1 + num_layers = num_hidden_states - 1 + + # Eagle3 uses 3 aux layers from layer 1, num_layers//2, num_layers-4 + low_aux_layer = 1 + offset + mid_aux_layer = num_layers // 2 - 1 + offset + last_aux_layer = num_layers - 4 + offset + + hidden_states0 = outputs.hidden_states[low_aux_layer] + hidden_states1 = outputs.hidden_states[mid_aux_layer] + hidden_states2 = 
outputs.hidden_states[last_aux_layer] + + hidden_states = torch.cat( + (hidden_states0, hidden_states1, hidden_states2), dim=-1 + ) + + # apply pading + target = outputs.logits + target = padding(target, left=False) + input_ids = padding(input_ids, left=False) + + if target is not None: + target = target.to(device) + loss_mask = loss_mask[..., None] + loss_mask = loss_mask.to(device) + + return hidden_states, target, loss_mask, input_ids + + @torch.no_grad() + def _get_input_embeds( + self, + input_ids: torch.Tensor, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + ) -> torch.Tensor: + # get input embeding with image + # inputs_embeds = self.target_model.model.get_input_embeddings()(input_ids) + inputs_embeds = self.draft_model.embed_input_ids(input_ids) + image_embeds = self.target_model.model.get_image_features( + pixel_values, image_grid_thw + ) + image_embeds = torch.cat(image_embeds, dim=0) + n_image_tokens = ( + input_ids == self.target_model.model.config.image_token_id + ).sum() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.target_model.model.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> Tuple[List[torch.Tensor], 
List[torch.Tensor], List[torch.Tensor]]: + """ + Online eagle model trainer, modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L711 + + Args: + input_ids: (batch, seq_len) + attention_mask: (batch, seq_len) + loss_mask: (batch, seq_len) + past_key_values: We dont use this past_key_values in eagle3, but keep it for compatibility. We control kvcache by cache_hidden. + position_ids: (batch, seq_len) + pixel_values: batch image pixel values, used for VLM models + image_grid_thw: (batch, 3), image grid thw, used for VLM models + """ + # Step 1: prepare data with the target model + hidden_states, target, loss_mask, input_ids = self._prepare_data( + input_ids, attention_mask, loss_mask, pixel_values, image_grid_thw + ) + + # basic info + batch_size, seq_length, _ = hidden_states.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + # Step 2: project the concatenated hidden states to the target hidden size + hidden_states = self.draft_model.project_hidden_states(hidden_states) + + # Step 3: process kv cache, position ids and position ids + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + attention_mask_tensor = ( + attention_mask + if not isinstance(attention_mask, dict) + else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal( + attention_mask_tensor[:, 0], dim1=1, dim2=2 + ) + attention_mask_tensor = ( + attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + ) + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + position_ids, rope_deltas = self.target_model.model.get_rope_index( + input_ids, + image_grid_thw, + None, + second_per_grid_ts=None, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + else: + position_ids = 
position_ids + + # Step 4: handle attention mask + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + attention_mask = self.draft_model.prepare_decoder_attention_mask( + attention_mask=attention_mask, + hidden_states=hidden_states, + batch_size=batch_size, + seq_length=seq_length, + past_key_values_length=past_key_values_length, + ) + + # Step 5: run TTT + plosses = [] + vlosses = [] + acces = [] + cache_hidden = [[], []] + + for idx in range(self.length): + is_last = idx == self.length - 1 + + # Step 5.1: embed the input ids + # inputs_embeds = self._get_input_embeds(input_ids, pixel_values, image_grid_thw) + inputs_embeds = self.draft_model.embed_input_ids(input_ids) + inputs_embeds = inputs_embeds.to(hidden_states.dtype) + + # Step 5.2: run the draft model backbone + hidden_states_out = self.draft_model.backbone( + input_embeds=inputs_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=True, + ) + + # Step 5.3: handle vocab size + with torch.no_grad(): + target_head = target + target_max_token = target_head.argmax(-1) + target_mask = self.draft_model.t2d[target_max_token] + target_mask = target_mask[..., None].int() + position_mask = target_mask * loss_mask + target_head = target_head[..., self.draft_model.t2d] + target_head = target_head.float() + target_p = nn.Softmax(dim=2)(target_head) + target_p = target_p.detach() + + # update hidden states for next step + hidden_states = hidden_states_out + + # Step 5.4: get logits + logits = self.draft_model.compute_logits(hidden_states) + logits = logits.float() + + # Step 5.5: calculate loss + out_logp = nn.LogSoftmax(dim=2)(logits) + plogp = target_p * out_logp + loss = -torch.sum(position_mask * plogp, 2).mean() + + # Step 5.6: record metrics + plosses.append(loss) + with torch.no_grad(): + acces.append( + ( + (logits.argmax(-1) 
== target_p.argmax(-1)) + * position_mask.squeeze(-1) + ) + .sum() + .item() + / (loss_mask.sum().item() + 1e-6) + ) + + if not is_last: + # Step 5.7: we need to update the loss mask + input_ids = padding(input_ids, left=False) + target = padding(target, left=False) + loss_mask = padding(loss_mask, left=False) + ind = torch.arange(seq_length, device=attention_mask.device) + ind0 = ind[idx:] + ind1 = ind[: seq_length - idx] + attention_mask[:, :, ind0, ind1] = torch.finfo(attention_mask.dtype).min + return plosses, vlosses, acces + + def _compute_target_p_padded(target, t2d, loss_mask, length): with torch.no_grad(): target_p, position_mask = _compute_target_p( diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index f90e4da82..f883fcd81 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -28,8 +28,9 @@ import torch from datasets import Dataset as HFDataset +from qwen_vl_utils import process_vision_info from tqdm import tqdm -from transformers import PreTrainedTokenizer +from transformers import ImageProcessingMixin, PreTrainedTokenizer from specforge.utils import padding @@ -139,6 +140,141 @@ def preprocess_conversations( return results +def preprocess_vlm_conversations( + processor: ImageProcessingMixin, + examples: List[Conversation], + chat_template: ChatTemplate, + max_length: int = 2048, +) -> Dict[str, List[torch.Tensor]]: + """ + Preprocess a batch of ShareGPT style conversations. + + Args: + processor: The image processor to use for processing images. + examples: A list of examples, where each example is a dictionary containing: + - image: The image in the conversation. + - conversations: A list of conversations, where each conversation is a list of messages. + chat_template: The chat template to use for formatting the conversations. + max_length: The maximum length of the tokenized input. + + Returns: + A dictionary containing: + - input_ids: List of tokenized input IDs. 
+ - loss_mask: List of loss masks indicating which tokens should contribute to the loss. + - attention_mask: List of attention masks. + - pixel_values: List of pixel values for images in the examples. + - image_grid_thw: List of image grid tensors. + """ + system_prompt = chat_template.system_prompt + user_message_separator = ( + f"{chat_template.end_of_turn_token}{chat_template.user_header}" + ) + assistant_message_separator = ( + f"{chat_template.end_of_turn_token}{chat_template.assistant_header}" + ) + + # prepare result + results = { + "input_ids": [], + "loss_mask": [], + "attention_mask": [], + "pixel_values": [], + "image_grid_thw": [], + } + + # Note: currently, we assume that each example has only one image + for i, image in enumerate(examples["image"]): + source = examples["conversations"][i] + messages = [{"role": "system", "content": system_prompt}] + if not source: + # if the source is None, skip it + continue + + if source[0]["role"] != "user": + # if the first message is not from user, skip it + source = source[1:] + + convroles = ["user", "assistant"] + for j, sentence in enumerate(source): + role = sentence["role"] + assert role == convroles[j % 2], f"unexpected role {role}" + if role == "user": + # if the message is from user and has image, process the image + messages.append( + { + "role": role, + "content": [ + { + "type": "image", + "image": image, + }, + {"type": "text", "text": sentence["content"]}, + ], + } + ) + else: + messages.append({"role": role, "content": sentence["content"]}) + + conversation = processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + # get vision infor use qwen_vl_utils + image_inputs, video_inputs = process_vision_info(messages) + assert image_inputs is not None, "image_inputs must not be None" + + encoding = processor( + text=[conversation], + images=image_inputs, + videos=video_inputs, + max_length=max_length, + truncation=True, + return_tensors="pt", + 
return_offsets_mapping=True, + add_special_tokens=False, + ) + input_ids = encoding.input_ids[0] + offsets = encoding.offset_mapping[0] + loss_mask = torch.zeros(len(input_ids), dtype=torch.long) + pixel_values = encoding.pixel_values + image_grid_thw = encoding.image_grid_thw[0] + + # Find spans of assistant responses using regex + assistant_pattern = ( + re.escape(assistant_message_separator) + + r"(.*?)(?=" + + re.escape(user_message_separator) + + "|$)" + ) + + # get conversation with image info + decoded_conversation = processor.tokenizer.decode( + encoding.input_ids[0], skip_special_tokens=False + ) + + for match in re.finditer(assistant_pattern, decoded_conversation, re.DOTALL): + # Assistant response text span (excluding assistant_header itself) + assistant_start_char = match.start(1) + assistant_end_char = match.end(1) + + # Mark tokens overlapping with assistant response + for idx, (token_start, token_end) in enumerate(offsets): + # Token is part of the assistant response span + if token_end <= assistant_start_char: + continue # token before assistant text + if token_start > assistant_end_char: + continue # token after assistant text + loss_mask[idx] = 1 + + results["input_ids"].append(input_ids[None, :]) + results["loss_mask"].append(loss_mask[None, :]) + results["attention_mask"].append(torch.ones_like(loss_mask)[None, :]) + results["pixel_values"].append(pixel_values) + results["image_grid_thw"].append(image_grid_thw[None, :]) + return results + + def build_eagle3_dataset( dataset: HFDataset, tokenizer: PreTrainedTokenizer, @@ -148,6 +284,8 @@ def build_eagle3_dataset( num_proc: Optional[int] = 8, cache_dir: Optional[str] = None, cache_key: Optional[str] = None, + is_vlm: Optional[bool] = False, + processor: Optional[ImageProcessingMixin] = None, ) -> HFDataset: """ build eagle3 dataset @@ -161,10 +299,14 @@ def build_eagle3_dataset( num_proc: The number of processes to use for multiprocessing. 
cache_dir: The directory to use for caching the processed dataset. cache_key: The key to use for caching the processed dataset. + is_vlm: Whether the dataset is for VLM models. + processor: The image processor to use for processing images. Returns: The processed HF dataset. """ + if is_vlm: + assert processor is not None, "processor must be provided when is_vlm is True" # Get chat template assert ( chat_template in TEMPLATE_REGISTRY.get_all_template_names() @@ -176,12 +318,20 @@ def build_eagle3_dataset( def preprocess_function(examples): # Always do preprocessing - processed = preprocess_conversations( - tokenizer, - examples["conversations"], - template, - max_length, - ) + if is_vlm: + processed = preprocess_vlm_conversations( + processor, + examples, + template, + max_length, + ) + else: + processed = preprocess_conversations( + tokenizer, + examples["conversations"], + template, + max_length, + ) return processed # Process dataset only once @@ -199,11 +349,18 @@ def preprocess_function(examples): f"cache_dir and cache_key must be provided together to make caching work" ) + # reduce batch size for VLM datasets to avoid PyArrow offset overflow + if is_vlm: + batch_size = 200 + else: + batch_size = 1000 dataset = dataset.map( preprocess_function, batched=True, num_proc=num_proc, + batch_size=batch_size, remove_columns=original_cols, + # keep_in_memory=True, load_from_cache_file=load_from_cache_file, cache_file_name=cache_file_name, ) diff --git a/specforge/data/template.py b/specforge/data/template.py index 02d980997..cfef4ea8b 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -114,6 +114,16 @@ def get_all_template_names(self) -> List[str]: ), ) +TEMPLATE_REGISTRY.register( + name="qwen2-vl", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + ), +) + TEMPLATE_REGISTRY.register( name="deepseek", 
template=ChatTemplate( diff --git a/specforge/data/utils.py b/specforge/data/utils.py index 55817481a..1ce6a070a 100644 --- a/specforge/data/utils.py +++ b/specforge/data/utils.py @@ -116,6 +116,105 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: return batch +class VlmDataCollatorWithPadding: + """ + Datacollator that will dynamically pad the inputs for batching. + """ + + def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad to the longest sequence in the batch. + + Args: + intensors: (B, n, S) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N, S) + """ + B, n, S = intensors.shape + padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad 2D tensor to the longest sequence in the batch. + + Args: + intensors: (B, n) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N) + """ + B, n = intensors.shape + padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Collate a batch of features. 
+ + Args: + features: A list of features, where each feature is a dictionary containing: + - input_ids: torch.Tensor of shape (n,) + - attention_mask: torch.Tensor of shape (n,) + - loss_mask: torch.Tensor of shape (n,) + - pixel_values: torch.Tensor of shape (grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size) + - image_grid_thw: torch.Tensor of shape (3,) + + Returns: + A dictionary containing: + - input_ids: torch.Tensor of shape (B, N) + - attention_mask: torch.Tensor of shape (B, N) + - loss_mask: torch.Tensor of shape (B, N) + """ + max_length = max(item["input_ids"].shape[1] for item in features) + batch_input_ids = torch.cat( + [self.paddingtensor2D(item["input_ids"], max_length) for item in features] + ) + batch_attention_mask = torch.cat( + [ + self.paddingtensor2D(item["attention_mask"], max_length) + for item in features + ] + ) + batch_loss_mask = torch.cat( + [self.paddingtensor2D(item["loss_mask"], max_length) for item in features] + ) + batch_pixel_values = torch.cat( + [item["pixel_values"] for item in features], dim=0 + ) + batch_image_grid_thw = torch.cat( + [item["image_grid_thw"] for item in features], dim=0 + ) + batch = { + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, + "loss_mask": batch_loss_mask, + "pixel_values": batch_pixel_values, + "image_grid_thw": batch_image_grid_thw, + "hidden_state": None, + "target": None, + } + if all("hidden_state" in item for item in features): + assert all( + "target" in item for item in features + ), "target is required when hidden_state is provided" + batch["hidden_state"] = torch.cat( + [ + self.paddingtensor(item["hidden_state"], max_length) + for item in features + ] + ) + batch["target"] = torch.cat( + [self.paddingtensor(item["target"], max_length) for item in features] + ) + return batch + + def prepare_dp_dataloaders( dataset: Dataset, batch_size: int, @@ -123,6 +222,7 @@ def prepare_dp_dataloaders( process_group: Optional[dist.ProcessGroup] = 
None, pin_memory: Optional[bool] = False, shuffle: Optional[bool] = False, + is_vlm: Optional[bool] = False, **dataloader_kwargs ) -> DataLoader: """ @@ -135,6 +235,7 @@ def prepare_dp_dataloaders( process_group: The process group for distributed training. pin_memory: Whether to pin memory for data loading. shuffle: Whether to shuffle the dataset. + is_vlm: Whether the dataset is a vision-language model dataset. **dataloader_kwargs: Additional keyword arguments for the DataLoader. Returns: @@ -145,13 +246,17 @@ def prepare_dp_dataloaders( sampler = DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=shuffle ) + if is_vlm: + datacollator_cls = VlmDataCollatorWithPadding + else: + datacollator_cls = DataCollatorWithPadding dataloader = DataLoader( dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=pin_memory, - collate_fn=DataCollatorWithPadding(), + collate_fn=datacollator_cls(), **dataloader_kwargs ) return dataloader diff --git a/specforge/modeling/auto.py b/specforge/modeling/auto.py index 18575f001..b8aa8f071 100644 --- a/specforge/modeling/auto.py +++ b/specforge/modeling/auto.py @@ -11,6 +11,7 @@ Llama4TextConfig, LlamaConfig, PretrainedConfig, + Qwen2_5_VLConfig, Qwen2Config, Qwen3Config, Qwen3MoeConfig, diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 94d6c1fcc..024254ff7 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -103,6 +103,51 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. 
The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + mrope_section = mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1 + ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1 + ).unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def prepare_decoder_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ): @@ -252,6 +297,46 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): ) +class LlamaMutiRotaryEmbedding(LlamaRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + super().__init__(dim, max_position_embeddings, base, device) + self.scaling_factor = scaling_factor + + def forward(self, x, position_ids): + # In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) 
+        inv_freq_expanded = (
+            self.inv_freq[None, None, :, None]
+            .float()
+            .expand(3, position_ids.shape[1], -1, 1)
+        )
+        position_ids_expanded = position_ids[
+            :, :, None, :
+        ].float()  # shape (3, bs, 1, positions)
+
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (
+                inv_freq_expanded.float() @ position_ids_expanded.float()
+            ).transpose(2, 3)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.scaling_factor
+            sin = emb.sin() * self.scaling_factor
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -289,7 +374,8 @@ def _init_rope(self):
             )
         else:
             scaling_type = self.config.rope_scaling["rope_type"]
-            scaling_factor = self.config.rope_scaling["factor"]
+            if "factor" in self.config.rope_scaling:
+                scaling_factor = self.config.rope_scaling["factor"]
             if scaling_type == "linear":
                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                     self.head_dim,
@@ -307,6 +393,10 @@
                 self.rotary_emb = LlamaRotaryEmbedding(
                     self.head_dim, max_position_embeddings=self.max_position_embeddings
                 )
+            elif scaling_type == "mrope":
+                self.rotary_emb = LlamaMutiRotaryEmbedding(
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings
+                )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
@@ -344,11 +434,22 @@ def forward(
         ).transpose(1, 2)
 
         if cache_hidden is None:
-            cos, sin = self.rotary_emb(query_states, seq_len=q_len)
-            cos, sin = cos.to(query_states.device), sin.to(query_states.device)
-            query_states, key_states = apply_rotary_pos_emb(
-                query_states, key_states, cos, sin, position_ids
-            )
+            if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
+                cos, sin = self.rotary_emb(query_states, position_ids)
+                cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+                query_states, key_states = apply_multimodal_rotary_pos_emb(
+                    query_states,
+                    key_states,
+                    cos,
+                    sin,
+                    self.config.rope_scaling["mrope_section"],
+                )
+            else:
+                cos, sin = self.rotary_emb(query_states, seq_len=q_len)
+                cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+                query_states, key_states = apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin, position_ids
+                )
 
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -364,12 +465,22 @@
         else:
             lck = len(cache_hidden[0])
-
-            cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
-            cos, sin = cos.to(query_states.device), sin.to(query_states.device)
-            query_states, key_states = apply_rotary_pos_emb(
-                query_states, key_states, cos, sin, position_ids + lck
-            )
+            if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
+                cos, sin = self.rotary_emb(query_states, position_ids + lck)
+                cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+                query_states, key_states = apply_multimodal_rotary_pos_emb(
+                    query_states,
+                    key_states,
+                    cos,
+                    sin,
+                    self.config.rope_scaling["mrope_section"],
+                )
+            else:
+                cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
+                cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+                query_states, key_states = apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin, position_ids + lck
+                )
 
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
diff --git a/tests/test_qwenvl_loss_mask.py b/tests/test_qwenvl_loss_mask.py
new file mode 100644
index 000000000..74523861f
--- /dev/null
+++ b/tests/test_qwenvl_loss_mask.py
@@ -0,0 +1,64 @@
+from transformers import AutoProcessor
+
+from specforge.data.preprocessing import preprocess_vlm_conversations
+from specforge.data.template import TEMPLATE_REGISTRY
+
+model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
+processor = AutoProcessor.from_pretrained(model_path)
+# ANSI color codes
+RED = "\033[91m"
+GREEN = "\033[92m"
+RESET = "\033[0m"
+
+
+def test_preprocess_vlm_conversations():
+    conversations = [
+        {"role": "user", "content": "what is in the image?"},
+        {"role": "assistant", "content": "This is an image of a cat."},
+    ]
+
+    examples = {
+        "id": ["example1"],
+        "image": [
+            "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
+        ],
+        "conversations": [conversations],
+    }
+    # examples = [examples]  # Wrap in a list to match expected input format
+    chat_template = TEMPLATE_REGISTRY.get("qwen2-vl")
+    processed = preprocess_vlm_conversations(processor, examples, chat_template, 4096)
+
+    input_ids = processed["input_ids"][0]
+    print(f"Loss mask sum: {processed['loss_mask'][0].sum()}")
+    loss_mask = processed["loss_mask"][0].squeeze(0).tolist()
+    input_ids = input_ids.squeeze(0)
+    current_mask = loss_mask[0]
+    current_ids = []
+
+    for i in range(len(input_ids)):
+        # for i in range(input_ids.shape[-1]):
+        if current_mask == loss_mask[i]:
+            current_ids.append(input_ids[i])
+        else:
+            decoded_text = processor.tokenizer.decode(
+                current_ids, skip_special_tokens=False
+            )
+            if current_mask == 0:
+                print(f"{RED}{decoded_text}{RESET}", end="")
+            else:
+                print(f"{GREEN}{decoded_text}{RESET}", end="")
+            current_ids = [input_ids[i]]
+            current_mask = loss_mask[i]
+
+    print(
+        f"{GREEN}{processor.tokenizer.decode(current_ids, skip_special_tokens=False)}{RESET}"
+    )
+
+    print()
+    print(f"input_ids shape: {processed['input_ids'][0].shape}")
+    print(f"loss_mask shape: {processed['loss_mask'][0].shape}")
+    # print(f"hidden_state shape: {processed['hidden_state'].shape}")
+
+
+if __name__ == "__main__":
+    test_preprocess_vlm_conversations()