Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add remote client that directly utilizes the flyteidl-rust Python bindings #2536

Closed
31 changes: 20 additions & 11 deletions flytekit/clients/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import flyteidl_rust as flyteidl


def iterate_node_executions(
client,
workflow_execution_identifier=None,
Expand All @@ -22,13 +25,16 @@ def iterate_node_executions(
counter = 0
while True:
if workflow_execution_identifier is not None:
node_execs, next_token = client.list_node_executions(
workflow_execution_identifier=workflow_execution_identifier,
limit=num_to_fetch,
token=token,
filters=filters,
unique_parent_id=unique_parent_id,
node_execution_list = client.list_node_executions(
flyteidl.admin.NodeExecutionListRequest(
workflow_execution_id=workflow_execution_identifier,
limit=num_to_fetch,
token=token,
filters=filters or "",
unique_parent_id=unique_parent_id or "",
)
)
node_execs, next_token = node_execution_list.node_executions, node_execution_list.token
else:
node_execs, next_token = client.list_node_executions_for_task_paginated(
task_execution_identifier=task_execution_identifier,
Expand Down Expand Up @@ -61,12 +67,15 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte
num_to_fetch = limit
counter = 0
while True:
task_execs, next_token = client.list_task_executions_paginated(
node_execution_identifier=node_execution_identifier,
limit=num_to_fetch,
token=token,
filters=filters,
task__execution_list = client.list_task_executions(
flyteidl.admin.TaskExecutionListRequest(
node_execution_id=node_execution_identifier,
limit=num_to_fetch,
token=token,
filters=filters or "",
)
)
task_execs, next_token = task__execution_list.task_executions, task__execution_list.token
for t in task_execs:
counter += 1
if limit is not None and counter > limit:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/clis/sdk_in_container/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE
from flytekit.configuration import ImageConfig
from flytekit.configuration.plugin import get_plugin
from flytekit.remote.remote import FlyteRemote
from flytekit.remote.remote_rs import FlyteRemote

FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote"

Expand Down
17 changes: 8 additions & 9 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dataclasses import dataclass, field, fields
from typing import cast, get_args

import flyteidl_rust as flyteidl
import rich_click as click
from dataclasses_json import DataClassJsonMixin
from rich.progress import Progress
Expand Down Expand Up @@ -37,10 +38,10 @@
from flytekit.loggers import logger
from flytekit.models import security
from flytekit.models.common import RawOutputDataConfig
from flytekit.models.interface import Parameter, Variable
from flytekit.models.types import SimpleType
from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow, remote_fs
from flytekit.models.interface import Parameter
from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow, remote_fs
from flytekit.remote.executions import FlyteWorkflowExecution
from flytekit.remote.remote_rs import FlyteRemote
from flytekit.tools import module_loader
from flytekit.tools.script_mode import _find_project_root, compress_scripts
from flytekit.tools.translator import Options
Expand Down Expand Up @@ -358,7 +359,7 @@ def to_click_option(
ctx: click.Context,
flyte_ctx: FlyteContext,
input_name: str,
literal_var: Variable,
literal_var: flyteidl.core.Variable,
python_type: typing.Type,
default_val: typing.Any,
required: bool,
Expand All @@ -380,7 +381,7 @@ def to_click_option(
default_val = False

description_extra = ""
if literal_var.type.simple == SimpleType.STRUCT:
if literal_var.type == flyteidl.core.SimpleType.Struct:
if default_val and not isinstance(default_val, ArtifactQuery):
if type(default_val) == dict or type(default_val) == list:
default_val = json.dumps(default_val)
Expand Down Expand Up @@ -508,7 +509,6 @@ def _run(*args, **kwargs):
Click command function that is used to execute a flyte workflow from the given entity in the file.
"""
# By the time we get to this function, all the loading has already happened

run_level_params: RunLevelParams = ctx.obj
logger.debug(f"Running {entity.name} with {kwargs} and run_level_params {run_level_params}")

Expand Down Expand Up @@ -565,7 +565,6 @@ def _run(*args, **kwargs):
module_name=run_level_params.computed_params.module,
copy_all=run_level_params.copy_all,
)

run_remote(
remote,
remote_entity,
Expand Down Expand Up @@ -612,7 +611,7 @@ def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, Fly
def _get_params(
self,
ctx: click.Context,
inputs: typing.Dict[str, Variable],
inputs: typing.Dict[str, flyteidl.core.Variable],
native_inputs: typing.Dict[str, type],
fixed: typing.Optional[typing.Dict[str, Literal]] = None,
defaults: typing.Optional[typing.Dict[str, Parameter]] = None,
Expand Down Expand Up @@ -783,7 +782,7 @@ def _create_command(
# Add options for each of the workflow inputs
params = []
for input_name, input_type_val in loaded_entity.python_interface.inputs_with_defaults.items():
literal_var = loaded_entity.interface.inputs.get(input_name)
literal_var = loaded_entity.interface.inputs.variables.get(input_name)
python_type, default_val = input_type_val
required = type(None) not in get_args(python_type) and default_val is None
params.append(to_click_option(ctx, flyte_ctx, input_name, literal_var, python_type, default_val, required))
Expand Down
2 changes: 1 addition & 1 deletion flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from flytekit.configuration import Config, get_config_file
from flytekit.loggers import logger
from flytekit.remote import FlyteRemote
from flytekit.remote.remote_rs import FlyteRemote


@runtime_checkable
Expand Down
86 changes: 47 additions & 39 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import warnings
from abc import abstractmethod
from base64 import b64encode
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import (
Any,
Coroutine,
Expand All @@ -41,8 +41,7 @@
cast,
)

from flyteidl.core import artifact_id_pb2 as art_id
from flyteidl.core import tasks_pb2
import flyteidl_rust as flyteidl

from flytekit.configuration import LocalConfig, SerializationSettings
from flytekit.core.artifact_utils import (
Expand Down Expand Up @@ -72,11 +71,9 @@
from flytekit.core.utils import timeit
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import task as _task_model
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.documentation import Description, Documentation
from flytekit.models.interface import Variable
from flytekit.models.security import SecurityContext

Expand Down Expand Up @@ -134,6 +131,8 @@ class TaskMetadata(object):
retries: int = 0
timeout: Optional[Union[datetime.timedelta, int]] = None
pod_template_name: Optional[str] = None
generates_deck: Optional[bool] = False
tags: Optional[Dict[str, str]] = field(default_factory=dict) # type: ignore

def __post_init__(self):
if self.timeout:
Expand All @@ -151,28 +150,35 @@ def __post_init__(self):
)

@property
def retry_strategy(self) -> _literal_models.RetryStrategy:
return _literal_models.RetryStrategy(self.retries)
def retry_strategy(self) -> flyteidl.core.RetryStrategy:
return flyteidl.core.RetryStrategy(self.retries)

def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
def to_taskmetadata_model(self) -> flyteidl.core.TaskMetadata:
"""
Converts to _task_model.TaskMetadata
Converts to flyteidl.core.TaskMetadata
"""
from flytekit import __version__

return _task_model.TaskMetadata(
return flyteidl.core.TaskMetadata(
discoverable=self.cache,
runtime=_task_model.RuntimeMetadata(
_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"
),
timeout=self.timeout,
runtime=flyteidl.core.RuntimeMetadata(int(flyteidl.core.RuntimeType.FlyteSdk), __version__, "python"),
timeout=flyteidl.wkt.Duration(
seconds=(
self.timeout.days * 24 * 60 * 60 + self.timeout.seconds + self.timeout.microseconds // 1_000_000 # type: ignore
),
nanos=(self.timeout.microseconds % 1_000_000 * 1_000), # type: ignore
)
if not isinstance(self.timeout, int)
else flyteidl.wkt.Duration(seconds=int(self.timeout), nanos=0),
retries=self.retry_strategy,
interruptible=self.interruptible,
interruptible_value=self.interruptible,
discovery_version=self.cache_version,
deprecated_error_message=self.deprecated,
cache_serializable=self.cache_serialize,
pod_template_name=self.pod_template_name,
pod_template_name=self.pod_template_name or "",
cache_ignore_input_vars=self.cache_ignore_input_vars,
generates_deck=self.generates_deck,
tags=self.tags,
)


Expand All @@ -196,11 +202,11 @@ def __init__(
self,
task_type: str,
name: str,
interface: _interface_models.TypedInterface,
interface: flyteidl.core.TypedInterface,
metadata: Optional[TaskMetadata] = None,
task_type_version=0,
security_ctx: Optional[SecurityContext] = None,
docs: Optional[Documentation] = None,
docs: Optional[flyteidl.admin.DescriptionEntity] = None,
**kwargs,
):
self._task_type = task_type
Expand All @@ -214,7 +220,7 @@ def __init__(
FlyteEntities.entities.append(self)

@property
def interface(self) -> _interface_models.TypedInterface:
def interface(self) -> flyteidl.core.TypedInterface:
return self._interface

@property
Expand Down Expand Up @@ -242,7 +248,7 @@ def security_context(self) -> SecurityContext:
return self._security_ctx

@property
def docs(self) -> Documentation:
def docs(self) -> flyteidl.admin.DescriptionEntity:
return self._docs

def get_type_for_input_var(self, k: str, v: Any) -> type:
Expand Down Expand Up @@ -285,14 +291,14 @@ def local_execute(
kwargs = translate_inputs_to_literals(
ctx,
incoming_values=kwargs,
flyte_interface_types=self.interface.inputs,
flyte_interface_types=self.interface.inputs.variables,
native_types=self.get_input_types(), # type: ignore
)
except TypeTransformerFailedError as exc:
msg = f"Failed to convert inputs of task '{self.name}':\n {exc}"
logger.error(msg)
raise TypeError(msg) from exc
input_literal_map = _literal_models.LiteralMap(literals=kwargs)
input_literal_map = flyteidl.core.LiteralMap(literals=kwargs)

# if metadata.cache is set, check memoized version
local_config = LocalConfig.auto()
Expand Down Expand Up @@ -338,11 +344,10 @@ def local_execute(
return outputs_literal_map

outputs_literals = outputs_literal_map.literals

# TODO: maybe this is the part that should be done for local execution: we pass the outputs to some special
# location; otherwise we don't really need to, right? The higher-level execute could just handle the LiteralMap.
# After running, we again have to wrap the outputs, if any, back into Promise objects
output_names = list(self.interface.outputs.keys()) # type: ignore
output_names = list(self.interface.outputs.variables.keys()) # type: ignore
if len(output_names) != len(outputs_literals):
# Length check, clean up exception
raise AssertionError(f"Length difference {len(output_names)} {len(outputs_literals)}")
Expand All @@ -360,7 +365,7 @@ def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Pro
def compile(self, ctx: FlyteContext, *args, **kwargs):
raise Exception("not implemented")

def get_container(self, settings: SerializationSettings) -> Optional[_task_model.Container]:
def get_container(self, settings: SerializationSettings) -> Optional[flyteidl.core.Container]:
"""
Returns the container definition (if any) that is used to run the task on hosted Flyte.
"""
Expand Down Expand Up @@ -391,7 +396,7 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
"""
return None

def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
def get_extended_resources(self, settings: SerializationSettings) -> Optional[flyteidl.core.ExtendedResources]:
"""
Returns the extended resources to allocate to the task on hosted Flyte.
"""
Expand All @@ -404,8 +409,8 @@ def local_execution_mode(self) -> ExecutionState.Mode:
def sandbox_execute(
self,
ctx: FlyteContext,
input_literal_map: _literal_models.LiteralMap,
) -> _literal_models.LiteralMap:
input_literal_map: flyteidl.core.LiteralMap,
) -> flyteidl.core.LiteralMap:
"""
Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime.
"""
Expand All @@ -418,8 +423,8 @@ def sandbox_execute(
def dispatch_execute(
self,
ctx: FlyteContext,
input_literal_map: _literal_models.LiteralMap,
) -> _literal_models.LiteralMap:
input_literal_map: flyteidl.core.LiteralMap,
) -> flyteidl.core.LiteralMap:
"""
This method translates Flyte's Type system based input values and invokes the actual call to the executor
This method is also invoked during runtime.
Expand Down Expand Up @@ -508,17 +513,22 @@ def __init__(
self._disable_deck = True
if self._python_interface.docstring:
if self.docs is None:
self._docs = Documentation(
short_description=self._python_interface.docstring.short_description,
long_description=Description(value=self._python_interface.docstring.long_description),
self._docs = flyteidl.admin.DescriptionEntity(
short_description=self._python_interface.docstring.short_description
if self._python_interface.docstring.short_description
else "",
long_description=flyteidl.admin.Description(
content=self._python_interface.docstring.long_description, format=0, icon_link=""
),
tags=[],
)
else:
if self._python_interface.docstring.short_description:
cast(
Documentation, self._docs
flyteidl.admin.DescriptionEntity, self._docs
).short_description = self._python_interface.docstring.short_description
if self._python_interface.docstring.long_description:
cast(Documentation, self._docs).long_description = Description(
cast(flyteidl.admin.DescriptionEntity, self._docs).long_description = flyteidl.admin.Description(
value=self._python_interface.docstring.long_description
)

Expand Down Expand Up @@ -576,9 +586,7 @@ def compile(self, ctx: FlyteContext, *args, **kwargs) -> Optional[Union[Tuple[Pr
def _outputs_interface(self) -> Dict[Any, Variable]:
return self.interface.outputs # type: ignore

def _literal_map_to_python_input(
self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext
) -> Dict[str, Any]:
def _literal_map_to_python_input(self, literal_map: flyteidl.core.LiteralMap, ctx: FlyteContext) -> Dict[str, Any]:
return TypeEngine.literal_map_to_kwargs(ctx, literal_map, self.python_interface.inputs)

def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteContext):
Expand Down Expand Up @@ -629,7 +637,7 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte
metadata[md_key] = md_val
logger.info(f"Adding {om.additional_items} additional metadata items {metadata} for {k}")
if om.dynamic_partitions or om.time_partition:
a = art_id.ArtifactID(
a = art_id.ArtifactID( # type: ignore
partitions=idl_partitions_from_dict(om.dynamic_partitions),
time_partition=idl_time_partition_from_datetime(
om.time_partition, om.artifact.time_partition_granularity
Expand Down
Loading
Loading