Skip to content

Commit

Permalink
Fix mypy errors (#1313)
Browse files Browse the repository at this point in the history
* wip

Signed-off-by: Kevin Su <[email protected]>

* Fix mypy errors

Signed-off-by: Kevin Su <[email protected]>

* Fix mypy errors

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix test

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* Update type

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* update dev-requirements.txt

Signed-off-by: Kevin Su <[email protected]>

* Address comment

Signed-off-by: Kevin Su <[email protected]>

* upgrade torch

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
  • Loading branch information
2 people authored and eapolinario committed May 16, 2023
1 parent ded15a8 commit af49155
Show file tree
Hide file tree
Showing 55 changed files with 406 additions and 350 deletions.
11 changes: 6 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ fmt: ## Format code with black and isort

.PHONY: lint
lint: ## Run linters
mypy flytekit/core || true
mypy flytekit/types || true
mypy tests/flytekit/unit/core || true
# Exclude setup.py to fix error: Duplicate module named "setup"
mypy plugins --exclude setup.py || true
mypy flytekit/core
mypy flytekit/types
# allow-empty-bodies: Allow empty body in function.
# disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked".
# Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass.
mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core
pre-commit run --all-files

.PHONY: spellcheck
Expand Down
8 changes: 4 additions & 4 deletions flytekit/core/base_sql_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, Optional, Type, TypeVar
from typing import Any, Dict, Optional, Tuple, Type, TypeVar

from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.interface import Interface
Expand All @@ -22,11 +22,11 @@ def __init__(
self,
name: str,
query_template: str,
task_config: Optional[T] = None,
task_type="sql_task",
inputs: Optional[Dict[str, Type]] = None,
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
metadata: Optional[TaskMetadata] = None,
task_config: Optional[T] = None,
outputs: Dict[str, Type] = None,
outputs: Optional[Dict[str, Type]] = None,
**kwargs,
):
"""
Expand Down
48 changes: 30 additions & 18 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@
import datetime
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast

from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities
from flytekit.core.context_manager import (
ExecutionParameters,
ExecutionState,
FlyteContext,
FlyteContextManager,
FlyteEntities,
)
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.local_cache import LocalTaskCache
from flytekit.core.promise import (
Expand Down Expand Up @@ -156,7 +162,7 @@ def __init__(
self,
task_type: str,
name: str,
interface: Optional[_interface_models.TypedInterface] = None,
interface: _interface_models.TypedInterface,
metadata: Optional[TaskMetadata] = None,
task_type_version=0,
security_ctx: Optional[SecurityContext] = None,
Expand All @@ -174,7 +180,7 @@ def __init__(
FlyteEntities.entities.append(self)

@property
def interface(self) -> Optional[_interface_models.TypedInterface]:
def interface(self) -> _interface_models.TypedInterface:
return self._interface

@property
Expand Down Expand Up @@ -294,8 +300,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
vals = [Promise(var, outputs_literals[var]) for var in output_names]
return create_task_output(vals, self.python_interface)

def __call__(self, *args, **kwargs):
return flyte_entity_call_handler(self, *args, **kwargs)
def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore

def compile(self, ctx: FlyteContext, *args, **kwargs):
raise Exception("not implemented")
Expand Down Expand Up @@ -339,8 +345,8 @@ def sandbox_execute(
"""
Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime.
"""
es = ctx.execution_state
b = es.user_space_params.with_task_sandbox()
es = cast(ExecutionState, ctx.execution_state)
b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox()
ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
return self.dispatch_execute(ctx, input_literal_map)

Expand Down Expand Up @@ -389,7 +395,7 @@ def __init__(
self,
task_type: str,
name: str,
task_config: T,
task_config: Optional[T],
interface: Optional[Interface] = None,
environment: Optional[Dict[str, str]] = None,
disable_deck: bool = True,
Expand Down Expand Up @@ -426,9 +432,13 @@ def __init__(
)
else:
if self._python_interface.docstring.short_description:
self._docs.short_description = self._python_interface.docstring.short_description
cast(
Documentation, self._docs
).short_description = self._python_interface.docstring.short_description
if self._python_interface.docstring.long_description:
self._docs.long_description = Description(value=self._python_interface.docstring.long_description)
cast(Documentation, self._docs).long_description = Description(
value=self._python_interface.docstring.long_description
)

# TODO lets call this interface and the other as flyte_interface?
@property
Expand All @@ -439,25 +449,25 @@ def python_interface(self) -> Interface:
return self._python_interface

@property
def task_config(self) -> T:
def task_config(self) -> Optional[T]:
"""
Returns the user-specified task config which is used for plugin-specific handling of the task.
"""
return self._task_config

def get_type_for_input_var(self, k: str, v: Any) -> Optional[Type[Any]]:
def get_type_for_input_var(self, k: str, v: Any) -> Type[Any]:
"""
Returns the python type for an input variable by name.
"""
return self._python_interface.inputs[k]

def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]:
def get_type_for_output_var(self, k: str, v: Any) -> Type[Any]:
"""
Returns the python type for the specified output variable by name.
"""
return self._python_interface.outputs[k]

def get_input_types(self) -> Optional[Dict[str, type]]:
def get_input_types(self) -> Dict[str, type]:
"""
Returns the names and python types as a dictionary for the inputs of this task.
"""
Expand Down Expand Up @@ -503,7 +513,9 @@ def dispatch_execute(

# Create another execution context with the new user params, but let's keep the same working dir
with FlyteContextManager.with_context(
ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params))
ctx.with_execution_state(
cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params)
)
# type: ignore
) as exec_ctx:
# TODO We could support default values here too - but not part of the plan right now
Expand Down Expand Up @@ -596,7 +608,7 @@ def dispatch_execute(
# After the execute has been successfully completed
return outputs_literal_map

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore
"""
This is the method that will be invoked directly before executing the task method and before all the inputs
are converted. One particular case where this is useful is if the context is to be modified for the user process
Expand All @@ -614,7 +626,7 @@ def execute(self, **kwargs) -> Any:
"""
pass

def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any:
"""
Post execute is called after the execution has completed, with the user_params and can be used to clean-up,
or alter the outputs to match the intended tasks outputs. If not overridden, then this function is a No-op
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/class_based_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs):
def name(self) -> str:
return "ClassStorageTaskResolver"

