Skip to content

Commit

Permalink
[Issue-3617] Enables FlyteFiles, FlyteDirectors, and StructuredDatase…
Browse files Browse the repository at this point in the history
…ts inputs in papermill plugin (#1612)
  • Loading branch information
peridotml authored May 11, 2023
1 parent e44b802 commit 26d1f29
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
record_outputs
"""

from .task import NotebookTask, record_outputs
from .task import NotebookTask, load_flytedirectory, load_flytefile, load_structureddataset, record_outputs
96 changes: 92 additions & 4 deletions plugins/flytekit-papermill/flytekitplugins/papermill/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,28 @@
import logging
import os
import sys
import tempfile
import typing
from typing import Any

import nbformat
import papermill as pm
from flyteidl.core.literals_pb2 import Literal as _pb2_Literal
from flyteidl.core.literals_pb2 import LiteralMap as _pb2_LiteralMap
from google.protobuf import text_format as _text_format
from nbconvert import HTMLExporter

from flytekit import FlyteContext, PythonInstanceTask
from flytekit import FlyteContext, PythonInstanceTask, StructuredDataset
from flytekit.configuration import SerializationSettings
from flytekit.core import utils
from flytekit.core.context_manager import ExecutionParameters
from flytekit.deck.deck import Deck
from flytekit.extend import Interface, TaskPlugins, TypeEngine
from flytekit.loggers import logger
from flytekit.models import task as task_models
from flytekit.models.literals import LiteralMap
from flytekit.types.file import HTMLPage, PythonNotebook
from flytekit.models.literals import Literal, LiteralMap
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, HTMLPage, PythonNotebook

T = typing.TypeVar("T")

Expand All @@ -28,6 +32,8 @@ def _dummy_task_func():
return None


SAVE_AS_LITERAL = (FlyteFile, FlyteDirectory, StructuredDataset)

PAPERMILL_TASK_PREFIX = "pm.nb"


Expand Down Expand Up @@ -255,6 +261,10 @@ def execute(self, **kwargs) -> Any:
singleton
"""
logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.")
for k, v in kwargs.items():
if isinstance(v, SAVE_AS_LITERAL):
kwargs[k] = save_python_val_to_file(v)

# Execute Notebook via Papermill.
pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs, log_output=self._stream_logs) # type: ignore

Expand All @@ -265,6 +275,7 @@ def execute(self, **kwargs) -> Any:
if outputs:
m = outputs.literals
output_list = []

for k, type_v in self.python_interface.outputs.items():
if k == self._IMPLICIT_OP_NOTEBOOK:
output_list.append(self.output_notebook_path)
Expand All @@ -274,7 +285,7 @@ def execute(self, **kwargs) -> Any:
v = TypeEngine.to_python_value(ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v)
output_list.append(v)
else:
raise RuntimeError(f"Expected output {k} of type {v} not found in the notebook outputs")
raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs")

return tuple(output_list)

Expand Down Expand Up @@ -307,3 +318,80 @@ def record_outputs(**kwargs) -> str:
lit = TypeEngine.to_literal(ctx, python_type=type(v), python_val=v, expected=expected)
m[k] = lit
return LiteralMap(literals=m).to_flyte_idl()


def save_python_val_to_file(input: Any) -> str:
"""Save a python value to a local file as a Flyte literal.
Args:
input (Any): the python value
Returns:
str: the path to the file
"""
ctx = FlyteContext.current_context()
expected = TypeEngine.to_literal_type(type(input))
lit = TypeEngine.to_literal(ctx, python_type=type(input), python_val=input, expected=expected)

tmp_file = tempfile.mktemp(suffix="bin")
utils.write_proto_to_file(lit.to_flyte_idl(), tmp_file)
return tmp_file


def load_python_val_from_file(path: str, dtype: T) -> T:
"""Loads a python value from a Flyte literal saved to a local file.
If the path matches the type, it is returned as is. This enables
reusing the parameters cell for local development.
Args:
path (str): path to the file
dtype (T): the type of the literal
Returns:
T: the python value of the literal
"""
if isinstance(path, dtype):
return path

