Skip to content

Commit 2a662e7

Browse files
authored
[DEVX-210]: YAML based workflow creation (#175)
1 parent 4356d24 commit 2a662e7

25 files changed

+811
-31
lines changed

README.md

+16
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,23 @@ all_workflow = app.list_workflow()
195195
# List all workflows in the community, filtered by description
196196
all_face_community_workflows = App().list_workflows(filter_by={"query": "face"}, only_in_app=False) # Get all face related workflows
197197
```
198+
#### Workflow Create
199+
Create a new workflow specified by a yaml config file.
200+
```python
201+
# Note: CLARIFAI_PAT must be set as env variable.
202+
from clarifai.client.app import App
203+
app = App(app_id="app_id", user_id="user_id")
204+
workflow = app.create_workflow(config_filepath="config.yml")
205+
```
198206

207+
#### Workflow Export
208+
Export an existing workflow from Clarifai as a local yaml file.
209+
```python
210+
# Note: CLARIFAI_PAT must be set as env variable.
211+
from clarifai.client.workflow import Workflow
212+
workflow = Workflow("https://clarifai.com/clarifai/main/workflows/Demographics")
213+
workflow.export('demographics_workflow.yml')
214+
```
199215

200216
## More Examples
201217
See many more code examples in this [repo](https://github.com/Clarifai/examples).

clarifai/client/app.py

+92-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import os
2+
import uuid
13
from typing import Any, Dict, List
24

5+
import yaml
36
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
47
from clarifai_grpc.grpc.api.status import status_code_pb2
58
from google.protobuf.json_format import MessageToDict
@@ -13,7 +16,9 @@
1316
from clarifai.client.workflow import Workflow
1417
from clarifai.errors import UserError
1518
from clarifai.urls.helper import ClarifaiUrlHelper
16-
from clarifai.utils.logging import get_logger
19+
from clarifai.utils.logging import display_workflow_tree, get_logger
20+
from clarifai.workflows.utils import get_yaml_output_info_proto, is_same_yaml_model
21+
from clarifai.workflows.validate import validate
1722

1823

1924
class App(Lister, BaseClient):
@@ -236,30 +241,98 @@ def create_model(self, model_id: str, **kwargs) -> Model:
236241

237242
return Model(model_id=model_id, **kwargs)
238243

239-
def create_workflow(self,
                    config_filepath: str,
                    generate_new_id: bool = False,
                    display: bool = True) -> Workflow:
    """Creates a workflow for the app from a yaml config file.

    The config is loaded with ``yaml.safe_load`` and passed through ``validate``.
    For every node, the referenced model is looked up; if the lookup fails with
    "Model does not exist" the model and a first model version are created.
    The resulting nodes are then posted as a single new workflow.

    Args:
        config_filepath (str): The path to the yaml workflow config file.
        generate_new_id (bool): If True, ignore the workflow ID in the config and
            use a freshly generated UUID4 string instead.
        display (bool): If True, display the workflow nodes tree.

    Returns:
        Workflow: A Workflow object built from the created workflow in the response.

    Raises:
        UserError: If no file exists at ``config_filepath``.
        Exception: If a model lookup fails for any reason other than the model
            not existing, or if the PostWorkflows response status is not SUCCESS.

    Example:
        >>> from clarifai.client.app import App
        >>> app = App(app_id="app_id", user_id="user_id")
        >>> workflow = app.create_workflow(config_filepath="config.yml")
    """
    if not os.path.exists(config_filepath):
        raise UserError(f"Workflow config file not found at {config_filepath}")

    with open(config_filepath, 'r', encoding='utf-8') as file:
        data = yaml.safe_load(file)

    data = validate(data)
    workflow = data['workflow']

    # Resolve one model proto per workflow node, in node order.
    all_models = []
    for yml_node in workflow['nodes']:
        yml_model = yml_node['model']
        output_info = get_yaml_output_info_proto(yml_model.get('output_info', None))
        try:
            model = self.model(
                yml_model['model_id'],
                yml_model.get('model_version_id', ""),
                user_id=yml_model.get('user_id', ""),
                app_id=yml_model.get('app_id', ""))
        except Exception as e:
            # Only a missing model is recoverable; anything else (auth, network, ...)
            # must surface instead of leaving `model` unbound or stale.
            if "Model does not exist" not in str(e):
                raise
            model = self.create_model(
                **{k: v for k, v in yml_model.items() if k != 'output_info'})
            model_version = model.create_model_version(
                model_id=yml_model['model_id'], output_info=output_info)
            all_models.append(model_version.model_info)
            continue

        # If the model version ID is specified, or if the yaml model is the same as
        # the one in the api, reuse the existing model as-is.
        if yml_model.get("model_version_id", "") or is_same_yaml_model(
                model.model_info, yml_model):
            all_models.append(model.model_info)
        else:  # Create a new model version
            model_version = model.create_model_version(
                model_id=yml_model['model_id'], output_info=output_info)
            all_models.append(model_version.model_info)

    # Convert nodes to resources_pb2.WorkflowNodes.
    nodes = []
    for yml_node, model_info in zip(workflow['nodes'], all_models):
        workflow_node = resources_pb2.WorkflowNode(id=yml_node['id'], model=model_info)
        # Add node inputs if they exist, i.e. if these nodes do not connect directly
        # to the input.
        for node_input in yml_node.get("node_inputs") or []:
            workflow_node.node_inputs.append(
                resources_pb2.NodeInput(node_id=node_input['node_id']))
        nodes.append(workflow_node)

    workflow_id = str(uuid.uuid4()) if generate_new_id else workflow['id']

    # Create the workflow.
    request = service_pb2.PostWorkflowsRequest(
        user_app_id=self.user_app_id,
        workflows=[resources_pb2.Workflow(id=workflow_id, nodes=nodes)])

    response = self._grpc_request(self.STUB.PostWorkflows, request)
    if response.status.code != status_code_pb2.SUCCESS:
        raise Exception(response.status)
    self.logger.info("\nWorkflow created\n%s", response.status)

    dict_response = MessageToDict(response, preserving_proto_field_name=True)
    # Address the created workflow by field name rather than by key position.
    created_workflow = dict_response["workflows"][0]
    # Display the workflow nodes tree.
    if display:
        display_workflow_tree(created_workflow["nodes"])
    kwargs = self.process_response_keys(created_workflow, "workflow")

    return Workflow(**kwargs)
263336

264337
def create_module(self, module_id: str, description: str, **kwargs) -> Module:
265338
"""Creates a module for the app.
@@ -329,8 +402,16 @@ def model(self, model_id: str, model_version_id: str = "", **kwargs) -> Model:
329402
>>> app = App(app_id="app_id", user_id="user_id")
330403
>>> model_v1 = app.model(model_id="model_id", model_version_id="model_version_id")
331404
"""
332-
request = service_pb2.GetModelRequest(
333-
user_app_id=self.user_app_id, model_id=model_id, version_id=model_version_id)
405+
# Change user_app_id based on whether user_id or app_id is specified.
406+
if kwargs.get("user_id") or kwargs.get("app_id"):
407+
request = service_pb2.GetModelRequest(
408+
user_app_id=self.auth_helper.get_user_app_id_proto(
409+
kwargs.get("user_id"), kwargs.get("app_id")),
410+
model_id=model_id,
411+
version_id=model_version_id)
412+
else:
413+
request = service_pb2.GetModelRequest(
414+
user_app_id=self.user_app_id, model_id=model_id, version_id=model_version_id)
334415
response = self._grpc_request(self.STUB.GetModel, request)
335416

336417
if response.status.code != status_code_pb2.SUCCESS:

clarifai/client/base.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import datetime
33
from typing import Any, Callable
44

5-
from google.protobuf.json_format import MessageToDict # noqa
5+
from google.protobuf import struct_pb2
66
from google.protobuf.timestamp_pb2 import Timestamp
77
from google.protobuf.wrappers_pb2 import BoolValue
88

@@ -71,7 +71,10 @@ def convert_string_to_timestamp(self, date_str) -> Timestamp:
7171
try:
7272
datetime_obj = datetime.strptime(date_str, '%Y-%m-%dT%H:%M:%S.%fZ')
7373
except ValueError:
74-
datetime_obj = datetime.strptime(date_str, '%Y-%m-%dT%H:%M:%SZ')
74+
try:
75+
datetime_obj = datetime.strptime(date_str, '%Y-%m-%dT%H:%M:%SZ')
76+
except ValueError:
77+
return Timestamp()
7578

7679
# Convert the datetime object to a Timestamp object
7780
timestamp_obj = Timestamp()
@@ -99,8 +102,12 @@ def convert_recursive(item):
99102
value = self.convert_string_to_timestamp(value)
100103
elif key in ['workflow_recommended']:
101104
value = BoolValue(value=True)
102-
elif key in ['metadata', 'fields_map', 'params']:
103-
continue # TODO Fix "app_duplication",proto struct
105+
elif key in ['fields_map', 'params']:
106+
value_s = struct_pb2.Struct()
107+
value_s.update(value)
108+
value = value_s
109+
elif key in ['metadata']:
110+
continue # TODO Fix "app_duplication"
104111
new_item[key] = convert_recursive(value)
105112
return new_item
106113
elif isinstance(item, list):

clarifai/client/model.py

+68
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
66
from clarifai_grpc.grpc.api.resources_pb2 import Input
77
from clarifai_grpc.grpc.api.status import status_code_pb2
8+
from google.protobuf.json_format import MessageToDict
89

910
from clarifai.client.base import BaseClient
1011
from clarifai.client.lister import Lister
@@ -52,6 +53,73 @@ def __init__(self,
5253
BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id)
5354
Lister.__init__(self)
5455

56+
def create_model_version(self, model_id: str, **kwargs) -> 'Model':
    """Creates a model version for the Model.

    Args:
        model_id (str): The ID of the model to create a version for. This value,
            not the ID held by this object, is sent in the request.
        **kwargs: Additional keyword arguments to be passed to Model Version.
            - description (str): The description of the model version.
            - concepts (list[Concept]): The concepts to associate with the model version.
            - output_info (resources_pb2.OutputInfo): The output info to associate with the model version.

    Returns:
        Model: A Model object built from the ``model`` entry of the response.

    Raises:
        Exception: If the PostModelVersions response status is not SUCCESS.

    Example:
        >>> from clarifai.client.model import Model
        >>> model = Model("model_url")
                    or
        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
        >>> model_version = model.create_model_version(model_id='model_id', description='model_version_description')
    """
    request = service_pb2.PostModelVersionsRequest(
        user_app_id=self.user_app_id,
        model_id=model_id,
        model_versions=[resources_pb2.ModelVersion(**kwargs)])

    response = self._grpc_request(self.STUB.PostModelVersions, request)
    if response.status.code != status_code_pb2.SUCCESS:
        raise Exception(response.status)
    self.logger.info("\nModel Version created\n%s", response.status)

    # The returned Model is built purely from the response payload; the request
    # kwargs are not carried over.
    dict_response = MessageToDict(response, preserving_proto_field_name=True)
    model_kwargs = self.process_response_keys(dict_response['model'], 'model')

    return Model(**model_kwargs)
91+
92+
def list_versions(self) -> List['Model']:
    """Lists all the versions for the model.

    Every page of ListModelVersions results is collected first; each entry's
    ``model_version_id`` key is then renamed to ``id`` before it is handed to
    a new Model as its ``model_version``.

    Returns:
        List[Model]: A list of Model objects for the versions of the model.

    Example:
        >>> from clarifai.client.model import Model
        >>> model = Model("model_url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
                    or
        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
        >>> all_model_versions = model.list_versions()
    """
    request_data = {
        'user_app_id': self.user_app_id,
        'model_id': self.id,
        'per_page': self.default_page_size,
    }
    pages = self.list_all_pages_generator(self.STUB.ListModelVersions,
                                          service_pb2.ListModelVersionsRequest, request_data)
    versions_info = list(pages)

    # First pass: rename the version identifier key on every entry.
    for version_info in versions_info:
        version_info['id'] = version_info.pop('model_version_id')

    # Second pass: wrap each version in a Model sharing this model's kwargs.
    version_models = []
    for version_info in versions_info:
        merged_kwargs = dict(self.kwargs, model_version=version_info)
        version_models.append(Model(model_id=self.id, **merged_kwargs))
    return version_models
122+
55123
def predict(self, inputs: List[Input]):
56124
"""Predicts the model based on the given inputs.
57125

clarifai/client/workflow.py

+23
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from clarifai.errors import UserError
1111
from clarifai.urls.helper import ClarifaiUrlHelper
1212
from clarifai.utils.logging import get_logger
13+
from clarifai.workflows.export import Exporter
1314

1415

1516
class Workflow(Lister, BaseClient):
@@ -182,6 +183,28 @@ def list_versions(self) -> List['Workflow']:
182183
for workflow_version_info in all_workflow_versions_info
183184
]
184185

186+
def export(self, out_path: str):
    """Exports the workflow to a yaml file.

    Fetches the workflow with GetWorkflow, then hands the response to an
    ``Exporter`` which parses it and writes it to ``out_path``.

    Args:
        out_path (str): The path to save the yaml file to.

    Raises:
        Exception: If the GetWorkflow response status is not SUCCESS.

    Example:
        >>> from clarifai.client.workflow import Workflow
        >>> workflow = Workflow("https://clarifai.com/clarifai/main/workflows/Demographics")
        >>> workflow.export('out_path')
    """
    get_request = service_pb2.GetWorkflowRequest(
        user_app_id=self.user_app_id, workflow_id=self.id)
    get_response = self._grpc_request(self.STUB.GetWorkflow, get_request)

    status = get_response.status
    if status.code != status_code_pb2.SUCCESS:
        raise Exception(f"Workflow Export failed with response {status!r}")

    with Exporter(get_response) as exporter:
        exporter.parse()
        exporter.export(out_path)

    self.logger.info(f"Exported workflow to {out_path}")
207+
185208
def __getattr__(self, name):
186209
return getattr(self.workflow_info, name)
187210

clarifai/utils/logging.py

+53-3
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,66 @@
11
import logging
2-
from typing import Optional
2+
from collections import defaultdict
3+
from typing import Dict, List, Optional
34

5+
from rich import print as rprint
46
from rich.logging import RichHandler
57
from rich.table import Table
68
from rich.traceback import install
9+
from rich.tree import Tree
710

811
install()
912

1013

11-
def table_from_dict(data, column_names, title="") -> Table:
14+
def display_workflow_tree(nodes_data: List[Dict]) -> None:
    """Displays a tree of the workflow nodes.

    The tree is rooted at a synthetic "Input" node. A node with no
    ``node_inputs`` (key missing, ``None`` or empty) hangs directly under
    "Input"; every other node is attached beneath each node it lists as an
    input, so a node with several inputs is rendered once per parent.

    Args:
        nodes_data (List[Dict]): Workflow node dicts. Each is read for ``id``,
            ``model`` and, optionally, ``node_inputs`` (a list of dicts with a
            ``node_id`` key).
    """
    # Mapping of node_id (or "Input") to the ids of the nodes fed by it.
    node_adj_mapping = defaultdict(list)
    # Mapping of node_id to the node data info.
    nodes_data_dict = {}
    for node in nodes_data:
        nodes_data_dict[node["id"]] = node
        node_inputs = node.get("node_inputs")
        if not node_inputs:
            # Missing, None and empty inputs all mean "connected to the workflow input";
            # otherwise the node would never be attached to the tree.
            node_adj_mapping["Input"].append(node["id"])
            continue
        for node_input in node_inputs:
            node_adj_mapping[node_input["node_id"]].append(node["id"])

    # Leaf nodes are those that no other node lists as an input.
    leaf_node_ids = {node_id for node_id in nodes_data_dict if node_id not in node_adj_mapping}

    def build_node_tree(node_id="Input"):
        """Recursively builds a rich tree of the workflow nodes."""
        # NOTE(review): no cycle detection here; assumes the node graph is acyclic.
        # Leaves are highlighted in green, everything else in white.
        style_str = "green" if node_id in leaf_node_ids else "white"

        if node_id != "Input":
            # Real nodes are rendered as a one-row table of their model info.
            node_table = table_from_dict(
                [nodes_data_dict[node_id]["model"]],
                column_names=["id", "model_type_id", "app_id", "user_id"],
                title="Node: " + node_id)
            tree = Tree(node_table, style=style_str, guide_style="underline2 white")
        else:
            tree = Tree(f"[green] {node_id}", style=style_str, guide_style="underline2 white")

        # Recursively add the child nodes of the current node to the tree.
        for child in node_adj_mapping.get(node_id, []):
            tree.add(build_node_tree(child))

        return tree

    rprint(build_node_tree("Input"))
59+
60+
61+
def table_from_dict(data: List[Dict], column_names: List[str], title: str = "") -> Table:
1262
"""Use this function for printing tables from a list of dicts."""
13-
table = Table(title=title, show_header=True, header_style="bold blue")
63+
table = Table(title=title, show_lines=False, show_header=True, header_style="blue")
1464
for column_name in column_names:
1565
table.add_column(column_name)
1666
for row in data:

clarifai/workflows/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)