
Commit

To improve the code's compatibility with different PyTorch versions. (#361)

Co-authored-by: wangzaijun <[email protected]>
hiworldwzj and wangzaijun authored Mar 18, 2024
1 parent 8debc65 commit 6d67fbb
Showing 2 changed files with 2 additions and 2 deletions.
lightllm/models/internlm_xcomposer/internlm_visual.py (1 addition, 1 deletion)
@@ -7,7 +7,6 @@
 from typing import List, Union
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
-from transformers import CLIPVisionModel


 class InternVisionModel:
@@ -57,6 +56,7 @@ def load_model(self, weight_dir):
             transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                  (0.26862954, 0.26130258, 0.27577711)),
         ])
+        from transformers import CLIPVisionModel
         self.vision_tower = CLIPVisionModel.from_pretrained(vision_path)
         self.vision_tower.requires_grad_(False)
         self.resize_pos(config, vision_path)
lightllm/models/llava/llava_visual.py (1 addition, 1 deletion)
@@ -4,7 +4,6 @@
 import os
 from PIL import Image
 from typing import List, Union
-from transformers import CLIPVisionModel, CLIPImageProcessor


 class LlavaVisionModel:
@@ -26,6 +25,7 @@ def load_model(self, weight_dir):
         if vision_path.startswith("./"):
             vision_path = os.path.join(weight_dir, vision_path)

+        from transformers import CLIPVisionModel, CLIPImageProcessor
         self.image_processor = CLIPImageProcessor.from_pretrained(vision_path)
         self.vision_tower = CLIPVisionModel.from_pretrained(vision_path).half()
         self.vision_tower.requires_grad_(False)
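Both files make the same change: the transformers CLIP imports move from module scope into load_model, so transformers (and the torch-version-dependent code it pulls in) is only imported when a visual model is actually loaded, not whenever the module itself is imported. Below is a minimal sketch of this deferred-import pattern, assuming transformers is installed; the class name VisionTowerSketch is illustrative and not part of the repository.

# Minimal sketch of the deferred-import pattern applied by this commit.
# VisionTowerSketch is a hypothetical stand-in, not actual lightllm code.
class VisionTowerSketch:
    def __init__(self):
        self.vision_tower = None

    def load_model(self, vision_path):
        # Importing CLIPVisionModel here, at call time, means that importing
        # this module never fails on transformers/PyTorch combinations that
        # cannot load CLIPVisionModel.
        from transformers import CLIPVisionModel

        self.vision_tower = CLIPVisionModel.from_pretrained(vision_path)
        self.vision_tower.requires_grad_(False)
        return self.vision_tower

The tradeoff is that an incompatible transformers/PyTorch installation now surfaces when load_model is called rather than at import time, which is the point of the change: code paths that never load the visual tower keep working.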
