Fix caption_with_cogvlm.py for cogvlm2 + textfile strategy #936

Merged: 2 commits, Sep 4, 2024

toolkit/captioning/caption_with_cogvlm.py (16 changes: 13 additions & 3 deletions)
@@ -293,6 +293,8 @@ def process_directory(
                 counter += 1

             image.save(new_filepath)
+        else:
+            new_filepath = full_filepath
         if args.target_backend_id:
             upload_to_s3(s3_client, bucket_name, image, new_filename)
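
For context: before this change, new_filepath was assigned only inside the branch that handles --output_dir, so a run without it (e.g. the textfile caption strategy this PR targets) could reach the S3 upload with new_filepath undefined, raising UnboundLocalError. A minimal sketch of the repaired control flow (the helper name and path logic are illustrative, not taken from the script):

import os

def resolve_output_path(full_filepath, output_dir):
    # Distilled from the patched process_directory(): only build a new
    # destination path when an output directory is configured.
    if output_dir:
        new_filepath = os.path.join(output_dir, os.path.basename(full_filepath))
        # ... the real script re-saves the processed image here ...
    else:
        # The fix: with no --output_dir, fall back to the original path so
        # later upload/caption code always sees a defined new_filepath.
        new_filepath = full_filepath
    return new_filepath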

@@ -342,15 +344,23 @@ def main():
     if args.output_dir and not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
     logger.info("Loading CogVLM model. This should only occur once.")
-    from transformers import AutoModelForCausalLM, LlamaTokenizer
+    from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer

-    tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
     logger.info(f"Loading CogVLM in {args.precision} precision.")
     if "cogvlm2" in args.model_path and torch.backends.mps.is_available():
         logger.warning(
             "Can not run CogVLM 2 on MPS because Triton is unavailable. Falling back to CogVLM 1.1"
         )
         args.model_path = "THUDM/cogvlm-chat-hf"
+    elif "cogvlm2" in args.model_path:
+        import sysconfig
+
+        print(sysconfig.get_paths()["include"])
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_path, trust_remote_code=True
+        )
+    else:
+        tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

     model = AutoModelForCausalLM.from_pretrained(
         args.model_path,
         torch_dtype=torch_dtype,
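
The tokenizer selection this hunk introduces reduces to the sketch below. CogVLM 2 repositories bundle their own tokenizer and need trust_remote_code=True; CogVLM 1.1 reuses the Vicuna 7B LlamaTokenizer; and on Apple MPS the script falls back to CogVLM 1.1 because CogVLM 2 depends on Triton, which is unavailable there. The helper name is illustrative, and the sysconfig print from the diff (which exposes the Python include path, presumably as a Triton build aid) is omitted:

import torch
from transformers import AutoTokenizer, LlamaTokenizer

def select_model_and_tokenizer(model_path):
    # Hypothetical standalone restatement of the patched logic in main().
    if "cogvlm2" in model_path and torch.backends.mps.is_available():
        # CogVLM 2 needs Triton, which has no MPS backend; downgrade to 1.1.
        model_path = "THUDM/cogvlm-chat-hf"
    if "cogvlm2" in model_path:
        # CogVLM 2 ships a custom tokenizer inside the model repo.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    else:
        # CogVLM 1.1 pairs with the Vicuna 7B LlamaTokenizer.
        tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
    return model_path, tokenizer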