Skip to content

Commit

Permalink
feat: add TaskOnKart.dump type
Browse files Browse the repository at this point in the history
  • Loading branch information
kitagry committed Apr 23, 2024
1 parent cc38660 commit 33e9f2d
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 61 deletions.
38 changes: 32 additions & 6 deletions gokart/build.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from functools import partial
from logging import getLogger
from typing import Any, Optional
from typing import Literal, Optional, TypeVar, cast, overload

import backoff
import luigi
Expand All @@ -11,6 +11,8 @@
from gokart.target import TargetOnKart
from gokart.task import TaskOnKart

T = TypeVar('T')


class LoggerConfig:
def __init__(self, level: int):
Expand Down Expand Up @@ -42,13 +44,13 @@ def __init__(self):
self.flag: bool = False


def _get_output(task: TaskOnKart) -> Any:
def _get_output(task: TaskOnKart[T]) -> T:
output = task.output()
# FIXME: currently, nested output is not supported
if isinstance(output, list) or isinstance(output, tuple):
return [t.load() for t in output if isinstance(t, TargetOnKart)]
return cast(T, [t.load() for t in output if isinstance(t, TargetOnKart)])
if isinstance(output, dict):
return {k: t.load() for k, t in output.items() if isinstance(t, TargetOnKart)}
return cast(T, {k: t.load() for k, t in output.items() if isinstance(t, TargetOnKart)})
if isinstance(output, TargetOnKart):
return output.load()
raise ValueError(f'output type is not supported: {type(output)}')
Expand All @@ -66,15 +68,39 @@ def _reset_register(keep={'gokart', 'luigi'}):
]


# Typing overload for the default case (return_value=True): the declared
# result type is T, the type parameter of the TaskOnKart[T] passed as `task`.
@overload
def build(
task: TaskOnKart[T],
return_value: Literal[True] = True,
reset_register: bool = True,
log_level: int = logging.ERROR,
task_lock_exception_max_tries: int = 10,
task_lock_exception_max_wait_seconds: int = 600,
**env_params,
) -> T: ...


# Typing overload for return_value=False: the call is declared to return None.
# `return_value` has no default in this signature, so it only matches when
# False is passed explicitly.
@overload
def build(
task: TaskOnKart[T],
return_value: Literal[False],
reset_register: bool = True,
log_level: int = logging.ERROR,
task_lock_exception_max_tries: int = 10,
task_lock_exception_max_wait_seconds: int = 600,
**env_params,
) -> None: ...


def build(
task: TaskOnKart,
task: TaskOnKart[T],
return_value: bool = True,
reset_register: bool = True,
log_level: int = logging.ERROR,
task_lock_exception_max_tries: int = 10,
task_lock_exception_max_wait_seconds: int = 600,
**env_params,
) -> Optional[Any]:
) -> Optional[T]:
"""
Run gokart task for local interpreter.
Sharing the most of its parameters with luigi.build (see https://luigi.readthedocs.io/en/stable/api/luigi.html?highlight=build#luigi.build)
Expand Down
23 changes: 17 additions & 6 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import types
from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, Generator, Generic, List, Optional, Set, TypeVar, Union, overload

import luigi
import pandas as pd
Expand All @@ -24,7 +24,10 @@
logger = getLogger(__name__)


class TaskOnKart(luigi.Task):
T = TypeVar('T')


class TaskOnKart(luigi.Task, Generic[T]):
"""
This is a wrapper class of luigi.Task.
Expand Down Expand Up @@ -269,7 +272,7 @@ def _load(targets):
return list(data.values())[0]
return data

def load_generator(self, target: Union[None, str, TargetOnKart] = None) -> Any:
def load_generator(self, target: Union[None, str, TargetOnKart] = None) -> Generator[Any, None, None]:
def _load(targets):
if isinstance(targets, list) or isinstance(targets, tuple):
for t in targets:
Expand Down Expand Up @@ -305,7 +308,15 @@ def _flatten_recursively(dfs):
data = data[list(required_columns)]
return data

def dump(self, obj, target: Union[None, str, TargetOnKart] = None) -> None:
# Typing overload: when no target is given, `obj` is annotated as T,
# the type parameter of this TaskOnKart[T].
@overload
def dump(self, obj: T, target: None = None) -> None:
...

# Typing overload: when an explicit target (a str or a TargetOnKart) is given,
# `obj` is annotated as Any rather than T.
@overload
def dump(self, obj: Any, target: Union[str, TargetOnKart]) -> None:
...

def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None) -> None:
PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace)
if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame):
assert not obj.empty
Expand All @@ -323,13 +334,13 @@ def get_own_code(self):
own_codes = self.get_code(self)
return ''.join(sorted(list(own_codes - gokart_codes)))

def make_unique_id(self):
def make_unique_id(self) -> str:
unique_id = self.task_unique_id or self._make_hash_id()
if self.cache_unique_id:
self.task_unique_id = unique_id
return unique_id

def _make_hash_id(self):
def _make_hash_id(self) -> str:
def _to_str_params(task):
if isinstance(task, TaskOnKart):
return str(task.make_unique_id()) if task.significant else None
Expand Down
75 changes: 32 additions & 43 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ fakeredis = "*"
mypy = "*"
types-redis = "*"
matplotlib = "*"
typing-extensions = "^4.11.0"

[tool.ruff]
# All the rules are listed on https://docs.astral.sh/ruff/rules/
Expand Down
Loading

0 comments on commit 33e9f2d

Please sign in to comment.