Skip to content

Commit

Permalink
feat: download files using multithreading TDE-822 (#580)
Browse files Browse the repository at this point in the history
* WIP

* DEBUG changes

* WIP

* fix: formatting

* fix: revert tmp dir

* feat: add error if missing files

* Chore tidy code

* fix: tidy log comments

* fix: move download command to fs_s3

* fix: remove fs_local reference

* fix: currently specific to tiffs

* fix: move exception to download function

* fix: make download function more generic

and separate out sidecar file listing

* fix: rename variables/logs/comments to match generalisation

* Update scripts/files/fs.py

Co-authored-by: paulfouquet <[email protected]>

* fix: return destination in write and remove unnecessary function

* fix: use os to manipulate paths 

and make find_sidecars multithreaded, as it is otherwise slow

* fix: appease formatting gods

* fix: reduce concurrency for memory limit concerns

---------

Co-authored-by: paulfouquet <[email protected]>
  • Loading branch information
MDavidson17 and paulfouquet authored Aug 20, 2023
1 parent 110f4ec commit 8ab269d
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 42 deletions.
68 changes: 67 additions & 1 deletion scripts/files/fs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional

from linz_logger import get_log

from scripts.aws.aws_helper import is_s3
from scripts.files import fs_local, fs_s3


def write(destination: str, source: bytes) -> None:
def write(destination: str, source: bytes) -> str:
"""Write a file from its source to a destination path.
Args:
Expand All @@ -13,6 +19,7 @@ def write(destination: str, source: bytes) -> None:
fs_s3.write(destination, source)
else:
fs_local.write(destination, source)
return destination


def read(path: str) -> bytes:
Expand Down Expand Up @@ -42,3 +49,62 @@ def exists(path: str) -> bool:
if is_s3(path):
return fs_s3.exists(path)
return fs_local.exists(path)


def write_all(inputs: List[str], target: str, concurrency: Optional[int] = 4) -> List[str]:
    """Copy a list of files into the target folder using multithreading.

    Each file is read and written inside a worker thread, so reads run
    concurrently and at most `concurrency` file payloads are in flight at once.
    The destination name is the base name of the source path.

    Args:
        inputs: list of files to read
        target: target folder to write to
        concurrency: maximum number of worker threads

    Returns:
        list of written file paths (in completion order)

    Raises:
        Exception: if at least one of the inputs could not be read or written
    """

    def _copy(source: str) -> str:
        """Read one source file and write it into `target`, returning the destination path."""
        return write(os.path.join(target, os.path.basename(source)), read(source))

    written_tiffs: List[str] = []
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        # The read happens inside the worker (not while building this dict) so that
        # read failures are reported through the future like write failures are.
        futures = {executor.submit(_copy, source): source for source in inputs}
        for future in as_completed(futures):
            error = future.exception()
            if error is not None:
                get_log().warn("Failed Read-Write", error=error, path=futures[future])
            else:
                written_tiffs.append(future.result())

    if len(inputs) != len(written_tiffs):
        get_log().error("Missing Files", count=len(inputs) - len(written_tiffs))
        raise Exception("Not all source files were written")
    return written_tiffs


def find_sidecars(inputs: List[str], extensions: List[str], concurrency: Optional[int] = 4) -> List[str]:
    """Search for sidecar files that exist alongside the input files.

    A sidecar file is a file with the same name as the input file but with a different extension.
    All existence checks are submitted to the thread pool up-front so the pool stays busy
    across extensions.

    Args:
        inputs: list of input files to search for extensions
        extensions: the sidecar file extensions, appended as-is to the input path
            stripped of its own extension (e.g. ".prj")
        concurrency: maximum number of worker threads

    Returns:
        list of existing sidecar files, ordered by extension and then by input order
    """

    def _validate_path(path: str) -> Optional[str]:
        """Helper inner function to re-return the path if it exists rather than a boolean."""
        return path if exists(path) else None

    candidates = [f"{os.path.splitext(source)[0]}{extension}" for extension in extensions for source in inputs]

    sidecars: List[str] = []
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        checks = [(candidate, executor.submit(_validate_path, candidate)) for candidate in candidates]
        # Walk the futures in submission order (each call blocks until that check is done)
        # so the returned list has a stable order regardless of thread scheduling.
        for candidate, future in checks:
            error = future.exception()
            if error is not None:
                get_log().warn("Find sidecar failed", error=error, path=candidate)
                continue
            result = future.result()
            if result:
                sidecars.append(result)
    return sidecars
46 changes: 5 additions & 41 deletions scripts/standardising.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from multiprocessing import Pool
from typing import List, Optional

import ulid
from linz_logger import get_log

from scripts.aws.aws_helper import is_s3
from scripts.cli.cli_helper import TileFiles
from scripts.files.file_tiff import FileTiff, FileTiffType
from scripts.files.fs import exists, read, write
from scripts.files.files_helper import is_tiff
from scripts.files.fs import exists, find_sidecars, read, write, write_all
from scripts.gdal.gdal_bands import get_gdal_band_offset
from scripts.gdal.gdal_helper import get_gdal_version, run_gdal
from scripts.gdal.gdal_preset import (
Expand Down Expand Up @@ -74,44 +74,6 @@ def run_standardising(
return standardized_tiffs


def download_tiffs(files: List[str], target: str) -> List[str]:
    """Download tiff files, and some of their sidecar files if they exist, to the target dir.

    Each tiff is written to `target` under a newly generated ULID name with a ".tiff"
    extension; its ".prj" / ".tfw" sidecars (when readable) share the same ULID base name.

    Args:
        files: source tiff paths
        target: target folder to write to

    Returns:
        paths of the downloaded tiffs (sidecar files are not included)

    Example:
        ```
        >>> download_tiffs(["s3://elevation/SN9457_CE16_10k_0502.tif"], "/tmp/")
        ["/tmp/<ulid>.tiff"]
        ```
    """
    downloaded_files: List[str] = []
    for file in files:
        target_file_path = os.path.join(target, str(ulid.ULID()))
        input_file_path = target_file_path + ".tiff"
        get_log().info("download_tiff", path=file, target_path=input_file_path)

        write(input_file_path, read(file))
        downloaded_files.append(input_file_path)

        base_file_path = os.path.splitext(file)[0]
        # Attempt to download sidecar files too
        for ext in [".prj", ".tfw"]:
            try:
                write(target_file_path + ext, read(base_file_path + ext))
            except Exception:  # pylint: disable=broad-except
                # Sidecars are optional: a missing/unreadable one is skipped on purpose.
                # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
                continue
            get_log().info("download_tiff_sidecar", path=base_file_path + ext, target_path=target_file_path + ext)

    return downloaded_files


def create_vrt(source_tiffs: List[str], target_path: str, add_alpha: bool = False) -> str:
"""Create a VRT from a list of tiffs files
Expand Down Expand Up @@ -168,8 +130,10 @@ def standardising(
# Download any needed file from S3 ["/foo/bar.tiff", "s3://foo"] => "/tmp/bar.tiff", "/tmp/foo.tiff"
with tempfile.TemporaryDirectory() as tmp_path:
standardized_working_path = os.path.join(tmp_path, standardized_file_name)
sidecars = find_sidecars(files.input, [".prj", ".tfw"])
source_files = write_all(files.input + sidecars, tmp_path)
source_tiffs = [file for file in source_files if is_tiff(file)]

source_tiffs = download_tiffs(files.input, tmp_path)
vrt_add_alpha = True

for file in source_tiffs:
Expand Down

0 comments on commit 8ab269d

Please sign in to comment.