Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion airflow/providers/papermill/operators/papermill.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,30 @@ class PapermillOperator(BaseOperator):
:type output_nb: str
:param parameters: the notebook parameters to set
:type parameters: dict
:param kernel_name: (optional) name of kernel to execute the notebook against
(ignores kernel name in the notebook document metadata)
:type kernel_name: str
"""

supports_lineage = True

template_fields = ('input_nb', 'output_nb', 'parameters')
template_fields = ('input_nb', 'output_nb', 'parameters', 'kernel_name')

def __init__(
self,
*,
input_nb: Optional[str] = None,
output_nb: Optional[str] = None,
parameters: Optional[Dict] = None,
kernel_name: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)

self.input_nb = input_nb
self.output_nb = output_nb
self.parameters = parameters
self.kernel_name = kernel_name
if input_nb:
self.inlets.append(NoteBook(url=input_nb, parameters=self.parameters))
if output_nb:
Expand All @@ -79,4 +84,5 @@ def execute(self, context):
parameters=item.parameters,
progress_bar=False,
report_mode=True,
kernel_name=self.kernel_name,
)
13 changes: 11 additions & 2 deletions tests/providers/papermill/operators/test_papermill.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,28 @@ class TestPapermillOperator(unittest.TestCase):
def test_execute(self, mock_papermill):
in_nb = "/tmp/does_not_exist"
out_nb = "/tmp/will_not_exist"
kernel_name = "python3"
parameters = {"msg": "hello_world", "train": 1}

op = PapermillOperator(
input_nb=in_nb,
output_nb=out_nb,
parameters=parameters,
task_id="papermill_operator_test",
kernel_name=kernel_name,
dag=None,
)

op.pre_execute(context={}) # make sure to have the inlets
op.pre_execute(context={}) # Make sure to have the inlets
op.execute(context={})

mock_papermill.execute_notebook.assert_called_once_with(
in_nb, out_nb, parameters=parameters, progress_bar=False, report_mode=True
in_nb,
out_nb,
parameters=parameters,
kernel_name=kernel_name,
progress_bar=False,
report_mode=True,
)

def test_render_template(self):
Expand All @@ -56,6 +63,7 @@ def test_render_template(self):
input_nb="/tmp/{{ dag.dag_id }}.ipynb",
output_nb="/tmp/out-{{ dag.dag_id }}.ipynb",
parameters={"msgs": "dag id is {{ dag.dag_id }}!"},
kernel_name="python3",
dag=dag,
)

Expand All @@ -66,3 +74,4 @@ def test_render_template(self):
assert "/tmp/test_render_template.ipynb" == getattr(operator, 'input_nb')
assert '/tmp/out-test_render_template.ipynb' == getattr(operator, 'output_nb')
assert {"msgs": "dag id is test_render_template!"} == getattr(operator, 'parameters')
assert "python3" == getattr(operator, 'kernel_name')