From 36f1bf27a80f5877b88a915749341664cef4a1ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?= <remilouf@gmail.com>
Date: Thu, 28 Nov 2024 23:15:15 +0100
Subject: [PATCH] Fix library top-level imports (#1296)

Users are currently running into install issues: after a clean install
of `outlines`, they get an error message asking for `transformers` to
be installed. This should not be the case, as `transformers` is not
required for every integration. In this PR we remove the `transformers`
and `datasets` top-level imports and add per-integration optional
dependencies.
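
Concretely, the module-level imports of optional libraries move inside
the functions that need them, with a `TYPE_CHECKING` guard preserving
the type annotations. A minimal sketch of the pattern (names taken from
the `outlines/models/vllm.py` change below):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from transformers import PreTrainedTokenizerBase


def adapt_tokenizer(tokenizer: "PreTrainedTokenizerBase") -> "PreTrainedTokenizerBase":
    # Deferred import: fails only if this integration is actually used
    # without `transformers` installed.
    from transformers import SPIECE_UNDERLINE
    ...
```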

## TODO
- [x] Test `import outlines` from a clean install (see the sketch below)
- [x] Test installing outlines with vLLM optional dependencies
- [x] Test installing outlines with MLX optional dependencies
- [x] Test installing outlines with transformers optional dependencies
- [x] Test installing outlines with llama-cpp optional dependencies
- [x] Test installing outlines with exllamav2 optional dependencies
- [x] Test installing outlines with openai optional dependencies
- [x] Update the documentation
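
As a quick check for the first item, a minimal smoke test (assuming a
fresh virtual environment with only `outlines` installed, no extras):

```python
import sys

import outlines  # must succeed without `transformers` installed

# The optional libraries should not be pulled in at import time.
assert "transformers" not in sys.modules
assert "datasets" not in sys.modules
```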

Supersedes #1295.  Fixes #1263.
---
 docs/reference/models/llamacpp.md                      | 6 +++++-
 docs/reference/models/mlxlm.md                         | 6 +++++-
 docs/reference/models/openai.md                        | 6 +++++-
 docs/reference/models/transformers.md                  | 4 ++--
 docs/reference/models/vllm.md                          | 6 +++++-
 outlines/models/transformers.py                        | 4 ++--
 outlines/models/vllm.py                                | 7 ++++---
 pyproject.toml                                         | 9 ++++++++-
 tests/generate/test_integration_transformers_vision.py | 2 +-
 9 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/docs/reference/models/llamacpp.md b/docs/reference/models/llamacpp.md
index 24b0fdc97..51b62eca8 100644
--- a/docs/reference/models/llamacpp.md
+++ b/docs/reference/models/llamacpp.md
@@ -4,7 +4,11 @@ Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/l
 
 !!! Note "Installation"
 
-    You need to install the `llama-cpp-python` library to use the llama.cpp integration. See the [installation section](#installation) for instructions to install `llama-cpp-python` with CUDA, Metal, ROCm and other backends.
+    You need to install the `llama-cpp-python` library to use the llama.cpp integration. See the [installation section](#installation) for instructions to install `llama-cpp-python` with CUDA, Metal, ROCm and other backends. To get started quickly, you can also run:
+
+    ```bash
+    pip install "outlines[llamacpp]"
+    ```
 
 ## Load the model
 
diff --git a/docs/reference/models/mlxlm.md b/docs/reference/models/mlxlm.md
index cf7bb7443..d435b9c1f 100644
--- a/docs/reference/models/mlxlm.md
+++ b/docs/reference/models/mlxlm.md
@@ -4,7 +4,11 @@ Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx
 
 !!! Note "Installation"
 
-    You need to install the `mlx` and `mlx-lm` libraries on a device which [supports Metal](https://support.apple.com/en-us/102894) to use the mlx-lm integration.
+    You need to install the `mlx` and `mlx-lm` libraries on a device which [supports Metal](https://support.apple.com/en-us/102894) to use the mlx-lm integration. To get started quickly, you can also run:
+
+    ```bash
+    pip install "outlines[mlxlm]"
+    ```
 
 
 ## Load the model
diff --git a/docs/reference/models/openai.md b/docs/reference/models/openai.md
index 5c737c916..638107568 100644
--- a/docs/reference/models/openai.md
+++ b/docs/reference/models/openai.md
@@ -2,7 +2,11 @@
 
 !!! Installation
 
-    You need to install the `openai` library to be able to use the OpenAI API in Outlines.
+    You need to install the `openai` library to use the OpenAI API in Outlines. Alternatively, you can run:
+
+    ```bash
+    pip install "outlines[openai]"
+    ```
 
 ## OpenAI models
 
diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md
index 2a13e28ec..f4c319540 100644
--- a/docs/reference/models/transformers.md
+++ b/docs/reference/models/transformers.md
@@ -3,10 +3,10 @@
 
 !!! Installation
 
-    You need to install the `transformer`, `datasets` and `torch` libraries to be able to use these models in Outlines:
+    You need to install the `transformers`, `datasets` and `torch` libraries to use these models in Outlines. Alternatively, you can run:
 
     ```bash
-    pip install torch transformers datasets
+    pip install "outlines[transformers]"
     ```
 
 
diff --git a/docs/reference/models/vllm.md b/docs/reference/models/vllm.md
index fb1c830fa..8789b588e 100644
--- a/docs/reference/models/vllm.md
+++ b/docs/reference/models/vllm.md
@@ -3,7 +3,11 @@
 
 !!! Note "Installation"
 
-    You need to install the `vllm` library to use the vLLM integration. See the [installation section](#installation) for instructions to install vLLM for CPU or ROCm.
+    You need to install the `vllm` library to use the vLLM integration. See the [installation section](#installation) for instructions to install vLLM for CPU or ROCm. To get started quickly, you can also run:
+
+    ```bash
+    pip install "outlines[vllm]"
+    ```
 
 ## Load the model
 
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 7ecc9013f..444492500 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -2,8 +2,6 @@
 import inspect
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 
-from datasets.fingerprint import Hasher
-
 from outlines.generate.api import GenerationParameters, SamplingParameters
 from outlines.models.tokenizer import Tokenizer
 
@@ -116,6 +114,8 @@ def __eq__(self, other):
         return NotImplemented
 
     def __hash__(self):
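+        # Deferred import: `datasets` is optional and only needed to hash the tokenizer.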
+        from datasets.fingerprint import Hasher
+
         return hash(Hasher.hash(self.tokenizer))
 
     def __getstate__(self):
diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py
index d1f97bde2..778c27c6f 100644
--- a/outlines/models/vllm.py
+++ b/outlines/models/vllm.py
@@ -1,11 +1,10 @@
 import dataclasses
 from typing import TYPE_CHECKING, List, Optional, Union
 
-from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase
-
 from outlines.generate.api import GenerationParameters, SamplingParameters
 
 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
     from vllm import LLM
     from vllm.sampling_params import SamplingParams
 
@@ -188,7 +187,7 @@ def vllm(model_name: str, **vllm_model_params):
     return VLLM(model)
 
 
-def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
+def adapt_tokenizer(tokenizer: "PreTrainedTokenizerBase") -> "PreTrainedTokenizerBase":
     """Adapt a tokenizer to use to compile the FSM.
 
     The API of Outlines tokenizers is slightly different to that of `transformers`. In
@@ -205,6 +204,8 @@ def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBa
     PreTrainedTokenizerBase
         The adapted tokenizer.
     """
+    from transformers import SPIECE_UNDERLINE
+
     tokenizer.vocabulary = tokenizer.get_vocab()
     tokenizer.special_tokens = set(tokenizer.all_special_tokens)
 
diff --git a/pyproject.toml b/pyproject.toml
index 294fbe4b8..5b005cfbd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,6 @@ dependencies = [
    "jsonschema",
    "requests",
    "tqdm",
-   "datasets",
    "typing_extensions",
    "pycountry",
    "airportsdata",
@@ -46,6 +45,12 @@ dependencies = [
 dynamic = ["version"]
 
 [project.optional-dependencies]
+vllm = ["vllm", "transformers", "numpy2"]
+transformers = ["transformers", "accelerate", "datasets", "numpy<2"]
+mlxlm = ["mlx-lm", "datasets"]
+openai = ["openai"]
+llamacpp = ["llama-cpp-python", "transformers", "datasets", "numpy<2"]
+exllamav2 = ["exllamav2"]
 test = [
     "pre-commit",
     "pytest",
@@ -61,10 +66,12 @@ test = [
     "mlx-lm>=0.19.2; platform_machine == 'arm64' and sys_platform == 'darwin'",
     "huggingface_hub",
     "openai>=1.0.0",
+    "datasets",
     "vllm; sys_platform != 'darwin'",
     "transformers",
     "pillow",
     "exllamav2",
+    "jax"
 ]
 serve = [
     "vllm>=0.3.0",
diff --git a/tests/generate/test_integration_transformers_vision.py b/tests/generate/test_integration_transformers_vision.py
index 28b516c57..ee4f84c06 100644
--- a/tests/generate/test_integration_transformers_vision.py
+++ b/tests/generate/test_integration_transformers_vision.py
@@ -23,7 +23,7 @@ def img_from_url(url):
 @pytest.fixture(scope="session")
 def model(tmp_path_factory):
     return transformers_vision(
-        "trl-internal-testing/tiny-random-LlavaForConditionalGeneration",
+        "trl-internal-testing/tiny-LlavaForConditionalGeneration",
         model_class=LlavaForConditionalGeneration,
         device="cpu",
     )