Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:

- name: Install
shell: bash
run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages

- name: Execute
shell: bash
Expand Down Expand Up @@ -107,7 +107,7 @@ jobs:

- name: Install
shell: bash
run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages

- name: Execute
shell: bash
Expand Down Expand Up @@ -151,7 +151,7 @@ jobs:

- name: Install
shell: bash
run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages

- name: Execute
shell: bash
Expand Down Expand Up @@ -195,7 +195,7 @@ jobs:

- name: Install
shell: bash
run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages

- name: Execute
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ jobs:

- name: Install
shell: bash
run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages

- name: Execute
shell: bash
Expand Down
14 changes: 13 additions & 1 deletion slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from slime.utils.metric_checker import MetricChecker
from slime.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix
from slime.utils.misc import load_function
from slime.utils.processing_utils import load_processor
from slime.utils.ray_utils import Box
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
from slime.utils.tracking_utils import init_tracking
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(self, args, pg):
self._metric_checker = MetricChecker.maybe_create(args)
if self.args.use_fault_tolerance:
self._health_monitor = RolloutHealthMonitor(self, args)
self.processor = None

def dispose(self):
if self._metric_checker is not None:
Expand Down Expand Up @@ -275,7 +277,17 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
train_data["metadata"] = [sample.train_metadata for sample in samples]

if samples[0].multimodal_inputs is not None:
train_data["multimodal_inputs"] = [sample.multimodal_inputs for sample in samples]
if self.processor is None:
self.processor = load_processor(self.args.hf_checkpoint, trust_remote_code=True)
train_data["multimodal_inputs"] = []
for sample in samples:
# Get input IDs with full prompt (text + multimodal)
processor_output = self.processor(text=sample.prompt, **sample.multimodal_inputs)

# Extract multimodal tokens (exclude text-related tokens)
train_data["multimodal_inputs"].append(
{k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]}
)

if "teacher_log_probs" in samples[0].__dict__:
train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]
Expand Down
15 changes: 2 additions & 13 deletions slime/rollout/sft_rollout.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from slime.utils.mask_utils import MultiTurnLossMaskGenerator
from slime.utils.processing_utils import load_processor, load_tokenizer, prepare_model_inputs
from slime.utils.processing_utils import load_processor, load_tokenizer

__all__ = ["generate_rollout"]

Expand Down Expand Up @@ -46,18 +46,7 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
messages = sample.prompt
tools = sample.metadata.get("tools", None)

input_ids, extra_info = prepare_model_inputs(
messages, TOKENIZER, PROCESSOR, sample.metadata, args.apply_chat_template, args.apply_chat_template_kwargs
)

has_multimodal = bool(extra_info.get("images") or extra_info.get("videos"))
if has_multimodal:
sample.multimodal_inputs = extra_info["multimodal_inputs"]
token_ids, loss_mask = MASK_GENERATOR.get_loss_mask_with_multimodal_alignment(
messages, input_ids, tools=tools
)
else:
token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)
token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)

response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0]

Expand Down
35 changes: 11 additions & 24 deletions slime/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@
from slime.utils.http_utils import get, post
from slime.utils.mask_utils import get_response_lengths
from slime.utils.misc import SingletonMeta, load_function
from slime.utils.processing_utils import (
encode_image_for_rollout_engine,
load_processor,
load_tokenizer,
prepare_model_inputs,
)
from slime.utils.processing_utils import encode_image_for_rollout_engine, load_processor, load_tokenizer
from slime.utils.types import Sample

from .rm_hub import async_rm, batched_async_rm
Expand Down Expand Up @@ -90,26 +85,21 @@ def submit_generate_tasks(self, samples: list[list[Sample]]) -> None:

async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample:
"""Generate using traditional SGLang router with token-based workflow"""
if args.ci_test:
assert isinstance(sample.prompt, str)

state = GenerateState(args)
url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

assert (
sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED
), f"Sample status is {sample.status}"

prompt_ids, extra_info = prepare_model_inputs(
sample.prompt,
state.tokenizer,
state.processor,
sample.metadata,
args.apply_chat_template,
args.apply_chat_template_kwargs,
)

sample.prompt = extra_info.get("formatted_prompt", sample.prompt)
image_data = extra_info.get("images", [])
video_data = extra_info.get("videos", [])
multimodal_inputs = extra_info.get("multimodal_inputs", None)
if state.processor:
processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs)
prompt_ids = processor_output["input_ids"][0]
else:
prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False)

if len(sample.response) > 0:
sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids)
Expand All @@ -130,12 +120,9 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
if args.use_rollout_routing_replay:
payload["return_routed_experts"] = True