def get_all_tasks(self) -> List[PythonAutoContainerTask]:
def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type:ignore
return self.mapping

def add(self, t: PythonAutoContainerTask):
Expand All @@ -33,7 +33,7 @@ def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
idx = int(loader_args[0])
return self.mapping[idx]

def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]:
def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: # type: ignore
"""
This is responsible for turning an instance of a task into args that the load_task function can reconstitute.
"""
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP
return self._compute_outputs(n)
return self._condition

def if_(self, expr: bool) -> Case:
def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case:
return self._condition._if(expr)

def compute_output_vars(self) -> typing.Optional[typing.List[str]]:
Expand Down Expand Up @@ -360,7 +360,7 @@ def create_branch_node_promise_var(node_id: str, var: str) -> str:
return f"{node_id}.{var}"


def merge_promises(*args: Promise) -> typing.List[Promise]:
def merge_promises(*args: Optional[Promise]) -> typing.List[Promise]:
node_vars: typing.Set[typing.Tuple[str, str]] = set()
merged_promises: typing.List[Promise] = []
for p in args:
Expand Down Expand Up @@ -414,7 +414,7 @@ def transform_to_boolexpr(


def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]:
expr, promises = transform_to_boolexpr(c.expr)
expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr))
n = c.output_promise.ref.node # type: ignore
return _core_wf.IfBlock(condition=expr, then_node=n), promises

