Skip to content

Commit

Permalink
Add methods to get shapes by type and trait
Browse files Browse the repository at this point in the history
Model now contains methods for every shape type to get a Set of shapes
or that type or a set of shapes of a type that have a specific trait.
These methods are simpler to use than `shapes(Class)` and
`toSet(Class)`, and they hopefully encourage the use of the caches that
Model uses. These methods also eliminated a lot of boilerplat that
getShapesWithTrait previously required when only shapes of a certain
type were needed.
  • Loading branch information
mtdowling committed May 18, 2021
1 parent c52551b commit be0cc88
Show file tree
Hide file tree
Showing 33 changed files with 702 additions and 202 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,8 @@ public final class ArnIndex implements KnowledgeIndex {

public ArnIndex(Model model) {
// Pre-compute the ARN services.
for (Shape shape : model.getShapesWithTrait(ServiceTrait.class)) {
shape.asServiceShape().ifPresent(service -> {
arnServices.put(service.getId(), service.expectTrait(ServiceTrait.class).getArnNamespace());
});
for (Shape service : model.getServiceShapesWithTrait(ServiceTrait.class)) {
arnServices.put(service.getId(), service.expectTrait(ServiceTrait.class).getArnNamespace());
}

// Pre-compute all of the ArnTemplates in a service shape.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public final class ArnTemplateValidator extends AbstractValidator {
public List<ValidationEvent> validate(Model model) {
ArnIndex arnIndex = ArnIndex.of(model);
List<ValidationEvent> events = new ArrayList<>();
for (Shape shape : model.getShapesWithTrait(ServiceTrait.class)) {
shape.asServiceShape().ifPresent(service -> events.addAll(validateService(model, arnIndex, service)));
for (ServiceShape service : model.getServiceShapesWithTrait(ServiceTrait.class)) {
events.addAll(validateService(model, arnIndex, service));
}
return events;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.Optional;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.validation.AbstractValidator;
import software.amazon.smithy.model.validation.ValidationEvent;
import software.amazon.smithy.utils.MapUtils;
Expand All @@ -44,10 +43,8 @@ public final class EventSourceValidator extends AbstractValidator {
@Override
public List<ValidationEvent> validate(Model model) {
List<ValidationEvent> events = new ArrayList<>();
for (Shape shape : model.getShapesWithTrait(ServiceTrait.class)) {
shape.asServiceShape()
.flatMap(service -> validateService(service, service.expectTrait(ServiceTrait.class)))
.ifPresent(events::add);
for (ServiceShape service : model.getServiceShapesWithTrait(ServiceTrait.class)) {
validateService(service, service.expectTrait(ServiceTrait.class)).ifPresent(events::add);
}
return events;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.regex.Pattern;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.validation.AbstractValidator;
import software.amazon.smithy.model.validation.ValidationEvent;
import software.amazon.smithy.model.validation.ValidationUtils;
Expand Down Expand Up @@ -71,10 +70,8 @@ public final class SdkServiceIdValidator extends AbstractValidator {
@Override
public List<ValidationEvent> validate(Model model) {
List<ValidationEvent> events = new ArrayList<>();
for (Shape shape : model.getShapesWithTrait(ServiceTrait.class)) {
shape.asServiceShape()
.flatMap(service -> validateService(service, service.expectTrait(ServiceTrait.class)))
.ifPresent(events::add);
for (ServiceShape service : model.getServiceShapesWithTrait(ServiceTrait.class)) {
validateService(service, service.expectTrait(ServiceTrait.class)).ifPresent(events::add);
}
return events;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,13 @@ public Model onRemove(ModelTransformer transformer, Collection<Shape> shapes, Mo

private Set<Shape> getServicesToUpdate(Model model, Set<ShapeId> removedOperations, Set<ShapeId> removedErrors) {
Set<Shape> result = new HashSet<>();
for (Shape shape : model.getShapesWithTrait(ClientEndpointDiscoveryTrait.class)) {
shape.asServiceShape().ifPresent(service -> {
ClientEndpointDiscoveryTrait trait = service.expectTrait(ClientEndpointDiscoveryTrait.class);
if (removedOperations.contains(trait.getOperation()) || removedErrors.contains(trait.getError())) {
ServiceShape.Builder builder = service.toBuilder();
builder.removeTrait(ClientEndpointDiscoveryTrait.ID);
result.add(builder.build());
}
});
for (ServiceShape service : model.getServiceShapesWithTrait(ClientEndpointDiscoveryTrait.class)) {
ClientEndpointDiscoveryTrait trait = service.expectTrait(ClientEndpointDiscoveryTrait.class);
if (removedOperations.contains(trait.getOperation()) || removedErrors.contains(trait.getError())) {
ServiceShape.Builder builder = service.toBuilder();
builder.removeTrait(ClientEndpointDiscoveryTrait.ID);
result.add(builder.build());
}
}
return result;
}
Expand All @@ -92,15 +90,13 @@ private Set<Shape> getOperationsToUpdate(

// Get all endpoint discovery operations
Set<Shape> result = new HashSet<>();
for (Shape shape : model.getShapesWithTrait(ClientDiscoveredEndpointTrait.class)) {
shape.asOperationShape().ifPresent(operation -> {
ClientDiscoveredEndpointTrait trait = operation.expectTrait(ClientDiscoveredEndpointTrait.class);
// Only get the ones where discovery is optional, as it is safe to remove in that case.
// Only get the ones that aren't still bound to a service that requires endpoint discovery.
if (!trait.isRequired() && !stillBoundOperations.contains(operation.getId())) {
result.add(operation.toBuilder().removeTrait(ClientDiscoveredEndpointTrait.ID).build());
}
});
for (OperationShape operation : model.getOperationShapesWithTrait(ClientDiscoveredEndpointTrait.class)) {
ClientDiscoveredEndpointTrait trait = operation.expectTrait(ClientDiscoveredEndpointTrait.class);
// Only get the ones where discovery is optional, as it is safe to remove in that case.
// Only get the ones that aren't still bound to a service that requires endpoint discovery.
if (!trait.isRequired() && !stillBoundOperations.contains(operation.getId())) {
result.add(operation.toBuilder().removeTrait(ClientDiscoveredEndpointTrait.ID).build());
}
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,23 @@ public ClientEndpointDiscoveryIndex(Model model) {
TopDownIndex topDownIndex = TopDownIndex.of(model);
OperationIndex opIndex = OperationIndex.of(model);

for (Shape shape : model.getShapesWithTrait(ClientEndpointDiscoveryTrait.class)) {
shape.asServiceShape().ifPresent(service -> {
ClientEndpointDiscoveryTrait trait = service.expectTrait(ClientEndpointDiscoveryTrait.class);
ShapeId endpointOperationId = trait.getOperation();
ShapeId endpointErrorId = trait.getError();
for (ServiceShape service : model.getServiceShapesWithTrait(ClientEndpointDiscoveryTrait.class)) {
ClientEndpointDiscoveryTrait trait = service.expectTrait(ClientEndpointDiscoveryTrait.class);
ShapeId endpointOperationId = trait.getOperation();
ShapeId endpointErrorId = trait.getError();

Optional<OperationShape> endpointOperation = model.getShape(endpointOperationId)
.flatMap(Shape::asOperationShape);
Optional<StructureShape> endpointError = model.getShape(endpointErrorId)
.flatMap(Shape::asStructureShape);
Optional<OperationShape> endpointOperation = model.getShape(endpointOperationId)
.flatMap(Shape::asOperationShape);
Optional<StructureShape> endpointError = model.getShape(endpointErrorId)
.flatMap(Shape::asStructureShape);

if (endpointOperation.isPresent() && endpointError.isPresent()) {
Map<ShapeId, ClientEndpointDiscoveryInfo> serviceInfo = getOperations(
service, endpointOperation.get(), endpointError.get(), topDownIndex, opIndex);
if (!serviceInfo.isEmpty()) {
endpointDiscoveryInfo.put(service.getId(), serviceInfo);
}
if (endpointOperation.isPresent() && endpointError.isPresent()) {
Map<ShapeId, ClientEndpointDiscoveryInfo> serviceInfo = getOperations(
service, endpointOperation.get(), endpointError.get(), topDownIndex, opIndex);
if (!serviceInfo.isEmpty()) {
endpointDiscoveryInfo.put(service.getId(), serviceInfo);
}
});
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,8 @@ public List<ValidationEvent> validate(Model model) {
OperationIndex opIndex = OperationIndex.of(model);

Map<ServiceShape, ClientEndpointDiscoveryTrait> endpointDiscoveryServices = new HashMap<>();
for (Shape shape : model.getShapesWithTrait(ClientEndpointDiscoveryTrait.class)) {
shape.asServiceShape().ifPresent(service -> {
endpointDiscoveryServices.put(service, service.expectTrait(ClientEndpointDiscoveryTrait.class));
});
for (ServiceShape service : model.getServiceShapesWithTrait(ClientEndpointDiscoveryTrait.class)) {
endpointDiscoveryServices.put(service, service.expectTrait(ClientEndpointDiscoveryTrait.class));
}

List<ValidationEvent> validationEvents = endpointDiscoveryServices.values().stream()
Expand Down
Loading

0 comments on commit be0cc88

Please sign in to comment.