From 50747ad9c168537b8c68249e8254f556736a4b71 Mon Sep 17 00:00:00 2001 From: John Green Date: Wed, 25 May 2022 16:51:03 +0200 Subject: [PATCH 1/2] Add support to specify language name in PapermillOperator --- airflow/providers/papermill/operators/papermill.py | 5 ++++- tests/providers/papermill/operators/test_papermill.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py index 36e8539b1e9c1..b79ac226c120f 100644 --- a/airflow/providers/papermill/operators/papermill.py +++ b/airflow/providers/papermill/operators/papermill.py @@ -50,7 +50,7 @@ class PapermillOperator(BaseOperator): supports_lineage = True - template_fields: Sequence[str] = ('input_nb', 'output_nb', 'parameters', 'kernel_name') + template_fields: Sequence[str] = ('input_nb', 'output_nb', 'parameters', 'kernel_name', 'language_name') def __init__( self, @@ -59,6 +59,7 @@ def __init__( output_nb: Optional[str] = None, parameters: Optional[Dict] = None, kernel_name: Optional[str] = None, + language_name: Optional[str] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -67,6 +68,7 @@ def __init__( self.output_nb = output_nb self.parameters = parameters self.kernel_name = kernel_name + self.language_name = language_name if input_nb: self.inlets.append(NoteBook(url=input_nb, parameters=self.parameters)) if output_nb: @@ -84,4 +86,5 @@ def execute(self, context: 'Context'): progress_bar=False, report_mode=True, kernel_name=self.kernel_name, + language_name=self.language_name, ) diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py index 24d44e386a645..fc18aa82f5cd4 100644 --- a/tests/providers/papermill/operators/test_papermill.py +++ b/tests/providers/papermill/operators/test_papermill.py @@ -31,6 +31,7 @@ def test_execute(self, mock_papermill): in_nb = "/tmp/does_not_exist" out_nb = "/tmp/will_not_exist" kernel_name = "python3" + language_name = "python" parameters = {"msg": "hello_world", "train": 1} op = PapermillOperator( @@ -39,6 +40,7 @@ def test_execute(self, mock_papermill): parameters=parameters, task_id="papermill_operator_test", kernel_name=kernel_name, + language_name=language_name, dag=None, ) @@ -50,6 +52,7 @@ def test_execute(self, mock_papermill): out_nb, parameters=parameters, kernel_name=kernel_name, + language_name=language_name, progress_bar=False, report_mode=True, ) @@ -64,6 +67,7 @@ def test_render_template(self): output_nb="/tmp/out-{{ dag.dag_id }}.ipynb", parameters={"msgs": "dag id is {{ dag.dag_id }}!"}, kernel_name="python3", + language_name="python", dag=dag, ) @@ -75,3 +79,4 @@ def test_render_template(self): assert '/tmp/out-test_render_template.ipynb' == getattr(operator, 'output_nb') assert {"msgs": "dag id is test_render_template!"} == getattr(operator, 'parameters') assert "python3" == getattr(operator, 'kernel_name') + assert "python" == getattr(operator, 'language_name') From 643465d8feafdfa99ec5edbdf831fe27b6d37846 Mon Sep 17 00:00:00 2001 From: John Green Date: Wed, 25 May 2022 17:27:56 +0200 Subject: [PATCH 2/2] Replace getattr() with simple attribute access --- tests/providers/papermill/operators/test_papermill.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py index fc18aa82f5cd4..e635524d01782 100644 --- a/tests/providers/papermill/operators/test_papermill.py +++ b/tests/providers/papermill/operators/test_papermill.py @@ -75,8 +75,8 @@ def test_render_template(self): ti.dag_run = DagRun(execution_date=DEFAULT_DATE) ti.render_templates() - assert "/tmp/test_render_template.ipynb" == getattr(operator, 'input_nb') - assert '/tmp/out-test_render_template.ipynb' == getattr(operator, 'output_nb') - assert {"msgs": "dag id is test_render_template!"} == getattr(operator, 'parameters') - assert "python3" == getattr(operator, 'kernel_name') - assert "python" == getattr(operator, 'language_name') + assert "/tmp/test_render_template.ipynb" == operator.input_nb + assert '/tmp/out-test_render_template.ipynb' == operator.output_nb + assert {"msgs": "dag id is test_render_template!"} == operator.parameters + assert "python3" == operator.kernel_name + assert "python" == operator.language_name