Skip to content

Commit 0a0c74e

Browse files
authored
[XPU] Support PaddleOCR-VL model for XPU (#4529)
* [XPU] support PaddleOCR-VL in XPU * [XPU] fix PaddleOCR-VL pos_emb_type
1 parent 2a9ed72 commit 0a0c74e

File tree

2 files changed

+109
-27
lines changed

2 files changed

+109
-27
lines changed

fastdeploy/model_executor/layers/attention/xpu_attn_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def __init__(
8383
self.rope_theta: float = (
8484
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
8585
)
86-
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
86+
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
87+
fd_config.model_config, "use_3d_rope", False
88+
)
8789
self.causal: bool = getattr(fd_config.model_config, "causal", True)
8890
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
8991
self.rank: int = fd_config.parallel_config.tensor_parallel_rank

fastdeploy/worker/xpu_model_runner.py

Lines changed: 106 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
420420
req_len = len(req_dicts)
421421
has_prefill_task = False
422422
has_decode_task = False
423+
multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]}
423424
rope_3d_position_ids = {
424425
"position_ids_idx": [],
425426
"position_ids_lst": [],
@@ -436,24 +437,39 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
436437
if self.enable_mm:
437438
inputs = request.multimodal_inputs
438439
if request.with_image:
439-
vision_inputs = {}
440-
vision_inputs["input_ids"] = paddle.to_tensor(
441-
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
442-
)
443-
vision_inputs["token_type_ids"] = paddle.to_tensor(
444-
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
445-
)
446-
vision_inputs["image_type_ids"] = paddle.to_tensor(
447-
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
448-
dtype=paddle.int64,
449-
)
450-
vision_inputs["images"] = paddle.to_tensor(
451-
inputs["images"][request.image_start : request.image_end], dtype="uint8"
452-
)
453-
vision_inputs["grid_thw"] = paddle.to_tensor(
454-
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
455-
)
456-
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
440+
if envs.FD_ENABLE_MAX_PREFILL:
441+
multi_vision_inputs["images_lst"].append(
442+
paddle.to_tensor(inputs["images"][request.image_start : request.image_end])
443+
)
444+
multi_vision_inputs["grid_thw_lst"].extend(
445+
inputs["grid_thw"][request.num_image_start : request.num_image_end]
446+
)
447+
multi_vision_inputs["cu_seqlens"].extend(
448+
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
449+
)
450+
multi_vision_inputs["vit_position_ids_lst"].extend(
451+
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
452+
)
453+
else:
454+
vision_inputs = {}
455+
vision_inputs["input_ids"] = paddle.to_tensor(
456+
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
457+
)
458+
vision_inputs["token_type_ids"] = paddle.to_tensor(
459+
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
460+
)
461+
vision_inputs["image_type_ids"] = paddle.to_tensor(
462+
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
463+
dtype=paddle.int64,
464+
)
465+
vision_inputs["images"] = paddle.to_tensor(
466+
inputs["images"][request.image_start : request.image_end],
467+
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
468+
)
469+
vision_inputs["grid_thw"] = paddle.to_tensor(
470+
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
471+
)
472+
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
457473
else:
458474
self.share_inputs["image_features"] = None
459475

@@ -570,6 +586,9 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
570586
else:
571587
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
572588

589+
if len(multi_vision_inputs["images_lst"]) > 0:
590+
self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
591+
573592
if len(rope_3d_position_ids["position_ids_idx"]) > 0:
574593
packed_position_ids = paddle.to_tensor(
575594
np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
@@ -826,14 +845,24 @@ def _init_share_inputs(self, max_num_seqs: int):
826845

827846
if self.enable_mm:
828847
head_dim = self.model_config.head_dim
848+
if "paddleocr" in self.model_config.model_type: # neox style = True
849+
rope_head_dim = head_dim
850+
else: # neox style = False
851+
rope_head_dim = head_dim // 2
852+
853+
if head_dim == self.model_config.head_dim:
854+
self.share_inputs["pos_emb_type"] = "NORMAL"
855+
else:
856+
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
857+
829858
self.share_inputs["rope_emb"] = paddle.full(
830859
shape=[
831860
max_num_seqs,
832861
2,
833862
1,
834863
self.model_config.max_model_len,
835864
1,
836-
head_dim // 2,
865+
rope_head_dim,
837866
],
838867
fill_value=0,
839868
dtype="float32",
@@ -866,8 +895,8 @@ def _prepare_inputs(self, is_dummy_run=False) -> None:
866895
# Update bad tokens len
867896
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
868897

869-
if self.enable_mm: # pos_emb_type is different in EB and VL
870-
self.forward_meta.pos_emb_type = "HALF_HEAD_DIM"
898+
if self.enable_mm:
899+
self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
871900
self.forward_meta.attn_backend = self.attn_backends[0]
872901
self.initialize_attention_backend()
873902

@@ -1338,12 +1367,10 @@ def _preprocess_mm_task(self, one: dict) -> None:
13381367
)
13391368
return result
13401369

1341-
@paddle.no_grad()
1342-
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
1343-
"""extract_vision_features"""
1370+
def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
13441371
assert inputs["images"] is not None
13451372
grid_thw = inputs["grid_thw"]
1346-
1373+
# ernie-vl has images norm
13471374
images = inputs["images"].cast("float32")
13481375
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
13491376
images = images / self.image_preprocess.image_std_tensor
@@ -1353,7 +1380,6 @@ def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
13531380
token_type_ids_w_video = token_type_ids
13541381
input_ids = inputs["input_ids"]
13551382
# convert to img patch id
1356-
# TODO(lulinjun): may need to check model_config and model_cfg
13571383
image_mask = input_ids == self.model_config.im_patch_id
13581384
image_type_ids = inputs["image_type_ids"]
13591385
with paddle.amp.auto_cast(
@@ -1369,6 +1395,7 @@ def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
13691395
image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
13701396
image_features = ScatterOp.apply(image_features, axis=-1) # mp 切 Fea
13711397
image_features = image_features.reshape([S, -1])
1398+
# ernie-vl has resampler_model
13721399
image_features = self.model.resampler_model(
13731400
image_features,
13741401
image_mask,
@@ -1378,6 +1405,59 @@ def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
13781405
)
13791406
return image_features
13801407

1408+
def extract_vision_features_paddleocr(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
1409+
if envs.FD_ENABLE_MAX_PREFILL:
1410+
inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"])
1411+
images = paddle.concat(inputs["images_lst"]).cast("bfloat16")
1412+
grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64")
1413+
position_ids = paddle.to_tensor(inputs["vit_position_ids_lst"], dtype="int64")
1414+
cu_seqlens = paddle.cumsum(paddle.to_tensor(inputs["cu_seqlens"])).cast("int32")
1415+
else:
1416+
assert inputs["images"] is not None
1417+
grid_thw = inputs["grid_thw"]
1418+
images = inputs["images"]
1419+
1420+
position_ids = []
1421+
cu_seqlens = [0]
1422+
for idx, thw in enumerate(grid_thw):
1423+
numel = np.prod(np.array(thw))
1424+
position_ids.append(paddle.arange(numel) % np.prod(thw[1:]))
1425+
cu_seqlens.append(cu_seqlens[-1] + numel)
1426+
1427+
position_ids = paddle.concat(position_ids, axis=0).to(images.place)
1428+
cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place)
1429+
1430+
with paddle.amp.auto_cast(
1431+
True,
1432+
custom_black_list=self.amp_black,
1433+
custom_white_list=self.amp_white,
1434+
level="O2",
1435+
dtype=self.model_config.dtype,
1436+
):
1437+
image_features = self.model.visual(
1438+
pixel_values=images,
1439+
image_grid_thw=grid_thw,
1440+
position_ids=position_ids,
1441+
interpolate_pos_encoding=True,
1442+
cu_seqlens=cu_seqlens,
1443+
use_rope=True,
1444+
window_size=-1,
1445+
)
1446+
image_features = self.model.projector(image_features, grid_thw)
1447+
image_features = paddle.concat(image_features, axis=0)
1448+
1449+
return image_features
1450+
1451+
@paddle.no_grad()
1452+
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
1453+
"""extract_vision_features"""
1454+
if "ernie" in self.model_config.model_type:
1455+
return self.extract_vision_features_ernie(inputs)
1456+
elif "paddleocr" in self.model_config.model_type:
1457+
return self.extract_vision_features_paddleocr(inputs)
1458+
else:
1459+
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
1460+
13811461
@paddle.no_grad()
13821462
def prepare_rope3d(
13831463
self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]

0 commit comments

Comments
 (0)