3 changes: 3 additions & 0 deletions .github/workflows/tests.yml
@@ -89,6 +89,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: google-github-actions/auth@v3
with:
credentials_json: ${{ secrets.GCP_CREDENTIALS }}
- uses: astral-sh/setup-uv@v7
with:
enable-cache: true
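The added auth step provisions GCP credentials before the test run, so later steps can rely on Application Default Credentials rather than explicit key handling. A minimal sketch of how client code would pick those up, assuming GCP_CREDENTIALS holds a service-account key:

import google.auth
from google.cloud.storage import Client

# After google-github-actions/auth runs, default credentials resolve to the
# injected service account, so no key file needs to be passed around.
credentials, project_id = google.auth.default()
client = Client(credentials=credentials, project=project_id)
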
7 changes: 6 additions & 1 deletion packages/lmi/pyproject.toml
@@ -40,7 +40,9 @@ requires-python = ">=3.11"
[project.optional-dependencies]
dev = [
"fhaviary[xml]",
"fhlmi[local,progress,typing,vcr]",
"fhlmi[image,local,progress,typing,vcr]",
"google-auth>=2", # Pin to keep recent
"google-cloud-storage>=3", # Pin to keep recent
"httpx-aiohttp",
"ipython>=8", # Pin to keep recent
"litellm>=1.71", # Lower pin for aiohttp transport adoption
@@ -59,6 +61,9 @@ dev = [
"refurb>=2", # Pin to keep recent
"typeguard",
]
image = [
"pillow>=10.3.0", # Pin for py.typed
]
local = [
"numpy",
"sentence-transformers",
46 changes: 36 additions & 10 deletions packages/lmi/src/lmi/embeddings.py
@@ -13,7 +13,37 @@
from lmi.cost_tracker import track_costs
from lmi.llms import PassThroughRouter
from lmi.rate_limiter import GLOBAL_LIMITER
from lmi.utils import get_litellm_retrying_config
from lmi.utils import get_litellm_retrying_config, is_encoded_image

URL_ENCODED_IMAGE_TOKEN_ESTIMATE = 85 # tokens


def estimate_tokens(
document: str
| list[str]
| list[litellm.ChatCompletionImageObject]
| list[litellm.types.llms.vertex_ai.PartType],
) -> float:
"""Estimate token count for rate limiting purposes."""
if isinstance(document, str): # Text or a data URL
return (
URL_ENCODED_IMAGE_TOKEN_ESTIMATE
if is_encoded_image(document)
else len(document) / CHARACTERS_PER_TOKEN_ASSUMPTION
)
# For multimodal content, estimate based on text parts and add fixed cost for images
token_count = 0.0
for part in document:
if isinstance(part, str): # Part of a batch of text or data URLs
token_count += estimate_tokens(part)
# Handle different multimodal formats
elif part.get("type") == "image_url": # OpenAI format
token_count += URL_ENCODED_IMAGE_TOKEN_ESTIMATE
elif ( # Gemini text format -- https://ai.google.dev/api#text-only-prompt
"text" in part
):
token_count += len(part["text"]) / CHARACTERS_PER_TOKEN_ASSUMPTION # type: ignore[typeddict-item]
return token_count
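
A quick illustration of the heuristic, assuming CHARACTERS_PER_TOKEN_ASSUMPTION is 4 (the value implied by the test expectations below); the data URL string is a truncated placeholder:

from lmi.embeddings import estimate_tokens

assert estimate_tokens("Hello world") == 2.75  # 11 characters / 4
mixed = [
    "What is in this image?",  # 22 characters / 4 = 5.5 tokens
    "data:image/png;base64,iVBORw0KGgo...",  # flat 85-token image estimate
]
assert estimate_tokens(mixed) == 90.5  # 5.5 + 85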


class EmbeddingModes(StrEnum):
@@ -39,7 +69,7 @@ def set_mode(self, mode: EmbeddingModes) -> None:

@abstractmethod
async def embed_documents(self, texts: list[str]) -> list[list[float]]:
pass
"""Embed a list of documents."""

async def embed_document(self, text: str) -> list[float]:
return (await self.embed_documents([text]))[0]
@@ -138,7 +168,7 @@ def _truncate_if_large(self, texts: list[str]) -> list[str]:
# heuristic about ratio of tokens to characters
conservative_char_token_ratio = 3
maybe_too_large = max_tokens * conservative_char_token_ratio
if any(len(t) > maybe_too_large for t in texts):
if any(len(t) > maybe_too_large for t in texts if not is_encoded_image(t)):
try:
enct = tiktoken.encoding_for_model("cl100k_base")
enc_batch = enct.encode_ordinary_batch(texts)
@@ -154,16 +184,12 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]:
N = len(texts)
embeddings = []
for i in range(0, N, batch_size):
await self.check_rate_limit(
sum(
len(t) / CHARACTERS_PER_TOKEN_ASSUMPTION
for t in texts[i : i + batch_size]
)
)
batch = texts[i : i + batch_size]
await self.check_rate_limit(sum(estimate_tokens(t) for t in batch))

response = await track_costs(self.router.aembedding)(
model=self.name,
input=texts[i : i + batch_size],
input=batch,
dimensions=self.ndim,
**self.config.get("kwargs", {}),
)
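
With this change, each batch's rate-limit check goes through estimate_tokens, so image entries are budgeted at a flat 85 tokens instead of their (huge) character length. A hedged usage sketch mirroring the batching test below; the GCS path is illustrative:

import asyncio

from lmi.embeddings import LiteLLMEmbeddingModel

async def main() -> None:
    model = LiteLLMEmbeddingModel(name="vertex_ai/multimodalembedding@001")
    # Vertex AI accepts at most one text and one image per request,
    # so batch size 1 embeds each input separately.
    model.config["batch_size"] = 1
    embeddings = await model.embed_documents([
        "Some text",
        "gs://tmp-lmi-test/sf_districts.png",  # GCS URLs count as encoded images
    ])
    print([len(e) for e in embeddings])  # [1408, 1408] per the tests below

asyncio.run(main())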
52 changes: 52 additions & 0 deletions packages/lmi/src/lmi/utils.py
@@ -1,4 +1,5 @@
import asyncio
import base64
import contextlib
import logging
import logging.config
@@ -15,7 +16,10 @@
tqdm = None # type: ignore[assignment,misc]

if TYPE_CHECKING:
from typing import IO

import vcr.request
from PIL._typing import StrOrBytesPath


def configure_llm_logs() -> None:
@@ -127,3 +131,51 @@ def update_litellm_max_callbacks(value: int = 1000) -> None:
SEE: https://github.com/BerriAI/litellm/issues/9792
"""
litellm.litellm_core_utils.logging_callback_manager.LoggingCallbackManager.MAX_CALLBACKS = value


def bytes_to_string(value: bytes) -> str:
"""Convert bytes to a base64-encoded string."""
# 1. Convert bytes to base64 bytes
# 2. Convert base64 bytes to base64 string,
# using UTF-8 since base64 produces ASCII characters
return base64.b64encode(value).decode("utf-8")


def string_to_bytes(value: str) -> bytes:
"""Convert a base64-encoded string to bytes."""
# 1. Convert base64 string to base64 bytes (the noqa comment is to make this clear)
# 2. Convert base64 bytes to original bytes
return base64.b64decode(value.encode("utf-8")) # noqa: FURB120


def validate_image(path: "StrOrBytesPath | IO[bytes]") -> None:
"""
Validate that the file at the given path is a valid image.

Raises:
OSError: If the image file is truncated.
""" # noqa: DOC502
try:
from PIL import Image
except ImportError as exc:
raise ImportError(
"Image validation requires the 'image' extra for 'pillow'. Please:"
" `pip install fhlmi[image]`."
) from exc

with Image.open(path) as img:
img.load()
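
For example, raw bytes can be validated via an in-memory buffer before upload (a sketch; the file path is illustrative):

import io

from lmi.utils import validate_image

with open("sf_districts.png", "rb") as f:
    validate_image(io.BytesIO(f.read()))  # raises OSError if truncated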


def encode_image_as_url(image_type: str, image_data: bytes | str) -> str:
"""Convert image data to an RFC 2397 data URL format."""
if isinstance(image_data, bytes):
image_data = bytes_to_string(image_data)
return f"data:image/{image_type};base64,{image_data}"


def is_encoded_image(image: str) -> bool:
"""Check if the given image is a GCS URL or a RFC 2397 data URL."""
return image.startswith("gs://") or (
image.startswith("data:image/") and ";base64," in image
)
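
Taken together, these helpers round-trip bytes through base64 and produce data URLs that is_encoded_image recognizes; a small self-contained sketch (the byte string is a placeholder, not a full PNG):

from lmi.utils import (
    bytes_to_string,
    encode_image_as_url,
    is_encoded_image,
    string_to_bytes,
)

raw = bytes([0x89, 0x50, 0x4E, 0x47])  # PNG magic bytes only
assert string_to_bytes(bytes_to_string(raw)) == raw  # lossless round trip

url = encode_image_as_url(image_type="png", image_data=raw)
assert url.startswith("data:image/png;base64,")
assert is_encoded_image(url)
assert is_encoded_image("gs://some-bucket/districts.png")  # GCS URLs also count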
25 changes: 25 additions & 0 deletions packages/lmi/tests/conftest.py
@@ -11,6 +11,7 @@
import pytest
import vcr.stubs.httpx_stubs
from dotenv import load_dotenv
from google.cloud.storage import Client

from lmi.utils import (
ANTHROPIC_API_KEY_HEADER,
@@ -23,6 +24,7 @@

TESTS_DIR = Path(__file__).parent
CASSETTES_DIR = TESTS_DIR / "cassettes"
STUB_DATA_DIR = TESTS_DIR / "stub_data"


@pytest.fixture(autouse=True, scope="session")
@@ -84,6 +86,29 @@ def fixture_reset_log_levels(caplog) -> Iterator[None]:
logger.propagate = True


@pytest.fixture(name="png_image", scope="session")
def fixture_png_image() -> bytes:
with (STUB_DATA_DIR / "sf_districts.png").open("rb") as f:
return f.read()


TMP_LMI_TEST_GCS_BUCKET = "tmp-lmi-test"


@pytest.fixture(name="png_image_gcs", scope="session")
def fixture_png_image_gcs(png_image: bytes) -> str:
"""Get or create a temporary GCS bucket, upload test image, and return GCS URL."""
client = Client()
bucket = client.bucket(TMP_LMI_TEST_GCS_BUCKET)
if not bucket.exists(): # Get or create the bucket
bucket = client.create_bucket(bucket)
blob_name = "sf_districts.png"
blob = bucket.blob(blob_name)
if not blob.exists(client):
blob.upload_from_string(png_image, content_type="image/png")
return f"gs://{TMP_LMI_TEST_GCS_BUCKET}/{blob_name}"


class PreReadCompatibleAiohttpResponseStream(
httpx_aiohttp.transport.AiohttpResponseStream
):
Binary file added packages/lmi/tests/stub_data/sf_districts.png
84 changes: 83 additions & 1 deletion packages/lmi/tests/test_embeddings.py
@@ -4,6 +4,7 @@

import litellm
import pytest
import tiktoken
from litellm.caching import Cache, InMemoryCache
from pytest_subtests import SubTests

@@ -15,8 +16,34 @@
SentenceTransformerEmbeddingModel,
SparseEmbeddingModel,
embedding_model_factory,
estimate_tokens,
)
from lmi.utils import VCR_DEFAULT_MATCH_ON
from lmi.utils import VCR_DEFAULT_MATCH_ON, encode_image_as_url


def test_estimate_tokens(subtests: SubTests, png_image: bytes) -> None:
with subtests.test(msg="text only"):
text_only = "Hello world"
text_only_estimated_token_count = estimate_tokens(text_only)
assert text_only_estimated_token_count == 2.75, (
"Expected a reasonable token estimate"
)
text_only_actual_token_count = len(
tiktoken.get_encoding("cl100k_base").encode(text_only)
)
assert text_only_estimated_token_count == pytest.approx(
text_only_actual_token_count, abs=1
), "Estimation should be within one token of what tiktoken"

with subtests.test(msg="multimodal"): # Text + image
multimodal = [
"What is in this image?",
encode_image_as_url(image_type="png", image_data=png_image),
]
assert estimate_tokens(multimodal) == 90.5, (
"Expected a reasonable token estimate"
)


class TestLiteLLMEmbeddingModel:
@@ -231,6 +258,61 @@ async def test_router_usage(
# Confirm use of the sentinel timeout in the Router's model_list or pass through
assert mock_aembedding.call_args.kwargs["timeout"] == self.SENTINEL_TIMEOUT

@pytest.mark.asyncio
async def test_multimodal_embedding(
self, subtests: SubTests, png_image_gcs: str
) -> None:
multimodal_model = LiteLLMEmbeddingModel(
name=f"{litellm.LlmProviders.VERTEX_AI.value}/multimodalembedding@001"
)

with subtests.test(msg="text or image only"):
embedding_text_only = await multimodal_model.embed_document("Some text")
assert len(embedding_text_only) == 1408
assert all(isinstance(x, float) for x in embedding_text_only)

embedding_image_only = await multimodal_model.embed_document(png_image_gcs)
assert len(embedding_image_only) == 1408
assert all(isinstance(x, float) for x in embedding_image_only)

assert embedding_image_only != embedding_text_only

with (
subtests.test(msg="denies two texts"),
pytest.raises(litellm.BadRequestError, match="one instance"),
):
# This is more of a confirmation/demonstration that Vertex AI denies any
# embedding request containing >1 text or >1 image
await multimodal_model.embed_documents(["A", "B"])

with subtests.test(msg="text and image mixing"):
(embedding_image_text,) = await multimodal_model.embed_documents([
"What is in this image?",
png_image_gcs,
])
assert len(embedding_image_text) == 1408
assert all(isinstance(x, float) for x in embedding_image_text)

(embedding_two_images,) = await multimodal_model.embed_documents([
png_image_gcs,
png_image_gcs,
])
assert len(embedding_two_images) == 1408
assert all(isinstance(x, float) for x in embedding_two_images)

assert embedding_image_text != embedding_two_images

with subtests.test(msg="batching"):
multimodal_model.config["batch_size"] = 1
embeddings = await multimodal_model.embed_documents([
"Some text",
png_image_gcs,
])
assert len(embeddings) == 2
for embedding in embeddings:
assert len(embedding) == 1408
assert all(isinstance(x, float) for x in embedding)


@pytest.mark.asyncio
async def test_sparse_embedding_model(subtests: SubTests):
25 changes: 25 additions & 0 deletions packages/lmi/tests/test_utils.py
@@ -0,0 +1,25 @@
import base64

import pytest

from lmi.utils import bytes_to_string, string_to_bytes


@pytest.mark.parametrize(
"value",
[
pytest.param(b"Hello, World!", id="simple-text"),
pytest.param(b"", id="empty-bytes"),
pytest.param(bytes([0, 1, 2, 255, 128, 64]), id="binary-data"),
pytest.param(b"Test data for base64 encoding", id="base64-validation"),
pytest.param("Hello 世界 🌍".encode(), id="utf8-text"),
],
)
def test_str_bytes_conversions(value: bytes) -> None:
# Test round-trip conversion
encoded_string = bytes_to_string(value)
decoded_bytes = string_to_bytes(encoded_string)
assert decoded_bytes == value

# Validate that encoded string is valid base64
assert base64.b64decode(encoded_string) == value
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -183,6 +183,7 @@ ignore_missing_imports = true
module = [
"accelerate.*", # SEE: https://github.com/huggingface/accelerate/issues/2396
"dask_cuda",
"google.cloud.storage", # SEE: https://github.com/googleapis/python-storage/issues/393
"networkx", # SEE: https://github.com/networkx/networkx/issues/3988
"pydot",
"transformers.*", # SEE: https://github.com/huggingface/transformers/pull/18485