Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def __init__(
self.disable_verify_ssl = disable_verify_ssl
self.disable_tcp_keepalive = disable_tcp_keepalive

self._is_in_cluster: Optional[bool] = None

# these params used for transition in KPO to K8s hook
# for a deprecation period we will continue to consider k8s settings from airflow.cfg
self._deprecated_core_disable_tcp_keepalive: Optional[bool] = None
Expand Down Expand Up @@ -232,11 +234,13 @@ def get_conn(self) -> Any:

if in_cluster:
self.log.debug("loading kube_config from: in_cluster configuration")
self._is_in_cluster = True
config.load_incluster_config()
return client.ApiClient()

if kubeconfig_path is not None:
self.log.debug("loading kube_config from: %s", kubeconfig_path)
self._is_in_cluster = False
config.load_kube_config(
config_file=kubeconfig_path,
client_configuration=self.client_configuration,
Expand All @@ -249,6 +253,7 @@ def get_conn(self) -> Any:
self.log.debug("loading kube_config from: connection kube_config")
temp_config.write(kubeconfig.encode())
temp_config.flush()
self._is_in_cluster = False
config.load_kube_config(
config_file=temp_config.name,
client_configuration=self.client_configuration,
Expand All @@ -265,14 +270,24 @@ def _get_default_client(self, *, cluster_context=None):
# in the default location
try:
config.load_incluster_config(client_configuration=self.client_configuration)
self._is_in_cluster = True
except ConfigException:
self.log.debug("loading kube_config from: default file")
self._is_in_cluster = False
config.load_kube_config(
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()

@property
def is_in_cluster(self):
"""Expose whether the hook is configured with ``load_incluster_config`` or not"""
if self._is_in_cluster is not None:
return self._is_in_cluster
self.api_client # so we can determine if we are in_cluster or not
return self._is_in_cluster

@cached_property
def api_client(self) -> Any:
"""Cached Kubernetes API client"""
Expand Down
9 changes: 7 additions & 2 deletions airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ def pod_manager(self) -> PodManager:
return PodManager(kube_client=self.client)

def get_hook(self):
warnings.warn("get_hook is deprecated. Please use hook instead.", DeprecationWarning, stacklevel=2)
return self.hook

@cached_property
def hook(self) -> KubernetesHook:
hook = KubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
Expand All @@ -341,8 +346,7 @@ def get_hook(self):

@cached_property
def client(self) -> CoreV1Api:
hook = self.get_hook()
return hook.core_v1_client
return self.hook.core_v1_client

def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]:
"""Returns an already-running pod for this task instance if one exists."""
Expand Down Expand Up @@ -568,6 +572,7 @@ def build_pod_request_obj(self, context=None):
pod.metadata.labels.update(
{
'airflow_version': airflow_version.replace('+', '-'),
'airflow_kpo_in_cluster': str(self.hook.is_in_cluster),
}
)
pod_mutation_hook(pod)
Expand Down
15 changes: 11 additions & 4 deletions kubernetes_tests/test_kubernetes_pod_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def setUp(self):
'foo': 'bar',
'kubernetes_pod_operator': 'True',
'airflow_version': airflow_version.replace('+', '-'),
'airflow_kpo_in_cluster': 'False',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'dag_id': 'dag',
'task_id': ANY,
Expand Down Expand Up @@ -734,6 +735,7 @@ def test_pod_template_file_with_overrides_system(self):
'fizz': 'buzz',
'foo': 'bar',
'airflow_version': mock.ANY,
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down Expand Up @@ -773,6 +775,7 @@ def test_pod_template_file_with_full_pod_spec(self):
'fizz': 'buzz',
'foo': 'bar',
'airflow_version': mock.ANY,
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down Expand Up @@ -815,6 +818,7 @@ def test_full_pod_spec(self):
'fizz': 'buzz',
'foo': 'bar',
'airflow_version': mock.ANY,
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down Expand Up @@ -882,9 +886,10 @@ def test_init_container(self):
@mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom")
@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
@mock.patch(HOOK_CLASS, new=MagicMock)
def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock):
@mock.patch(HOOK_CLASS)
def test_pod_template_file(self, hook_mock, await_pod_completion_mock, extract_xcom_mock):
# todo: This isn't really a system test
hook_mock.return_value.is_in_cluster = False
extract_xcom_mock.return_value = '{}'
path = sys.path[0] + '/tests/kubernetes/pod.yaml'
k = KubernetesPodOperator(
Expand Down Expand Up @@ -920,6 +925,7 @@ def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock):
'metadata': {
'annotations': {},
'labels': {
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down Expand Up @@ -968,13 +974,14 @@ def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock):

@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
@mock.patch(HOOK_CLASS, new=MagicMock)
def test_pod_priority_class_name(self, await_pod_completion_mock):
@mock.patch(HOOK_CLASS)
def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock):
"""
Test ability to assign priorityClassName to pod

todo: This isn't really a system test
"""
hook_mock.return_value.is_in_cluster = False

priority_class_name = "medium-test"
k = KubernetesPodOperator(
Expand Down
2 changes: 2 additions & 0 deletions kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def setUp(self):
'foo': 'bar',
'kubernetes_pod_operator': 'True',
'airflow_version': airflow_version.replace('+', '-'),
'airflow_kpo_in_cluster': 'False',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'dag_id': 'dag',
'task_id': 'task',
Expand Down Expand Up @@ -571,6 +572,7 @@ def test_pod_template_file_with_overrides_system(self):
'fizz': 'buzz',
'foo': 'bar',
'airflow_version': mock.ANY,
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down
7 changes: 7 additions & 0 deletions tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def test_in_cluster_connection(
else:
mock_get_default_client.assert_called()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
if mock_get_default_client.called:
# get_default_client sets it, but it's mocked
assert kubernetes_hook.is_in_cluster is None
else:
assert kubernetes_hook.is_in_cluster is in_cluster_called

@pytest.mark.parametrize('in_cluster_fails', [True, False])
@patch("kubernetes.config.kube_config.KubeConfigLoader")
Expand All @@ -130,10 +135,12 @@ def test_get_default_client(
mock_incluster.assert_called_once()
mock_merger.assert_called_once_with(KUBE_CONFIG_PATH)
mock_loader.assert_called_once()
assert kubernetes_hook.is_in_cluster is False
else:
mock_incluster.assert_called_once()
mock_merger.assert_not_called()
mock_loader.assert_not_called()
assert kubernetes_hook.is_in_cluster is True
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

@pytest.mark.parametrize(
Expand Down
20 changes: 14 additions & 6 deletions tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> k8s.V
remote_pod_mock = MagicMock()
remote_pod_mock.status.phase = 'Succeeded'
self.await_pod_mock.return_value = remote_pod_mock
if not isinstance(self.hook_mock.return_value.is_in_cluster, bool):
self.hook_mock.return_value.is_in_cluster = True
operator.execute(context=context)
return self.await_start_mock.call_args[1]['pod']

Expand Down Expand Up @@ -170,15 +172,17 @@ def test_envs_from_configmaps(
pod = self.run_pod(k)
assert pod.spec.containers[0].env_from == env_from

def test_labels(self):
@pytest.mark.parametrize(("in_cluster",), ([True], [False]))
def test_labels(self, in_cluster):
self.hook_mock.return_value.is_in_cluster = in_cluster
k = KubernetesPodOperator(
namespace="default",
image="ubuntu:16.04",
cmds=["bash", "-cx"],
labels={"foo": "bar"},
name="test",
task_id="task",
in_cluster=False,
in_cluster=in_cluster,
do_xcom_push=False,
)
pod = self.run_pod(k)
Expand All @@ -190,6 +194,7 @@ def test_labels(self):
"try_number": "1",
"airflow_version": mock.ANY,
"run_id": "test",
"airflow_kpo_in_cluster": str(in_cluster),
}

def test_labels_mapped(self):
Expand All @@ -209,6 +214,7 @@ def test_labels_mapped(self):
"airflow_version": mock.ANY,
"run_id": "test",
"map_index": "10",
"airflow_kpo_in_cluster": "True",
}

def test_find_pod_labels(self):
Expand Down Expand Up @@ -391,6 +397,7 @@ def test_full_pod_spec(self, randomize_name, pod_spec):
"task_id": "task",
"try_number": "1",
"airflow_version": mock.ANY,
"airflow_kpo_in_cluster": "True",
"run_id": "test",
}

Expand Down Expand Up @@ -429,6 +436,7 @@ def test_full_pod_spec_kwargs(self, randomize_name, pod_spec):
"task_id": "task",
"try_number": "1",
"airflow_version": mock.ANY,
"airflow_kpo_in_cluster": "True",
"run_id": "test",
}

Expand Down Expand Up @@ -499,6 +507,7 @@ def test_pod_template_file(self, randomize_name, pod_template_file):
"task_id": "task",
"try_number": "1",
"airflow_version": mock.ANY,
"airflow_kpo_in_cluster": "True",
"run_id": "test",
}
assert pod.metadata.namespace == "mynamespace"
Expand Down Expand Up @@ -568,6 +577,7 @@ def test_pod_template_file_kwargs_override(self, randomize_name, pod_template_fi
"task_id": "task",
"try_number": "1",
"airflow_version": mock.ANY,
"airflow_kpo_in_cluster": "True",
"run_id": "test",
}

Expand Down Expand Up @@ -877,13 +887,11 @@ def test_patch_core_settings(self, key, value, attr, patched_value):
# the hook attr should be None
op = KubernetesPodOperator(task_id='abc', name='hi')
self.hook_patch.stop()
hook = op.get_hook()
assert getattr(hook, attr) is None
assert getattr(op.hook, attr) is None
# now check behavior with a non-default value
with conf_vars({('kubernetes', key): value}):
op = KubernetesPodOperator(task_id='abc', name='hi')
hook = op.get_hook()
assert getattr(hook, attr) == patched_value
assert getattr(op.hook, attr) == patched_value


def test__suppress():
Expand Down