Skip to content

Commit

Permalink
static type analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
Altay Sansal committed Mar 1, 2024
1 parent 85f9fb4 commit 6c9ea41
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/segy/standards/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
register_spec(SegyStandard.REV1, rev1_segy)


__all__ = ["rev0_segy", "rev1_segy"]
__all__ = ["rev0_segy", "rev1_segy", "SegyStandard"]
15 changes: 9 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def _make_header_field_descriptor(
dt_string: str = "i2",
names: list[str] | None = None,
offsets: list[int] | None = None,
endianness: str = "big",
) -> dict:
endianness: Endianness = Endianness.BIG,
) -> dict[str, list[HeaderFieldDescriptor] | int]:
"""Convenience function for creating parameters needed for descriptors.
Args:
Expand Down Expand Up @@ -92,7 +92,7 @@ def _make_trace_header_descriptor(
dt_string: str = "i2",
names: list[str] | None = None,
offsets: list[int] | None = None,
endianness: str = "big",
endianness: str = Endianness.BIG,
) -> TraceHeaderDescriptor:
"""Convenience function for creating TraceHeaderDescriptors.
Expand Down Expand Up @@ -124,7 +124,7 @@ def make_trace_data_descriptor() -> Callable:
def _make_trace_data_descriptor(
format: ScalarType = ScalarType.IBM32, # noqa: A002
endianness: Endianness = Endianness.BIG,
description: str = None,
description: str | None = None,
samples: int = 10,
) -> TraceDataDescriptor:
"""Convenience function for creating TraceDataDescriptors.
Expand Down Expand Up @@ -154,7 +154,10 @@ def make_trace_descriptor(
) -> Callable:
"""Fixture wrapper for helper function to create TraceDescriptors."""

def _make_trace_descriptor(head_params: dict, data_params: dict) -> TraceDescriptor:
def _make_trace_descriptor(
head_params: dict[str, str | list[str] | Endianness],
data_params: dict[str, str | int | Endianness],
) -> TraceDescriptor:
"""Convenience function for creating TraceDescriptor object.
Args:
Expand All @@ -181,7 +184,7 @@ def _make_binary_header_descriptor(
dt_string: str = "i2",
names: list[str] | None = None,
offsets: list[int] | None = None,
endianness: str = "big",
endianness: str = Endianness.BIG,
) -> BinaryHeaderDescriptor:
"""Helper function for creating BinaryHeaderDescriptor objects.
Expand Down
2 changes: 1 addition & 1 deletion tests/main/test_ibm_ieee.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_single_ibm_to_ieee(self, ieee: float, ibm: int) -> None:


@pytest.mark.parametrize("shape", [(1,), (10,), (20, 20), (150, 150)])
def test_ieee_to_ibm_roundtrip(shape: tuple) -> None:
def test_ieee_to_ibm_roundtrip(shape: tuple[int, ...]) -> None:
"""Convert values from IEEE to IBM and back to IEEE."""
rng = np.random.default_rng()
expected_ieee = rng.normal(size=shape).astype("float32")
Expand Down
15 changes: 10 additions & 5 deletions tests/main/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from segy.indexing import bounds_check
from segy.indexing import merge_cat_file
from segy.indexing import trace_ibm2ieee_inplace
from segy.schema import Endianness

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -80,13 +81,17 @@ def test_merge_cat_file(
("header_params", "data_params", "float_vals"),
[
(
{"dt_string": "i2,i4", "names": ["a", "b"], "endianness": "little"},
{"format": "uint32", "endianness": "little", "samples": 3},
{
"dt_string": "i2,i4",
"names": ["a", "b"],
"endianness": Endianness.LITTLE,
},
{"format": "uint32", "endianness": Endianness.LITTLE, "samples": 3},
[0.0, 0.1, 3.141593],
),
(
{"dt_string": "i2,i4", "names": ["a", "b"], "endianness": "big"},
{"format": "uint32", "endianness": "big", "samples": 3},
{"dt_string": "i2,i4", "names": ["a", "b"], "endianness": Endianness.BIG},
{"format": "uint32", "endianness": Endianness.BIG, "samples": 3},
[1.01, -2.01, 33.11],
),
],
Expand All @@ -102,7 +107,7 @@ def test_trace_ibm2ieee_inplace(
samp_trace = np.zeros(1, dtype=trace_descr.dtype)
ieee_floats = np.array(float_vals, dtype="<f4")
ibm_floats = ieee2ibm(ieee_floats)
if trace_descr.data_descriptor.endianness == "big":
if trace_descr.data_descriptor.endianness == Endianness.BIG:
# emulating how trace indexer swaps byte order
# of big endian data types
samp_trace = samp_trace.byteswap(inplace=True).newbyteorder()
Expand Down
5 changes: 3 additions & 2 deletions tests/schema/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from segy.schema import TraceDataDescriptor
from segy.schema import TraceHeaderDescriptor
from segy.schema.data_type import DataTypeDescriptor
from segy.schema.data_type import Endianness

# Constants defined for ScalarType
DTYPE_FORMATS = [s.value for s in ScalarType]

DTYPE_ENDIANNESS = ["little", "big"]
DTYPE_ENDIANNESS = [Endianness.LITTLE, Endianness.BIG]

# For cases where a description is supplied or not
DTYPE_DESCRIPTIONS = [None, "this is a data type description"]
Expand Down Expand Up @@ -86,7 +87,7 @@ def trace_header_descriptors(

def generate_unique_names(count: int) -> list[str]:
"""Helper function to create random unique names as placeholders during testing."""
names: set = set()
names: set[str] = set()
rng = np.random.default_rng()
while len(names) < count:
name_length = rng.integers(5, 10) # noqa: S311
Expand Down
20 changes: 11 additions & 9 deletions tests/schema/test_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

import numpy as np
import pytest
Expand Down Expand Up @@ -37,7 +36,7 @@
("float16", "f2"),
],
)
@pytest.mark.parametrize("endianness", ["big", "little"])
@pytest.mark.parametrize("endianness", [Endianness.BIG, Endianness.LITTLE])
@pytest.mark.parametrize("description", [None, "this is a data type description"])
def test_data_type_descriptor(
format_: tuple[str, str], endianness: str, description: str | None
Expand Down Expand Up @@ -95,34 +94,37 @@ def build_sfd_helper(
)


def build_sdt_fields(*params: dict[str, Any]) -> list[StructuredFieldDescriptor]:
def build_sdt_fields(
*params: tuple[str, str, str, int],
) -> tuple[StructuredFieldDescriptor, ...]:
"""Convenience for creating a list of StructuredFieldDescriptors."""
return [build_sfd_helper(*p) for p in params]
return tuple(build_sfd_helper(*p) for p in params)


@pytest.mark.parametrize(
("fields", "item_size", "offset"),
[
(
build_sdt_fields(
("int32", "little", "varA", 2),
("int16", "little", "varB", 0),
("int32", "little", "varC", 6),
("int32", Endianness.LITTLE, "varA", 2),
("int16", Endianness.LITTLE, "varB", 0),
("int32", Endianness.LITTLE, "varC", 6),
),
10,
0,
),
(
build_sdt_fields(
("float32", "big", "varA", 0), ("float32", "big", "varB", 4)
("float32", Endianness.BIG, "varA", 0),
("float32", Endianness.BIG, "varB", 4),
),
8,
12,
),
],
)
def test_structured_data_type_descriptor(
fields: list[StructuredFieldDescriptor], item_size: int, offset: int
fields: tuple[StructuredFieldDescriptor, ...], item_size: int, offset: int
) -> None:
"""This tests for creatin a StructuredDataTypeDescriptor for different component data types."""
new_sdtd = StructuredDataTypeDescriptor(
Expand Down
2 changes: 1 addition & 1 deletion tests/schema/test_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def void_buffer(buff_size: int) -> np.ndarray:


def get_dt_info(
dt: np.dtype,
dt: np.dtype[Any],
atrnames: list[str] | None = None,
) -> dict:
"""Helper function to get info about a numpy dtype."""
Expand Down
4 changes: 2 additions & 2 deletions tests/standards/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_get_spec(standard_enum: SegyStandard, base_spec: SegyDescriptor) -> Non
def test_get_nonexistent_spec_error() -> None:
"""Test missing / non-existent SegyStandard from registry."""
with pytest.raises(NotImplementedError):
get_spec("non_existent")
get_spec("non_existent") # type: ignore


def test_register_custom_descriptor() -> None:
Expand All @@ -37,4 +37,4 @@ def test_register_nondescriptor_error() -> None:
"""Test if not providing a descriptor to registration."""
msg = "spec_cls must be a subclass of SegyDescriptor."
with pytest.raises(ValueError, match=msg):
register_spec(SegyStandard.CUSTOM, "not_a_descriptor")
register_spec(SegyStandard.CUSTOM, "not_a_descriptor") # type: ignore

0 comments on commit 6c9ea41

Please sign in to comment.