Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
01dd714
support qwen2_5_vl online
Lzhang-hub Jul 29, 2025
49c1754
delete nohup
Lzhang-hub Jul 29, 2025
449a7f3
add qwen2.5-vl eagle model
Lzhang-hub Jul 30, 2025
879f30e
add todo
Lzhang-hub Jul 30, 2025
a51ed5a
clean dev code
Lzhang-hub Jul 30, 2025
cae128d
support batch and fix position_ids bug
Lzhang-hub Jul 31, 2025
3de7f6b
add eval wandb metrics
Lzhang-hub Jul 31, 2025
d25c229
fix eval bug
Lzhang-hub Jul 31, 2025
8d5d8b0
fix eval dataloader bug
Lzhang-hub Aug 1, 2025
da54bf3
add comment
Lzhang-hub Aug 1, 2025
705a988
Merge branch 'main' into qwen-vl
Lzhang-hub Aug 1, 2025
8ab0c2d
merge main
Lzhang-hub Aug 1, 2025
bc6f27f
rename vlm online eagle3 model name
Lzhang-hub Aug 1, 2025
1fb216b
clean code
Lzhang-hub Aug 1, 2025
7873301
fix ttt input embeds bug
Lzhang-hub Aug 1, 2025
d0591b8
fix eval metrics bug
Lzhang-hub Aug 1, 2025
61f52f5
merge qwen-vl draft model to llama3
Lzhang-hub Aug 1, 2025
a6bb3c7
fix qwen vl train shell
Lzhang-hub Aug 4, 2025
5a26c21
add timeout config
Lzhang-hub Aug 4, 2025
c8195a6
qwenvl draft input without image embedding
Lzhang-hub Aug 5, 2025
07f3681
qwenvl draft input without image embedding
Lzhang-hub Aug 13, 2025
5fbcb71
Revert "qwenvl draft input without image embedding"
Lzhang-hub Aug 13, 2025
c604eb0
Merge branch 'main' into qwen-vl
Lzhang-hub Aug 13, 2025
d035fb6
Merge branch 'main' into qwen-vl
Lzhang-hub Aug 18, 2025
98fd7de
fix gitignore
Lzhang-hub Aug 18, 2025
44e6e9b
Merge branch 'main' into qwen-vl
Lzhang-hub Aug 22, 2025
0557314
fix wandb error
Lzhang-hub Aug 22, 2025
ec956e7
fix lint
Lzhang-hub Aug 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ cache/
outputs/
wandb/
.idea
.vscode/

# macOS
.DS_Store

.vscode/
40 changes: 40 additions & 0 deletions configs/qwen2-5-vl-eagle3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 3584,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 8192,
"max_window_layers": 28,
"model_type": "llama",
"target_model_type": "qwen2_5_vl",
"num_attention_heads": 28,
"num_hidden_layers": 1,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"pretraining_tp": 1,
"rope_scaling": {
"type": "mrope",
"mrope_section": [
16,
24,
24
]
},
"rope_theta": 1000000,
"sliding_window": 32768,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 152064,
"draft_vocab_size": 32000
}
28 changes: 28 additions & 0 deletions examples/run_qwen2_5_vl_eagle3_online.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/bin/bash
# Launch online EAGLE3 draft-model training for Qwen2.5-VL-7B-Instruct.
# Usage: run_qwen2_5_vl_eagle3_online.sh [NUM_GPUS]   (defaults to 1 GPU, tp=1)

# Resolve the repo root relative to this script so it works from any CWD.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

# support tp1 train eagle3 for qwen2.5-vl-7b-instruct
NUM_GPUS=${1:-1}

# Variables are quoted so paths containing spaces do not word-split.
torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3_online.py" \
    --target-model-path Qwen/Qwen2.5-VL-7B-Instruct \
    --draft-model-config "$ROOT_DIR/configs/qwen2-5-vl-eagle3.json" \
    --train-data-path "$ROOT_DIR/cache/dataset/allava4v_train.jsonl" \
    --output-dir "$ROOT_DIR/outputs/Qwen2.5-VL-7B-eagle3" \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 8192 \
    --dist-timeout 360 \
    --chat-template qwen2-vl \
    --cache-dir "$ROOT_DIR/cache" \
    --embedding-key model.embed_tokens.weight \
    --tp-size 1 \
    --is-vlm \
    --min-pixels 50176 \
    --max-pixels 802816
123 changes: 105 additions & 18 deletions scripts/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def parse_args():
parser.add_argument(
"--dataset",
type=str,
choices=["ultrachat", "sharegpt", "opc"],
choices=["ultrachat", "sharegpt", "sharegpt4v", "allava4v", "opc"],
help="The demo dataset to quickly run the training for speculative decoding",
)
parser.add_argument(
Expand All @@ -48,6 +48,17 @@ def parse_args():
default=None,
help="The path to the custom dataset, if not specified, the default dataset will be loaded",
)
parser.add_argument(
"--sample-size",
type=int,
default=None,
help="The number of samples to process from the dataset, if not specified, all samples will be processed",
)
parser.add_argument(
"--split-eval",
action="store_true",
help="Whether to split the dataset into train and eval sets, default is False",
)
return parser.parse_args()


Expand Down Expand Up @@ -101,12 +112,83 @@ def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]:
return row, skipped_count