proto = utils.load_proto_from_file(_pb2_Literal, path)
lit = Literal.from_flyte_idl(proto)
ctx = FlyteContext.current_context()
python_value = TypeEngine.to_python_value(ctx, lit, dtype)
return python_value


def load_flytefile(path: str) -> T:
"""Loads a FlyteFile from a file.
Args:
path (str): path to the file
Returns:
T: the python value of the literal
"""
return load_python_val_from_file(path=path, dtype=FlyteFile)


def load_flytedirectory(path: str) -> T:
"""Loads a FlyteDirectory from a file.
Args:
path (str): path to the file
Returns:
T: the python value of the literal
"""
return load_python_val_from_file(path=path, dtype=FlyteDirectory)


def load_structureddataset(path: str) -> T:
"""Loads a StructuredDataset from a file.
Args:
path (str): path to the file
Returns:
T: the python value of the literal
"""
return load_python_val_from_file(path=path, dtype=StructuredDataset)
42 changes: 40 additions & 2 deletions plugins/flytekit-papermill/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import datetime
import os
import tempfile

import pandas as pd
from flytekitplugins.papermill import NotebookTask
from flytekitplugins.pod import Pod
from kubernetes.client import V1Container, V1PodSpec

import flytekit
from flytekit import kwtypes
from flytekit import StructuredDataset, kwtypes, task
from flytekit.configuration import Image, ImageConfig
from flytekit.types.file import PythonNotebook
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, PythonNotebook

from .testdata.datatype import X

Expand Down Expand Up @@ -134,3 +137,38 @@ def test_notebook_pod_task():
nb.get_command(serialization_settings)
== nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"]
)


def test_flyte_types():
@task
def create_file() -> FlyteFile:
tmp_file = tempfile.mktemp()
with open(tmp_file, "w") as f:
f.write("abc")
return FlyteFile(path=tmp_file)

@task
def create_dir() -> FlyteDirectory:
tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, "file.txt"), "w") as f:
f.write("abc")
return FlyteDirectory(path=tmp_dir)

@task
def create_sd() -> StructuredDataset:
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
return StructuredDataset(dataframe=df)

ff = create_file()
fd = create_dir()
sd = create_sd()

nb_name = "nb-types"
nb_types = NotebookTask(
name="test",
notebook_path=_get_nb_path(nb_name, abs=False),
inputs=kwtypes(ff=FlyteFile, fd=FlyteDirectory, sd=StructuredDataset),
outputs=kwtypes(success=bool),
)
success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd)
assert success is True, "Notebook execution failed"
7 changes: 3 additions & 4 deletions plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
"outputs": [],
"source": [
"from flytekitplugins.papermill import record_outputs\n",
"\n",
"record_outputs(square=out)"
]
},
Expand All @@ -49,7 +48,7 @@
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -63,9 +62,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
100 changes: 100 additions & 0 deletions plugins/flytekit-papermill/tests/testdata/nb-types.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"ff = None\n",
"fd = None\n",
"sd = None"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"from flytekitplugins.papermill import (\n",
" load_flytefile, load_flytedirectory, load_structureddataset,\n",
" record_outputs\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ff = load_flytefile(ff)\n",
"fd = load_flytedirectory(fd)\n",
"sd = load_structureddataset(sd)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read file\n",
"with open(ff.download(), 'r') as f:\n",
" text = f.read()\n",
" assert text == \"abc\", \"Text does not match\"\n",
"\n",
"# check file inside directory\n",
"with open(os.path.join(fd.download(),\"file.txt\"), 'r') as f:\n",
" text = f.read()\n",
" assert text == \"abc\", \"Text does not match\"\n",
"\n",
"# check dataset\n",
"df = sd.open(pd.DataFrame).all()\n",
"expected = pd.DataFrame({\"a\": [1, 2], \"b\": [3, 4]})\n",
"assert df.equals(expected), \"Dataframes do not match\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"outputs"
]
},
"outputs": [],
"source": [
"record_outputs(success=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

0 comments on commit 26d1f29

Please sign in to comment.