diff --git a/smithy-model/src/main/java/software/amazon/smithy/model/Model.java b/smithy-model/src/main/java/software/amazon/smithy/model/Model.java index 498a9c36e49..dac1702ead9 100644 --- a/smithy-model/src/main/java/software/amazon/smithy/model/Model.java +++ b/smithy-model/src/main/java/software/amazon/smithy/model/Model.java @@ -202,6 +202,27 @@ public Shape expectShape(ShapeId id) { "Shape not found in model: " + id, SourceLocation.NONE)); } + /** + * Attempts to retrieve a {@link Shape} by {@link ShapeId} and + * throws if not found or if the shape is not of the expected type. + * + * @param id Shape to retrieve by ID. + * @param type Shape type to expect and convert to. + * @return Returns the shape. + * @throws ExpectationNotMetException if the shape is not found or is not the expected type. + */ + @SuppressWarnings("unchecked") + public T expectShape(ShapeId id, Class type) { + Shape shape = expectShape(id); + if (type.isInstance(shape)) { + return (T) shape; + } + + throw new ExpectationNotMetException(String.format( + "Expected shape `%s` to be an instance of `%s`, but found `%s`", + id, type.getSimpleName(), shape.getType()), shape); + } + /** * Gets a stream of {@link Shape}s in the index. * diff --git a/smithy-model/src/main/java/software/amazon/smithy/model/shapes/Shape.java b/smithy-model/src/main/java/software/amazon/smithy/model/shapes/Shape.java index fb329672128..78a980c176d 100644 --- a/smithy-model/src/main/java/software/amazon/smithy/model/shapes/Shape.java +++ b/smithy-model/src/main/java/software/amazon/smithy/model/shapes/Shape.java @@ -24,6 +24,7 @@ import software.amazon.smithy.model.Model; import software.amazon.smithy.model.SourceException; import software.amazon.smithy.model.SourceLocation; +import software.amazon.smithy.model.node.ExpectationNotMetException; import software.amazon.smithy.model.traits.TagsTrait; import software.amazon.smithy.model.traits.Trait; import software.amazon.smithy.utils.MapUtils; @@ -207,6 +208,19 @@ public final Optional getTrait(Class traitClass) { .map(trait -> (T) trait); } + /** + * Gets specific {@link Trait} by class from the shape or throws if not found. + * + * @param traitClass Trait class to retrieve. + * @param The instance of the trait to retrieve. + * @return Returns the matching trait. + * @throws ExpectationNotMetException if the trait cannot be found. + */ + public final T expectTrait(Class traitClass) { + return getTrait(traitClass).orElseThrow(() -> new ExpectationNotMetException(String.format( + "Expected shape `%s` to have a trait `%s`", getId(), traitClass.getCanonicalName()), this)); + } + /** * Gets all of the traits attached to the shape. * diff --git a/smithy-model/src/test/java/software/amazon/smithy/model/ModelTest.java b/smithy-model/src/test/java/software/amazon/smithy/model/ModelTest.java index 6377907d8a7..d5a9114fac0 100644 --- a/smithy-model/src/test/java/software/amazon/smithy/model/ModelTest.java +++ b/smithy-model/src/test/java/software/amazon/smithy/model/ModelTest.java @@ -21,8 +21,12 @@ import static org.hamcrest.Matchers.not; import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.node.ExpectationNotMetException; import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.shapes.IntegerShape; +import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.model.shapes.StringShape; import software.amazon.smithy.model.traits.TraitDefinition; @@ -56,4 +60,22 @@ public void modelEquality() { assertThat(modelA, not(equalTo(modelB))); assertThat(modelA, not(equalTo(null))); } + + @Test + public void successfullyExpectsShapesOfType() { + StringShape shape = StringShape.builder().id("ns.foo#A").build(); + Model model = Model.builder().addShape(shape).build(); + + assertThat(model.expectShape(ShapeId.from("ns.foo#A"), StringShape.class), equalTo(shape)); + } + + @Test + public void throwsIfShapeNotOfRightType() { + StringShape shape = StringShape.builder().id("ns.foo#A").build(); + Model model = Model.builder().addShape(shape).build(); + + Assertions.assertThrows(ExpectationNotMetException.class, () -> { + model.expectShape(ShapeId.from("ns.foo#A"), IntegerShape.class); + }); + } } diff --git a/smithy-model/src/test/java/software/amazon/smithy/model/shapes/ShapeTest.java b/smithy-model/src/test/java/software/amazon/smithy/model/shapes/ShapeTest.java index 8d47561874f..5c37bb73dfe 100644 --- a/smithy-model/src/test/java/software/amazon/smithy/model/shapes/ShapeTest.java +++ b/smithy-model/src/test/java/software/amazon/smithy/model/shapes/ShapeTest.java @@ -31,7 +31,9 @@ import org.junit.jupiter.api.Test; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.SourceLocation; +import software.amazon.smithy.model.node.ExpectationNotMetException; import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.traits.DeprecatedTrait; import software.amazon.smithy.model.traits.DocumentationTrait; import software.amazon.smithy.model.traits.Trait; @@ -126,6 +128,7 @@ public void hasTraits() { assertTrue(shape.getTrait(MyTrait.class).isPresent()); assertTrue(shape.getMemberTrait(model, MyTrait.class).isPresent()); + assertEquals(shape.getTrait(MyTrait.class).get(), shape.expectTrait(MyTrait.class)); assertTrue(shape.findTrait("foo.baz#foo").isPresent()); assertTrue(shape.findMemberTrait(model, "foo.baz#foo").isPresent()); @@ -150,6 +153,13 @@ public void hasTraits() { assertThat(traits, hasItem(documentationTrait)); } + @Test + public void throwsWhenTraitNotFound() { + Shape string = StringShape.builder().id("com.foo#example").build(); + + Assertions.assertThrows(ExpectationNotMetException.class, () -> string.expectTrait(DeprecatedTrait.class)); + } + @Test public void traitsMustNotBeNull() { Assertions.assertThrows(IllegalArgumentException.class, () -> {