Merged
99 commits
f07f575
Support multi-modal datasets (preprocessing part)
DarkLight1337 Apr 30, 2026
76a7a2d
Doc
DarkLight1337 Apr 30, 2026
b8f3f4a
Doc
DarkLight1337 Apr 30, 2026
8688e1b
Iterate
DarkLight1337 May 1, 2026
ddabe1c
Fix
DarkLight1337 May 1, 2026
c9f845e
Simplify
DarkLight1337 May 1, 2026
7b44b1a
Clean
DarkLight1337 May 1, 2026
ce2a072
Use ShareGPT4V to avoid outdated version of datasets
DarkLight1337 May 1, 2026
a81744a
Improve UX
DarkLight1337 May 1, 2026
5f9a87c
Again
DarkLight1337 May 1, 2026
e827b77
Improve
DarkLight1337 May 1, 2026
4f0f9b8
Fix
DarkLight1337 May 1, 2026
3101671
Fixes
DarkLight1337 May 1, 2026
c53fb6f
Update and cover more edge cases
DarkLight1337 May 1, 2026
df2c542
Add trust remote code
DarkLight1337 May 1, 2026
a0da622
Update CLI reference
DarkLight1337 May 1, 2026
2dfa0f5
Format
DarkLight1337 May 1, 2026
cf2631c
More cleanup and tests
DarkLight1337 May 1, 2026
3abacf8
Reduce diff
DarkLight1337 May 1, 2026
f2ed9c3
Fix mypy
DarkLight1337 May 1, 2026
8eab139
Address AI comments
DarkLight1337 May 1, 2026
5088837
Add torchaudio and torchvision to dependencies
DarkLight1337 May 1, 2026
a37e900
Handle transformers v4
DarkLight1337 May 1, 2026
1713b22
Typo
DarkLight1337 May 1, 2026
7e66a57
Clean up
DarkLight1337 May 2, 2026
84d0969
Fix
DarkLight1337 May 2, 2026
f0151b3
Clean
DarkLight1337 May 2, 2026
83fc95a
Improve UX
DarkLight1337 May 2, 2026
b9bdc6c
Fix
DarkLight1337 May 2, 2026
97086d6
Add type annotations
DarkLight1337 May 2, 2026
d9f6efa
Improve error message
DarkLight1337 May 2, 2026
679d8c7
Avoid name clash
DarkLight1337 May 2, 2026
720ccb4
Rename
DarkLight1337 May 2, 2026
de2e38a
Reword
DarkLight1337 May 2, 2026
95c6ec6
Fix
DarkLight1337 May 2, 2026
f949b7c
Fix
DarkLight1337 May 2, 2026
8d0c01d
Fix vllm messages not actually being passed
DarkLight1337 May 2, 2026
20b365f
Fix whitespacec
DarkLight1337 May 2, 2026
a3d7a35
Cleanup
DarkLight1337 May 2, 2026
b163817
Simplify
DarkLight1337 May 2, 2026
347471e
Fix
DarkLight1337 May 2, 2026
7af08bb
Simplify
DarkLight1337 May 2, 2026
68897d9
Move to global
DarkLight1337 May 2, 2026
99cac6b
Fix response parsing
DarkLight1337 May 2, 2026
9199bbb
Fix missing special tokens
DarkLight1337 May 2, 2026
fc66e08
Update
DarkLight1337 May 2, 2026
2a01ac0
Fix chat
DarkLight1337 May 2, 2026
df0b721
Parameterize `run_prepare_data`
DarkLight1337 May 2, 2026
f1c9a5c
Add e2e test
DarkLight1337 May 2, 2026
46e6ab3
Enforce eager to reduce startup time
DarkLight1337 May 2, 2026
4a53d60
Update
DarkLight1337 May 2, 2026
bf2a5c7
Avoid argument bloat
DarkLight1337 May 2, 2026
8a4ac6d
Use enforce eager by default
DarkLight1337 May 2, 2026
999be1f
Typo
DarkLight1337 May 2, 2026
204687d
Up
DarkLight1337 May 2, 2026
c31cb45
Typo
DarkLight1337 May 2, 2026
4a037d6
Don't rely on sampling
DarkLight1337 May 2, 2026
a1fa7b1
Reduce diff
DarkLight1337 May 2, 2026
d9c5ca4
Fix symlink
DarkLight1337 May 2, 2026
09a9bb7
Use the real dataset for nightly
DarkLight1337 May 2, 2026
54f5bf7
Use MM prompts for testing
DarkLight1337 May 2, 2026
f8fa8f9
Fix
DarkLight1337 May 2, 2026
6cbd2c8
Fix
DarkLight1337 May 2, 2026
feb939e
Fix
DarkLight1337 May 4, 2026
04e89e3
Quality
DarkLight1337 May 4, 2026
f8e9bd4
Fix evaluation
DarkLight1337 May 4, 2026
2a338ba
Revert unnecsesary changes
DarkLight1337 May 4, 2026
427bdab
Up
DarkLight1337 May 4, 2026
9799d26
Update thresholds
DarkLight1337 May 5, 2026
d5b5b43
Change the key
DarkLight1337 May 5, 2026
cf0bd0d
Typo
DarkLight1337 May 5, 2026
2ac74ab
Avoid misleading warning for Qwen3.5
DarkLight1337 May 7, 2026
ab936fc
Avoid CPU contention for MM processing
DarkLight1337 May 7, 2026
86639c4
Merge branch 'main' into mm-dataset
DarkLight1337 May 7, 2026
9348a58
Format
DarkLight1337 May 7, 2026
2ff9bbb
mv
DarkLight1337 May 7, 2026
f72d76b
Reduce scope
DarkLight1337 May 7, 2026
8c8c242
Improve messaging
DarkLight1337 May 15, 2026
6b64384
Fix typo
DarkLight1337 May 15, 2026
434ce83
Fix
DarkLight1337 May 15, 2026
3116024
Fix
DarkLight1337 May 15, 2026
a8ce7fb
Fix
DarkLight1337 May 15, 2026
dd4c628
Improve
DarkLight1337 May 15, 2026
b9017fd
Improve messaging
DarkLight1337 May 18, 2026
5d60c44
Merge branch 'main' into mm-dataset
DarkLight1337 May 18, 2026
94ab458
mypy
DarkLight1337 May 18, 2026
df17d4b
Merge branch 'main' into mm-dataset
DarkLight1337 May 21, 2026
3272829
Fix
DarkLight1337 May 21, 2026
07f7d7c
Fix 2
DarkLight1337 May 21, 2026
8a10f1b
Use load_processor everywhere
DarkLight1337 May 22, 2026
7f0c317
Use graph by default for server
DarkLight1337 May 22, 2026
be339a8
Change default
DarkLight1337 May 22, 2026
b6336ed
Fix llm args
DarkLight1337 May 22, 2026
1d31818
Be more robust
DarkLight1337 May 22, 2026
5362c7a
Fix whitespace
DarkLight1337 May 22, 2026
e509c6d
Doc
DarkLight1337 May 22, 2026
2d39e73
Format
DarkLight1337 May 22, 2026
ee95bd6
mypy
DarkLight1337 May 22, 2026
f86b93e
Merge branch 'main' into mm-dataset
shanjiaz May 26, 2026
2 changes: 2 additions & 0 deletions docs/cli/prepare_data.md
@@ -26,6 +26,8 @@ python scripts/prepare_data.py \

Example: `meta-llama/Llama-3.1-8B-Instruct`

- **`--trust-remote-code`** (flag) Allow executing code from HF Hub when loading the target model's processor.

### Data Arguments

- **`--data`** (str, required, repeatable) Path to training data. Can be a HuggingFace dataset name or local path. Use multiple times to specify multiple datasets.
2 changes: 2 additions & 0 deletions docs/cli/train.md
@@ -32,6 +32,8 @@ torchrun --standalone --nproc_per_node=4 scripts/train.py \

- **`--verifier-name-or-path`** (str, required) HuggingFace model ID or local path for the verifier/target model.

- **`--trust-remote-code`** (flag) Allow executing code from HF Hub when loading the verifier's tokenizer.

- **`--speculator-type`** (str, default: `"eagle3"`) Type of speculator model to train. Options: `eagle3`, `dflash`

- **`--from-pretrained`** (str, default: `""`) Path to a pretrained draft model to finetune.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -50,6 +50,8 @@ dependencies = [
"safetensors",
"setuptools",
"torch>=2.9.0,<=2.11.0",
"torchaudio",
"torchvision",
"tqdm>=4.66.3,<=4.67.3",
"transformers>=4.56.1,<5.9.0",
"typer>=0.12.0",
@@ -249,6 +251,7 @@ select = [
"PTH", # os.path is acceptable in scripts
"T201", # print statements are acceptable in scripts
"SLF001", # allow private member access for model configuration
"PLR0915", # allow long parse_args functions
]

"examples/**/*.py" = [
28 changes: 16 additions & 12 deletions scripts/data_generation_offline.py
@@ -31,6 +31,7 @@
DEFAULT_REQUEST_TIMEOUT,
generate_hidden_states_async,
)
from speculators.train.data import build_client_item
from speculators.train.logger import setup_root_logger

logger = logging.getLogger(__name__)
@@ -66,8 +67,8 @@ def parse_args():
type=str,
default=None,
help=(
"HuggingFace model ID or local path for target model (default auto select)."
"For verification purposes only."
"HuggingFace model ID or local path for target model "
"(default auto select). For verification purposes only."
),
)
parser.add_argument(
@@ -113,16 +114,16 @@ def parse_args():
type=int,
default=32,
help=(
"Number of active vLLM requests at a time."
"Number of active vLLM requests at a time. "
"Note: number of async workers set to 2*concurrency"
),
)
parser.add_argument(
"--validate-outputs",
action="store_true",
help=(
"Load generated safetensor files and check output token ids match prompt"
" tokens and hidden states seq_len matches num tokens"
"Load generated safetensor files and check output token ids match "
"prompt tokens and hidden states seq_len matches num tokens"
),
)
parser.add_argument(
@@ -276,16 +277,14 @@ async def worker(
queue.task_done()
continue

input_ids = item["input_ids"].tolist()

target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors"

try:
async with vllm_semaphore: # Limit number of active generate calls
hidden_states_path = await generate_hidden_states_async(
client,
model,
input_ids,
item,
timeout=request_timeout,
max_retries=max_retries,
)
@@ -295,7 +294,9 @@ async def worker(
)
if validate_outputs:
await asyncio.to_thread(
check_safetensors_file, target_hidden_states_path, input_ids
check_safetensors_file,
target_hidden_states_path,
item["input_ids"],
)
except Exception as e:
if fail_on_error:
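The worker hunks above bound the number of in-flight requests with a semaphore and push the blocking safetensors check onto a thread via `asyncio.to_thread`. A minimal sketch of that pattern follows, under stated assumptions: `fake_generate` and `check_tokens` are hypothetical stand-ins for `generate_hidden_states_async` and `check_safetensors_file`, and the sketch runs the check after releasing the semaphore, which may differ from the real script.

```python
import asyncio


def check_tokens(expected_ids, produced_ids):
    """Blocking validation step; raises if the token ids differ."""
    if list(expected_ids) != list(produced_ids):
        raise ValueError("token ids do not match")


async def fake_generate(item, state):
    """Hypothetical stand-in for the real request; tracks in-flight calls."""
    state["active"] += 1
    state["peak"] = max(state["peak"], state["active"])
    await asyncio.sleep(0.01)
    state["active"] -= 1
    return list(item["input_ids"])


async def worker(item, semaphore, state):
    async with semaphore:  # at most `concurrency` requests in flight
        produced = await fake_generate(item, state)
    # Run the blocking check off the event loop thread.
    await asyncio.to_thread(check_tokens, item["input_ids"], produced)
    return item["idx"]


async def demo(concurrency=2):
    semaphore = asyncio.Semaphore(concurrency)
    state = {"active": 0, "peak": 0}
    items = [{"idx": i, "input_ids": [i, i + 1]} for i in range(6)]
    done = await asyncio.gather(*(worker(it, semaphore, state) for it in items))
    return done, state["peak"]


done, peak = asyncio.run(demo())
```

Because the whole item dict is now passed to the request function, the validation step reads `item["input_ids"]` directly instead of relying on a separately extracted list.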
@@ -325,12 +326,15 @@ async def _feed_queue(to_process, dataset, queue, cancel_event):
for i in to_process:
if cancel_event.is_set():
break
item = dataset[i]

dataset_item = dataset[i]
client_item = build_client_item(dataset_item) | {"idx": i}

# Check cancel_event while waiting for queue space to avoid
# deadlocking when all workers have died.
while not cancel_event.is_set():
try:
queue.put_nowait({"idx": i, "input_ids": item["input_ids"]})
queue.put_nowait(client_item)
break
except asyncio.QueueFull:
await asyncio.sleep(0.1)
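The `_feed_queue` hunk above merges the dataset index into the client item with the dict union operator and retries `put_nowait` while watching the cancel event. A rough, self-contained sketch of that loop is shown here; `build_client_item` is replaced by a hypothetical stand-in (the real one lives in `speculators.train.data` and its output shape is not shown in this diff), and the sample dataset is invented for illustration.

```python
import asyncio


def build_client_item(dataset_item: dict) -> dict:
    # Hypothetical stand-in: the real helper may add multi-modal fields.
    return {"input_ids": list(dataset_item["input_ids"])}


async def feed_queue(to_process, dataset, queue, cancel_event):
    for i in to_process:
        if cancel_event.is_set():
            break
        # Dict union (Python 3.9+): right-hand keys win on conflict.
        client_item = build_client_item(dataset[i]) | {"idx": i}
        # Poll instead of awaiting put() so a cancel is noticed even
        # when no consumer is left to drain the queue.
        while not cancel_event.is_set():
            try:
                queue.put_nowait(client_item)
                break
            except asyncio.QueueFull:
                await asyncio.sleep(0.01)


async def demo():
    dataset = [{"input_ids": [1, 2]}, {"input_ids": [3]}, {"input_ids": [4, 5, 6]}]
    queue = asyncio.Queue(maxsize=1)
    cancel_event = asyncio.Event()
    feeder = asyncio.create_task(
        feed_queue(range(len(dataset)), dataset, queue, cancel_event)
    )
    received = []
    for _ in dataset:
        received.append(await queue.get())
        queue.task_done()
    await feeder
    return received


received = asyncio.run(demo())
```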
@@ -397,7 +401,7 @@ async def generate_and_save_hidden_states(args, dataset):
if args.model and args.model != model_id:
raise ValueError(
f"An explicit model name was passed ({args.model}) which doesn't match"
"found model_id {model_id}."
f" found model_id {model_id}. "
"Please make sure --endpoint is set to the correct vllm instance."
)
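Several hunks in this file fix messages built from adjacent string literals. Python joins such literals with no separator, and an `f` prefix applies only to the literal that carries it. The sketch below reproduces both pitfalls with made-up values for `model` and `model_id`; the `fixed` variant is one possible wording, not necessarily the exact text in the PR.

```python
model, model_id = "model-a", "model-b"

# Pitfall 1: no space at the literal boundary.
# Pitfall 2: the second literal lacks the `f` prefix, so the braces stay literal.
broken = (
    f"An explicit model name was passed ({model}) which doesn't match"
    "found model_id {model_id}."
)

# Each interpolating literal gets its own `f`, and each boundary its own space.
fixed = (
    f"An explicit model name was passed ({model}) which doesn't match"
    f" found model_id {model_id}. "
    "Please make sure --endpoint is set to the correct vllm instance."
)
```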

11 changes: 10 additions & 1 deletion scripts/prepare_data.py
@@ -49,6 +49,14 @@ def parse_args():
required=True,
help="HuggingFace model ID or local path for target model",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help=(
"Allow executing code from HF Hub when loading the target model's "
"processor."
),
)

# Data arguments
parser.add_argument(
@@ -75,7 +83,7 @@ def parse_args():
type=str,
default=None,
help=(
"Path to save token frequency distribution"
"Path to save token frequency distribution "
"(default: args.output / 'token_freq.pt')"
),
)
@@ -177,6 +185,7 @@ def main():
assistant_pattern=args.assistant_pattern,
turn_dropout=args.turn_dropout,
minimum_valid_tokens=args.minimum_valid_tokens,
trust_remote_code=args.trust_remote_code,
)

log.info("Done preparing data")
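The new `--trust-remote-code` option is a plain `store_true` flag whose value is forwarded as a keyword argument. A small sketch of that flow is given here, assuming a hypothetical `--model` argument name and a hypothetical `loader_kwargs` helper; neither is taken from the script itself.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # `--model` is a hypothetical name for the target-model argument.
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",  # absent -> False, present -> True
        help="Allow executing code from HF Hub when loading the processor.",
    )
    return parser


def loader_kwargs(args):
    """Hypothetical helper: the keyword arguments a loader would receive."""
    return {"trust_remote_code": args.trust_remote_code}


default_args = build_parser().parse_args(["--model", "some/model"])
opted_in = build_parser().parse_args(["--model", "some/model", "--trust-remote-code"])
```

Keeping the default at `False` means remote code only runs when the user opts in explicitly on the command line.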
6 changes: 6 additions & 0 deletions scripts/train.py
@@ -259,6 +259,7 @@ def main(args: argparse.Namespace):
args.verifier_name_or_path,
transformer_layer_config.vocab_size,
args.mask_token_id,
trust_remote_code=args.trust_remote_code,
)

registry = SpeculatorModel.registry
@@ -398,6 +399,11 @@ def _checkpoint_freq(value: str) -> float:
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--verifier-name-or-path", type=str, required=True)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Allow executing code from HF Hub when loading the verifier's tokenizer.",
)
parser.add_argument(
"--speculator-type",
type=str,
71 changes: 68 additions & 3 deletions src/speculators/data_generation/configs.py
@@ -1,5 +1,6 @@
"""Configuration registries for data generation pipeline."""

import os
from collections.abc import Callable
from dataclasses import dataclass

@@ -9,14 +10,15 @@
]


@dataclass
@dataclass(kw_only=True)
class DatasetConfig:
"""Configuration for loading a dataset"""

name: str
hf_path: str
split: str
subset: str | None = None
split: str
filter_fn: Callable[[dict], bool] | None = None
normalize_fn: Callable[[dict], dict] | None = None


@@ -35,6 +37,60 @@ def _normalize_gsm8k(example: dict) -> dict:
}


def get_coco_dir():
return os.getenv("COCO_DIR") or "coco/"


def _parse_sharegpt4v_part(part: str, image_path: str):
if part == "<image>":
return {"type": "image", "path": image_path}

return {"type": "text", "text": part}


def _parse_sharegpt4v_user_content(content: str, image_path: str):
return [_parse_sharegpt4v_part(part, image_path) for part in content.split("\n")]


def _parse_sharegpt4v_assistant_content(content: str):
return [{"type": "text", "text": content}]


def _filter_sharegpt4v_coco(example: dict) -> bool:
return example["image"].startswith("coco/")


def _normalize_sharegpt4v_coco(example: dict) -> dict:
coco_dir = get_coco_dir()
image_path = os.path.join(coco_dir, example["image"].removeprefix("coco/"))

if not os.path.exists(image_path):
state_str = "set to" if os.getenv("COCO_DIR") else "default"

raise ValueError(
f"No image found at <{image_path}>. "
f"Please download COCO 2017 Train Images from "
f"<http://images.cocodataset.org/zips/train2017.zip> and place the "
f"extracted folder under `COCO_DIR` ({state_str}: `{coco_dir}`)."
)

messages = [
(
turn
| {
"value": (
_parse_sharegpt4v_user_content(turn["value"], image_path)
if turn["from"] in ("human", "user")
else _parse_sharegpt4v_assistant_content(turn["value"])
)
}
)
for turn in example["conversations"]
]

return {"conversations": messages}


DATASET_CONFIGS: dict[str, DatasetConfig] = {
"sharegpt": DatasetConfig(
name="sharegpt",
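The ShareGPT4V helpers in the hunk above split a user turn on newlines and turn a bare `<image>` line into an image part, leaving every other line as text. The following sketch mirrors that logic with shortened function names; the sample turn and the image path are invented for illustration.

```python
def parse_part(part: str, image_path: str) -> dict:
    # A line that is exactly "<image>" becomes an image reference.
    if part == "<image>":
        return {"type": "image", "path": image_path}
    return {"type": "text", "text": part}


def parse_user_content(content: str, image_path: str) -> list[dict]:
    return [parse_part(part, image_path) for part in content.split("\n")]


# Hypothetical sample turn in the ShareGPT4V conversation format.
turn = {"from": "human", "value": "<image>\nWhat is in this picture?"}
image_path = "coco/train2017/example.jpg"

parts = parse_user_content(turn["value"], image_path)
# Dict union keeps "from" and replaces "value" with the structured parts.
normalized_turn = turn | {"value": parts}
```

Note that the placeholder is matched per line, so an `<image>` token embedded in the middle of a line would stay in the text part unchanged.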
@@ -50,8 +106,17 @@ def _normalize_gsm8k(example: dict) -> dict:
"gsm8k": DatasetConfig(
name="gsm8k",
hf_path="openai/gsm8k",
split="train",
subset="main",
split="train",
normalize_fn=_normalize_gsm8k,
),
# NOTE: You need to serve vLLM with `--allowed-local-media-path /path/to/coco`
"sharegpt4v_coco": DatasetConfig(
name="sharegpt4v_coco",
hf_path="Lin-Chen/ShareGPT4V",
subset="ShareGPT4V",
split="train",
filter_fn=_filter_sharegpt4v_coco,
normalize_fn=_normalize_sharegpt4v_coco,
),
}
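The `sharegpt4v_coco` entry pairs a filter that keeps only `coco/` images with a normalizer that rebases those paths onto `COCO_DIR`. A minimal sketch of that path handling follows. It assumes the filter runs before the normalizer, which this diff does not show, and it passes the environment in as a plain mapping (a hypothetical convenience) so the example does not depend on the real process environment. The records are made up.

```python
import os


def get_coco_dir(env=None) -> str:
    env = os.environ if env is None else env
    # `or` also falls back when COCO_DIR is set to an empty string.
    return env.get("COCO_DIR") or "coco/"


def keep_coco(example: dict) -> bool:
    return example["image"].startswith("coco/")


def resolve_image_path(example: dict, coco_dir: str) -> str:
    # Drop the dataset-relative "coco/" prefix, then rebase onto coco_dir.
    return os.path.join(coco_dir, example["image"].removeprefix("coco/"))


# Hypothetical records; only the first one points at a COCO image.
records = [
    {"image": "coco/train2017/example.jpg"},
    {"image": "other/images/example.jpg"},
]

kept = [r for r in records if keep_coco(r)]
coco_dir = get_coco_dir({"COCO_DIR": "/data/coco"})
paths = [resolve_image_path(r, coco_dir) for r in kept]
```

Filtering first matters in this sketch because the normalizer in the diff raises when the resolved file is missing, which would otherwise trip on every non-COCO record.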