diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index d38eb5268..e8697f2b6 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -63,7 +63,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash
@@ -107,7 +107,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash
@@ -151,7 +151,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash
@@ -195,7 +195,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash
diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2
index 0ed66ffae..712a0bba4 100644
--- a/.github/workflows/pr-test.yml.j2
+++ b/.github/workflows/pr-test.yml.j2
@@ -94,7 +94,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash
diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py
index d2aa0db22..3ad953f3d 100644
--- a/slime/ray/rollout.py
+++ b/slime/ray/rollout.py
@@ -20,6 +20,7 @@
 from slime.utils.metric_checker import MetricChecker
 from slime.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix
 from slime.utils.misc import load_function
+from slime.utils.processing_utils import load_processor
 from slime.utils.ray_utils import Box
 from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
 from slime.utils.tracking_utils import init_tracking
@@ -77,6 +78,7 @@ def __init__(self, args, pg):
         self._metric_checker = MetricChecker.maybe_create(args)
         if self.args.use_fault_tolerance:
             self._health_monitor = RolloutHealthMonitor(self, args)
+        self.processor = None
 
     def dispose(self):
         if self._metric_checker is not None:
@@ -275,7 +277,17 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]):
             train_data["metadata"] = [sample.train_metadata for sample in samples]
 
         if samples[0].multimodal_inputs is not None:
-            train_data["multimodal_inputs"] = [sample.multimodal_inputs for sample in samples]
+            if self.processor is None:
+                self.processor = load_processor(self.args.hf_checkpoint, trust_remote_code=True)
+            train_data["multimodal_inputs"] = []
+            for sample in samples:
+                # Get input IDs with full prompt (text + multimodal)
+                processor_output = self.processor(text=sample.prompt, **sample.multimodal_inputs)
+
+                # Extract multimodal tokens (exclude text-related tokens)
+                train_data["multimodal_inputs"].append(
+                    {k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]}
+                )
 
         if "teacher_log_probs" in samples[0].__dict__:
             train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]
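For reference, the conversion step added above behaves like this standalone sketch (a paraphrase, not an excerpt; the checkpoint name is illustrative, and `multimodal_inputs` is the `{"images": ..., "videos": ...}` dict that `slime/utils/data.py` now caches on each sample):

```python
# Sketch of the lazy processor load + feature re-extraction performed in
# _convert_samples_to_train_data, assuming a HF AutoProcessor checkpoint.
from transformers import AutoProcessor

_processor = None

def extract_multimodal_features(hf_checkpoint: str, prompt: str, multimodal_inputs: dict) -> dict:
    global _processor
    if _processor is None:
        # mirrors `self.processor = None` in __init__ plus the load-on-first-use above
        _processor = AutoProcessor.from_pretrained(hf_checkpoint, trust_remote_code=True)
    output = _processor(text=prompt, **multimodal_inputs)
    # input_ids / attention_mask describe the text side; the remaining keys
    # (e.g. pixel_values and grid metadata for vision models) are the
    # multimodal features the trainer needs.
    return {k: v for k, v in output.items() if k not in ["input_ids", "attention_mask"]}
```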
diff --git a/slime/rollout/sft_rollout.py b/slime/rollout/sft_rollout.py
index e66782fac..6b914a964 100644
--- a/slime/rollout/sft_rollout.py
+++ b/slime/rollout/sft_rollout.py
@@ -1,7 +1,7 @@
 import logging
 
 from slime.utils.mask_utils import MultiTurnLossMaskGenerator
-from slime.utils.processing_utils import load_processor, load_tokenizer, prepare_model_inputs
+from slime.utils.processing_utils import load_processor, load_tokenizer
 
 __all__ = ["generate_rollout"]
 
@@ -46,18 +46,7 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
         messages = sample.prompt
         tools = sample.metadata.get("tools", None)
 
-        input_ids, extra_info = prepare_model_inputs(
-            messages, TOKENIZER, PROCESSOR, sample.metadata, args.apply_chat_template, args.apply_chat_template_kwargs
-        )
-
-        has_multimodal = bool(extra_info.get("images") or extra_info.get("videos"))
-        if has_multimodal:
-            sample.multimodal_inputs = extra_info["multimodal_inputs"]
-            token_ids, loss_mask = MASK_GENERATOR.get_loss_mask_with_multimodal_alignment(
-                messages, input_ids, tools=tools
-            )
-        else:
-            token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)
+        token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)
 
         response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0]
 
diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py
index 3988eeb8e..e063bce4f 100644
--- a/slime/rollout/sglang_rollout.py
+++ b/slime/rollout/sglang_rollout.py
@@ -20,12 +20,7 @@
 from slime.utils.http_utils import get, post
 from slime.utils.mask_utils import get_response_lengths
 from slime.utils.misc import SingletonMeta, load_function
-from slime.utils.processing_utils import (
-    encode_image_for_rollout_engine,
-    load_processor,
-    load_tokenizer,
-    prepare_model_inputs,
-)
+from slime.utils.processing_utils import encode_image_for_rollout_engine, load_processor, load_tokenizer
 from slime.utils.types import Sample
 
 from .rm_hub import async_rm, batched_async_rm
@@ -90,6 +85,9 @@ def submit_generate_tasks(self, samples: list[list[Sample]]) -> None:
 
 async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample:
     """Generate using traditional SGLang router with token-based workflow"""
+    if args.ci_test:
+        assert isinstance(sample.prompt, str)
+
     state = GenerateState(args)
 
     url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
@@ -97,19 +95,11 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample:
     assert (
         sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED
     ), f"Sample status is {sample.status}"
-    prompt_ids, extra_info = prepare_model_inputs(
-        sample.prompt,
-        state.tokenizer,
-        state.processor,
-        sample.metadata,
-        args.apply_chat_template,
-        args.apply_chat_template_kwargs,
-    )
-
-    sample.prompt = extra_info.get("formatted_prompt", sample.prompt)
-    image_data = extra_info.get("images", [])
-    video_data = extra_info.get("videos", [])
-    multimodal_inputs = extra_info.get("multimodal_inputs", None)
+    if state.processor:
+        processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs)
+        prompt_ids = processor_output["input_ids"][0]
+    else:
+        prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False)
 
     if len(sample.response) > 0:
         sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids)
@@ -130,12 +120,9 @@
     if args.use_rollout_routing_replay:
         payload["return_routed_experts"] = True
 
-    if image_data:
+    if sample.multimodal_inputs and sample.multimodal_inputs["images"]:
+        image_data = sample.multimodal_inputs["images"]
         payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data]
-        sample.multimodal_inputs = multimodal_inputs
-
-    if video_data:
-        raise NotImplementedError("Video data is not supported yet")
 
     # Use existing tokens for multi-turn or tokenize the new prompt
     if len(sample.response) > 0:
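The payload changes above reduce to a small helper; a minimal sketch, assuming SGLang's native `/generate` API accepts `input_ids` and `sampling_params` alongside `image_data` (only `image_data` and `return_routed_experts` are confirmed by this patch):

```python
# Hypothetical helper showing how the request body is assembled for samples
# that carry images; encode_image_for_rollout_engine yields base64 JPEG strings.
from slime.utils.processing_utils import encode_image_for_rollout_engine

def build_generate_payload(prompt_ids: list[int], sampling_params: dict, multimodal_inputs: dict | None) -> dict:
    payload = {"input_ids": prompt_ids, "sampling_params": sampling_params}
    if multimodal_inputs and multimodal_inputs["images"]:
        payload["image_data"] = [
            encode_image_for_rollout_engine(image) for image in multimodal_inputs["images"]
        ]
    return payload
```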
diff --git a/slime/utils/data.py b/slime/utils/data.py
index 45ee103fa..ddb9cf459 100644
--- a/slime/utils/data.py
+++ b/slime/utils/data.py
@@ -49,17 +49,16 @@ def _parse_generalized_path(s: str):
     return s, None
 
 
-def _should_skip_prompt(
-    prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
-):
+def _should_skip_prompt(formatted_prompt: str, tokenizer, processor, max_length, multimodal_inputs=None):
     if max_length is None:
         return False
 
-    from slime.utils.processing_utils import prepare_model_inputs
+    if processor:
+        processor_output = processor(text=formatted_prompt, **multimodal_inputs)
+        input_ids = processor_output["input_ids"][0]
+    else:
+        input_ids = tokenizer.encode(formatted_prompt, add_special_tokens=False)
 
-    input_ids, _ = prepare_model_inputs(
-        prompt, tokenizer, processor, metadata, apply_chat_template, apply_chat_template_kwargs
-    )
     return len(input_ids) > max_length
 
 
@@ -140,6 +139,7 @@ def __init__(
             prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)
 
             metadata = data.get(metadata_key) or {}
+            tools = None
             if tool_key is not None and tool_key in data:
                 tools = data[tool_key]
                 if isinstance(tools, str):
@@ -149,17 +149,37 @@ def __init__(
                     assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
                 metadata["tools"] = tools
 
+            if apply_chat_template:
+                formatted_prompt = tokenizer.apply_chat_template(
+                    prompt,
+                    tools=tools,
+                    tokenize=False,
+                    add_generation_prompt=True,
+                    **(apply_chat_template_kwargs or {}),
+                )
+            else:
+                formatted_prompt = prompt
+
+            if processor:
+                # temporary solution, will write image utils for slime later
+                from qwen_vl_utils import process_vision_info
+
+                assert isinstance(prompt, list)
+                images, videos = process_vision_info(prompt)
+                multimodal_inputs = {"images": images, "videos": videos}
+            else:
+                multimodal_inputs = None
+
             # TODO: this is slow.
-            if _should_skip_prompt(
-                prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
-            ):
+            if _should_skip_prompt(formatted_prompt, tokenizer, processor, max_length, multimodal_inputs):
                 continue
 
             self.origin_samples.append(
                 Sample(
-                    prompt=prompt,
+                    prompt=formatted_prompt,
                     label=data[label_key] if label_key is not None else None,
                     metadata=metadata,
+                    multimodal_inputs=multimodal_inputs,
                 )
             )
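End to end, the dataset-side flow introduced above looks like this minimal sketch (illustrative checkpoint and image path; the message layout follows `qwen_vl_utils` conventions):

```python
# Apply the chat template once at dataset load time, then cache the raw vision
# inputs so downstream code only sees (formatted_prompt, multimodal_inputs).
from qwen_vl_utils import process_vision_info
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")  # illustrative checkpoint
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///tmp/example.jpg"},  # hypothetical path
            {"type": "text", "text": "Describe the image."},
        ],
    }
]
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos = process_vision_info(messages)
multimodal_inputs = {"images": images, "videos": videos}
```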
diff --git a/slime/utils/processing_utils.py b/slime/utils/processing_utils.py
index 9863c90d4..ab16acfea 100644
--- a/slime/utils/processing_utils.py
+++ b/slime/utils/processing_utils.py
@@ -2,7 +2,6 @@
 import io
 import logging
 
-import numpy as np
 from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin
 
 logger = logging.getLogger(__name__)
@@ -26,62 +25,6 @@ def load_processor(name_or_path: str, **kwargs):
     return proc
 
 
-def prepare_model_inputs(
-    prompt, tokenizer, processor=None, metadata=None, apply_chat_template=False, apply_chat_template_kwargs=None
-):
-    """Prepare all inputs for model inference.
-
-    Returns:
-        tuple: (input_ids, extra_info)
-            - input_ids: Token IDs for the prompt
-            - extra_info: Dict with 'images', 'videos', 'multimodal_inputs' (or empty dict)
-    """
-    tools = metadata.get("tools") if metadata else None
-    if isinstance(prompt, (list, np.ndarray)):
-        assert (
-            apply_chat_template
-        ), f"apply_chat_template must be True when prompt is a list or numpy array, current prompt is {prompt}"
-        formatted_prompt = tokenizer.apply_chat_template(
-            prompt,
-            tools=tools,
-            tokenize=False,
-            add_generation_prompt=True,
-            **(apply_chat_template_kwargs or {}),
-        )
-    elif isinstance(prompt, str):
-        assert (
-            not apply_chat_template
-        ), f"apply_chat_template must be False when prompt is a string, current prompt is {prompt}"
-        formatted_prompt = prompt
-    else:
-        raise ValueError(f"Invalid prompt type: {type(prompt)}, current prompt is {prompt}")
-
-    if not processor:
-        input_ids = tokenizer.encode(formatted_prompt, add_special_tokens=False)
-        return input_ids, {"formatted_prompt": formatted_prompt}
-    else:
-        # temporary solution, will write image utils for slime later
-        from qwen_vl_utils import process_vision_info
-
-        images, videos = process_vision_info(prompt)
-
-        # Get input IDs with full prompt (text + multimodal)
-        processor_output = processor(text=formatted_prompt, images=images, videos=videos)
-        input_ids = processor_output["input_ids"][0]
-
-        # Extract multimodal tokens (exclude text-related tokens)
-        multimodal_inputs = {k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]}
-
-        extra_info = {
-            "formatted_prompt": formatted_prompt,
-            "images": images,
-            "videos": videos,
-            "multimodal_inputs": multimodal_inputs,
-        }
-
-        return input_ids, extra_info
-
-
 def encode_image_for_rollout_engine(image) -> str:
     """Load an image from path, ensure RGB, encode as JPEG base64 string."""
     buffer = io.BytesIO()
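With `prepare_model_inputs` removed, callers that used it only for tokenization can follow this migration sketch (a summary of the pattern now used in `_should_skip_prompt` and `generate`, not a new API in this patch):

```python
# The chat template is applied once in slime/utils/data.py, so prompts arrive
# here as pre-formatted strings and are tokenized directly.
def tokenize_formatted_prompt(formatted_prompt: str, tokenizer, processor=None, multimodal_inputs=None) -> list[int]:
    if processor is not None:
        # processor output is batch-shaped; take the single row
        return processor(text=formatted_prompt, **multimodal_inputs)["input_ids"][0]
    return tokenizer.encode(formatted_prompt, add_special_tokens=False)
```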