Skip to content

Commit

Permalink
Improve test coverage of error cases
Browse files Browse the repository at this point in the history
  • Loading branch information
eguiraud committed Oct 29, 2023
1 parent fa225b2 commit 4e01912
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,26 @@

schemas = {
"scale": schemav2.Correction(
name="test scalar",
name="constant scale that still requires 1 input",
version=2,
inputs=[schemav2.Variable(name="x", type="real")],
output=schemav2.Variable(name="weight", type="real"),
data=1.234,
),
"scale-no-input": schemav2.Correction(
name="constant scale that requires no input",
version=2,
inputs=[],
output=schemav2.Variable(name="weight", type="real"),
data=1.234,
),
"scale-two-inputs": schemav2.Correction(
name="constant scale that requires two inputs",
version=2,
inputs=[schemav2.Variable(name="x", type="real"), schemav2.Variable(name="y", type="real")],
output=schemav2.Variable(name="weight", type="real"),
data=1.234,
),
"simple-uniform-binning": schemav2.Correction(
name="simple uniform binning",
version=2,
Expand Down Expand Up @@ -144,6 +158,19 @@ def test_evaluate_scale_nojax():
assert np.allclose(values, [1.234, 1.234])


def test_evaluate_scale_no_input():
cg = CorrectionWithGradient(schemas["scale-no-input"])
value = cg.evaluate()
value.item()
assert math.isclose(value.item(), 1.234)


def test_input_sizes_mismatch():
cg = CorrectionWithGradient(schemas["scale-two-inputs"])
with pytest.raises(ValueError, match="The shapes of all non-scalar inputs should match."):
cg.evaluate([1.0, 2.0], [3.0, 4.0, 5.0])


@pytest.mark.parametrize("jit", [False, True])
def test_evaluate_scale(jit):
cg = CorrectionWithGradient(schemas["scale"])
Expand Down

0 comments on commit 4e01912

Please sign in to comment.