@@ -30,7 +30,7 @@ class TextClassifierDescriptor(Descriptor[TextClassifier]):
""" Descriptor for a TextClassifier implementation."""
```
- ### Step 3. Add to the Provider Interface
+ ### Step 2. Add to the Provider Interface
You must update the Provider interface with a new method to create your descriptor. This should have
a default implementation which simply raises, so that you do not need to update all existing providers.
@@ -39,23 +39,103 @@ a default implementation which simply raises; this makes it so that you need not
# daft.ai.provider
class Provider(ABC):
-     # ... existing code
+     # ... existing code
-     def get_text_classifier(self, model: str | None = None, **options: Any) -> TextClassifierDescriptor:
-         """Returns a TextClassifierDescriptor for this provider."""
-         raise not_implemented_err(self, method="classify_text")
+     def get_text_classifier(self, model: str | None = None, **options: Any) -> TextClassifierDescriptor:
+         """Returns a TextClassifierDescriptor for this provider."""
+         raise not_implemented_err(self, method="classify_text")
```
- ### Step 4. Define the Function.
+ ### Step 3. Define the Function
In `daft.functions.ai` you can add the function, and then re-export it in `daft.functions.__init__.py`.
+ The implementation is responsible for resolving the provider from the given arguments; it then
+ calls the appropriate provider method to get the relevant descriptor.
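+ For the re-export, a minimal sketch (the surrounding contents of `daft.functions.__init__.py` are assumed here):
+
+ ```python
+ # daft.functions.__init__.py
+ from daft.functions.ai import classify_text
+
+ __all__ = [
+     # ... existing exports
+     "classify_text",
+ ]
+ ```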
``` python
def classify_text(
-     text: Expression,
-     labels: LabelLike | list[LabelLike],
-     *,
-     provider: str | Provider | None = None,
-     model: str | None = None,
- ) -> Expression: ...
+     text: Expression,
+     labels: LabelLike | list[LabelLike],
+     *,
+     provider: str | Provider | None = None,
+     model: str | None = None,
+     **options: Any,
+ ) -> Expression:
+     from daft.ai._expressions import _TextClassifierExpression
+     from daft.ai.protocols import TextClassifier
+
+     # Load a TextClassifierDescriptor from the resolved provider
+     text_classifier = _resolve_provider(provider, "sentence_transformers").get_text_classifier(model, **options)
+
+     # Implement the expression here!
+
+     # This shows creating a class-based UDF which holds state
+     expr_udf = udf(
+         return_dtype=get_type_from_labels(labels),
+         concurrency=1,
+         use_process=False,
+     )
+
+     # We invoke the UDF with a class callable to create an Expression
+     expr = expr_udf(_TextClassifierExpression)  # <-- see the class below!
+     expr = expr.with_init_args(text_classifier)
+
+     # Now pass the input arguments to the expression!
+     return expr(text, labels)
+
+
+ class _TextClassifierExpression:
+     """Function expression implementation for a TextClassifier protocol."""
+
+     text_classifier: TextClassifier
+
+     def __init__(self, text_classifier: TextClassifierDescriptor):
+         # !! IMPORTANT: instantiate from the descriptor in __init__ !!
+         self.text_classifier = text_classifier.instantiate()
+
+     def __call__(self, text_series: Series, labels: list[Label]) -> list[Label]:
+         text = text_series.to_pylist()
+         return self.text_classifier.classify_text(text, labels) if text else []
+ ```
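+ Note that `get_type_from_labels` above is a small helper that maps the label set to the UDF's return dtype.
+ A minimal sketch of what it might look like, assuming classification results come back as plain string labels
+ (the real helper may derive a richer type):
+
+ ```python
+ from daft import DataType
+
+ def get_type_from_labels(labels: LabelLike | list[LabelLike]) -> DataType:
+     # Hypothetical helper: return each classification result as a string label.
+     return DataType.string()
+ ```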
+
+ ### Step 4. Implement the Protocol for a Provider
+
+ Here is a simplified example implementation of `embed_text` for OpenAI. This should give you
+ an idea of where your actual logic should live; the previous steps exist to properly
+ hook your new expression into the provider/model system.
+
+ ``` python
+ @dataclass
+ class OpenAITextEmbedderDescriptor(TextEmbedderDescriptor):
+     model: str  # store some metadata
+
+     # We can use the stored metadata to instantiate the protocol implementation
+     def instantiate(self) -> TextEmbedder:
+         return OpenAITextEmbedder(client=OpenAI(), model=self.model)
+
+ @dataclass
+ class OpenAITextEmbedder(TextEmbedder):
+     client: OpenAI
+     model: str
+
+     # This is a simple version using the batch API. The full implementation
+     # uses dynamic batching and has error-handling mechanisms.
+     def embed_text(self, text: list[str]) -> list[Embedding]:
+         response = self.client.embeddings.create(
+             input=text,
+             model=self.model,
+             encoding_format="float",
+         )
+         return [np.array(embedding.embedding) for embedding in response.data]
+ ```
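+
+ The `classify_text` case follows the same pattern. For illustration only, here is a rough sketch of what a
+ sentence-transformers backed descriptor and protocol implementation might look like; the class names, the
+ zero-shot similarity scheme, and the exact protocol signatures are assumptions rather than the actual implementation.
+
+ ```python
+ from dataclasses import dataclass
+
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+
+ # TextClassifier, TextClassifierDescriptor, and Label are assumed to be importable from daft.ai.
+
+
+ @dataclass
+ class SentenceTransformersTextClassifierDescriptor(TextClassifierDescriptor):
+     model: str
+
+     def instantiate(self) -> TextClassifier:
+         return SentenceTransformersTextClassifier(model=SentenceTransformer(self.model))
+
+
+ @dataclass
+ class SentenceTransformersTextClassifier(TextClassifier):
+     model: SentenceTransformer
+
+     def classify_text(self, text: list[str], labels: list[Label]) -> list[Label]:
+         # Embed the inputs and the candidate labels, then pick the closest label
+         # for each input (a simple zero-shot similarity scheme).
+         text_vecs = self.model.encode(text, normalize_embeddings=True)
+         label_vecs = self.model.encode([str(label) for label in labels], normalize_embeddings=True)
+         scores = text_vecs @ label_vecs.T
+         return [labels[int(i)] for i in scores.argmax(axis=1)]
+ ```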
+
+ ### Step 5. Expression Usage
+
+ You can now use this like any other expression.
+
+ ``` python
+ import daft
+
+ df = daft.read_parquet("/path/to/file.parquet")  # assuming it has some column 'text'
+ df = df.with_column("embedding", embed_text(df["text"], provider="openai"))  # <- set provider to 'openai'
+ df.show()
```
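+
+ The new `classify_text` function works the same way; for example (the label values here are illustrative):
+
+ ```python
+ df = df.with_column(
+     "sentiment",
+     classify_text(df["text"], labels=["positive", "negative"]),
+ )
+ df.show()
+ ```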