diff --git a/python/src/iceberg/io/base.py b/python/src/iceberg/io/base.py index 705c8b2031de..8bb905a86d57 100644 --- a/python/src/iceberg/io/base.py +++ b/python/src/iceberg/io/base.py @@ -26,6 +26,54 @@ from abc import ABC, abstractmethod from typing import Union +try: + from typing import Protocol, runtime_checkable +except ImportError: # pragma: no cover + from typing_extensions import Protocol # type: ignore + from typing_extensions import runtime_checkable + + +@runtime_checkable +class InputStream(Protocol): # pragma: no cover + """A protocol for the file-like object returned by InputFile.open(...) + + This outlines the minimally required methods for a seekable input stream returned from an InputFile + implementation's `open(...)` method. These methods are a subset of IOBase/RawIOBase. + """ + + def read(self, size: int) -> bytes: + ... + + def seek(self, offset: int, whence: int) -> None: + ... + + def tell(self) -> int: + ... + + def closed(self) -> bool: + ... + + def close(self) -> None: + ... + + +@runtime_checkable +class OutputStream(Protocol): # pragma: no cover + """A protocol for the file-like object returned by OutputFile.create(...) + + This outlines the minimally required methods for a writable output stream returned from an OutputFile + implementation's `create(...)` method. These methods are a subset of IOBase/RawIOBase. + """ + + def write(self, b: bytes) -> None: + ... + + def closed(self) -> bool: + ... + + def close(self) -> None: + ... + class InputFile(ABC): """A base class for InputFile implementations""" @@ -48,8 +96,11 @@ def exists(self) -> bool: """Checks whether the file exists""" @abstractmethod - def open(self): - """This method should return an instance of an seekable input stream.""" + def open(self) -> InputStream: + """This method should return an object that matches the InputStream protocol + + If a file does not exist at `self.location`, this should raise a FileNotFoundError. + """ class OutputFile(ABC): @@ -77,8 +128,8 @@ def to_input_file(self) -> InputFile: """Returns an InputFile for the location of this output file""" @abstractmethod - def create(self, overwrite: bool = False): - """This method should return a file-like object. + def create(self, overwrite: bool = False) -> OutputStream: + """This method should return an object that matches the OutputStream protocol. Args: overwrite(bool): If the file already exists at `self.location` @@ -87,6 +138,8 @@ def create(self, overwrite: bool = False): class FileIO(ABC): + """A base class for FileIO implementations""" + @abstractmethod def new_input(self, location: str) -> InputFile: """Get an InputFile instance to read bytes from the file at the given location""" diff --git a/python/tests/io/test_base.py b/python/tests/io/test_base.py index bf67bc62ee75..6a7773cb6786 100644 --- a/python/tests/io/test_base.py +++ b/python/tests/io/test_base.py @@ -22,7 +22,7 @@ import pytest -from iceberg.io.base import FileIO, InputFile, OutputFile +from iceberg.io.base import FileIO, InputFile, InputStream, OutputFile, OutputStream class LocalInputFile(InputFile): @@ -49,8 +49,11 @@ def __len__(self): def exists(self): return os.path.exists(self.parsed_location.path) - def open(self): - return open(self.parsed_location.path, "rb") + def open(self) -> InputStream: + input_file = open(self.parsed_location.path, "rb") + if not isinstance(input_file, InputStream): + raise TypeError("""Object returned from LocalInputFile.open does not match the OutputStream protocol.""") + return input_file class LocalOutputFile(OutputFile): @@ -80,8 +83,11 @@ def exists(self): def to_input_file(self): return LocalInputFile(location=self.location) - def create(self, overwrite: bool = False) -> None: - return open(self.parsed_location.path, "wb" if overwrite else "xb") + def create(self, overwrite: bool = False) -> OutputStream: + output_file = open(self.parsed_location.path, "wb" if overwrite else "xb") + if not isinstance(output_file, OutputStream): + raise TypeError("""Object returned from LocalOutputFile.create does not match the OutputStream protocol.""") + return output_file class LocalFileIO(FileIO):