Populate remaining tests for schemas. (#29)
* added tests for remaining components in schema files.

* removed unused functions in tests/schema/conftest.py

* fixed BinaryFileHeader and TraceTextHeader changes to use StructuredDataTypeDescriptor

* removed unnecessary roundtrip json to model from tests
ta-hill authored Mar 5, 2024
1 parent 8ae29ad commit 4062c68
Showing 5 changed files with 216 additions and 51 deletions.
21 changes: 21 additions & 0 deletions tests/conftest.py
@@ -15,6 +15,8 @@

from segy.schema import Endianness
from segy.schema import ScalarType
from segy.schema import TextHeaderDescriptor
from segy.schema import TextHeaderEncoding
from segy.schema import TraceDataDescriptor
from segy.schema import TraceDescriptor
from segy.schema.data_type import StructuredDataTypeDescriptor
@@ -228,6 +230,25 @@ def _make_binary_header_descriptor(
return _make_binary_header_descriptor


@pytest.fixture(scope="module")
def make_text_header_descriptor() -> Callable[..., TextHeaderDescriptor]:
"""Fixture wrapper around helper function for creating BinaryHeaderDescriptor."""

def _make_text_header_descriptor(
rows: int = 40,
cols: int = 80,
encoding: TextHeaderEncoding = TextHeaderEncoding.EBCDIC,
format: ScalarType = ScalarType.UINT8, # noqa: A002
offset: int | None = None,
) -> TextHeaderDescriptor:
"""Helper function for creating text header descriptor objects."""
return TextHeaderDescriptor(
rows=rows, cols=cols, encoding=encoding, format=format, offset=offset
)

return _make_text_header_descriptor


def generate_unique_names(count: int) -> list[str]:
"""Helper function to create random unique names as placeholders during testing."""
names: set[str] = set()
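
For reference, a minimal sketch of how a test might consume the new factory fixture; the test name and assertions are illustrative (assuming the descriptor exposes its constructor arguments as attributes) and are not part of this commit:

from collections.abc import Callable

from segy.schema import TextHeaderDescriptor
from segy.schema import TextHeaderEncoding


def test_make_text_header_descriptor_defaults(
    make_text_header_descriptor: Callable[..., TextHeaderDescriptor],
) -> None:
    """The factory should build a 40 x 80 EBCDIC descriptor when called with defaults."""
    descriptor = make_text_header_descriptor()  # factory provided by the fixture
    assert descriptor.rows == 40
    assert descriptor.cols == 80
    assert descriptor.encoding == TextHeaderEncoding.EBCDIC
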
95 changes: 60 additions & 35 deletions tests/schema/conftest.py
@@ -1,13 +1,14 @@
"""Tests for the compontent classes in the schema directory."""

import itertools
import string
from collections.abc import Callable
from collections.abc import Generator
from typing import Any

import numpy as np
import pytest

from segy.schema import ScalarType
from segy.schema import TextHeaderDescriptor
from segy.schema import TraceDataDescriptor
from segy.schema.data_type import DataTypeDescriptor
from segy.schema.data_type import Endianness
@@ -34,14 +35,31 @@
",".join([*["i4"] * 3, *["i2"] * 4, *["i4"] * 5, *["i2"] * 6]),
]

TEXT_HEADER_DESCRIPTORS_PARAMS = [
{"rows": 40, "cols": 80, "encoding": "ascii", "format": "uint8", "offset": 0},
{"rows": 40, "cols": 80, "encoding": "ebcdic", "format": "uint8", "offset": 0},
]

@pytest.fixture(
params=[
(dt_string, None, dt_endian)
for dt_string in BINARY_HEADER_TEST_DTYPE_STRINGS
for dt_endian in DTYPE_ENDIANNESS
]
)

BINARY_HEADER_DESCRIPTORS_PARAMS = [
(dt_string, None, None, dt_endian)
for dt_string in BINARY_HEADER_TEST_DTYPE_STRINGS
for dt_endian in DTYPE_ENDIANNESS
]

TRACE_HEADER_DESCRIPTORS_PARAMS = [
(dt_string, dt_endian)
for dt_string in TRACE_HEADER_TEST_DTYPE_STRINGS
for dt_endian in DTYPE_ENDIANNESS
]

TRACE_DATA_DESCRIPTORS_PARAMS = [
(p1, p2, DTYPE_DESCRIPTIONS[1], 100)
for p1, p2 in zip(DTYPE_FORMATS, itertools.cycle(DTYPE_ENDIANNESS))
]


@pytest.fixture(params=BINARY_HEADER_DESCRIPTORS_PARAMS)
def binary_header_descriptors(
request: pytest.FixtureRequest,
make_binary_header_descriptor: Callable[..., StructuredDataTypeDescriptor],
@@ -56,17 +74,14 @@ def binary_header_descriptors(
Structured data type descriptor object for binary header
"""
return make_binary_header_descriptor(
dt_string=request.param[0], names=request.param[1], endianness=request.param[2]
dt_string=request.param[0],
names=request.param[1],
offsets=request.param[2],
endianness=request.param[3],
)


@pytest.fixture(
params=[
(dt_string, dt_endian)
for dt_string in TRACE_HEADER_TEST_DTYPE_STRINGS
for dt_endian in DTYPE_ENDIANNESS
]
)
@pytest.fixture(params=TRACE_HEADER_DESCRIPTORS_PARAMS)
def trace_header_descriptors(
request: pytest.FixtureRequest,
make_trace_header_descriptor: Callable[..., StructuredDataTypeDescriptor],
@@ -86,18 +101,6 @@ def trace_header_descriptors(
)


def generate_unique_names(count: int) -> list[str]:
"""Helper function to create random unique names as placeholders during testing."""
names: set[str] = set()
rng = np.random.default_rng()
while len(names) < count:
name_length = rng.integers(5, 10) # noqa: S311
letters = rng.choice(list(string.ascii_uppercase), size=name_length) # noqa: S311
name = "".join(letters)
names.add(name)
return list(names)


@pytest.fixture(
params=[
(p1, p2, DTYPE_DESCRIPTIONS[1])
@@ -113,12 +116,7 @@ def data_types(request: pytest.FixtureRequest) -> DataTypeDescriptor:
)


@pytest.fixture(
params=[
(p1, p2, DTYPE_DESCRIPTIONS[1], 100)
for p1, p2 in zip(DTYPE_FORMATS, itertools.cycle(DTYPE_ENDIANNESS))
]
)
@pytest.fixture(params=TRACE_DATA_DESCRIPTORS_PARAMS)
def trace_data_descriptors(
request: pytest.FixtureRequest,
make_trace_data_descriptor: Callable[..., TraceDataDescriptor],
@@ -188,3 +186,30 @@ def text_header_samples(
) -> str:
"""Fixture that generates fixed size text header test data from strings."""
return format_str_to_text_header(request.param)


@pytest.fixture()
def custom_segy_file_descriptors(
request: pytest.FixtureRequest,
make_text_header_descriptor: Callable[..., TextHeaderDescriptor],
make_binary_header_descriptor: Callable[..., StructuredDataTypeDescriptor],
make_trace_header_descriptor: Callable[..., StructuredDataTypeDescriptor],
make_trace_data_descriptor: Callable[..., TraceDataDescriptor],
) -> Generator[dict[str, Any], None, None]:
"""Helper fixture to return a requested number of custom segy file descriptor params."""
num_file_descriptors = getattr(request, "param", 1)
for i in range(num_file_descriptors):
yield {
"text_header_descriptor": make_text_header_descriptor(
*TEXT_HEADER_DESCRIPTORS_PARAMS[i].values()
),
"binary_header_descriptor": make_binary_header_descriptor(
*BINARY_HEADER_DESCRIPTORS_PARAMS[i]
),
"trace_header_descriptor": make_trace_header_descriptor(
*TRACE_HEADER_DESCRIPTORS_PARAMS[i]
),
"trace_data_descriptor": make_trace_data_descriptor(
*TRACE_DATA_DESCRIPTORS_PARAMS[i]
),
}
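
A note on the indirect parametrization that custom_segy_file_descriptors relies on (see tests/schema/test_segy_descriptor.py below): pytest passes the parametrized value to the fixture as request.param. A minimal, self-contained sketch of that pattern, with illustrative names that are not part of this commit:

import pytest


@pytest.fixture()
def numbers(request: pytest.FixtureRequest) -> list[int]:
    """Build a list whose length comes from the indirectly parametrized value."""
    count = getattr(request, "param", 1)  # falls back to 1 when not parametrized
    return list(range(count))


@pytest.mark.parametrize("numbers", [3], indirect=True)
def test_numbers(numbers: list[int]) -> None:
    assert numbers == [0, 1, 2]
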
73 changes: 71 additions & 2 deletions tests/schema/test_data_type.py
@@ -3,6 +3,7 @@

from __future__ import annotations

import string
from typing import TYPE_CHECKING

import numpy as np
@@ -121,17 +122,28 @@ def build_sdt_fields(
8,
12,
),
(
build_sdt_fields(
("ibm32", Endianness.BIG, "varA", 0),
("float32", Endianness.BIG, "varB", 4),
),
None,
12,
),
],
)
def test_structured_data_type_descriptor(
fields: tuple[StructuredFieldDescriptor, ...], item_size: int, offset: int
fields: tuple[StructuredFieldDescriptor, ...], item_size: int | None, offset: int
) -> None:
"""This tests for creatin a StructuredDataTypeDescriptor for different component data types."""
new_sdtd = StructuredDataTypeDescriptor(
fields=list(fields), item_size=item_size, offset=offset
)
assert new_sdtd.dtype.names == tuple([f.name for f in fields])
assert new_sdtd.item_size == new_sdtd.dtype.itemsize
if item_size is not None:
assert new_sdtd.item_size == new_sdtd.dtype.itemsize
else:
assert new_sdtd.item_size == item_size


def test_trace_data_descriptors(trace_data_descriptors: TraceDataDescriptor) -> None:
@@ -147,3 +159,60 @@ def test_trace_data_descriptors(trace_data_descriptors: TraceDataDescriptor) ->
expected = np.dtype(f"({samples},){endianness}{format_}")

assert trace_data_descriptors.dtype == expected
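
As a quick illustration of the expected-dtype string used above, the same kind of subarray dtype built by hand (the sample count and format are illustrative):

import numpy as np

expected = np.dtype("(100,)>f4")  # 100 big-endian float32 samples per trace
assert expected.shape == (100,)
assert expected.base == np.dtype(">f4")
assert expected.itemsize == 400  # 100 samples x 4 bytes
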


@pytest.mark.parametrize(
("json_string", "expected"),
[
(
"""
{"description": "dummy description",
"fields": [{"description": "description of field_one",
"format": "int32","endianness": "big",
"name": "field_one", "offset": 0} ,
{"description": "description of field_two",
"format": "ibm32", "endianness": "big",
"name": "field_two", "offset": 4}],
"itemSize": 8,"offset": 200}
""",
StructuredDataTypeDescriptor(
description="dummy description",
fields=[
StructuredFieldDescriptor(
description="description of field_one",
format=ScalarType.INT32,
endianness=Endianness.BIG,
name="field_one",
offset=0,
),
StructuredFieldDescriptor(
description="description of field_two",
format=ScalarType.IBM32,
endianness=Endianness.BIG,
name="field_two",
offset=4,
),
],
item_size=8,
offset=200,
),
)
],
)
def test_validate_json_structured_data_type_descriptor(
json_string: str, expected: StructuredDataTypeDescriptor
) -> None:
"""Test for validating recreating a StrucrutedDataTypeDescriptor from a JSON string."""
validated_json = StructuredDataTypeDescriptor.model_validate_json(json_string)
assert validated_json.description == expected.description
assert validated_json.fields == expected.fields
assert validated_json.item_size == expected.item_size
assert validated_json.offset == expected.offset
assert _compare_json_strings(json_string, expected.model_dump_json())


def _compare_json_strings(s1: str, s2: str) -> bool:
"""Helper function for clearing whitespace to compare json strings."""
remove = string.whitespace
mapping = {ord(c): None for c in remove}
return s1.translate(mapping) == s2.translate(mapping)
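
The whitespace-stripping comparison above works because the serialized field order is stable; if key order ever differed, a structural comparison would be more robust. A possible alternative, not part of this commit:

import json


def _compare_json_payloads(s1: str, s2: str) -> bool:
    """Compare two JSON strings structurally, ignoring whitespace and key order."""
    return json.loads(s1) == json.loads(s2)
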
32 changes: 18 additions & 14 deletions tests/schema/test_header.py
@@ -45,6 +45,14 @@ def test_full_text_headers(
len(split_lines),
len(split_lines[0]),
)
# assertions for exception cases
with pytest.raises(ValueError, match="Text length must be equal to rows x cols."):
new_text_head_desc._encode(text_header_samples[:-10])
with pytest.raises(
ValueError, match="rows x cols must be equal wrapped text length."
):
new_text_head_desc._wrap(text_header_samples[:-10])
assert "\n" not in new_text_head_desc._unwrap(text_header_samples)


def test_binary_header_descriptors(
@@ -89,27 +97,23 @@ def void_buffer(buff_size: int) -> npt.NDArray[np.void]:
Prefills with random bytes.
"""
rng = np.random.default_rng()
new_void_buffer = None
if isinstance(buff_size, int):
new_void_buffer = np.frombuffer(rng.bytes(buff_size), dtype=np.void(buff_size))
return new_void_buffer
return np.frombuffer(rng.bytes(buff_size), dtype=np.void(buff_size))


def get_dt_info(
dt: np.dtype[Any],
atrnames: list[str] | None = None,
) -> dict[str, Any]:
"""Helper function to get info about a numpy dtype."""
if atrnames is None:
atrnames = [
"descr",
"str",
"fields",
"itemsize",
"byteorder",
"shape",
"names",
]
atrnames = [
"descr",
"str",
"fields",
"itemsize",
"byteorder",
"shape",
"names",
]
dt_info = dict(zip(atrnames, operator.attrgetter(*atrnames)(dt)))
dt_info["offsets"] = [f[-1] for f in dt_info["fields"].values()]
dt_info["combo_str"] = ",".join([f[1] for f in dt_info["descr"]])
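
For reference, a rough sketch of what get_dt_info returns for a small structured dtype, assuming the helper shown above; the dtype and expected values are illustrative:

import numpy as np

dt = np.dtype([("a", ">i4"), ("b", ">i2")])
info = get_dt_info(dt)
assert info["names"] == ("a", "b")
assert info["offsets"] == [0, 4]
assert info["combo_str"] == ">i4,>i2"
assert info["itemsize"] == 6
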
46 changes: 46 additions & 0 deletions tests/schema/test_segy_descriptor.py
@@ -0,0 +1,46 @@
"""tests for SegyDescriptor components."""
from typing import Any

import pytest

from segy.standards import SegyStandard
from segy.standards.registry import get_spec


@pytest.mark.parametrize(
("rev_number", "custom_segy_file_descriptors"),
[(0.0, 1), (1.0, 1)],
indirect=["custom_segy_file_descriptors"],
)
def test_custom_segy_descriptor(
rev_number: float, custom_segy_file_descriptors: dict[str, Any]
) -> None:
"""Test for creating customized SegyDescriptor."""
rev_spec = get_spec(SegyStandard(rev_number))
custom_spec = rev_spec.customize(
text_header_spec=custom_segy_file_descriptors["text_header_descriptor"],
binary_header_fields=custom_segy_file_descriptors[
"binary_header_descriptor"
].fields,
extended_text_spec=None,
trace_header_fields=custom_segy_file_descriptors[
"trace_header_descriptor"
].fields,
trace_data_spec=custom_segy_file_descriptors["trace_data_descriptor"],
)
assert (
custom_spec.text_file_header
== custom_segy_file_descriptors["text_header_descriptor"]
)
assert (
custom_spec.binary_file_header.fields
== custom_segy_file_descriptors["binary_header_descriptor"].fields
)
assert (
custom_spec.trace.header_descriptor.fields
== custom_segy_file_descriptors["trace_header_descriptor"].fields
)
assert (
custom_spec.trace.data_descriptor
== custom_segy_file_descriptors["trace_data_descriptor"]
)
