Skip to content

Commit

Permalink
Add Florence 2 model series by Microsoft (#42)
Browse files Browse the repository at this point in the history
* initial implementation

* simplify

* add other florence 2

* add florence 2 test

* Bump version: 0.1.1 → 0.1.2

* update readme

* update quickstart
  • Loading branch information
dnth authored Oct 29, 2024
1 parent 7a28291 commit ae9964b
Show file tree
Hide file tree
Showing 8 changed files with 507 additions and 7 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,13 @@ pip install -e .
<td><pre lang="python"><code>xinfer.create_model("fancyfeast/llama-joycaption-alpha-two-hf-llava")</code></pre></td>
</tr>
<tr>
<td><a href="https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct">Llama-3.2 Vision</a></td>
<td><a href="https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct">Llama-3.2 Vision Series</a></td>
<td><pre lang="python"><code>xinfer.create_model("meta-llama/Llama-3.2-11B-Vision-Instruct")</code></pre></td>
</tr>
<tr>
<td><a href="https://huggingface.co/microsoft/Florence-2-base-ft">Florence-2 Series</a></td>
<td><pre lang="python"><code>xinfer.create_model("microsoft/Florence-2-base-ft")</code></pre></td>
</tr>
</tbody>
</table>
</body>
Expand Down
372 changes: 372 additions & 0 deletions nbs/florence-2.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions nbs/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"It's recommended to restart the kernel once all the dependencies are installed. Uncomment the following line to restart the kernel."
"It's recommended to restart the kernel once all the dependencies are installed."
]
},
{
Expand All @@ -119,8 +119,8 @@
"metadata": {},
"outputs": [],
"source": [
"# from IPython import get_ipython\n",
"# get_ipython().kernel.do_shutdown(restart=True)"
"from IPython import get_ipython\n",
"get_ipython().kernel.do_shutdown(restart=True)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "xinfer"
version = "0.1.1"
version = "0.1.2"
dynamic = [
"dependencies",
]
Expand Down Expand Up @@ -48,7 +48,7 @@ universal = true


[tool.bumpversion]
current_version = "0.1.1"
current_version = "0.1.2"
commit = true
tag = true

Expand Down
40 changes: 40 additions & 0 deletions tests/smoke/test_florence2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path

import pytest
import torch

import xinfer


@pytest.fixture
def model():
return xinfer.create_model(
"microsoft/Florence-2-base-ft", device="cpu", dtype="float32"
)


@pytest.fixture
def test_image():
return str(Path(__file__).parent.parent / "test_data" / "test_image_1.jpg")


def test_florence2_initialization(model):
assert model.model_id == "microsoft/Florence-2-base-ft"
assert model.device == "cpu"
assert model.dtype == torch.float32


def test_florence2_inference(model, test_image):
prompt = "<CAPTION>"
result = model.infer(test_image, prompt)

assert isinstance(result, str)
assert len(result) > 0


def test_florence2_batch_inference(model, test_image):
prompt = "<CAPTION>"
result = model.infer_batch([test_image, test_image], [prompt, prompt])

assert isinstance(result, list)
assert len(result) == 2
2 changes: 1 addition & 1 deletion xinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Dickson Neoh"""
__email__ = "[email protected]"
__version__ = "0.1.1"
__version__ = "0.1.2"

from .core import create_model, list_models
from .model_registry import ModelInputOutput, register_model
Expand Down
1 change: 1 addition & 0 deletions xinfer/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .blip2 import BLIP2
from .florence2 import Florence2
from .joycaption import JoyCaption
from .llama32 import Llama32Vision, Llama32VisionInstruct
from .moondream import Moondream
Expand Down
83 changes: 83 additions & 0 deletions xinfer/transformers/florence2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

from ..model_registry import ModelInputOutput, register_model
from ..models import BaseModel, track_inference


@register_model(
"microsoft/Florence-2-large", "transformers", ModelInputOutput.IMAGE_TEXT_TO_TEXT
)
@register_model(
"microsoft/Florence-2-base",
"transformers",
ModelInputOutput.IMAGE_TEXT_TO_TEXT,
)
@register_model(
"microsoft/Florence-2-large-ft",
"transformers",
ModelInputOutput.IMAGE_TEXT_TO_TEXT,
)
@register_model(
"microsoft/Florence-2-base-ft",
"transformers",
ModelInputOutput.IMAGE_TEXT_TO_TEXT,
)
class Florence2(BaseModel):
def __init__(
self,
model_id: str,
device: str = "cpu",
dtype: str = "float32",
):
super().__init__(model_id, device, dtype)
self.load_model()

def load_model(self):
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, trust_remote_code=True
).to(self.device, self.dtype)
self.model.eval()
self.model = torch.compile(self.model, mode="max-autotune")
self.processor = AutoProcessor.from_pretrained(
self.model_id, trust_remote_code=True
)

@track_inference
def infer(self, image: str, prompt: str = None, **generate_kwargs) -> str:
output = self.infer_batch([image], [prompt], **generate_kwargs)
return output[0]

@track_inference
def infer_batch(
self, images: list[str], prompts: list[str] = None, **generate_kwargs
) -> list[str]:
images = self.parse_images(images)
inputs = self.processor(text=prompts, images=images, return_tensors="pt").to(
self.device, self.dtype
)

if "max_new_tokens" not in generate_kwargs:
generate_kwargs["max_new_tokens"] = 1024
if "num_beams" not in generate_kwargs:
generate_kwargs["num_beams"] = 3

with torch.inference_mode():
generated_ids = self.model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
**generate_kwargs,
)

generated_text = self.processor.batch_decode(
generated_ids, skip_special_tokens=False
)

parsed_answers = [
self.processor.post_process_generation(
text, task=prompt, image_size=(img.width, img.height)
).get(prompt)
for text, prompt, img in zip(generated_text, prompts, images)
]

return parsed_answers

0 comments on commit ae9964b

Please sign in to comment.