diff --git a/smithy-codegen-core/src/main/java/software/amazon/smithy/codegen/core/TopologicalIndex.java b/smithy-codegen-core/src/main/java/software/amazon/smithy/codegen/core/TopologicalIndex.java new file mode 100644 index 00000000000..80b9b806091 --- /dev/null +++ b/smithy-codegen-core/src/main/java/software/amazon/smithy/codegen/core/TopologicalIndex.java @@ -0,0 +1,192 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.codegen.core; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.KnowledgeIndex; +import software.amazon.smithy.model.knowledge.NeighborProviderIndex; +import software.amazon.smithy.model.loader.Prelude; +import software.amazon.smithy.model.neighbor.NeighborProvider; +import software.amazon.smithy.model.neighbor.Relationship; +import software.amazon.smithy.model.neighbor.RelationshipDirection; +import software.amazon.smithy.model.selector.PathFinder; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.shapes.SimpleShape; +import software.amazon.smithy.model.shapes.ToShapeId; +import software.amazon.smithy.utils.FunctionalUtils; + +/** + * Creates a reverse-topological ordering of shapes. + * + *

This kind of reverse topological ordering is useful for languages + * like C++ that need to define shapes before they can be referenced. + * Only non-recursive shapes are reverse-topologically ordered using + * {@link #getOrderedShapes()}. However, recursive shapes are queryable + * through {@link #getRecursiveShapes()}. When this returned {@code Set} is + * iterated, recursive shapes are ordered by their degree of recursion (the + * number of edges across all recursive closures), and then by shape ID + * when multiple shapes have the same degree of recursion. + * + *

The recursion closures of a shape can be queried using + * {@link #getRecursiveClosure(ToShapeId)}. This method returns a list of + * paths from the shape back to itself. This list can be useful for code + * generation to generate different code based on if a recursive path + * passes through particular types of shapes. + */ +public final class TopologicalIndex implements KnowledgeIndex { + + private final Set shapes = new LinkedHashSet<>(); + private final Map> recursiveShapes = new LinkedHashMap<>(); + + public TopologicalIndex(Model model) { + // A reverse-topological sort can't be performed on recursive shapes, + // so instead, recursive shapes are explored first and removed from + // the topological sort. + computeRecursiveShapes(model); + + // Next, the model is explored using a DFS so that targets of shapes + // are ordered before the shape itself. + NeighborProvider provider = NeighborProviderIndex.of(model).getProvider(); + model.shapes() + // Note that while we do not scan the prelude here, shapes from + // the prelude are pull into the ordered result if referenced. + .filter(FunctionalUtils.not(Prelude::isPreludeShape)) + .filter(shape -> !recursiveShapes.containsKey(shape)) + // Sort here to provide a deterministic result. + .sorted() + .forEach(shape -> visitShape(provider, shape)); + } + + private void computeRecursiveShapes(Model model) { + // PathFinder is used to find all paths from U -> U. + PathFinder finder = PathFinder.create(model); + + // The order of recursive shapes is first by the number of edges + // (the degree of recursion), and then alphabetically by shape ID. + Map>> edgesToShapePaths = new TreeMap<>(); + for (Shape shape : model.toSet()) { + if (!Prelude.isPreludeShape(shape) && !(shape instanceof SimpleShape)) { + // Find all paths from the shape back to itself. + List paths = finder.search(shape, shape); + if (!paths.isEmpty()) { + int edgeCount = 0; + for (PathFinder.Path path : paths) { + edgeCount += path.size(); + } + edgesToShapePaths.computeIfAbsent(edgeCount, s -> new TreeMap<>()) + .put(shape, Collections.unmodifiableList(paths)); + } + } + } + + for (Map.Entry>> entry : edgesToShapePaths.entrySet()) { + recursiveShapes.putAll(entry.getValue()); + } + } + + private void visitShape(NeighborProvider provider, Shape shape) { + // Visit members before visiting containers. Note that no 'visited' + // set is needed since only non-recursive shapes are traversed. + for (Relationship rel : provider.getNeighbors(shape)) { + if (rel.getRelationshipType().getDirection() == RelationshipDirection.DIRECTED) { + if (!rel.getNeighborShapeId().equals(shape.getId()) && rel.getNeighborShape().isPresent()) { + visitShape(provider, rel.getNeighborShape().get()); + } + } + } + + shapes.add(shape); + } + + /** + * Creates a new {@code TopologicalIndex}. + * + * @param model Model to create the index from. + * @return The created (or previously cached) {@code TopologicalIndex}. + */ + public static TopologicalIndex of(Model model) { + return model.getKnowledge(TopologicalIndex.class, TopologicalIndex::new); + } + + /** + * Gets all reverse-topologically ordered shapes, including members. + * + *

When the returned {@code Set} is iterated, shapes are returned in + * reverse-topological. Note that the returned set does not contain + * recursive shapes. + * + * @return Non-recursive shapes in a reverse-topological ordered {@code Set}. + */ + public Set getOrderedShapes() { + return Collections.unmodifiableSet(shapes); + } + + /** + * Gets all shapes that have edges that are part of a recursive closure, + * including container shapes (list/set/map/structure/union) and members. + * + *

When iterated, the returned {@code Set} is ordered from fewest number + * of edges to the most number of edges in the recursive closures, and then + * alphabetically by shape ID when there are multiple entries with + * the same number of edges. + * + * @return All shapes that are part of a recursive closure. + */ + public Set getRecursiveShapes() { + return Collections.unmodifiableSet(recursiveShapes.keySet()); + } + + /** + * Checks if the given shape has edges with recursive references. + * + * @param shape Shape to check. + * @return True if the shape has recursive edges. + */ + public boolean isRecursive(ToShapeId shape) { + return !getRecursiveClosure(shape).isEmpty(); + } + + /** + * Gets the recursive closure of a given shape represented as + * {@link PathFinder.Path} objects. + * + * @param shape Shape to get the recursive closures of. + * @return The closures of the shape, or an empty {@code List} if the shape is not recursive. + */ + public List getRecursiveClosure(ToShapeId shape) { + if (shape instanceof Shape) { + return recursiveShapes.getOrDefault(shape, Collections.emptyList()); + } + + // If given an ID, we need to scan the recursive shapes to look for a matching ID. + ShapeId id = shape.toShapeId(); + for (Map.Entry> entry : recursiveShapes.entrySet()) { + if (entry.getKey().getId().equals(id)) { + return entry.getValue(); + } + } + + return Collections.emptyList(); + } +} diff --git a/smithy-codegen-core/src/test/java/software/amazon/smithy/codegen/core/TopologicalIndexTest.java b/smithy-codegen-core/src/test/java/software/amazon/smithy/codegen/core/TopologicalIndexTest.java new file mode 100644 index 00000000000..127ee651150 --- /dev/null +++ b/smithy-codegen-core/src/test/java/software/amazon/smithy/codegen/core/TopologicalIndexTest.java @@ -0,0 +1,102 @@ +package software.amazon.smithy.codegen.core; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeId; + +public class TopologicalIndexTest { + + private static Model model; + + @BeforeAll + public static void before() { + model = Model.assembler() + .addImport(TopologicalIndexTest.class.getResource("topological-sort.smithy")) + .assemble() + .unwrap(); + } + + @AfterAll + public static void after() { + model = null; + } + + @Test + public void sortsTopologically() { + TopologicalIndex index = TopologicalIndex.of(model); + + List ordered = new ArrayList<>(); + for (Shape shape : index.getOrderedShapes()) { + ordered.add(shape.getId().toString()); + } + + List recursive = new ArrayList<>(); + for (Shape shape : index.getRecursiveShapes()) { + recursive.add(shape.getId().toString()); + } + + assertThat(ordered, contains( + "smithy.example#MyString", + "smithy.example#BamList$member", + "smithy.example#BamList", + "smithy.api#Integer", + "smithy.example#Bar$baz", + "smithy.example#Bar$bam", + "smithy.example#Bar", + "smithy.example#Foo$foo", + "smithy.example#Foo$bar", + "smithy.example#Foo")); + + assertThat(recursive, contains( + "smithy.example#Recursive$b", + "smithy.example#Recursive$a", + "smithy.example#RecursiveList", + "smithy.example#RecursiveList$member", + "smithy.example#Recursive")); + } + + @Test + public void checksIfShapeByIdIsRecursive() { + TopologicalIndex index = TopologicalIndex.of(model); + + assertThat(index.isRecursive(ShapeId.from("smithy.example#Recursive$b")), is(true)); + assertThat(index.isRecursive(ShapeId.from("smithy.example#MyString")), is(false)); + } + + @Test + public void checksIfShapeIsRecursive() { + TopologicalIndex index = TopologicalIndex.of(model); + + assertThat(index.isRecursive(model.expectShape(ShapeId.from("smithy.example#MyString"))), is(false)); + assertThat(index.isRecursive(model.expectShape(ShapeId.from("smithy.example#Recursive$b"))), is(true)); + } + + @Test + public void getsRecursiveClosureById() { + TopologicalIndex index = TopologicalIndex.of(model); + + assertThat(index.getRecursiveClosure(ShapeId.from("smithy.example#MyString")), empty()); + assertThat(index.getRecursiveClosure(ShapeId.from("smithy.example#Recursive$b")), not(empty())); + } + + @Test + public void getsRecursiveClosureByShape() { + TopologicalIndex index = TopologicalIndex.of(model); + + assertThat(index.getRecursiveClosure(model.expectShape(ShapeId.from("smithy.example#MyString"))), + empty()); + assertThat(index.getRecursiveClosure(model.expectShape(ShapeId.from("smithy.example#Recursive$b"))), + not(empty())); + } +} diff --git a/smithy-codegen-core/src/test/resources/software/amazon/smithy/codegen/core/topological-sort.smithy b/smithy-codegen-core/src/test/resources/software/amazon/smithy/codegen/core/topological-sort.smithy new file mode 100644 index 00000000000..c974cbe231c --- /dev/null +++ b/smithy-codegen-core/src/test/resources/software/amazon/smithy/codegen/core/topological-sort.smithy @@ -0,0 +1,26 @@ +namespace smithy.example + +string MyString + +structure Foo { + foo: MyString, + bar: Bar, +} + +structure Bar { + baz: Integer, + bam: BamList, +} + +list BamList { + member: MyString +} + +structure Recursive { + a: RecursiveList, + b: Recursive, +} + +list RecursiveList { + member: Recursive, +} diff --git a/smithy-model/src/main/java/software/amazon/smithy/model/selector/PathFinder.java b/smithy-model/src/main/java/software/amazon/smithy/model/selector/PathFinder.java index 28d7ace2111..6e5ff6a2b6a 100644 --- a/smithy-model/src/main/java/software/amazon/smithy/model/selector/PathFinder.java +++ b/smithy-model/src/main/java/software/amazon/smithy/model/selector/PathFinder.java @@ -18,6 +18,7 @@ import java.util.AbstractList; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; import java.util.List; @@ -101,12 +102,6 @@ public List search(ToShapeId startingShape, String targetSelector) { * @return Returns the list of matching paths. */ public List search(ToShapeId startingShape, Selector targetSelector) { - Shape shape = model.getShape(startingShape.toShapeId()).orElse(null); - - if (shape == null) { - return ListUtils.of(); - } - // Find all shapes that match the selector then work backwards from there. Set candidates = targetSelector.select(model); @@ -116,7 +111,30 @@ public List search(ToShapeId startingShape, Selector targetSelector) { } LOGGER.finest(() -> candidates.size() + " shapes matched the PathFinder selector of " + targetSelector); - return new Search(reverseProvider, shape, candidates).execute(); + return searchFromShapeToSet(startingShape, candidates); + } + + private List searchFromShapeToSet(ToShapeId startingShape, Set candidates) { + Shape shape = model.getShape(startingShape.toShapeId()).orElse(null); + if (shape == null || candidates.isEmpty()) { + return ListUtils.of(); + } else { + return new Search(reverseProvider, shape, candidates).execute(); + } + } + + /** + * Finds all of the possible paths from the {@code startingShape} to the + * the {@code targetShape}. + * + * @param startingShape Starting shape to find the paths from. + * @param targetShape The shape to try to find a path to. + * @return Returns the list of matching paths. + */ + public List search(ToShapeId startingShape, ToShapeId targetShape) { + return searchFromShapeToSet( + startingShape, + model.getShape(targetShape.toShapeId()).map(Collections::singleton).orElse(Collections.emptySet())); } /** @@ -183,7 +201,7 @@ private Optional createPathTo(ToShapeId operationId, String memberName, Re * An immutable {@code Relationship} path from a starting shape to an end shape. */ public static final class Path extends AbstractList { - private List relationships; + private final List relationships; public Path(List relationships) { if (relationships.isEmpty()) { @@ -239,7 +257,7 @@ public Shape getStartShape() { * starting shape. * * @return Returns the ending shape of the Path. - * @throws SourceException if the relationship is invalid. + * @throws SourceException if the last relationship is invalid. */ public Shape getEndShape() { Relationship last = relationships.get(relationships.size() - 1); diff --git a/smithy-model/src/test/java/software/amazon/smithy/model/selector/PathFinderTest.java b/smithy-model/src/test/java/software/amazon/smithy/model/selector/PathFinderTest.java index 677ada41247..4abe65c3c77 100644 --- a/smithy-model/src/test/java/software/amazon/smithy/model/selector/PathFinderTest.java +++ b/smithy-model/src/test/java/software/amazon/smithy/model/selector/PathFinderTest.java @@ -151,4 +151,22 @@ public void createsPathToInputAndOutputMember() { assertThat(output.get().toString(), equalTo("[id|smithy.example#Operation] -[output]-> [id|smithy.example#Output] -[member]-> [id|smithy.example#Output$foo] > [id|smithy.api#String]")); } + + @Test + public void createsPathToOtherShape() { + MemberShape recursiveMember = MemberShape.builder().id("a.b#Struct$a").target("a.b#Struct").build(); + StructureShape struct = StructureShape.builder() + .id("a.b#Struct") + .addMember(recursiveMember) + .build(); + Model model = Model.builder().addShapes(struct).build(); + PathFinder finder = PathFinder.create(model); + List paths = finder.search(struct.getId(), struct.getId()); + + assertThat(paths, hasSize(1)); + assertThat(paths.get(0), hasSize(2)); + assertThat(paths.get(0).getShapes(), hasSize(3)); + assertThat(paths.get(0).getStartShape(), equalTo(struct)); + assertThat(paths.get(0).getEndShape(), equalTo(struct)); + } }