def process_sharegpt4v_row(row) -> "Tuple[Optional[Dict], Optional[int]]":
    """Convert one ShareGPT4V / ALLaVA-4V sample to the training chat format.

    sharegpt4v dataset schema:
    {
        "id": str,
        "image": str, # path to the image
        "conversations": [
            {
                "from": <human|gpt>,
                "value": <message>,
            },
            ...
        ]
    }

    Returns:
        (formatted_row, skipped_count) on success, where skipped_count is the
        number of messages whose role was not in ROLE_MAPPING, or
        (None, None) when the referenced image file does not exist on disk
        (callers drop the sample when the row is None).
    """
    conversations = row["conversations"]
    # Dataset stores image paths relative to the ALLaVA-4V download root.
    image = f'FreedomIntelligence/ALLaVA-4V/{row["image"]}'
    if not os.path.exists(image):
        print(f"Image path {image} does not exist, skipping this sample.")
        return None, None
    formatted_conversations = []
    skipped_count = 0
    for message in conversations:
        if message["from"] not in ROLE_MAPPING:
            skipped_count += 1
            continue
        new_role = ROLE_MAPPING[message["from"]]
        if new_role == "user":
            # The image placeholder token is stripped; the image itself is
            # carried separately in the "image" field of the output row.
            text_content = message["value"].replace("<image>\n", "")
            content = text_content
        else:
            content = message["value"]
        formatted_conversations.append({"role": new_role, "content": content})

    row = {"id": row["id"], "image": image, "conversations": formatted_conversations}
    return row, skipped_count


def load_dataset_from_path(data_path: Path):
    """Load a local dataset file, deriving the builder name from its extension."""
    # ".jsonl" -> "jsonl": the bare extension doubles as the datasets builder id.
    builder_name = data_path.suffix.split(".")[1]
    return load_dataset(builder_name, data_files=str(data_path), split="train")


def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
    """Run proc_fn over the train (and optional eval) split and save JSONL files.

    Args:
        train_ds: iterable of raw samples for the training split.
        test_ds: optional iterable of raw samples for the eval split, or None
            when no eval split was requested.
        output_path: directory (Path) that receives the output JSONL files.
        proc_fn: callable mapping a raw sample to (row_dict, skipped_count);
            samples returned with a None row are dropped entirely.
        dataset_name: basename used for the "<name>_train.jsonl" /
            "<name>_test.jsonl" output files.
    """
    train_output_jsonl_path = output_path.joinpath(f"{dataset_name}_train.jsonl")
    if train_output_jsonl_path.exists():
        print(
            f"The dataset {dataset_name} has already been processed and saved in {train_output_jsonl_path}, skipping..."
        )
        return

    def _write_split(ds, path, desc):
        # Write one split to JSONL; returns the number of skipped messages.
        skipped = 0
        with open(path, "w") as f:
            for item in tqdm(ds, desc=desc):
                row, skipped_count = proc_fn(item)
                if row is None:
                    continue
                skipped += skipped_count
                f.write(json.dumps(row) + "\n")
        return skipped

    total_skipped_count = _write_split(
        train_ds, train_output_jsonl_path, f"Processing {dataset_name} dataset"
    )
    total_rows = len(train_ds)

    if test_ds is not None:
        test_output_jsonl_path = output_path.joinpath(f"{dataset_name}_test.jsonl")
        total_skipped_count += _write_split(
            test_ds, test_output_jsonl_path, f"Processing {dataset_name} test dataset"
        )
        total_rows += len(test_ds)

    if total_skipped_count > 0:
        # Bug fix: the original computed len(train_ds) + len(test_ds) here,
        # which raised TypeError whenever test_ds was None.
        print(
            f"Skipped {total_skipped_count}/{total_rows} messages for {dataset_name}"
        )


import hashlib


Expand Down Expand Up @@ -135,6 +217,14 @@ def main():
print("Loading dataset from custom data path: ", args.data_path)
ds = load_dataset_from_path(Path(args.data_path))
proc_fn = process_sharegpt_row
elif args.dataset == "sharegpt4v":
ds = load_dataset("Lin-Chen/ShareGPT4V")["train"]
proc_fn = process_sharegpt4v_row
elif args.dataset == "allava4v":
ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[
"instruct"
]
proc_fn = process_sharegpt4v_row
elif args.dataset == "opc":
ds = load_dataset(
"OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct"
Expand All @@ -145,30 +235,27 @@ def main():
"This script only supports ultrachat_200k and sharegpt datasets for demo purpose, if you wish to use other datasets, please modify this script."
)

# filter and split dataset
if args.sample_size is not None and args.sample_size < len(ds):
ds = ds.select(range(args.sample_size))
print(f"Processing {args.sample_size} samples from the dataset {args.dataset}")
if args.split_eval:
ds = ds.train_test_split(test_size=0.05)
train_ds = ds["train"]
test_ds = ds["test"]
else:
train_ds = ds
test_ds = None

if args.output_path is None:
root_path = Path(__file__).parent.parent
output_path = root_path.joinpath("cache", "dataset")
output_path.mkdir(parents=True, exist_ok=True)
else:
output_path = Path(args.output_path)
output_path.mkdir(parents=True, exist_ok=True)

output_jsonl_path = output_path.joinpath(f"{args.dataset}.jsonl")

if output_jsonl_path.exists():
print(
f"The dataset {args.dataset} has already been processed and saved in {output_jsonl_path}, skipping..."
)
return

total_skipped_count = 0
with open(output_jsonl_path, "w") as f:
for item in tqdm(ds, desc=f"Processing {args.dataset} dataset"):
row, skipped_count = proc_fn(item)
total_skipped_count += skipped_count
f.write(json.dumps(row) + "\n")

if total_skipped_count > 0:
print(f"Skipped {total_skipped_count}/{len(ds)} messages for {args.dataset}")
process_and_save_ds(train_ds, test_ds, output_path, proc_fn, args.dataset)


if __name__ == "__main__":
Expand Down
Loading