Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Simplify type annotations #990

Merged
merged 7 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions scripts/aws/aws_credential_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional


# pylint: disable=too-many-instance-attributes
Expand All @@ -21,15 +20,15 @@ class CredentialSource:
"""
Role arn to use
"""
externalId: Optional[str] = None
externalId: str | None = None
"""
Role external ID if it exists
"""
roleSessionDuration: Optional[int] = 1 * 60 * 60
roleSessionDuration: int | None = 1 * 60 * 60
"""
Max duration of the assumed session in seconds, default 1 hours
"""
flags: Optional[str] = None
flags: str | None = None
"""
flags that the role can use either "r" for read-only or "rw" for read-write
"""
10 changes: 5 additions & 5 deletions scripts/aws/aws_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from os import environ
from time import sleep
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, NamedTuple
from urllib.parse import urlparse

from boto3 import Session
Expand All @@ -15,9 +15,9 @@

aws_profile = environ.get("AWS_PROFILE")
session = Session(profile_name=aws_profile)
sessions: Dict[str, Session] = {}
sessions: dict[str, Session] = {}

bucket_roles: List[CredentialSource] = []
bucket_roles: list[CredentialSource] = []

client_sts = session.client("sts")

Expand Down Expand Up @@ -67,7 +67,7 @@ def get_session(prefix: str) -> Session:
if current_session is not None:
return current_session

extra_args: Dict[str, Any] = {"DurationSeconds": cfg.roleSessionDuration}
extra_args: dict[str, Any] = {"DurationSeconds": cfg.roleSessionDuration}

if cfg.externalId:
extra_args["ExternalId"] = cfg.externalId
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_session_credentials(prefix: str, retry_count: int = 3) -> ReadOnlyCreden
raise last_error


def _get_credential_config(prefix: str) -> Optional[CredentialSource]:
def _get_credential_config(prefix: str) -> CredentialSource | None:
"""Get the credential config (`bucket-config`) for the `prefix`.

Args:
Expand Down
16 changes: 8 additions & 8 deletions scripts/cli/cli_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
from datetime import datetime
from os import environ
from typing import List, NamedTuple, Optional
from typing import NamedTuple

from linz_logger import get_log

Expand All @@ -16,10 +16,10 @@ class InputParameterError(Exception):

class TileFiles(NamedTuple):
output: str
inputs: List[str]
inputs: list[str]


def get_tile_files(source: str) -> List[TileFiles]:
def get_tile_files(source: str) -> list[TileFiles]:
"""Transform a JSON string representing a list of input file paths and output tile name created
by `argo-tasks` (see examples) to a list of `TileFiles`

Expand All @@ -37,7 +37,7 @@ def get_tile_files(source: str) -> List[TileFiles]:
[TileFiles(output='CE16_5000_1001', inputs=['s3://bucket/SN9457_CE16_10k_0501.tif'])]
"""
try:
source_json: List[TileFiles] = json.loads(
source_json: list[TileFiles] = json.loads(
source, object_hook=lambda d: TileFiles(inputs=d["input"], output=d["output"])
)
except (json.decoder.JSONDecodeError, KeyError) as e:
Expand All @@ -47,7 +47,7 @@ def get_tile_files(source: str) -> List[TileFiles]:
return source_json


def load_input_files(path: str) -> List[TileFiles]:
def load_input_files(path: str) -> list[TileFiles]:
"""Load the TileFiles from a JSON input file containing a list of output and input files.
Args:
path: path to a JSON file listing output name and input files
Expand All @@ -58,7 +58,7 @@ def load_input_files(path: str) -> List[TileFiles]:
source = json.dumps(json.loads(read(path)))

try:
tile_files: List[TileFiles] = get_tile_files(source)
tile_files: list[TileFiles] = get_tile_files(source)
return tile_files
except InputParameterError as e:
get_log().error("An error occurred while getting tile_files", error=str(e))
Expand All @@ -77,7 +77,7 @@ def valid_date(s: str) -> datetime:
raise argparse.ArgumentTypeError(msg) from e