if image_data:
if sample.multimodal_inputs and sample.multimodal_inputs["images"]:
image_data = sample.multimodal_inputs["images"]
payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data]
sample.multimodal_inputs = multimodal_inputs

if video_data:
raise NotImplementedError("Video data is not supported yet")

# Use existing tokens for multi-turn or tokenize the new prompt
if len(sample.response) > 0:
Expand Down
42 changes: 31 additions & 11 deletions slime/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,16 @@ def _parse_generalized_path(s: str):
return s, None


def _should_skip_prompt(
prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
):
def _should_skip_prompt(formatted_prompt: str, tokenizer, processor, max_length, multimodal_inputs=None):
if max_length is None:
return False

from slime.utils.processing_utils import prepare_model_inputs
if processor:
processor_output = processor(text=formatted_prompt, **multimodal_inputs)
input_ids = processor_output["input_ids"][0]
else:
input_ids = tokenizer.encode(formatted_prompt, add_special_tokens=False)

input_ids, _ = prepare_model_inputs(
prompt, tokenizer, processor, metadata, apply_chat_template, apply_chat_template_kwargs
)
return len(input_ids) > max_length


Expand Down Expand Up @@ -140,6 +139,7 @@ def __init__(
prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)

metadata = data.get(metadata_key) or {}
tools = None
if tool_key is not None and tool_key in data:
tools = data[tool_key]
if isinstance(tools, str):
Expand All @@ -149,17 +149,37 @@ def __init__(
assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
metadata["tools"] = tools

if apply_chat_template:
formatted_prompt = tokenizer.apply_chat_template(
prompt,
tools=tools,
tokenize=False,
add_generation_prompt=True,
**(apply_chat_template_kwargs or {}),
)
else:
formatted_prompt = prompt

if processor:
# temporary solution, will write image utils for slime later
from qwen_vl_utils import process_vision_info

assert isinstance(prompt, list)
images, videos = process_vision_info(prompt)
multimodal_inputs = {"images": images, "videos": videos}
else:
multimodal_inputs = None

# TODO: this is slow.
if _should_skip_prompt(
prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
):
if _should_skip_prompt(formatted_prompt, tokenizer, processor, max_length, multimodal_inputs):
continue

self.origin_samples.append(
Sample(
prompt=prompt,
prompt=formatted_prompt,
label=data[label_key] if label_key is not None else None,
metadata=metadata,
multimodal_inputs=multimodal_inputs,
)
)

Expand Down
57 changes: 0 additions & 57 deletions slime/utils/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import io
import logging

import numpy as np
from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin

logger = logging.getLogger(__name__)
Expand All @@ -26,62 +25,6 @@ def load_processor(name_or_path: str, **kwargs):
return proc


def prepare_model_inputs(
prompt, tokenizer, processor=None, metadata=None, apply_chat_template=False, apply_chat_template_kwargs=None
):
"""Prepare all inputs for model inference.

Returns:
tuple: (input_ids, extra_info)
- input_ids: Token IDs for the prompt
- extra_info: Dict with 'images', 'videos', 'multimodal_inputs' (or empty dict)
"""
tools = metadata.get("tools") if metadata else None
if isinstance(prompt, (list, np.ndarray)):
assert (
apply_chat_template
), f"apply_chat_template must be True when prompt is a list or numpy array, current prompt is {prompt}"
formatted_prompt = tokenizer.apply_chat_template(
prompt,
tools=tools,
tokenize=False,
add_generation_prompt=True,
**(apply_chat_template_kwargs or {}),
)
elif isinstance(prompt, str):
assert (
not apply_chat_template
), f"apply_chat_template must be False when prompt is a string, current prompt is {prompt}"
formatted_prompt = prompt
else:
raise ValueError(f"Invalid prompt type: {type(prompt)}, current prompt is {prompt}")

if not processor:
input_ids = tokenizer.encode(formatted_prompt, add_special_tokens=False)
return input_ids, {"formatted_prompt": formatted_prompt}
else:
# temporary solution, will write image utils for slime later
from qwen_vl_utils import process_vision_info

images, videos = process_vision_info(prompt)

# Get input IDs with full prompt (text + multimodal)
processor_output = processor(text=formatted_prompt, images=images, videos=videos)
input_ids = processor_output["input_ids"][0]

# Extract multimodal tokens (exclude text-related tokens)
multimodal_inputs = {k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]}

extra_info = {
"formatted_prompt": formatted_prompt,
"images": images,
"videos": videos,
"multimodal_inputs": multimodal_inputs,
}

return input_ids, extra_info


def encode_image_for_rollout_engine(image) -> str:
"""Load an image from path, ensure RGB, encode as JPEG base64 string."""
buffer = io.BytesIO()
Expand Down
Loading