2 changes: 2 additions & 0 deletions .gitignore
@@ -26,6 +26,8 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
checkpoints
test

# PyInstaller
# Usually these files are written by a python script from a template
22 changes: 21 additions & 1 deletion README.md
@@ -27,7 +27,27 @@ We use LLaVA codebase to train FastVLM variants. In order to train or finetune y
please follow instructions provided in [LLaVA](https://github.com/haotian-liu/LLaVA) codebase.
We provide instructions for running inference with our models.


### Setup
#### 🔹 Option 1: Using `uv` (fast Python package manager)

```bash
# 1. Install uv if not already installed
pip3 install uv

# 2. Create a virtual environment
uv venv .venv

# 3. Activate the virtual environment
source .venv/bin/activate # On macOS/Linux
# .venv\Scripts\activate   # On Windows (use this instead)

# 4. Sync and install dependencies
uv sync
```
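
Once the environment is synced, commands can also be run through `uv run` without activating the virtual environment. A minimal sketch (the paths are placeholders; model checkpoints are downloaded in a later step):

```bash
# Optional: run inference inside the uv-managed environment without activating .venv.
# The checkpoint and image paths below are placeholders.
uv run python predict.py --model-path /path/to/checkpoint-dir \
                         --image-file /path/to/image.png \
                         --prompt "Describe the image."
```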

#### 🔸 Option 2: Using conda

```bash
conda create -n fastvlm python=3.10
conda activate fastvlm
@@ -56,7 +76,7 @@ bash get_models.sh # Files will be downloaded to `checkpoints` directory.
To run inference with a PyTorch checkpoint, follow the instructions below:
```bash
python predict.py --model-path /path/to/checkpoint-dir \
                  --image-file /path/to/image.png \
                  --image-file test/image.png \
                  --prompt "Describe the image."
```

Empty file modified get_models.sh
100644 → 100755
Empty file.
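
The mode change above (100644 → 100755) marks `get_models.sh` as executable. If a local checkout is missing the executable bit, the same effect can be reproduced by hand; a minimal sketch:

```bash
# Mark the download script as executable, then run it.
chmod +x get_models.sh
./get_models.sh   # downloads model files into the `checkpoints` directory
```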
27 changes: 21 additions & 6 deletions predict.py
@@ -7,6 +7,7 @@

import torch
from PIL import Image
from loguru import logger

from llava.utils import disable_torch_init
from llava.conversation import conv_templates
@@ -15,9 +16,22 @@
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    logger.info("Using MPS device")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    logger.info("Using CUDA device")
else:
    DEVICE = torch.device("cpu")
    logger.info("Using CPU device")


def predict(args):
    # Remove generation config from model folder
    # to read generation parameters from args
    logger.info(f"Starting prediction with model_path={args.model_path}")
    logger.info(f"Prompt: {args.prompt}")

    # Remove generation config from model folder to read generation parameters from args
    model_path = os.path.expanduser(args.model_path)
    generation_config = None
    if os.path.exists(os.path.join(model_path, 'generation_config.json')):
@@ -28,7 +42,8 @@ def predict(args):
    # Load model
    disable_torch_init()
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, device="mps")
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, device=DEVICE)
    logger.info("Model loaded successfully")

    # Construct prompt
    qs = args.prompt
@@ -45,7 +60,7 @@ def predict(args):
    model.generation_config.pad_token_id = tokenizer.pad_token_id

    # Tokenize prompt
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(torch.device("mps"))
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(DEVICE)

    # Load and preprocess image
    image = Image.open(args.image_file).convert('RGB')
@@ -65,7 +80,7 @@ def predict(args):
            use_cache=True)

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(outputs)
    logger.success(f"Inference completed. Output: {outputs}")

    # Restore generation config
    if generation_config is not None:
@@ -84,4 +99,4 @@ def predict(args):
    parser.add_argument("--num_beams", type=int, default=1)
    args = parser.parse_args()

    predict(args)
    predict(args)
31 changes: 23 additions & 8 deletions pyproject.toml
@@ -7,19 +7,34 @@ name = "llava"
version = "1.2.2.post1"
description = "Towards GPT-4 like large language and visual assistant."
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"torch==2.6.0", "torchvision==0.21.0",
"transformers==4.48.3", "tokenizers==0.21.0", "sentencepiece==0.1.99", "shortuuid",
"accelerate==1.6.0", "peft>=0.10.0,<0.14.0", "bitsandbytes",
"pydantic", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2",
"gradio==5.11.0", "requests", "uvicorn", "fastapi",
"einops==0.6.1", "einops-exts==0.0.4", "timm==1.0.15",
"coremltools==8.2"
"torch==2.6.0",
"torchvision==0.21.0",
"transformers==4.48.3",
"tokenizers==0.21.0",
"sentencepiece==0.1.99",
"shortuuid",
"accelerate==1.6.0",
"peft>=0.10.0,<0.14.0",
"bitsandbytes",
"pydantic",
"markdown2[all]",
"numpy==1.26.4",
"scikit-learn==1.2.2",
"gradio==5.11.0",
"requests",
"uvicorn",
"fastapi",
"einops==0.6.1",
"einops-exts==0.0.4",
"timm==1.0.15",
"coremltools==8.2",
"loguru>=0.7.3",
]

[project.optional-dependencies]
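
As an aside, if the project is managed with `uv` (Option 1 in the README), a pinned dependency such as `loguru` can also be added from the command line rather than by editing `pyproject.toml` directly; a sketch, assuming a recent uv release:

```bash
# Adds the entry to pyproject.toml and re-resolves the project environment.
uv add "loguru>=0.7.3"
```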