def parse_list(list_s: str, separator: Optional[str] = ";") -> List[str]:
def parse_list(list_s: str, separator: str | None = ";") -> list[str]:
"""Transform a string representing a list to a list of strings
example: "foo; bar; foo bar" -> ["foo", "bar", "foo bar"]

Expand All @@ -93,7 +93,7 @@ def parse_list(list_s: str, separator: Optional[str] = ";") -> List[str]:
return []


def coalesce_multi_single(multi_items: Optional[str], single_item: Optional[str]) -> List[str]:
def coalesce_multi_single(multi_items: str | None, single_item: str | None) -> list[str]:
"""Coalesce strings containing either semicolon delimited values or a single
value into a list. `single_item` is used only if `multi_items` is falsy.
If both are falsy, an empty list is returned.
Expand Down
4 changes: 1 addition & 3 deletions scripts/cli/tests/cli_helper_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

from pytest_subtests import SubTests

from scripts.cli.cli_helper import TileFiles, coalesce_multi_single, get_tile_files, parse_list
Expand All @@ -12,7 +10,7 @@ def test_get_tile_files(subtests: SubTests) -> None:
expected_output_filename_b = "tile_name2"
expected_input_filenames = ["file_a.tiff", "file_b.tiff"]

source: List[TileFiles] = get_tile_files(file_source)
source: list[TileFiles] = get_tile_files(file_source)
with subtests.test():
assert expected_output_filename == source[0].output

Expand Down
3 changes: 1 addition & 2 deletions scripts/collection_from_items.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import json
import os
from typing import List

import shapely.geometry
import shapely.ops
Expand Down Expand Up @@ -89,7 +88,7 @@ def main() -> None:
arguments = parser.parse_args()
uri = arguments.uri

providers: List[Provider] = []
providers: list[Provider] = []
for producer_name in coalesce_multi_single(arguments.producer_list, arguments.producer):
providers.append({"name": producer_name, "roles": [ProviderRole.PRODUCER]})
for licensor_name in coalesce_multi_single(arguments.licensor_list, arguments.licensor):
Expand Down
20 changes: 10 additions & 10 deletions scripts/files/file_tiff.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from decimal import Decimal
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional
from typing import Annotated, Any
from urllib.parse import unquote

from scripts.gdal.gdal_helper import GDALExecutionException, gdal_info, run_gdal
Expand All @@ -28,8 +28,8 @@ class FileTiff:

def __init__(
self,
paths: List[str],
preset: Optional[str] = None,
paths: list[str],
preset: str | None = None,
) -> None:
paths_original = []
for p in paths:
Expand All @@ -40,9 +40,9 @@ def __init__(

self._paths_original = paths_original
self._path_standardised = ""
self._errors: List[Dict[str, Any]] = []
self._gdalinfo: Optional[GdalInfo] = None
self._srs: Optional[bytes] = None
self._errors: list[dict[str, Any]] = []
self._gdalinfo: GdalInfo | None = None
self._srs: bytes | None = None
if preset == "dem_lerc":
self._tiff_type = FileTiffType.DEM
else:
Expand Down Expand Up @@ -112,7 +112,7 @@ def set_path_standardised(self, path: str) -> None:
"""
self._path_standardised = path

def get_gdalinfo(self, path: Optional[str] = None) -> Optional[GdalInfo]:
def get_gdalinfo(self, path: str | None = None) -> GdalInfo | None:
"""Get the `gdalinfo` output for the file.
Run gdalinfo if not already ran or if different path is specified.
`path` is useful to specify a local file to avoid downloading from external source.
Expand Down Expand Up @@ -141,15 +141,15 @@ def get_gdalinfo(self, path: Optional[str] = None) -> Optional[GdalInfo]:
self.add_error(error_type=FileTiffErrorType.GDAL_INFO, error_message=f"error(s): {str(e)}")
return self._gdalinfo

def get_errors(self) -> List[Dict[str, Any]]:
def get_errors(self) -> list[dict[str, Any]]:
"""Get the Non Visual QA errors.

Returns:
a list of errors
"""
return self._errors

def get_paths_original(self) -> List[str]:
def get_paths_original(self) -> list[str]:
"""Get the path(es) of the original (non standardised) file.
It can be a list of path if the standardised file is a retiled image.

Expand All @@ -175,7 +175,7 @@ def get_tiff_type(self) -> FileTiffType:
return self._tiff_type

def add_error(
self, error_type: FileTiffErrorType, error_message: str, custom_fields: Optional[Dict[str, str]] = None
self, error_type: FileTiffErrorType, error_message: str, custom_fields: dict[str, str] | None = None
) -> None:
"""Add an error in Non Visual QA errors list.

Expand Down
16 changes: 7 additions & 9 deletions scripts/files/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING

from boto3 import resource
from linz_logger import get_log
Expand All @@ -17,7 +17,7 @@
S3Client = dict


def write(destination: str, source: bytes, content_type: Optional[str] = None) -> str:
def write(destination: str, source: bytes, content_type: str | None = None) -> str:
"""Write a file from its source to a destination path.

Args:
Expand Down Expand Up @@ -87,16 +87,14 @@ def exists(path: str) -> bool:
return fs_local.exists(path)


def modified(path: str, s3_client: Optional[S3Client] = None) -> datetime:
def modified(path: str, s3_client: S3Client | None = None) -> datetime:
"""Get modified datetime for S3 URL or local path"""
if is_s3(path):
return fs_s3.modified(fs_s3.bucket_name_from_path(path), fs_s3.prefix_from_path(path), s3_client)
return fs_local.modified(Path(path))


def write_all(
inputs: List[str], target: str, concurrency: Optional[int] = 4, generate_name: Optional[bool] = True
) -> List[str]:
def write_all(inputs: list[str], target: str, concurrency: int | None = 4, generate_name: bool | None = True) -> list[str]:
"""Writes list of files to target destination using multithreading.
Args:
inputs: list of files to read
Expand All @@ -107,7 +105,7 @@ def write_all(
Returns:
list of written file paths
"""
written_tiffs: List[str] = []
written_tiffs: list[str] = []
with ThreadPoolExecutor(max_workers=concurrency) as executor:
futuress = {write_file(executor, input_, target, generate_name): input_ for input_ in inputs}
for future in as_completed(futuress):
Expand All @@ -123,7 +121,7 @@ def write_all(
return written_tiffs


def write_sidecars(inputs: List[str], target: str, concurrency: Optional[int] = 4) -> None:
def write_sidecars(inputs: list[str], target: str, concurrency: int | None = 4) -> None:
"""Writes list of files (if found) to target destination using multithreading.
The copy of the files have a generated file name (@see `write_file`)

Expand All @@ -142,7 +140,7 @@ def write_sidecars(inputs: List[str], target: str, concurrency: Optional[int] =
get_log().info("wrote_sidecar_file", path=future.result())


def write_file(executor: ThreadPoolExecutor, input_: str, target: str, generate_name: Optional[bool] = True) -> Future[str]:
def write_file(executor: ThreadPoolExecutor, input_: str, target: str, generate_name: bool | None = True) -> Future[str]:
"""Read a file from a path and write it to a target path.
Args:
executor: A ThreadPoolExecutor instance.
Expand Down
13 changes: 7 additions & 6 deletions scripts/files/fs_s3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections.abc import Generator
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union
from typing import TYPE_CHECKING, Any

from boto3 import client, resource
from botocore.exceptions import ClientError
Expand All @@ -18,7 +19,7 @@
S3Client = GetObjectOutputTypeDef = dict


def write(destination: str, source: bytes, content_type: Optional[str] = None) -> None:
def write(destination: str, source: bytes, content_type: str | None = None) -> None:
"""Write a source (bytes) in a AWS s3 destination (path in a bucket).

Args:
Expand Down Expand Up @@ -172,7 +173,7 @@ def prefix_from_path(path: str) -> str:
return path.replace(f"s3://{bucket_name}/", "")


def list_files_in_uri(uri: str, suffixes: List[str], s3_client: Optional[S3Client]) -> List[str]:
def list_files_in_uri(uri: str, suffixes: list[str], s3_client: S3Client | None) -> list[str]:
"""Get a list of file paths from a s3 path based on their suffixes

Args:
Expand Down Expand Up @@ -215,8 +216,8 @@ def _get_object(bucket: str, file_name: str, s3_client: S3Client) -> GetObjectOu


def get_object_parallel_multithreading(
bucket: str, files_to_read: List[str], s3_client: Optional[S3Client], concurrency: int
) -> Generator[Any, Union[Any, BaseException], None]:
bucket: str, files_to_read: list[str], s3_client: S3Client | None, concurrency: int
) -> Generator[Any, Any | BaseException, None]:
"""Get s3 objects in parallel

Args:
Expand All @@ -242,6 +243,6 @@ def get_object_parallel_multithreading(
yield key, exception


def modified(bucket_name: str, key: str, s3_client: Optional[S3Client]) -> datetime:
def modified(bucket_name: str, key: str, s3_client: S3Client | None) -> datetime:
s3_client = s3_client or client("s3")
return _get_object(bucket_name, key, s3_client)["LastModified"]
2 changes: 1 addition & 1 deletion scripts/files/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Generator
from shutil import rmtree
from tempfile import mkdtemp
from typing import Generator

import pytest

Expand Down
10 changes: 4 additions & 6 deletions scripts/gdal/gdal_bands.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import List, Optional

from linz_logger import get_log

from scripts.gdal.gdal_helper import gdal_info
from scripts.gdal.gdalinfo import GdalInfo, GdalInfoBand


def find_band(bands: List[GdalInfoBand], color: str) -> Optional[GdalInfoBand]:
def find_band(bands: list[GdalInfoBand], color: str) -> GdalInfoBand | None:
"""Look for a specific colorInterperation inside of a `gdalinfo` band output.

Args:
Expand All @@ -23,7 +21,7 @@ def find_band(bands: List[GdalInfoBand], color: str) -> Optional[GdalInfoBand]:


# pylint: disable-msg=too-many-return-statements
def get_gdal_band_offset(file: str, info: Optional[GdalInfo] = None, preset: Optional[str] = None) -> List[str]:
def get_gdal_band_offset(file: str, info: GdalInfo | None = None, preset: str | None = None) -> list[str]:
"""Get the banding parameters for a `gdal_translate` command.

Args:
Expand All @@ -39,7 +37,7 @@ def get_gdal_band_offset(file: str, info: Optional[GdalInfo] = None, preset: Opt

bands = info["bands"]

band_alpha_arg: List[str] = []
band_alpha_arg: list[str] = []
if band_alpha := find_band(bands, "Alpha"):
band_alpha_arg.extend(["-b", str(band_alpha["band"])])

Expand Down Expand Up @@ -87,7 +85,7 @@ def get_gdal_band_offset(file: str, info: Optional[GdalInfo] = None, preset: Opt
return ["-b", str(band_red["band"]), "-b", str(band_green["band"]), "-b", str(band_blue["band"])] + band_alpha_arg


def get_gdal_band_type(file: str, info: Optional[GdalInfo] = None) -> str:
def get_gdal_band_type(file: str, info: GdalInfo | None = None) -> str:
"""Get the band type of the first band.

Args:
Expand Down
Loading
Loading