37 commits
fbd6c7c
0815-temp
SangChengC Aug 15, 2025
4acd7e7
0815-add-visual-only
SangChengC Aug 15, 2025
68dd163
0820-add-llm-only
SangChengC Aug 20, 2025
9aaf63b
0820
SangChengC Aug 20, 2025
66d0c10
0820-del-metric
SangChengC Aug 20, 2025
1f46fd2
add redis server for vit/llm disaggregation
shihaobai Aug 26, 2025
3561a17
remove unused code of http manager
shihaobai Aug 26, 2025
0ef48cb
[0826]modify visual server
SangChengC Aug 26, 2025
d4de040
merge main
shihaobai Aug 26, 2025
27ef8f3
add vit manager for vit-llm disaggr
shihaobai Aug 27, 2025
70bc956
[0827]temp
SangChengC Aug 27, 2025
ded28b7
update visual server manager
shihaobai Aug 27, 2025
3a89cf0
add visual start
shihaobai Aug 27, 2025
630d3ee
rename
shihaobai Aug 28, 2025
4407040
add vit register loop
shihaobai Aug 28, 2025
a566580
[0828]temp
SangChengC Aug 28, 2025
1ae9cd3
[0828]temp
SangChengC Aug 28, 2025
c99bb46
fix vit manager
shihaobai Aug 28, 2025
81cbc03
merge
shihaobai Aug 28, 2025
00b3b53
fix llm remote vit init
shihaobai Aug 28, 2025
62f80c4
[0828]temp
SangChengC Aug 28, 2025
b673a36
fix vit transfer
shihaobai Aug 28, 2025
2eaa709
Merge branch 'visual_only2' of https://github.com/ModelTC/lightllm in…
shihaobai Aug 28, 2025
67a3c38
fix connection bug
shihaobai Aug 28, 2025
676215e
add wait for embed for llm
shihaobai Aug 28, 2025
33923b9
[0828]fix vit embed
SangChengC Aug 29, 2025
fcac8e5
[0828]temp
SangChengC Aug 29, 2025
06f7817
[0828]temp
SangChengC Aug 29, 2025
daf1318
[0829]add free_afs
SangChengC Aug 29, 2025
1c16903
[support]add get_image_embedding
SangChengC Sep 3, 2025
c1d98eb
0903
SangChengC Sep 3, 2025
cffa0a0
0909
SangChengC Sep 9, 2025
0a296a1
0911
SangChengC Sep 11, 2025
6b95156
0911-add-other-multimodal's vit dispatch
SangChengC Sep 11, 2025
4853561
[fix]0915-fix-rpyc-cost
SangChengC Sep 16, 2025
ffe2f6b
[fix]fix redis
SangChengC Sep 19, 2025
e723c40
[fix]clean redis before start
SangChengC Sep 23, 2025
8 changes: 6 additions & 2 deletions lightllm/models/gemma3/gemma3_visual.py
@@ -16,7 +8,8 @@


class Gemma3VisionModel:
def __init__(self):
def __init__(self, kvargs):
self.remote_vit = kvargs.get("remote_vit", False)
pass

def load_model(self, weight_dir):
@@ -122,7 +123,10 @@ def encode(self, images: List[ImageItem]):
for i, img in enumerate(images):
if isinstance(img, ImageItem):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
if self.remote_vit:
image_data = img._preload_data
else:
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
img_tensors.append(t)
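The same two-branch image-loading change recurs in the LLaVA, Qwen2-VL, Qwen2.5-VL, Tarsier2 and Qwen-VL encoders below. For reference, a minimal sketch of the shared pattern, assuming (as in this diff) that a remotely dispatched ImageItem carries its raw bytes in `_preload_data`; this is an illustration, not the exact encoder code:

```python
# With a remote ViT the raw image bytes are assumed to arrive attached to the
# ImageItem as `_preload_data`; otherwise they are read from the local
# shared-memory segment keyed by the image uuid.
from io import BytesIO

from PIL import Image
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data


def load_image(img, remote_vit: bool) -> Image.Image:
    if remote_vit:
        raw = img._preload_data                       # bytes shipped with the request
    else:
        raw = read_shm(get_shm_name_data(img.uuid))   # bytes from local shm
    return Image.open(BytesIO(raw))
```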
13 changes: 7 additions & 6 deletions lightllm/models/internvl/model.py
@@ -64,21 +64,22 @@ def init_imageitem_extral_params(
img.extra_params["image_patch_max_num"] = 6
elif num_images > 6:
img.extra_params["image_patch_max_num"] = 0
img.patch_num = self.get_image_patch(img)
return

def init_audioitem_extral_params(
self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
):
return

def get_image_token_length(self, img: ImageItem):
return (
self.get_image_patch_func(
img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
)
* self.image_length
def get_image_patch(self, img: ImageItem):
return self.get_image_patch_func(
img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
)

def get_image_token_length(self, img: ImageItem):
return self.get_image_patch(img) * self.image_length

def get_audio_token_length(self, audio: AudioItem):
L = audio.audio_length
L = L if L <= 480000 else 480000 # max_length < 30s
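A rough illustration of what this InternVL refactor computes: the dynamic-resolution patch count is derived by `get_image_patch` (and also stored on the ImageItem as `patch_num` during `init_imageitem_extral_params`), and the prompt budget per image is that patch count times the fixed tokens-per-patch value. The numbers below are illustrative only, not InternVL defaults:

```python
# Illustrative only: patch_num comes from get_image_patch(), image_length is
# the model's fixed tokens-per-patch count.
def image_token_length(patch_num: int, image_length: int) -> int:
    return patch_num * image_length


# e.g. an image split into 3 patches at 256 tokens per patch
assert image_token_length(3, 256) == 768
```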
8 changes: 6 additions & 2 deletions lightllm/models/llava/llava_visual.py
@@ -15,7 +15,8 @@


class LlavaVisionModel:
def __init__(self):
def __init__(self, kvargs):
self.remote_vit = kvargs.get("remote_vit", False)
pass

def load_model(self, weight_dir):
@@ -133,7 +134,10 @@ def encode(self, images: List[ImageItem]):
for i, img in enumerate(images):
if isinstance(img, ImageItem):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
if self.remote_vit:
image_data = img._preload_data
else:
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data)).convert("RGB")
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
img_tensors.append(t)
6 changes: 5 additions & 1 deletion lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -171,6 +171,7 @@ def __init__(
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.remote_vit = kvargs.get("remote_vit", False)

self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

@@ -381,7 +382,10 @@ def encode(self, images: List[ImageItem]):
for i, img in enumerate(images):
if isinstance(img, ImageItem):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
if self.remote_vit:
image_data = img._preload_data
else:
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
image_data = resize_image(image_data)
pixel_values, image_grid_thw = self.processor.preprocess(image_data)
6 changes: 5 additions & 1 deletion lightllm/models/qwen2_vl/qwen2_visual.py
@@ -200,6 +200,7 @@ def __init__(
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.remote_vit = kvargs.get("remote_vit", False)

self.patch_embed = PatchEmbed(
patch_size=self.patch_size,
@@ -309,7 +310,10 @@ def encode(self, images: List[ImageItem]):
for i, img in enumerate(images):
if isinstance(img, ImageItem):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
if self.remote_vit:
image_data = img._preload_data
else:
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
image_data = resize_image(image_data)
pixel_values, image_grid_thw = self.processor.preprocess(image_data)
18 changes: 14 additions & 4 deletions lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -1,3 +1,5 @@
import rpyc
import socket
import torch
import torch.distributed as dist

@@ -6,9 +8,10 @@

from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed
from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed, read_afs
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
from lightllm.distributed.communication_op import all_reduce
from lightllm.utils.envs_utils import get_env_start_args


"""
@@ -29,6 +32,9 @@
class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
def __init__(self, network_config, mode):
super().__init__(network_config, mode)
self.args = get_env_start_args()
self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True})
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return

def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
@@ -50,9 +56,13 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
# skip the same image
if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))
img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
# pull the img_embeds by uid from shm or afs
if self.args.enable_remote_vit:
embed = read_afs(get_shm_name_embed(img["uuid"]), self.args.image_embed_dir)
else:
embed = read_shm(get_shm_name_embed(img["uuid"]))
self.cache_client.root.release([img["uuid"]])
img_weight.append(bytes2tensor(embed).cuda().reshape(img["token_num"], -1))
img_start_token_ids.append(img["token_id"])
img_token_lens.append(img["token_num"])
img_start_locs.append(img_start_loc)
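Condensed, a sketch of the embedding fetch this hunk adds, using the helpers imported above; the AFS path and the cache release follow the diff, but this is a simplified stand-alone version, not the actual method body:

```python
# Simplified sketch of the new fetch path: with remote ViT enabled the image
# embedding is read from the shared embed directory (AFS) written by the
# visual server; otherwise it comes from local shared memory. The embed-cache
# reference is released either way so the entry can be evicted.
from lightllm.server.embed_cache.utils import (
    bytes2tensor,
    get_shm_name_embed,
    read_afs,
    read_shm,
)


def fetch_image_embed(img: dict, args, cache_client):
    name = get_shm_name_embed(img["uuid"])
    if args.enable_remote_vit:
        raw = read_afs(name, args.image_embed_dir)    # produced by the remote visual server
    else:
        raw = read_shm(name)                          # produced by the local ViT workers
    cache_client.root.release([img["uuid"]])          # drop the embed-cache reference
    return bytes2tensor(raw).cuda().reshape(img["token_num"], -1)
```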
7 changes: 6 additions & 1 deletion lightllm/models/qwen_vl/qwen_visual.py
@@ -333,6 +333,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
class QWenVisionTransformer(nn.Module):
def __init__(
self,
kvargs,
image_size: int,
patch_size: int,
width: int,
@@ -344,6 +345,7 @@ def __init__(
**kwargs,
):
super().__init__()
self.remote_vit = kvargs.get("remote_vit", False)
image_height, image_width = self.image_size = (image_size, image_size)
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
@@ -422,7 +424,10 @@ def encode(self, image_uuids: List):
for i, item in enumerate(image_uuids):
if isinstance(item, int):
uuids.append(item)
image_data = read_shm(get_shm_name_data(item))
if self.remote_vit:
image_data = item._preload_data
else:
image_data = read_shm(get_shm_name_data(item.uuid))
image_data = Image.open(BytesIO(image_data)).convert("RGB")
t = self.image_transform(image_data)
img_tensors.append(t)
7 changes: 6 additions & 1 deletion lightllm/models/tarsier2/tarsier2_visual.py
@@ -152,6 +152,7 @@ def forward(self, image_features, input_embeddings):
class TarsierVisionTransformerPretrainedModel(nn.Module):
def __init__(
self,
kvargs,
vision_config=None,
text_config=None,
ignore_index=-100,
@@ -165,6 +166,7 @@ def __init__(
**kwargs,
):
super().__init__()
self.remote_vit = kvargs.get("remote_vit", False)
self.vision_tower = Qwen2VisionTransformerPretrainedModel(**vision_config)

if projection_head == "Pixel_Shuffle":
@@ -251,7 +253,10 @@ def encode(self, images: List[ImageItem]):
for i, img in enumerate(images):
if isinstance(img, ImageItem):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
if self.remote_vit:
image_data = img._preload_data
else:
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
image_data = resize_image(image_data)
pixel_values, image_grid_thw = self.processor.preprocess(image=image_data)
2 changes: 2 additions & 0 deletions lightllm/models/vit/model.py
@@ -1,4 +1,5 @@
import os
import time
import json
import torch
from lightllm.models.vit.layer_infer.pre_layer_infer import ViTPreLayerInfer
@@ -47,6 +48,7 @@ def __init__(self, kvargs):
self.quant_cfg_path = kvargs.get("quant_cfg", None)
self.load_image_func = get_load_image_func(self.weight_dir_)
self.max_batch_size = kvargs.get("max_batch_size", 1)
self.remote_vit = kvargs.get("remote_vit", False)

self._init_datatype()
self._init_config()
39 changes: 37 additions & 2 deletions lightllm/server/api_cli.py
@@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
type=str,
choices=["normal", "prefill", "decode", "pd_master", "config_server"],
choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual"],
default="normal",
help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
config_server is for pd split mode used to register pd_master node, and get pd_master node list,
@@ -353,7 +353,7 @@
"--visual_nccl_ports",
nargs="+",
type=int,
default=[29500],
default=None,
help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502",
)
parser.add_argument(
@@ -505,4 +505,39 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=0.03,
help="""The interval of the schedule time, default is 30ms.""",
)
parser.add_argument(
"--image_embed_dir",
type=str,
default=None,
help="path for vit embed",
)
parser.add_argument(
"--enable_remote_vit",
action="store_true",
help="Whether to enable remote vit for multimodal service.",
)
parser.add_argument(
"--remote_vit_port",
type=int,
default=12346,
help="The port number for the remote vit service.",
)
# redis for vit llm disaggregation
parser.add_argument(
"--redis_port",
type=int,
default=6379,
help="The port number for the redis service in config_server mode.",
)
parser.add_argument(
"--redis_evict_fraction",
type=float,
default=0.3,
help="The evict fraction for the redis service in config_server mode.",
)
parser.add_argument(
"--start_redis",
action="store_true",
help="Whether to start the redis service in config_server mode.",
)
return parser
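For orientation, hypothetical flag combinations for a ViT/LLM-disaggregated deployment built from the arguments added here. The argument names come from this diff; ports, hosts and paths are placeholders, not recommended values:

```python
# Hypothetical argv lists only; ports and paths are placeholders.

# 1) Standalone visual (ViT) server: the new "visual" run mode plus remote-ViT flags.
visual_server_argv = [
    "--run_mode", "visual",
    "--model_dir", "/path/to/vlm-weights",            # placeholder
    "--enable_remote_vit",
    "--remote_vit_port", "12346",
    "--image_embed_dir", "/mnt/afs/image_embeds",     # placeholder embed directory
]

# 2) LLM server: skips local ViT work and reads embeddings written by the visual server.
llm_server_argv = [
    "--run_mode", "normal",
    "--model_dir", "/path/to/vlm-weights",
    "--enable_remote_vit",
    "--image_embed_dir", "/mnt/afs/image_embeds",
]

# 3) config_server: can also host the Redis instance used for ViT registration.
config_server_argv = [
    "--run_mode", "config_server",
    "--start_redis",
    "--redis_port", "6379",
    "--redis_evict_fraction", "0.3",
]
```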
24 changes: 22 additions & 2 deletions lightllm/server/api_http.py
@@ -40,8 +40,11 @@
from lightllm.server.core.objs.sampling_params import SamplingParams
from .multimodal_params import MultimodalParams
from .httpserver.manager import HttpServerManager
from .visualserver.manager import VisualManager
from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
from .api_lightllm import lightllm_get_score

# from .visualserver.manager import VisualManager
from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding
from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.error_utils import ServerBusyError
@@ -69,6 +72,7 @@ class G_Objs:
g_generate_func: Callable = None
g_generate_stream_func: Callable = None
httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None
visual_manager: VisualManager = None
shared_token_load: TokenLoad = None

def set_args(self, args):
@@ -89,6 +93,8 @@ def set_args(self, args):
args,
metric_port=args.metric_port,
)
elif args.run_mode == "visual":
self.metric_client = MetricClient(args.metric_port)
else:
init_tokenizer(args) # for openai api
SamplingParams.load_generation_cfg(args.model_dir)
@@ -139,7 +145,7 @@ def get_model_name():
@app.get("/health", summary="Check server health")
@app.head("/health", summary="Check server health")
async def healthcheck(request: Request):
if g_objs.args.run_mode == "pd_master":
if g_objs.args.run_mode in ["pd_master", "visual"]:
return JSONResponse({"message": "Ok"}, status_code=200)

if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":
@@ -209,6 +215,18 @@ async def get_score(request: Request) -> Response:
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))


@app.post("/get_image_embedding")
async def get_image_embed(request: Request) -> Response:
try:
return await lightllm_get_image_embedding(request, g_objs.httpserver_manager)
except ServerBusyError as e:
logger.error("%s", str(e), exc_info=True)
return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e))
except Exception as e:
logger.error("An error occurred: %s", str(e), exc_info=True)
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))


@app.post("/")
async def compat_generate(request: Request) -> Response:
request_dict = await request.json()
@@ -334,6 +352,8 @@ async def startup_event():
logger.info("server start up")
loop = asyncio.get_event_loop()
g_objs.set_args(get_env_start_args())
if g_objs.httpserver_manager is None:
return
loop.create_task(g_objs.httpserver_manager.handle_loop())
logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}")
return
18 changes: 18 additions & 0 deletions lightllm/server/api_lightllm.py
@@ -5,6 +5,8 @@
from lightllm.server.core.objs.sampling_params import SamplingParams
from .multimodal_params import MultimodalParams
from .httpserver.manager import HttpServerManager
from .visualserver.manager import VisualManager
from fastapi.responses import JSONResponse
import ujson as json


@@ -150,3 +152,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]:

background_tasks = BackgroundTasks()
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)


async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response:
request_dict = await request.json()
# request_dict: {'parameters': {'max_new_tokens': 128},
# 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}}
sample_params_dict = request_dict["parameters"]
sampling_params = SamplingParams()
sampling_params.init(tokenizer=None, **sample_params_dict)
sampling_params.verify()
multimodal_params_dict = request_dict.get("multimodal_params", {})
multimodal_params = MultimodalParams(**multimodal_params_dict)

await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request)

return JSONResponse({"message": "OK"}, status_code=200)
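An illustrative client call against the new endpoint, mirroring the payload shape noted in the handler comment above. Host, port and the image file are placeholders; on success the handler returns {"message": "OK"} and the embedding is stored server-side rather than returned inline:

```python
# Illustrative request only; adjust host/port and the image path for your deployment.
import base64

import requests

with open("demo.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://localhost:8080/get_image_embedding",      # placeholder host/port
    json={
        "parameters": {"max_new_tokens": 1},
        "multimodal_params": {"images": [{"type": "base64", "data": img_b64}]},
    },
)
print(resp.status_code, resp.json())
```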
4 changes: 3 additions & 1 deletion lightllm/server/api_server.py
@@ -5,11 +5,13 @@
torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess
parser = make_argument_parser()
args = parser.parse_args()
from .api_start import pd_master_start, normal_or_p_d_start, config_server_start
from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start

if args.run_mode == "pd_master":
pd_master_start(args)
elif args.run_mode == "config_server":
config_server_start(args)
elif args.run_mode == "visual":
visual_start(args)
else:
normal_or_p_d_start(args)