Expand Down
14 changes: 7 additions & 7 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Tuple, Type

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
Expand Down Expand Up @@ -38,16 +38,16 @@ def __init__(
name: str,
image: str,
command: List[str],
inputs: Optional[Dict[str, Type]] = None,
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
metadata: Optional[TaskMetadata] = None,
arguments: List[str] = None,
outputs: Dict[str, Type] = None,
arguments: Optional[List[str]] = None,
outputs: Optional[Dict[str, Type]] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
input_data_dir: str = None,
output_data_dir: str = None,
input_data_dir: Optional[str] = None,
output_data_dir: Optional[str] = None,
metadata_format: MetadataFormat = MetadataFormat.JSON,
io_strategy: IOStrategy = None,
io_strategy: Optional[IOStrategy] = None,
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
Expand Down
20 changes: 10 additions & 10 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
flyte_context_Var: ContextVar[typing.List[FlyteContext]] = ContextVar("", default=[])

if typing.TYPE_CHECKING:
from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.base_task import Task, TaskResolverMixin


# Identifier fields use placeholders for registration-time substitution.
Expand Down Expand Up @@ -108,7 +108,7 @@ def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder:

def build(self) -> ExecutionParameters:
if not isinstance(self.working_dir, utils.AutoDeletingTempDir):
pathlib.Path(self.working_dir).mkdir(parents=True, exist_ok=True)
pathlib.Path(typing.cast(str, self.working_dir)).mkdir(parents=True, exist_ok=True)
return ExecutionParameters(
execution_date=self.execution_date,
stats=self.stats,
Expand All @@ -123,14 +123,14 @@ def build(self) -> ExecutionParameters:
)

@staticmethod
def new_builder(current: ExecutionParameters = None) -> Builder:
def new_builder(current: Optional[ExecutionParameters] = None) -> Builder:
return ExecutionParameters.Builder(current=current)

def with_task_sandbox(self) -> Builder:
prefix = self.working_directory
if isinstance(self.working_directory, utils.AutoDeletingTempDir):
prefix = self.working_directory.name
task_sandbox_dir = tempfile.mkdtemp(prefix=prefix)
task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) # type: ignore
p = pathlib.Path(task_sandbox_dir)
cp_dir = p.joinpath("__cp")
cp_dir.mkdir(exist_ok=True)
Expand Down Expand Up @@ -299,7 +299,7 @@ def get(self, key: str) -> typing.Any:
"""
Returns task specific context if present else raise an error. The returned context will match the key
"""
return self.__getattr__(attr_name=key)
return self.__getattr__(attr_name=key) # type: ignore


class SecretsManager(object):
Expand Down Expand Up @@ -480,14 +480,14 @@ class Mode(Enum):
LOCAL_TASK_EXECUTION = 3

mode: Optional[ExecutionState.Mode]
working_dir: os.PathLike
working_dir: Union[os.PathLike, str]
engine_dir: Optional[Union[os.PathLike, str]]
branch_eval_mode: Optional[BranchEvalMode]
user_space_params: Optional[ExecutionParameters]

def __init__(
self,
working_dir: os.PathLike,
working_dir: Union[os.PathLike, str],
mode: Optional[ExecutionState.Mode] = None,
engine_dir: Optional[Union[os.PathLike, str]] = None,
branch_eval_mode: Optional[BranchEvalMode] = None,
Expand Down Expand Up @@ -620,7 +620,7 @@ def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> Exec
return ExecutionState(working_dir=working_dir, user_space_params=self.user_space_params)

@staticmethod
def current_context() -> Optional[FlyteContext]:
def current_context() -> FlyteContext:
"""
This method exists only to maintain backwards compatibility. Please use
``FlyteContextManager.current_context()`` instead.
Expand Down Expand Up @@ -652,7 +652,7 @@ def get_deck(self) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ig
"""
from flytekit.deck.deck import _get_deck

return _get_deck(self.execution_state.user_space_params)
return _get_deck(typing.cast(ExecutionState, self.execution_state).user_space_params)

@dataclass
class Builder(object):
Expand Down Expand Up @@ -865,7 +865,7 @@ class FlyteEntities(object):
registration process
"""

entities = []
entities: List[Union["LaunchPlan", Task, "WorkflowBase"]] = [] # type: ignore


FlyteContextManager.initialize()
2 changes: 1 addition & 1 deletion flytekit/core/docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Docstring(object):
def __init__(self, docstring: str = None, callable_: Callable = None):
def __init__(self, docstring: Optional[str] = None, callable_: Optional[Callable] = None):
if docstring is not None:
self._parsed_docstring = parse(docstring)
else:
Expand Down
5 changes: 3 additions & 2 deletions flytekit/core/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
)
else:
# We don't know how to find the python interface here, approve() sets it below, See the code.
self._python_interface = None
self._python_interface = None # type: ignore

@property
def name(self) -> str:
Expand Down Expand Up @@ -105,7 +105,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
return p

# Assume this is an approval operation since that's the only remaining option.
msg = f"Pausing execution for {self.name}, literal value is:\n{self._upstream_item.val}\nContinue?"
msg = f"Pausing execution for {self.name}, literal value is:\n{typing.cast(Promise, self._upstream_item).val}\nContinue?"
proceed = click.confirm(msg, default=True)
if proceed:
# We need to return a promise here, and a promise is what should've been passed in by the call in approve()
Expand Down Expand Up @@ -167,6 +167,7 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st
raise ValueError("You can't use approval on a task that doesn't return anything.")

ctx = FlyteContextManager.current_context()
upstream_item = typing.cast(Promise, upstream_item)
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
if not upstream_item.ref.node.flyte_entity.python_interface:
raise ValueError(
Expand Down
Loading

0 comments on commit af49155

Please sign in to comment.