
Conversation

@yuan-luo
Collaborator

@yuan-luo yuan-luo commented Nov 2, 2025

Motivation

Refactor the VLM's entire load/preprocess/process path to a fully asynchronous coroutine mechanism.

The solution has several issues to consider:

  1. Using a ThreadPoolExecutor for CPU-side Python preprocessing (especially the per-frame ops) is constrained by the GIL; adding more threads only makes it slower.

  2. If we switch to ProcessPoolExecutor, we end up passing a decord.VideoReader instance (vr) as an argument across processes. VideoReader isn’t picklable, so it blows up once it crosses the process boundary. The following logs prove it:

[2025-11-05 19:25:52] INFO:     127.0.0.1:52534 - "POST /v1/chat/completions HTTP/1.1" 400 Bad Request
[2025-11-05 19:25:52] INFO:     127.0.0.1:52548 - "POST /v1/chat/completions HTTP/1.1" 400 Bad Request
[2025-11-05 19:25:52] INFO:     127.0.0.1:52556 - "POST /v1/chat/completions HTTP/1.1" 400 Bad Request
[2025-11-05 19:25:53] INFO:     127.0.0.1:52564 - "POST /v1/chat/completions HTTP/1.1" 400 Bad Request

{"object":"error","message":"ctypes objects containing pointers cannot be pickled","type":"BadRequestError","param":null,"code":400}
real    0m0.309s
user    0m0.001s
sys     0m0.003s
{"object":"error","message":"ctypes objects containing pointers cannot be pickled","type":"BadRequestError","param":null,"code":400}
real    0m0.269s
user    0m0.001s
sys     0m0.002s

The most robust and high-performance solution is to construct the VideoReader inside the worker process. The parent process should pass only a serializable “video input descriptor” (a file path) and use spawn or forkserver to create child processes (to avoid undefined behavior from inheriting GPU contexts).
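
A minimal sketch of that design, assuming a hypothetical worker function decode_video_frames (the frame-sampling logic and pool size are illustrative, not the final implementation):

import concurrent.futures
import multiprocessing as mp


def decode_video_frames(video_path: str, num_frames: int = 16):
    # decord is imported inside the worker so the parent process never holds its
    # ctypes state; only the picklable file path crosses the process boundary.
    from decord import VideoReader

    vr = VideoReader(video_path)
    total = len(vr)
    step = max(total // num_frames, 1)
    idx = list(range(0, total, step))[:num_frames]
    # A plain numpy array is picklable, so it can be returned to the parent.
    return vr.get_batch(idx).asnumpy()


# "spawn" (or "forkserver") avoids inheriting GPU/CUDA context from the parent.
video_pool = concurrent.futures.ProcessPoolExecutor(
    max_workers=4, mp_context=mp.get_context("spawn")
)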

[Update Nov 5] Working on moving the VideoReader construction from the parent process into each sub-process, passing only a serializable video_path. Since this changes the whole processing logic, the ETA is Nov 7.

Currently only the Qwen VL model has been modified; the change will be rolled out to the rest of the VL models once the solution is settled.

Modifications

The core modifications are:

  1. Keep the current TokenizerManager / AsyncMMDataProcessor interface unchanged.
  2. In load_mm_data_async(), make sure VIDEO data is processed through BaseMultimodalProcessor.submit_data_loading_tasks_async().
  3. submit_data_loading_tasks_async() returns a list of VideoInput objects; preprocess_video() then puts each VideoInput into the process pool for decoding (see the sketch below).
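
A rough illustration of step 3 (a sketch only: decode_video_frames is the hypothetical worker from the sketch above, and video_executor is assumed to be the process pool; in the PR these live on the processor class):

import asyncio


async def preprocess_video(video_executor, video_path: str):
    # Offload the CPU-heavy decode to the process pool so the tokenizer's
    # event loop stays responsive.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(video_executor, decode_video_frames, video_path)


async def preprocess_videos(video_executor, video_paths):
    # One coroutine per video; asyncio.gather preserves the request order.
    return await asyncio.gather(
        *(preprocess_video(video_executor, p) for p in video_paths)
    )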

Accuracy Tests

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the multimodal data preprocessing pipeline, particularly for VLM (Vision-Language Model) video handling, by integrating asynchronous programming patterns. The primary goal is to enhance performance and system responsiveness by delegating blocking I/O and computationally intensive video processing tasks to dedicated thread pools. This ensures that the main application flow remains unblocked, leading to a more efficient and scalable data processing architecture.

Highlights

  • Asynchronous Multimodal Data Loading: Introduced submit_data_loading_tasks_async and load_mm_data_async methods in base_processor.py to enable asynchronous loading of image, video, and audio data, improving responsiveness by offloading I/O operations.
  • Dedicated Video Processing Thread Pool: A new video_executor (ThreadPoolExecutor) has been added to base_processor.py to handle video preprocessing tasks in a dedicated thread pool, preventing these potentially CPU-intensive operations from blocking the main event loop.
  • Refactored Video Preprocessing Logic: The preprocess_video function in qwen_vl.py was refactored into an asynchronous instance method self.preprocess_video and a static method preprocess_video_task within the QwenVLImageProcessor class. This allows video preprocessing to be executed asynchronously via the new video_executor.
  • Integration of Async Workflow: The process_mm_data_async method in qwen_vl.py was updated to utilize the new asynchronous data loading and video preprocessing methods, ensuring the entire multimodal processing pipeline operates asynchronously.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@yuan-luo yuan-luo changed the title from "Refactor async vl preprocess v2" to "[WIP] Refactor async vl preprocess v2" on Nov 2, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully refactors the multimodal data preprocessing pipeline to be fully asynchronous, which is a great improvement for performance in an async environment. The changes are well-structured, introducing async versions of data loading methods and correctly utilizing asyncio features. The video-specific preprocessing in qwen_vl.py is also nicely refactored to leverage this new async infrastructure. My feedback includes suggestions for using modern asyncio APIs, improving performance by selecting the appropriate executor for CPU-bound tasks, and enhancing code clarity and robustness.

Comment on lines 173 to 175
self.video_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=int(os.environ.get("SGLANG_VIDEO_WORKERS", 8))
)

medium

The new video_executor is a ThreadPoolExecutor. Video preprocessing tasks, like those in qwen_vl.py::preprocess_video_task, often involve CPU-bound operations (e.g., decoding, resizing frames) that can be limited by Python's Global Interpreter Lock (GIL). For true parallelism with CPU-bound tasks, a ProcessPoolExecutor is typically more effective. Consider using the existing self.cpu_executor or creating a new ProcessPoolExecutor for video tasks to avoid potential performance bottlenecks. If ThreadPoolExecutor was chosen because the tasks are I/O-bound or release the GIL, a brief comment explaining this would be helpful for future maintainers.
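
For reference, a hedged sketch of the suggested direction (BaseMultimodalProcessor is used here only as an illustrative shell; the spawn context is an assumption, not part of the diff):

import concurrent.futures
import multiprocessing as mp
import os


class BaseMultimodalProcessor:  # illustrative shell of the relevant __init__ lines
    def __init__(self) -> None:
        # A process pool gives real parallelism for GIL-bound decode/resize work;
        # "spawn" avoids inheriting GPU/CUDA state in the workers.
        self.video_executor = concurrent.futures.ProcessPoolExecutor(
            max_workers=int(os.environ.get("SGLANG_VIDEO_WORKERS", 8)),
            mp_context=mp.get_context("spawn"),
        )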

"Mismatch between image tokens and estimated frame counts."
)

loop = asyncio.get_running_loop()

medium

The call to asyncio.get_running_loop() is inside a for loop. Since the event loop does not change during the execution of this method, this call can be moved outside the loop (e.g., before line 352) for a minor performance improvement and cleaner code.
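
A small self-contained illustration of the hoist (work, executor, and items are placeholders, not names from the PR):

import asyncio
from concurrent.futures import ThreadPoolExecutor


def work(item: int) -> int:
    return item * item


async def main() -> None:
    executor = ThreadPoolExecutor()
    items = [1, 2, 3]
    # Fetch the running loop once, before iterating; it cannot change while
    # this coroutine is executing.
    loop = asyncio.get_running_loop()
    results = await asyncio.gather(
        *(loop.run_in_executor(executor, work, i) for i in items)
    )
    print(results)  # [1, 4, 9]


asyncio.run(main())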

Comment on lines +504 to +534
else:
    prompt = prompt

medium

This else block is redundant as it assigns the prompt variable to itself. It can be removed to improve code clarity.

    image_factor: int = IMAGE_FACTOR,
) -> torch.Tensor:
    if self.video_executor is not None:
        loop = asyncio.get_event_loop()

medium

asyncio.get_event_loop() is deprecated since Python 3.10 and its behavior can be surprising (e.g., creating a new loop if one is not running). It's recommended to use asyncio.get_running_loop() which is safer as it raises a RuntimeError if no loop is running. This also makes it consistent with the usage in base_processor.py.

Suggested change
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()

@yuan-luo yuan-luo changed the title from "[WIP] Refactor async vl preprocess v2" to "Refactor async vl preprocess v2" on Nov 2, 2025
@yuan-luo yuan-luo force-pushed the refactor_async_vl_preprocess_v2 branch from 123ec2d to 8c2234f on November 3, 2025 02:19
@yuan-luo yuan-luo added Multi-modal multi-modal language model performance vlm labels Nov 5, 2025
@yuan-luo
Collaborator Author

yuan-luo commented Nov 5, 2025

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the video preprocessing pipeline to use asynchronous coroutines, aiming to resolve performance bottlenecks associated with the previous multiprocessing implementation. The changes introduce async versions of data loading and processing methods, currently targeting the Qwen VL model. The approach of using run_in_executor for blocking tasks is sound. My review focuses on improving maintainability by reducing code duplication in the new async methods and simplifying some of the logic. I've also suggested using a more modern and safer asyncio API for better consistency and robustness.

Comment on lines +479 to +518
async def load_mm_data_async(
    self,
    prompt: str,
    multimodal_tokens: MultimodalSpecialTokens,
    image_data: Optional[list] = None,
    video_data: Optional[list] = None,
    audio_data: Optional[list] = None,
    return_text: Optional[bool] = True,
    discard_alpha_channel: bool = True,
    audio_sample_rate: Optional[int] = None,
) -> BaseMultiModalProcessorOutput:

medium

There is significant code duplication between this new load_mm_data_async method and the existing load_mm_data method. The logic for processing the loaded data (from line 532 onwards) is nearly identical, with the main difference being await next(futures_iter) versus next(futures_iter).result().

To improve maintainability and reduce redundancy, consider refactoring the result-processing logic into a separate, private helper method. This helper could take the loaded data as an argument.

For example:

def _process_loaded_mm_data(self, text_parts, multimodal_tokens_pattern, task_info_iter, loaded_data):
    # ... common result processing logic ...

async def load_mm_data_async(self, ...):
    # ... submission logic ...
    results = await asyncio.gather(*futures)
    return self._process_loaded_mm_data(text_parts, multimodal_tokens, task_info_iter, results)

def load_mm_data(self, ...):
    # ... submission logic ...
    results = [f.result() for f in futures]
    return self._process_loaded_mm_data(text_parts, multimodal_tokens, task_info_iter, results)

This would make the code cleaner and easier to maintain, especially before extending this pattern to other models.

Comment on lines 239 to 291
ele = {}
total_frames, video_fps = len(vr), vr.get_avg_fps()
nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
max_pixels = max(
    min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
    int(min_pixels * 1.05),
)
max_pixels_supposed = ele.get("max_pixels", max_pixels)
if max_pixels_supposed > max_pixels:
    logger.warning(
        f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
    )
max_pixels = min(max_pixels_supposed, max_pixels)
if "resized_height" in ele and "resized_width" in ele:
    resized_height, resized_width = smart_resize(
        ele["resized_height"],
        ele["resized_width"],
        factor=image_factor,
    )
else:
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=image_factor,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )

medium

The ele dictionary is initialized as empty and never modified, making the subsequent calls to ele.get(...) and checks like if "resized_height" in ele effectively dead code. This makes the logic unnecessarily complex and potentially confusing for future maintenance.

Consider removing the ele dictionary and simplifying the code to use the default values directly.

Suggested change
ele = {}
total_frames, video_fps = len(vr), vr.get_avg_fps()
nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
max_pixels = max(
    min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
    int(min_pixels * 1.05),
)
max_pixels_supposed = ele.get("max_pixels", max_pixels)
if max_pixels_supposed > max_pixels:
    logger.warning(
        f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
    )
max_pixels = min(max_pixels_supposed, max_pixels)
if "resized_height" in ele and "resized_width" in ele:
    resized_height, resized_width = smart_resize(
        ele["resized_height"],
        ele["resized_width"],
        factor=image_factor,
    )
else:
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=image_factor,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )

total_frames, video_fps = len(vr), vr.get_avg_fps()
nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
nframes, _, height, width = video.shape
min_pixels = VIDEO_MIN_PIXELS
total_pixels = VIDEO_TOTAL_PIXELS
max_pixels = max(
    min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
    int(min_pixels * 1.05),
)
resized_height, resized_width = smart_resize(
    height,
    width,
    factor=image_factor,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

@yuan-luo yuan-luo changed the title from "Refactor async vl preprocess v2" to "[WIP] Refactor vl video path to full async mode" on Nov 5, 2025
@yuan-luo yuan-luo force-pushed the refactor_async_vl_preprocess_v2 branch from e3296fe to 653ca1a on November 5, 2025 13:38
@yuan-luo
Collaborator Author

yuan-luo commented Nov 6, 2025

Some child processes terminated abruptly and the process pool is no longer usable. Investigating.

[2025-11-06 10:37:24] Error in request: A child process terminated abruptly, the process pool is not usable anymore
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/entrypoints/openai/serving_base.py", line 105, in handle_request
    return await self._handle_non_streaming_request(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/entrypoints/openai/serving_chat.py", line 694, in _handle_non_streaming_request
    ret = await self.tokenizer_manager.generate_request(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 419, in generate_request
    tokenized_obj = await self._tokenize_one_request(obj)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 607, in _tokenize_one_request
    mm_inputs: Dict = await self.mm_data_processor.process(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/async_mm_data_processor.py", line 99, in process
    return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
  File "/opt/conda/lib/python3.10/asyncio/tasks.py", line 445, in wait_for
    return fut.result()
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/async_mm_data_processor.py", line 70, in _invoke
    return await self._proc_async(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/multimodal/processors/qwen_vl.py", line 321, in process_mm_data_async
    video_results = await asyncio.gather(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/multimodal/processors/qwen_vl.py", line 221, in preprocess_video
    return await loop.run_in_executor(
  File "uvloop/loop.pyx", line 2747, in uvloop.loop.Loop.run_in_executor
  File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 720, in submit
    raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
[2025-11-06 10:37:24] INFO:     127.0.0.1:43630 - "POST /v1/chat/completions HTTP/1.1" 500 Internal Server Error

@yuan-luo
Collaborator Author

yuan-luo commented Nov 6, 2025

The error occurs because decord can't be imported in the parent process, so the import was moved into the sub-process. After that fix, a new error appears.

Process SpawnProcess-10:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
    call_item = call_queue.get(block=True)
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/multimodal/processors/qwen_vl.py", line 14, in <module>
    from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen2_5_vl.py", line 43, in <module>
    from sglang.srt.layers.attention.vision import VisionAttention
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/attention/vision.py", line 41, in <module>
    from sglang.srt.layers.linear import (
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/linear.py", line 33, in <module>
    from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/quantization/__init__.py", line 19, in <module>
    from sglang.srt.layers.quantization.auto_round import AutoRoundConfig
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/quantization/auto_round.py", line 17, in <module>
    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
ImportError: cannot import name 'LinearBase' from partially initialized module 'sglang.srt.layers.linear' (most likely due to a circular import) (/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/linear.py)
[2025-11-06 11:21:51] Error in request: A process in the process pool was terminated abruptly while the future was running or pending.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/entrypoints/openai/serving_base.py", line 105, in handle_request
    return await self._handle_non_streaming_request(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/entrypoints/openai/serving_chat.py", line 694, in _handle_non_streaming_request
    ret = await self.tokenizer_manager.generate_request(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 419, in generate_request
    tokenized_obj = await self._tokenize_one_request(obj)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 607, in _tokenize_one_request
    mm_inputs: Dict = await self.mm_data_processor.process(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/async_mm_data_processor.py", line 99, in process
    return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
  File "/opt/conda/lib/python3.10/asyncio/tasks.py", line 445, in wait_for
    return fut.result()
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/async_mm_data_processor.py", line 70, in _invoke
    return await self._proc_async(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/multimodal/processors/qwen_vl.py", line 334, in process_mm_data_async
    video_results = await asyncio.gather(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/multimodal/processors/qwen_vl.py", line 221, in preprocess_video
    return await loop.run_in_executor(
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
[2025-11-06 11:21:51] INFO:     127.0.0.1:45588 - "POST /v1/chat/completions HTTP/1.1" 500 Internal Server Error

After investigating, the cause is the import dependency chain in qwen_vl.py. The subprocess (worker) imports a series of model layers; linear and quantization import each other recursively, which breaks the worker process.
Refactoring the code.

@yuan-luo yuan-luo force-pushed the refactor_async_vl_preprocess_v2 branch from 653ca1a to 3b55810 on November 7, 2025 04:05
@yuan-luo yuan-luo marked this pull request as draft November 7, 2025 05:48
@yuan-luo
Collaborator Author

yuan-luo commented Nov 7, 2025

The error printed when the subprocess terminates is:

ImportError: cannot import name 'LinearBase' from partially initialized module 'sglang.srt.layers.linear' (most likely due to a circular import) (/usr/local/lib/python3.10/dist-packages/sglang/srt/layers/linear.py)

The full log backtrace is:

[2025-11-06 22:35:36] INFO:     127.0.0.1:60680 - "POST /generate HTTP/1.1" 200 OK
[2025-11-06 22:35:36] The server is fired up and ready to roll!
[2025-11-06 22:36:16] INFO utils.py:148: Note: detected 248 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-06 22:36:16] INFO utils.py:151: Note: NumExpr detected 248 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-06 22:36:16] INFO utils.py:164: NumExpr defaulting to 16 threads.
Process SpawnProcess-3:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
    call_item = call_queue.get(block=True)
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/multimodal/processors/qwen_vl.py", line 16, in <module>
    from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/models/qwen2_5_vl.py", line 43, in <module>
    from sglang.srt.layers.attention.vision import VisionAttention
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/layers/attention/vision.py", line 44, in <module>
    from sglang.srt.layers.linear import (
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/layers/linear.py", line 34, in <module>
    from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/layers/quantization/__init__.py", line 19, in <module>
    from sglang.srt.layers.quantization.auto_round import AutoRoundConfig
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/layers/quantization/auto_round.py", line 17, in <module>
    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
ImportError: cannot import name 'LinearBase' from partially initialized module 'sglang.srt.layers.linear' (most likely due to a circular import) (/usr/local/lib/python3.10/dist-packages/sglang/srt/layers/linear.py)
[2025-11-06 22:36:20] Error in request: A process in the process pool was terminated abruptly while the future was running or pending.

@yuan-luo
Collaborator Author

yuan-luo commented Nov 7, 2025

The root cause is that when a child process deserializes (unpickles) a callable submitted to a ProcessPoolExecutor, it must import the module where that callable is defined. We put the callable in qwen_vl. Importing that module immediately pulls in various Qwen models → which triggers layers.linear → which imports quantization → and auto_round.py in turn imports layers.linear back, creating a circular dependency.
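
One way out, sketched under the assumption that the worker callable can live in a lightweight leaf module (the module name video_decode_worker.py and the code below are hypothetical, not the actual fix): keep the callable in a module that imports nothing from the model/layers stack, so the child only has to import that file when it unpickles the work item.

# Hypothetical leaf module: sglang/srt/multimodal/video_decode_worker.py
# It imports nothing from sglang.srt.models or sglang.srt.layers, so unpickling
# the submitted callable in the child never enters the linear <-> quantization
# import cycle.


def decode_video_frames(video_path: str, num_frames: int = 16):
    from decord import VideoReader  # heavy/ctypes import stays inside the worker

    vr = VideoReader(video_path)
    step = max(len(vr) // num_frames, 1)
    idx = list(range(0, len(vr), step))[:num_frames]
    return vr.get_batch(idx).asnumpy()


# qwen_vl.py would then submit the leaf-module function rather than anything
# defined in qwen_vl.py itself (hypothetical import path):
#   from sglang.srt.multimodal.video_decode_worker import decode_video_frames
#   frames = await loop.run_in_executor(self.video_executor, decode_video_frames, path)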

