Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Also allow passing buffer instead of path for retrieve_file and store_file methods in SFTPHook #44247

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
17 changes: 11 additions & 6 deletions providers/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import stat
import warnings
from fnmatch import fnmatch
from io import BytesIO
from typing import TYPE_CHECKING, Any, Callable, Sequence

import asyncssh
Expand Down Expand Up @@ -51,8 +52,6 @@ class SFTPHook(SSHHook):
- In contrast with FTPHook describe_directory only returns size, type and
modify. It doesn't return unix.owner, unix.mode, perm, unix.group and
unique.
- retrieve_file and store_file only take a local full path and not a
buffer.
- If no mode is passed to create_directory it will be created with 777
permissions.

Expand Down Expand Up @@ -247,11 +246,14 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: b
at that location.

:param remote_full_path: full path to the remote file
:param local_full_path: full path to the local file
:param local_full_path: full path to the local file or a file-like buffer
:param prefetch: controls whether prefetch is performed (default: True)
"""
conn = self.get_conn()
conn.get(remote_full_path, local_full_path, prefetch=prefetch)
if isinstance(local_full_path, BytesIO):
conn.getfo(remote_full_path, local_full_path, prefetch=prefetch)
else:
conn.get(remote_full_path, local_full_path, prefetch=prefetch)

def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None:
"""
Expand All @@ -261,10 +263,13 @@ def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool
from that location.

:param remote_full_path: full path to the remote file
:param local_full_path: full path to the local file
:param local_full_path: full path to the local file or a file-like buffer
"""
conn = self.get_conn()
conn.put(local_full_path, remote_full_path, confirm=confirm)
if isinstance(local_full_path, BytesIO):
conn.putfo(local_full_path, remote_full_path, confirm=confirm)
else:
conn.put(local_full_path, remote_full_path, confirm=confirm)

def delete_file(self, path: str) -> None:
"""
Expand Down
25 changes: 23 additions & 2 deletions providers/tests/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import json
import os
import shutil
from io import StringIO
from io import BytesIO, StringIO
from unittest import mock
from unittest.mock import AsyncMock, patch

Expand Down Expand Up @@ -88,7 +88,10 @@ def setup_test_cases(self, tmp_path_factory):
file.write("Test file")
with open(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), "a") as file:
file.write("Test file")
os.mkfifo(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS))
try:
os.mkfifo(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS))
except AttributeError:
os.makedirs(os.path.join(temp_dir, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS))

self.temp_dir = str(temp_dir)

Expand Down Expand Up @@ -180,6 +183,24 @@ def test_store_retrieve_and_delete_file(self):
output = self.hook.list_directory(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS))
assert output == [SUB_DIR, FIFO_FOR_TESTS]

def test_store_retrieve_and_delete_file_using_buffer(self):
file_contents = BytesIO(b"Test file")
self.hook.store_file(
remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
local_full_path=file_contents,
)
output = self.hook.list_directory(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS))
assert output == [SUB_DIR, FIFO_FOR_TESTS, TMP_FILE_FOR_TESTS]
retrieved_file_contents = BytesIO()
self.hook.retrieve_file(
remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
local_full_path=retrieved_file_contents,
)
assert retrieved_file_contents.getvalue() == file_contents.getvalue()
self.hook.delete_file(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
output = self.hook.list_directory(path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS))
assert output == [SUB_DIR, FIFO_FOR_TESTS]

def test_get_mod_time(self):
self.hook.store_file(
remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
Expand Down
Loading