diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py index ebbda6620ffdf..7c2e8d383d1b6 100644 --- a/airflow/providers/papermill/operators/papermill.py +++ b/airflow/providers/papermill/operators/papermill.py @@ -44,11 +44,14 @@ class PapermillOperator(BaseOperator): :type output_nb: str :param parameters: the notebook parameters to set :type parameters: dict + :param kernel_name: (optional) name of kernel to execute the notebook against + (ignores kernel name in the notebook document metadata) + :type kernel_name: str """ supports_lineage = True - template_fields = ('input_nb', 'output_nb', 'parameters') + template_fields = ('input_nb', 'output_nb', 'parameters', 'kernel_name') def __init__( self, @@ -56,6 +59,7 @@ def __init__( input_nb: Optional[str] = None, output_nb: Optional[str] = None, parameters: Optional[Dict] = None, + kernel_name: Optional[str] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -63,6 +67,7 @@ def __init__( self.input_nb = input_nb self.output_nb = output_nb self.parameters = parameters + self.kernel_name = kernel_name if input_nb: self.inlets.append(NoteBook(url=input_nb, parameters=self.parameters)) if output_nb: @@ -79,4 +84,5 @@ def execute(self, context): parameters=item.parameters, progress_bar=False, report_mode=True, + kernel_name=self.kernel_name, ) diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py index 71859326d7d81..24d44e386a645 100644 --- a/tests/providers/papermill/operators/test_papermill.py +++ b/tests/providers/papermill/operators/test_papermill.py @@ -30,6 +30,7 @@ class TestPapermillOperator(unittest.TestCase): def test_execute(self, mock_papermill): in_nb = "/tmp/does_not_exist" out_nb = "/tmp/will_not_exist" + kernel_name = "python3" parameters = {"msg": "hello_world", "train": 1} op = PapermillOperator( @@ -37,14 +38,20 @@ def test_execute(self, mock_papermill): output_nb=out_nb, parameters=parameters, task_id="papermill_operator_test", + kernel_name=kernel_name, dag=None, ) - op.pre_execute(context={}) # make sure to have the inlets + op.pre_execute(context={}) # Make sure to have the inlets op.execute(context={}) mock_papermill.execute_notebook.assert_called_once_with( - in_nb, out_nb, parameters=parameters, progress_bar=False, report_mode=True + in_nb, + out_nb, + parameters=parameters, + kernel_name=kernel_name, + progress_bar=False, + report_mode=True, ) def test_render_template(self): @@ -56,6 +63,7 @@ def test_render_template(self): input_nb="/tmp/{{ dag.dag_id }}.ipynb", output_nb="/tmp/out-{{ dag.dag_id }}.ipynb", parameters={"msgs": "dag id is {{ dag.dag_id }}!"}, + kernel_name="python3", dag=dag, ) @@ -66,3 +74,4 @@ def test_render_template(self): assert "/tmp/test_render_template.ipynb" == getattr(operator, 'input_nb') assert '/tmp/out-test_render_template.ipynb' == getattr(operator, 'output_nb') assert {"msgs": "dag id is test_render_template!"} == getattr(operator, 'parameters') + assert "python3" == getattr(operator, 'kernel_name')