|
| 1 | +import os |
| 2 | +import uuid |
1 | 3 | from typing import Any, Dict, List
|
2 | 4 |
|
| 5 | +import yaml |
3 | 6 | from clarifai_grpc.grpc.api import resources_pb2, service_pb2
|
4 | 7 | from clarifai_grpc.grpc.api.status import status_code_pb2
|
5 | 8 | from google.protobuf.json_format import MessageToDict
|
|
13 | 16 | from clarifai.client.workflow import Workflow
|
14 | 17 | from clarifai.errors import UserError
|
15 | 18 | from clarifai.urls.helper import ClarifaiUrlHelper
|
16 |
| -from clarifai.utils.logging import get_logger |
| 19 | +from clarifai.utils.logging import display_workflow_tree, get_logger |
| 20 | +from clarifai.workflows.utils import get_yaml_output_info_proto, is_same_yaml_model |
| 21 | +from clarifai.workflows.validate import validate |
17 | 22 |
|
18 | 23 |
|
19 | 24 | class App(Lister, BaseClient):
|
@@ -236,30 +241,98 @@ def create_model(self, model_id: str, **kwargs) -> Model:
|
236 | 241 |
|
237 | 242 | return Model(model_id=model_id, **kwargs)
|
238 | 243 |
|
239 |
| - def create_workflow(self, workflow_id: str, **kwargs) -> Workflow: |
| 244 | + def create_workflow(self, |
| 245 | + config_filepath: str, |
| 246 | + generate_new_id: bool = False, |
| 247 | + display: bool = True) -> Workflow: |
240 | 248 | """Creates a workflow for the app.
|
241 | 249 |
|
242 | 250 | Args:
|
243 |
| - workflow_id (str): The workflow ID for the workflow to create. |
244 |
| - **kwargs: Additional keyword arguments to be passed to the workflow. |
| 251 | + config_filepath (str): The path to the yaml workflow config file. |
| 252 | + generate_new_id (bool): If True, generate a new workflow ID. |
| 253 | + display (bool): If True, display the workflow nodes tree. |
245 | 254 |
|
246 | 255 | Returns:
|
247 |
| - Workflow: A Workflow object for the specified workflow ID. |
| 256 | + Workflow: A Workflow object for the specified workflow config. |
248 | 257 |
|
249 | 258 | Example:
|
250 | 259 | >>> from clarifai.client.app import App
|
251 | 260 | >>> app = App(app_id="app_id", user_id="user_id")
|
252 |
| - >>> workflow = app.create_workflow(workflow_id="workflow_id") |
| 261 | + >>> workflow = app.create_workflow(config_filepath="config.yml") |
253 | 262 | """
|
| 263 | + if not os.path.exists(config_filepath): |
| 264 | + raise UserError(f"Workflow config file not found at {config_filepath}") |
| 265 | + |
| 266 | + with open(config_filepath, 'r') as file: |
| 267 | + data = yaml.safe_load(file) |
| 268 | + |
| 269 | + data = validate(data) |
| 270 | + workflow = data['workflow'] |
| 271 | + |
| 272 | + # Get all model objects from the workflow nodes. |
| 273 | + all_models = [] |
| 274 | + for node in workflow['nodes']: |
| 275 | + output_info = get_yaml_output_info_proto(node['model'].get('output_info', None)) |
| 276 | + try: |
| 277 | + model = self.model( |
| 278 | + node['model']['model_id'], |
| 279 | + node['model'].get('model_version_id', ""), |
| 280 | + user_id=node['model'].get('user_id', ""), |
| 281 | + app_id=node['model'].get('app_id', "")) |
| 282 | + except Exception as e: |
| 283 | + if "Model does not exist" in str(e): |
| 284 | + model = self.create_model( |
| 285 | + **{k: v |
| 286 | + for k, v in node['model'].items() if k != 'output_info'}) |
| 287 | + model_version = model.create_model_version( |
| 288 | + model_id=node['model']['model_id'], output_info=output_info) |
| 289 | + all_models.append(model_version.model_info) |
| 290 | + continue |
| 291 | + |
| 292 | + # If the model version ID is specified, or if the yaml model is the same as the one in the api |
| 293 | + if node["model"].get("model_version_id", "") or is_same_yaml_model( |
| 294 | + model.model_info, node["model"]): |
| 295 | + all_models.append(model.model_info) |
| 296 | + else: # Create a new model version |
| 297 | + model = model.create_model_version( |
| 298 | + model_id=node['model']['model_id'], output_info=output_info) |
| 299 | + all_models.append(model.model_info) |
| 300 | + |
| 301 | + # Convert nodes to resources_pb2.WorkflowNodes. |
| 302 | + nodes = [] |
| 303 | + for i, yml_node in enumerate(workflow['nodes']): |
| 304 | + node = resources_pb2.WorkflowNode( |
| 305 | + id=yml_node['id'], |
| 306 | + model=all_models[i], |
| 307 | + ) |
| 308 | + # Add node inputs if they exist, i.e. if these nodes do not connect directly to the input. |
| 309 | + if yml_node.get("node_inputs"): |
| 310 | + for ni in yml_node.get("node_inputs"): |
| 311 | + node.node_inputs.append(resources_pb2.NodeInput(node_id=ni['node_id'])) |
| 312 | + nodes.append(node) |
| 313 | + |
| 314 | + workflow_id = workflow['id'] |
| 315 | + if generate_new_id: |
| 316 | + workflow_id = str(uuid.uuid4()) |
| 317 | + |
| 318 | + # Create the workflow. |
254 | 319 | request = service_pb2.PostWorkflowsRequest(
|
255 |
| - user_app_id=self.user_app_id, workflows=[resources_pb2.Workflow(id=workflow_id, **kwargs)]) |
| 320 | + user_app_id=self.user_app_id, |
| 321 | + workflows=[resources_pb2.Workflow(id=workflow_id, nodes=nodes)]) |
| 322 | + |
256 | 323 | response = self._grpc_request(self.STUB.PostWorkflows, request)
|
257 | 324 | if response.status.code != status_code_pb2.SUCCESS:
|
258 | 325 | raise Exception(response.status)
|
259 | 326 | self.logger.info("\nWorkflow created\n%s", response.status)
|
260 |
| - kwargs.update({'app_id': self.id, 'user_id': self.user_id}) |
261 | 327 |
|
262 |
| - return Workflow(workflow_id=workflow_id, **kwargs) |
| 328 | + dict_response = MessageToDict(response, preserving_proto_field_name=True) |
| 329 | + # Display the workflow nodes tree. |
| 330 | + if display: |
| 331 | + display_workflow_tree(dict_response["workflows"][0]["nodes"]) |
| 332 | + kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]][0], |
| 333 | + "workflow") |
| 334 | + |
| 335 | + return Workflow(**kwargs) |
263 | 336 |
|
264 | 337 | def create_module(self, module_id: str, description: str, **kwargs) -> Module:
|
265 | 338 | """Creates a module for the app.
|
@@ -329,8 +402,16 @@ def model(self, model_id: str, model_version_id: str = "", **kwargs) -> Model:
|
329 | 402 | >>> app = App(app_id="app_id", user_id="user_id")
|
330 | 403 | >>> model_v1 = app.model(model_id="model_id", model_version_id="model_version_id")
|
331 | 404 | """
|
332 |
| - request = service_pb2.GetModelRequest( |
333 |
| - user_app_id=self.user_app_id, model_id=model_id, version_id=model_version_id) |
| 405 | + # Change user_app_id based on whether user_id or app_id is specified. |
| 406 | + if kwargs.get("user_id") or kwargs.get("app_id"): |
| 407 | + request = service_pb2.GetModelRequest( |
| 408 | + user_app_id=self.auth_helper.get_user_app_id_proto( |
| 409 | + kwargs.get("user_id"), kwargs.get("app_id")), |
| 410 | + model_id=model_id, |
| 411 | + version_id=model_version_id) |
| 412 | + else: |
| 413 | + request = service_pb2.GetModelRequest( |
| 414 | + user_app_id=self.user_app_id, model_id=model_id, version_id=model_version_id) |
334 | 415 | response = self._grpc_request(self.STUB.GetModel, request)
|
335 | 416 |
|
336 | 417 | if response.status.code != status_code_pb2.SUCCESS:
|
|
0 commit comments