Skip to content

Commit 623155d

Browse files
committed
tutorial details
1 parent 98a0e3a commit 623155d

File tree

3 files changed

+108
-18
lines changed

3 files changed

+108
-18
lines changed

daft/ai/openai/protocols/text_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __post_init__(self) -> None:
6767
)
6868

6969
def get_provider(self) -> str:
70-
return "openai"
70+
return self.provider_name
7171

7272
def get_model(self) -> str:
7373
return self.model_name

daft/functions/ai/__init__.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,14 @@ def embed_text(
7979
# load a TextEmbedderDescriptor from the resolved provider
8080
text_embedder = _resolve_provider(provider, "sentence_transformers").get_text_embedder(model, **options)
8181

82-
# implemented as class-based UDF for now
83-
expr = udf(return_dtype=text_embedder.get_dimensions().as_dtype(), concurrency=1, use_process=False)(
84-
_TextEmbedderExpression
82+
# implemented as a class-based udf for now
83+
expr_callable = udf(
84+
return_dtype=text_embedder.get_dimensions().as_dtype(),
85+
concurrency=1,
86+
use_process=False,
8587
)
88+
89+
expr = expr_callable(_TextEmbedderExpression)
8690
expr = expr.with_init_args(text_embedder)
8791
return expr(text)
8892

@@ -112,8 +116,14 @@ def embed_image(
112116
from daft.ai.protocols import ImageEmbedder
113117

114118
image_embedder = _resolve_provider(provider, "transformers").get_image_embedder(model, **options)
115-
expr = udf(return_dtype=image_embedder.get_dimensions().as_dtype(), concurrency=1, use_process=False)(
116-
_ImageEmbedderExpression
119+
120+
# implemented as a class-based udf for now
121+
expr_udf = udf(
122+
return_dtype=image_embedder.get_dimensions().as_dtype(),
123+
concurrency=1,
124+
use_process=False,
117125
)
126+
127+
expr = expr_udf(_ImageEmbedderExpression)
118128
expr = expr.with_init_args(image_embedder)
119129
return expr(image)

docs/models/index.md

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class TextClassifierDescriptor(Descriptor[TextClassifier]):
3030
"""Descriptor for a TextClassifier implementation."""
3131
```
3232

33-
### Step 3. Add to the Provider Interface
33+
### Step 2. Add to the Provider Interface
3434

3535
You must update the Provider interface with a new method to create your descriptor. This should have
3636
a default implementation which simply raises; this makes it so that you need not update all existing providers.
@@ -39,23 +39,103 @@ a default implementation which simply raises; this makes it so that you need not
3939
# daft.ai.provider
4040
class Provider(ABC):
4141

42-
# ... existing code
42+
# ... existing code
4343

44-
def get_text_classifier(self, model: str | None = None, **options: Any) -> TextClassifierDescriptor:
45-
"""Returns a TextClassifierDescriptor for this provider."""
46-
raise not_implemented_err(self, method="classify_text")
44+
def get_text_classifier(self, model: str | None = None, **options: Any) -> TextClassifierDescriptor:
45+
"""Returns a TextClassifierDescriptor for this provider."""
46+
raise not_implemented_err(self, method="classify_text")
4747
```
4848

49-
### Step 4. Define the Function.
49+
### Step 3. Define the Function.
5050

5151
In `daft.functions.ai` you can add the function, and then re-export it in `daft.functions.__init__.py`.
52+
The implementation is responsible for resolving the provider from the given arguments, then you
53+
will call the appropriate provider method to get the relevant descriptor
5254

5355
```python
5456
def classify_text(
55-
text: Expression,
56-
labels: LabelLike | list[LabelLike],
57-
*,
58-
provider: str | Provider | None = None,
59-
model: str | None = None,
60-
) -> Expression: ...
57+
text: Expression,
58+
labels: LabelLike | list[LabelLike],
59+
*,
60+
provider: str | Provider | None = None,
61+
model: str | None = None,
62+
) -> Expression:
63+
from daft.ai._expressions import _TextClassifierExpression
64+
from daft.ai.protocols import TextClassifier
65+
66+
# Load a TextClassifierDescriptor from the resolved provider
67+
text_classifier = _resolve_provider(provider, "sentence_transformers").get_text_classifier(model, **options)
68+
69+
# Implement the expression here!
70+
71+
# This shows creating a class-based udf which holds state
72+
expr_udf = udf(
73+
return_dtype=get_type_from_labels(labels),
74+
concurrency=1,
75+
use_process=False,
76+
)
77+
78+
# We invoke the UDF with a class callable to create an Expression
79+
expr = expr_udf(_TextClassifierExpression) # <-- see step 4!
80+
expr = expr.with_init_args(text_classifier)
81+
82+
# Now pass the input arguments to the expression!
83+
return expr(text, labels)
84+
85+
86+
class _TextClassifierExpression:
87+
"""Function expression implementation for a TextClassifier protocol."""
88+
89+
text_classifier: TextClassifier
90+
91+
def __init__(self, text_classifier: TextClassifierDescriptor):
92+
# !! IMPORTANT: instantiate from the descriptor in __init__ !!
93+
self.text_classifier = text_classifier.instantiate()
94+
95+
def __call__(self, text_series: Series, labels: list[Label]) -> list[Embedding]:
96+
text = text_series.to_pylist()
97+
return self.text_classifier.classify_text(text, labels) if text else []
98+
```
99+
100+
## Step 4. Implement the Protocol for some Provider.
101+
102+
Here is a simplified example implementation of embed_text for OpenAI. This should give you
103+
and idea of where you actual logic should live, and the previous steps are to properly
104+
hook your new expression into the provider/model system.
105+
106+
```python
107+
dataclass
108+
class OpenAITextEmbedderDescriptor(TextEmbedderDescriptor):
109+
model: str # store some metadata
110+
111+
# We can use the stored metadata to instantiate the protocol implementation
112+
def instantiate(self) -> TextEmbedder:
113+
return OpenAITextEmbedder(client=OpenAI(), model=self.model)
114+
115+
@dataclass
116+
class OpenAITextEmbedder(TextEmbedder):
117+
client: OpenAI
118+
model: str
119+
120+
# This is a a imple version using the batch API. The full implementation
121+
# is uses dynamic batching and has error handling mechanisms.
122+
def embed_text(self, text: list[str]) -> list[Embedding]:
123+
response = self.client.embeddings.create(
124+
input=text,
125+
model=self.model,
126+
encoding_format="float",
127+
)
128+
return [np.array(embedding.embedding) for embedding in response.data]
129+
```
130+
131+
## Step 5. Expression Usage
132+
133+
You can now use this like any other expression.
134+
135+
```python
136+
import daft
137+
138+
df = daft.read_parquet("/path/to/file.parquet") # assuming has some column 'text'
139+
df = df.with_column("embedding", embed_text(df["text"], provider="openai")) # <- set provider to 'openai'
140+
df.show()
61141
```

0 commit comments

Comments
 (0)